forked from mindspore-Ecosystem/mindspore
!6728 [Ascend][DynamicShape] Dynamic shape feature
Merge pull request !6728 from caifubi/dynamic_shape_share_2
commit c951d42c2c
@@ -18,6 +1,7 @@ import os
 import sys
 from te.platform.cce_conf import te_set_version
 from te.platform.fusion_util import fusion_op
+import te
 from common import check_kernel_info, get_args, get_build_in_impl_path

 build_in_impl_path = get_build_in_impl_path()
@@ -38,6 +39,16 @@ def _initialize(impl_path):

     sys.path.insert(0, op_module_name)

+def _replace_range(args):
+    for arg in args:
+        if not arg.__contains__('range'):
+            continue
+        shape_range = arg["range"]
+        for range_item in shape_range:
+            for index, value in enumerate(range_item):
+                if value < 0:
+                    range_item[index] = None
+
 def build_op(build_type, json_str):
     """
     call op functions with function name and input args json_str
@@ -71,9 +82,16 @@ def build_op(build_type, json_str):
         outputs_args = get_args(kernel_info['op_info'], 'outputs')
         attrs_args = get_args(kernel_info['op_info'], 'attrs')
         kernel_name = kernel_info['op_info']['kernel_name']
+        is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
+        if is_dynamic_shape:
+            _replace_range(inputs_args)
+            _replace_range(outputs_args)

         if custom_flag:
             op_module = __import__(op_name)
         else:
+            if is_dynamic_shape:
+                op_module = __import__("impl.dynamic."+op_name, globals(), locals(), [op_name], 0)
+            else:
             op_module = __import__("impl."+op_name, globals(), locals(), [op_name], 0)
         # get function
@@ -92,6 +110,11 @@ def build_op(build_type, json_str):
         if kernel_name[0:19] == "bounding_box_encode":
             return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name_val=kernel_name)

+        if is_dynamic_shape:
+            with te.op.dynamic():
+                op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
+                return te.op.get_compile_info()
+        else:
         return op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)

     except Exception as e:
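(Illustration, not part of the commit: the hunks above add a dynamic-shape path to the TBE build. _replace_range rewrites negative range bounds to None, and the op is then imported from "impl.dynamic.<op_name>" and compiled under te.op.dynamic(). A minimal Python sketch of the range normalization; the sample op data below is made up.)

    # Hypothetical args in the style of kernel_info['op_info'] inputs; values are illustrative.
    inputs_args = [{"shape": [-1, 16], "range": [[1, -1], [16, 16]]}]

    def _replace_range(args):
        # Same normalization as the helper added above: negative (unknown) bounds become None.
        for arg in args:
            if 'range' not in arg:
                continue
            for range_item in arg["range"]:
                for index, value in enumerate(range_item):
                    if value < 0:
                        range_item[index] = None

    _replace_range(inputs_args)
    assert inputs_args[0]["range"] == [[1, None], [16, 16]]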
@@ -78,6 +78,7 @@ def _check_supported(kernel_info):
     """
     try:
         op_name = kernel_info['op_info']['name']
+        is_dynamic_shape = kernel_info['op_info']['is_dynamic_shape']
         impl_path = build_in_impl_path
         custom_flag = False
         if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
@@ -92,8 +93,11 @@ def _check_supported(kernel_info):

         if custom_flag:
             op_module = __import__(op_name)
+        elif is_dynamic_shape:
+            op_module = __import__("impl.dynamic." + op_name, globals(), locals(), [op_name], 0)
         else:
             op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)

         # get function
         if not hasattr(op_module, "check_supported"):
             return ""
@@ -219,6 +219,7 @@ if (ENABLE_D)
     set(ASCEND_DRIVER_PATH ${ASCEND_PATH}/driver/lib64/common)
     set(ASCEND_DRIVER_BACK_PATH ${ASCEND_PATH}/driver/lib64/driver)
     set(ASCEND_RUNTIME_PATH ${ASCEND_PATH}/fwkacllib/lib64)
+    set(ASCEND_OPP_PATH ${ASCEND_PATH}/opp/op_impl/built-in/ai_core/tbe/op_tiling)
     endif()

     MESSAGE("USE DAV LIB PATH: ${ASCEND_PATH}")
@@ -228,7 +229,8 @@ if (ENABLE_D)
     find_library(TSDCLIENT tsdclient HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
     find_library(DATATRANSFER datatransfer HINTS ${ASCEND_RUNTIME_PATH} ${ASCEND_DRIVER_BACK_PATH})
     find_library(PROFILING msprofiler ${ASCEND_RUNTIME_PATH})
-    target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER})
+    find_library(OPTILING optiling ${ASCEND_OPP_PATH})
+    target_link_libraries(mindspore ge_runtime ${CCE_LIB} ${RUNTIME_LIB} ${TSDCLIENT} ${HCCL} ${DATATRANSFER} ${OPTILING})
     target_link_libraries(mindspore -Wl,--start-group proto_input ${PROFILING} mindspore::protobuf -Wl,--end-group)
 elseif (CMAKE_SYSTEM_NAME MATCHES "Windows")
     target_link_libraries(mindspore -Wl,--start-group proto_input mindspore::protobuf mindspore::sentencepiece -Wl,--end-group)
@@ -258,6 +260,7 @@ if (ENABLE_D)
     set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/ascend-toolkit/latest/fwkacllib/lib64)
     set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/fwkacllib/lib64)
     set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/add-ons)
+    set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/Ascend/opp/op_impl/built-in/ai_core/tbe/op_tiling)
 elseif (ENABLE_GPU)
     set(MINDSPORE_RPATH ${MINDSPORE_RPATH}:/usr/local/cuda/lib64)
 endif ()
@@ -315,6 +318,8 @@ add_library(inference SHARED
     ${CMAKE_CURRENT_SOURCE_DIR}/backend/session/infer_session.cc
     ${LOAD_ONNX_SRC}
 )

+set_target_properties(inference PROPERTIES INSTALL_RPATH ${MINDSPORE_RPATH})
 target_link_libraries(inference PRIVATE ${PYTHON_LIBRARIES} ${SECUREC_LIBRARY}
     -Wl,--whole-archive mindspore proto_input -Wl,--no-whole-archive mindspore_gvar)
@@ -15,6 +15,7 @@ if (ENABLE_D)
     "akg/akg_kernel_attrs_process.cc"
     "akg/akg_kernel_metadata.cc"
     "tbe/*.cc"
+    "host/*.cc"
     "aicpu/*.cc"
     "rts/*.cc"
     "hccl/*.cc"
@@ -289,51 +289,25 @@ bool CreateNodeDefBytes(const std::shared_ptr<AnfNode> &anf_node,
   return true;
 }

-bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
-  if (!anf_node->isa<CNode>()) {
-    return true;
-  }
-
-  if (!AnfAlgo::IsDynamicShape(anf_node)) {
-    return true;
-  }
-
-  MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope();
-
-  int32_t unknown_shape_type = UnknowShapeOpType::DEPEND_COMPUTE;
-  uint64_t ext_info_head_len = kExtInfoHeadSize;
-  std::string ext_info;
-  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
-  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
-
-  // 1.addr:unknown shape type
-  uint64_t ext_info_len = ext_info.size();
-  ext_info_len += ext_info_head_len + sizeof(int32_t);
-
-  // 2.addr:input ShapeAndType
-  ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType);
-
-  // 3.addr:output ShapeAndType
-  ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType);
-
-  uint64_t ext_info_offset = ext_info.size();
-  ext_info.resize(ext_info_len, 0);
-  char *ext_info_buf = ext_info.data();
-
+uint64_t SetExtInfoShapeType(char *ext_info_buf, uint64_t ext_info_offset) {
   // deal1: unknown shape type
   ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
   info->infoType = FWK_ADPT_EXT_SHAPE_TYPE;
   info->infoLen = sizeof(int32_t);
-  ext_info_offset += ext_info_head_len;
+  ext_info_offset += kExtInfoHeadSize;
   int32_t *shape_type = reinterpret_cast<int32_t *>(ext_info_buf + ext_info_offset);
-  *shape_type = unknown_shape_type;
+  *shape_type = UnknowShapeOpType::DEPEND_COMPUTE;
   ext_info_offset += info->infoLen;
+  return ext_info_offset;
+}
+
+uint64_t SetExtInfoInputShapeType(char *ext_info_buf, uint64_t ext_info_offset,
+                                  const std::shared_ptr<AnfNode> &anf_node, size_t input_num) {
   // deal2:input ShapeAndType
-  info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
+  ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
   info->infoType = FWK_ADPT_EXT_INPUT_SHAPE;
   info->infoLen = input_num * sizeof(ShapeAndType);
-  ext_info_offset += ext_info_head_len;
+  ext_info_offset += kExtInfoHeadSize;

   ShapeAndType *inputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset);
   for (size_t input_index = 0; input_index < input_num; input_index++) {
@@ -364,12 +338,16 @@ bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_p
     }
   }
   ext_info_offset += info->infoLen;
+  return ext_info_offset;
+}
+
+uint64_t SetExtInfoOutputShapeType(char *ext_info_buf, uint64_t ext_info_offset,
+                                   const std::shared_ptr<AnfNode> &anf_node, size_t output_num) {
   // deal3:output ShapeAndType
-  info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
+  ExtInfo *info = reinterpret_cast<ExtInfo *>(ext_info_buf + ext_info_offset);
   info->infoType = FWK_ADPT_EXT_OUTPUT_SHAPE;
   info->infoLen = output_num * sizeof(ShapeAndType);
-  ext_info_offset += ext_info_head_len;
+  ext_info_offset += kExtInfoHeadSize;

   ShapeAndType *outputs = reinterpret_cast<ShapeAndType *>(ext_info_buf + ext_info_offset);
   for (size_t output_index = 0; output_index < output_num; output_index++) {
@@ -387,6 +365,47 @@ bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_p
     }
   }

+  ext_info_offset += info->infoLen;
+  return ext_info_offset;
+}
+
+bool CreateExtInfo(const std::shared_ptr<AnfNode> &anf_node, const std::shared_ptr<AicpuOpKernelMod> &kernel_mod_ptr) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  MS_EXCEPTION_IF_NULL(kernel_mod_ptr);
+  if (!anf_node->isa<CNode>()) {
+    return true;
+  }
+
+  if (!AnfAlgo::IsDynamicShape(anf_node)) {
+    return true;
+  }
+
+  MS_LOG(INFO) << "CreateExtInfo start, " << anf_node->fullname_with_scope();
+
+  uint64_t ext_info_head_len = kExtInfoHeadSize;
+  std::string ext_info;
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
+
+  // 1.addr:unknown shape type
+  uint64_t ext_info_len = ext_info.size();
+  ext_info_len += ext_info_head_len + sizeof(int32_t);
+
+  // 2.addr:input ShapeAndType
+  ext_info_len += ext_info_head_len + input_num * sizeof(ShapeAndType);
+
+  // 3.addr:output ShapeAndType
+  ext_info_len += ext_info_head_len + output_num * sizeof(ShapeAndType);
+
+  uint64_t ext_info_offset = ext_info.size();
+  ext_info.resize(ext_info_len, 0);
+  char *ext_info_buf = ext_info.data();
+
+  ext_info_offset = SetExtInfoShapeType(ext_info_buf, ext_info_offset);
+  ext_info_offset = SetExtInfoInputShapeType(ext_info_buf, ext_info_offset, anf_node, input_num);
+  ext_info_offset = SetExtInfoOutputShapeType(ext_info_buf, ext_info_offset, anf_node, output_num);
+
+  MS_LOG(INFO) << "Check ext_info_len:" << ext_info_len << " ext_info_offset:" << ext_info_offset;
   // set ext info
   kernel_mod_ptr->SetExtInfo(ext_info);
   return true;
@@ -26,8 +26,13 @@
 #include "utils/convert_utils.h"
 #include "backend/kernel_compiler/aicpu/aicpu_util.h"
 #include "utils/ms_context.h"
+#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h"
+#include "runtime/device/kernel_runtime.h"
+#include "runtime/device/ascend/executor/host_dynamic_kernel.h"

 using AicpuTaskInfoPtr = std::shared_ptr<ge::model_runner::AicpuTaskInfo>;
+using AicpuDynamicKernel = mindspore::device::ascend::AiCpuDynamicKernel;
+using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;

 namespace mindspore {
 namespace kernel {
@@ -93,7 +98,7 @@ void AicpuOpKernelMod::CreateCpuKernelInfo(const std::vector<AddressPtr> &inputs
   param_len += node_def_len;
   param_len += sizeof(uint32_t);

-  AicpuParamHead aicpu_param_head;
+  AicpuParamHead aicpu_param_head{};
   aicpu_param_head.length = param_len;
   aicpu_param_head.ioAddrNum = io_addrs_num;
@@ -178,5 +183,15 @@ std::vector<TaskInfoPtr> AicpuOpKernelMod::GenTask(const std::vector<AddressPtr>
   MS_LOG(INFO) << "AicpuOpKernelMod GenTask end";
   return {task_info_ptr};
 }
+
+device::DynamicKernelPtr AicpuOpKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+  AddressPtrList kernel_inputs;
+  AddressPtrList kernel_workspaces;
+  AddressPtrList kernel_outputs;
+  device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+
+  CreateCpuKernelInfo(kernel_inputs, kernel_outputs);
+  return std::make_shared<AicpuDynamicKernel>(stream_ptr, cnode_ptr, args_, ext_info_, node_so_, node_name_);
+}
 } // namespace kernel
 } // namespace mindspore
@@ -31,6 +31,7 @@ class AicpuOpKernelMod : public AscendKernelMod {

   std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;

   void SetInputList(const std::vector<int64_t> &inputList);
   void SetOutputList(const std::vector<int64_t> &outputList);
@@ -20,7 +20,7 @@

 namespace mindspore {
 namespace kernel {
-static std::map<int32_t, int32_t> MS_PROTO_DATA_TYPE_MAP = {
+static const std::map<int32_t, int32_t> kMsProtoDataTypeMap = {
   {mindspore::TypeId::kTypeUnknown, mindspore::DataType::MS_UNKNOWN},
   {mindspore::TypeId::kNumberTypeBool, mindspore::DataType::MS_BOOL},
   {mindspore::TypeId::kNumberTypeInt, mindspore::DataType::MS_INT32},
@@ -39,14 +39,38 @@ static std::map<int32_t, int32_t> MS_PROTO_DATA_TYPE_MAP = {
   {mindspore::TypeId::kNumberTypeFloat64, mindspore::DataType::MS_FLOAT64},
 };

+static const std::map<int32_t, int32_t> kProtoDataTypeToMsDataTypeMap = {
+  {mindspore::DataType::MS_UNKNOWN, mindspore::TypeId::kTypeUnknown},
+  {mindspore::DataType::MS_BOOL, mindspore::TypeId::kNumberTypeBool},
+  {mindspore::DataType::MS_INT32, mindspore::TypeId::kNumberTypeInt32},
+  {mindspore::DataType::MS_INT8, mindspore::TypeId::kNumberTypeInt8},
+  {mindspore::DataType::MS_INT16, mindspore::TypeId::kNumberTypeInt16},
+  {mindspore::DataType::MS_INT64, mindspore::TypeId::kNumberTypeInt64},
+  {mindspore::DataType::MS_UINT8, mindspore::TypeId::kNumberTypeUInt8},
+  {mindspore::DataType::MS_UINT16, mindspore::TypeId::kNumberTypeUInt16},
+  {mindspore::DataType::MS_UINT32, mindspore::TypeId::kNumberTypeUInt32},
+  {mindspore::DataType::MS_UINT64, mindspore::TypeId::kNumberTypeUInt64},
+  {mindspore::DataType::MS_FLOAT16, mindspore::TypeId::kNumberTypeFloat16},
+  {mindspore::DataType::MS_FLOAT32, mindspore::TypeId::kNumberTypeFloat32},
+  {mindspore::DataType::MS_FLOAT64, mindspore::TypeId::kNumberTypeFloat64},
+};
+
 int AicpuOpUtil::MsTypeToProtoType(TypeId ms_type) {
-  auto iter = MS_PROTO_DATA_TYPE_MAP.find(ms_type);
-  if (iter != MS_PROTO_DATA_TYPE_MAP.end()) {
-    return MS_PROTO_DATA_TYPE_MAP[ms_type];
-  } else {
+  auto iter = kMsProtoDataTypeMap.find(ms_type);
+  if (iter == kMsProtoDataTypeMap.end()) {
     MS_LOG(ERROR) << "UnSupported ms_type value" << static_cast<int>(ms_type);
     return -1;
   }
+  return iter->second;
+}
+
+int AicpuOpUtil::ProtoTypeToMsType(int proto_type) {
+  auto iter = kProtoDataTypeToMsDataTypeMap.find(proto_type);
+  if (iter == kProtoDataTypeToMsDataTypeMap.end()) {
+    MS_LOG(ERROR) << "UnSupported proto_type value:" << proto_type;
+    return -1;
+  }
+  return iter->second;
 }
 } // namespace kernel
 } // namespace mindspore
@@ -55,13 +55,6 @@ struct AicpuParamHead {
   uint64_t extInfoAddr; // extInfo address
 } __attribute__((packed));

-const uint32_t kExtInfoHeadSize = 8;
-struct ExtInfo {
-  int32_t infoType; // extend type
-  uint32_t infoLen; // length for infoMsg
-  char infoMsg[0]; // extend value
-} __attribute__((packed));
-
 // Extent info ShapeAndType
 const uint32_t kMaxShapeDims = 8;
 struct ShapeAndType {
@@ -69,6 +62,14 @@ struct ShapeAndType {
   int64_t dims[kMaxShapeDims];
 } __attribute__((packed));

+// Extend info structure for extInfoAddr
+const uint32_t kExtInfoHeadSize = 8;
+struct ExtInfo {
+  int32_t infoType; // extend type
+  uint32_t infoLen; // length for infoMsg
+  char infoMsg[0]; // extend value
+} __attribute__((packed));
+
 // Extend Info type for task
 enum FWKTaskExtInfoType {
   FWK_ADPT_EXT_SHAPE_TYPE = 0,
@@ -88,6 +89,7 @@ enum UnknowShapeOpType {
 class AicpuOpUtil {
  public:
   static int MsTypeToProtoType(TypeId ms_type);
+  static int ProtoTypeToMsType(int proto_type);

  private:
   // kernel id
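(Illustration, not part of the commit: the ExtInfo head moved above is an int32 infoType followed by a uint32 infoLen in a packed struct, which is where kExtInfoHeadSize = 8 comes from. Each of the three blocks CreateExtInfo writes (shape type, input ShapeAndType, output ShapeAndType) starts with this 8-byte head before its payload. A quick Python check of the layout; the payload value is a stand-in.)

    import struct

    EXT_INFO_HEAD = struct.Struct("<iI")  # packed: int32 infoType, uint32 infoLen
    assert EXT_INFO_HEAD.size == 8        # matches kExtInfoHeadSize

    FWK_ADPT_EXT_SHAPE_TYPE = 0           # from the FWKTaskExtInfoType enum above
    depend_compute = 0                    # stand-in for UnknowShapeOpType::DEPEND_COMPUTE

    # First block of the ext-info buffer: 8-byte head + one int32 payload.
    block = EXT_INFO_HEAD.pack(FWK_ADPT_EXT_SHAPE_TYPE, 4) + struct.pack("<i", depend_compute)
    assert len(block) == 8 + 4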
@@ -15,15 +15,34 @@
 */

 #include "backend/kernel_compiler/hccl/hccl_kernel.h"

+#include <map>
 #include "runtime/device/ascend/tasksink/runtime_utils.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "utils/utils.h"
 #include "utils/ms_context.h"
+#include "runtime/device/kernel_runtime.h"
+#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"

 using HcclTaskInfoPtr = std::shared_ptr<ge::model_runner::HcclTaskInfo>;
 using ge::model_runner::HcclTaskInfo;
 using mindspore::device::ascend::tasksink::RuntimeUtils;

+namespace {
+static std::map<std::string, std::string> kMsOpNameToHcomHcclType = {
+  {mindspore::kAllReduceOpName, mindspore::kHcomOpTypeAllReduce},
+  {mindspore::kAllGatherOpName, mindspore::kHcomOpTypeAllGather},
+  {mindspore::kBroadcastOpName, mindspore::kHcomOpTypeBroadcast},
+  {mindspore::kReduceScatterOpName, mindspore::kHcomOpTypeReduceScatter}};
+std::string MsOpNameToHcomOpType(const std::string &ms_op_type) {
+  auto iter = kMsOpNameToHcomHcclType.find(ms_op_type);
+  if (iter == kMsOpNameToHcomHcclType.end()) {
+    MS_LOG(EXCEPTION) << "Invalid MsOpType:" << ms_op_type;
+  }
+  return iter->second;
+}
+} // namespace

 namespace mindspore {
 namespace kernel {
 void HcclKernelFactory::Registe(const std::string &name, HcclKernelCreater &&fun) {
@@ -156,5 +175,30 @@ std::vector<TaskInfoPtr> HcclKernel::GenTask(const std::vector<AddressPtr> &inpu
   MS_EXCEPTION_IF_NULL(task_info_ptr);
   return {task_info_ptr};
 }
+
+device::DynamicKernelPtr HcclKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+  AddressPtrList inputs;
+  AddressPtrList workspaces;
+  AddressPtrList outputs;
+  device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &inputs, &workspaces, &outputs);
+
+  std::string hccl_type = MsOpNameToHcomOpType(AnfAlgo::GetCNodeName(anf_node_));
+
+  if (inputs.empty()) {
+    MS_LOG(EXCEPTION) << "Hccl kernel input is empty";
+  }
+  if (hccl_data_type_list_.empty()) {
+    MS_LOG(EXCEPTION) << "Hccl data type list is empty";
+  }
+  MS_EXCEPTION_IF_NULL(inputs.at(0));
+  auto input_data_addr = inputs.at(0)->addr;
+  MS_EXCEPTION_IF_NULL(outputs.at(0));
+  auto output_data_addr = outputs.at(0)->addr;
+  HcclDataType data_type = hccl_data_type_list_[0];
+
+  auto executor = std::make_shared<device::ascend::HcclDynamicKernel>(
+    hccl_type, input_data_addr, output_data_addr, hccl_count_, data_type, op_type_, root_id_, stream_ptr, cnode_ptr);
+  return executor;
+}
 } // namespace kernel
 } // namespace mindspore
@@ -41,6 +41,7 @@ class HcclKernel : public AscendKernelMod {
   const std::vector<size_t> &GetWorkspaceSizeList() const override;
   std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                    const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;

  protected:
   std::vector<std::vector<size_t>> hccl_kernel_input_shape_list_;
@@ -0,0 +1,52 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/host/dynamic_shape_kernel.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+void DynamicShapeKernel::Execute() {
+  MS_LOG(INFO) << "Execute DynamicShapeKernel Start";
+  auto input_num = AnfAlgo::GetInputTensorNum(cnode_ptr_);
+  if (input_num != 1) {
+    MS_LOG(EXCEPTION) << "Invalid Input Num:" << input_num;
+  }
+
+  auto prev_output_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode_ptr_, 0);
+  auto output_shape = std::vector<int>(SizeToInt(prev_output_shape.size()));
+
+  auto output_type = TypeId::kNumberTypeInt32;
+
+  auto output_tensor_for_sync = std::make_shared<tensor::Tensor>(output_type, output_shape);
+  auto data_ptr = static_cast<int32_t *>(output_tensor_for_sync->data_c());
+  for (size_t i = 0; i < prev_output_shape.size(); ++i) {
+    MS_LOG(INFO) << "DEBUG prev_output_shape[" << i << "]:" << prev_output_shape[i];
+    *(data_ptr + i) = prev_output_shape[i];
+  }
+
+  auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, 0);
+  MS_EXCEPTION_IF_NULL(output_addr);
+  output_addr->SyncHostToDevice(output_shape, LongToSize(output_tensor_for_sync->data().nbytes()),
+                                output_tensor_for_sync->data_type(), output_tensor_for_sync->data_c());
+  MS_LOG(INFO) << "Execute DynamicShapeKernel End";
+}
+
+device::DynamicKernelPtr DynamicShapeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+  return std::make_shared<DynamicShapeKernel>(stream_ptr, cnode_ptr);
+}
+} // namespace kernel
+} // namespace mindspore
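(Illustration, not part of the commit: the new DynamicShapeKernel above materializes the inferred shape of its single input as a 1-D int32 tensor on the host and copies it to the output device address. Reference semantics sketched with NumPy; names are illustrative.)

    import numpy as np

    def dynamic_shape_reference(x):
        # What the DynamicShape host kernel produces: the input's shape as int32 values.
        return np.array(x.shape, dtype=np.int32)

    assert (dynamic_shape_reference(np.zeros((4, 2, 3))) == np.array([4, 2, 3], dtype=np.int32)).all()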
@@ -0,0 +1,43 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_
+#include <vector>
+#include <memory>
+#include <string>
+#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
+#include "backend/kernel_compiler/host/host_kernel_mod.h"
+using HostDynamicKernel = mindspore::device::ascend::HostDynamicKernel;
+namespace mindspore {
+namespace kernel {
+class DynamicShapeKernel : public HostDynamicKernel {
+ public:
+  DynamicShapeKernel(void *stream, const CNodePtr &cnode_ptr) : HostDynamicKernel(stream, cnode_ptr) {}
+  ~DynamicShapeKernel() override = default;
+  void Execute() override;
+};
+
+class DynamicShapeKernelMod : public HostKernelMod {
+ public:
+  DynamicShapeKernelMod() = default;
+  ~DynamicShapeKernelMod() override = default;
+  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
+};
+MS_HOST_REG_KERNEL(DynamicShape, DynamicShapeKernelMod);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_DYNAMIC_SHAPE_KERNEL_H_
@@ -0,0 +1,42 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "backend/kernel_compiler/host/host_kernel_build.h"
+#include <string>
+#include "runtime/device/kernel_runtime.h"
+#include "backend/kernel_compiler/host/host_kernel_mod.h"
+#include "backend/session/anf_runtime_algorithm.h"
+#include "backend/session/kernel_graph.h"
+#include "backend/kernel_compiler/common_utils.h"
+
+namespace mindspore {
+namespace kernel {
+KernelModPtr HostOpBuild(const std::shared_ptr<AnfNode> &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  std::string opname = AnfAlgo::GetCNodeName(anf_node);
+  MS_LOG(INFO) << "Host op [" << opname << "]";
+  auto kerPtr = HostKernelFactory::Get(opname);
+  if (kerPtr == nullptr) {
+    MS_LOG(ERROR) << "Host can't find Kernel[" << opname << "]";
+    return nullptr;
+  }
+  if (!kerPtr->Init(anf_node)) {
+    MS_LOG(ERROR) << "Host Kernel initialize failed!";
+    return nullptr;
+  }
+  return kerPtr;
+}
+} // namespace kernel
+} // namespace mindspore
@@ -0,0 +1,27 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_
+#include <memory>
+#include "backend/kernel_compiler/kernel.h"
+
+namespace mindspore {
+namespace kernel {
+KernelModPtr HostOpBuild(const std::shared_ptr<AnfNode> &anf_node);
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_BUILD_H_
@@ -0,0 +1,59 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/host/host_kernel_metadata.h"
+#include <memory>
+#include <string>
+#include "backend/kernel_compiler/oplib/oplib.h"
+#include "backend/kernel_compiler/common_utils.h"
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace kernel {
+constexpr auto kDynamicShape = "DynamicShape";
+
+void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
+  MS_LOG(INFO) << "HostMetadataInfo.";
+  MS_EXCEPTION_IF_NULL(kernel_node);
+  MS_EXCEPTION_IF_NULL(kernel_info_list);
+  std::string op_name = AnfAlgo::GetCNodeName(kernel_node);
+  if (op_name != kDynamicShape) {
+    MS_LOG(DEBUG) << "Host does not have op [" << op_name << "]";
+    return;
+  }
+
+  std::vector<std::string> inputs_format{};
+  std::vector<TypeId> inputs_type{};
+  for (size_t input_index = 0; input_index < AnfAlgo::GetInputTensorNum(kernel_node); ++input_index) {
+    inputs_format.emplace_back(kOpFormat_DEFAULT);
+    inputs_type.push_back(AnfAlgo::GetPrevNodeOutputInferDataType(kernel_node, input_index));
+  }
+  std::vector<std::string> outputs_format;
+  std::vector<TypeId> outputs_type;
+  for (size_t output_index = 0; output_index < AnfAlgo::GetOutputTensorNum(kernel_node); ++output_index) {
+    outputs_format.emplace_back(kOpFormat_DEFAULT);
+    outputs_type.push_back(AnfAlgo::GetOutputInferDataType(kernel_node, output_index));
+  }
+  auto builder = KernelBuildInfo::KernelBuildInfoBuilder();
+  builder.SetInputsFormat(inputs_format);
+  builder.SetInputsDeviceType(inputs_type);
+  builder.SetOutputsFormat(outputs_format);
+  builder.SetOutputsDeviceType(outputs_type);
+  builder.SetKernelType(HOST_KERNEL);
+  kernel_info_list->push_back(builder.Build());
+}
+} // namespace kernel
+} // namespace mindspore
@@ -0,0 +1,30 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_
+
+#include <string>
+#include <vector>
+#include <memory>
+#include "backend/kernel_compiler/kernel_build_info.h"
+
+namespace mindspore {
+namespace kernel {
+void HostMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list);
+} // namespace kernel
+} // namespace mindspore
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_META_DATA_H_
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/kernel_compiler/host/host_kernel_mod.h"
+
+#include <memory>
+#include <vector>
+#include <string>
+#include <utility>
+#include "runtime/mem.h"
+#include "utils/ms_context.h"
+#include "runtime/device/kernel_runtime.h"
+#include "runtime/device/ascend/executor/host_dynamic_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+void HostKernelFactory::Registe(const std::string &name, HostKernelCreater &&fun) {
+  hostKernelMap_.emplace(name, std::move(fun));
+}
+
+std::shared_ptr<HostKernelMod> HostKernelFactory::Get(const std::string &name) {
+  const auto &map = Get().hostKernelMap_;
+  auto it = map.find(name);
+  if (it != map.end() && it->second) {
+    return (it->second)();
+  }
+  return nullptr;
+}
+
+HostKernelFactory &HostKernelFactory::Get() {
+  static HostKernelFactory instance;
+  return instance;
+}
+
+const std::vector<size_t> &HostKernelMod::GetInputSizeList() const { return input_size_list_; }
+const std::vector<size_t> &HostKernelMod::GetOutputSizeList() const { return output_size_list_; }
+const std::vector<size_t> &HostKernelMod::GetWorkspaceSizeList() const { return workspace_size_list_; }
+bool HostKernelMod::Init(const AnfNodePtr &anf_node) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  size_t input_num = AnfAlgo::GetInputTensorNum(anf_node);
+  size_t output_num = AnfAlgo::GetOutputTensorNum(anf_node);
+
+  for (size_t i = 0; i < input_num; i++) {
+    std::vector<size_t> shape_i = AnfAlgo::GetInputDeviceShape(anf_node, i);
+    TypePtr type_ptr = TypeIdToType(AnfAlgo::GetInputDeviceDataType(anf_node, i));
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    int64_t size_i = 1;
+    for (size_t j = 0; j < shape_i.size(); j++) {
+      size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
+    }
+    size_t type_byte = GetTypeByte(type_ptr);
+    if (type_byte == 0) {
+      return false;
+    }
+    size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
+    input_size_list_.push_back(LongToSize(size_i));
+  }
+
+  for (size_t i = 0; i < output_num; i++) {
+    std::vector<size_t> shape_i = AnfAlgo::GetOutputDeviceShape(anf_node, i);
+    TypePtr type_ptr = TypeIdToType(AnfAlgo::GetOutputDeviceDataType(anf_node, i));
+    MS_EXCEPTION_IF_NULL(type_ptr);
+    int64_t size_i = 1;
+    for (size_t j = 0; j < shape_i.size(); j++) {
+      size_i = LongMulWithOverflowCheck(size_i, static_cast<int>(shape_i[j]));
+    }
+    size_t type_byte = GetTypeByte(type_ptr);
+    if (type_byte == 0) {
+      return false;
+    }
+    size_i = LongMulWithOverflowCheck(size_i, SizeToInt(type_byte));
+    output_size_list_.push_back(LongToSize(size_i));
+  }
+  return true;
+}
+bool HostKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                           const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  return true;
+}
+std::vector<TaskInfoPtr> HostKernelMod::GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                                const std::vector<AddressPtr> &, uint32_t) {
+  return {};
+}
+} // namespace kernel
+} // namespace mindspore
@@ -0,0 +1,86 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_
+#include <vector>
+#include <memory>
+#include <string>
+#include <map>
+#include <utility>
+#include "backend/kernel_compiler/ascend_kernel_mod.h"
+namespace mindspore {
+namespace kernel {
+class HostKernelMod : public AscendKernelMod {
+ public:
+  HostKernelMod() = default;
+  ~HostKernelMod() override = default;
+  const std::vector<size_t> &GetInputSizeList() const override;
+  const std::vector<size_t> &GetOutputSizeList() const override;
+  const std::vector<size_t> &GetWorkspaceSizeList() const override;
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+              const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
+  std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+                                   const std::vector<AddressPtr> &, uint32_t) override;
+  device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override = 0;
+  bool Init(const AnfNodePtr &anf_node);
+
+ protected:
+  AnfNodePtr anf_node_;
+  std::string op_name_;
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+
+using HostKernelModPtr = std::shared_ptr<HostKernelMod>;
+using HostKernelModPtrList = std::vector<HostKernelModPtr>;
+using HostKernelCreater = std::function<std::shared_ptr<HostKernelMod>()>;
+
+class HostKernelFactory {
+  HostKernelFactory() = default;
+  ~HostKernelFactory() = default;
+
+ public:
+  static HostKernelFactory &Get();
+  void Registe(const string &name, HostKernelCreater &&fun);
+  static std::shared_ptr<HostKernelMod> Get(const string &name);
+
+ private:
+  std::map<string, HostKernelCreater> hostKernelMap_;
+};
+
+class _HostKernelRegister {
+ public:
+  _HostKernelRegister(const string &name, HostKernelCreater &&fun) {
+    HostKernelFactory::Get().Registe(name, std::move(fun));
+  }
+  ~_HostKernelRegister() = default;
+};
+
+#define _MS_HOST_REG_KERNEL_REG(KNAME, clazz) \
+  static_assert(std::is_base_of<HostKernelMod, clazz>::value, " must be base of HostKernelMod"); \
+  static const _HostKernelRegister g_##KNAME##_##_kernel_reg(#KNAME, []() { \
+    std::shared_ptr<clazz> ptr = nullptr; \
+    ptr = std::make_shared<clazz>(); \
+    MS_EXCEPTION_IF_NULL(ptr); \
+    return ptr; \
+  });
+
+#define MS_HOST_REG_KERNEL(KNAME, clazz) _MS_HOST_REG_KERNEL_REG(KNAME, clazz)
+} // namespace kernel
+} // namespace mindspore
+
+#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_HOST_HOST_KERNEL_MOD_H_
@@ -174,6 +174,9 @@ void KernelPack::ParseKernelJson(const nlohmann::json &js) {
   kernel_json_info_.block_dim = js["blockDim"];
   kernel_json_info_.kernel_name = js["kernelName"];
   kernel_json_info_.magic = js["magic"];
+  if (js.contains("opParaSize")) {
+    kernel_json_info_.op_para_size = js["opParaSize"];
+  }
   if (js.find("parameters") != js.end()) {
     if (!js.at("parameters").is_array()) {
       MS_LOG(DEBUG) << "Format error!,parameters should be array.";
@@ -25,9 +25,18 @@
 #include "ir/tensor.h"
 #include "abstract/dshape.h"
 #include "utils/log_adapter.h"
+#include "runtime/device/executor/dynamic_kernel.h"

 namespace mindspore {
-enum KernelType : int { UNKNOWN_KERNEL_TYPE = 0, AKG_KERNEL, AICPU_KERNEL, RT_KERNEL, HCCL_KERNEL, TBE_KERNEL };
+enum KernelType : int {
+  UNKNOWN_KERNEL_TYPE = 0,
+  AKG_KERNEL,
+  AICPU_KERNEL,
+  RT_KERNEL,
+  HCCL_KERNEL,
+  TBE_KERNEL,
+  HOST_KERNEL
+};

 namespace kernel {
 // Supported fusion type
@@ -69,7 +78,8 @@ struct KernelJsonInfo {
   std::vector<size_t> parameters;
   std::string sha256;
   std::vector<size_t> workspaces;
-  KernelJsonInfo() : block_dim(0) {}
+  uint32_t op_para_size;
+  KernelJsonInfo() : block_dim(0), op_para_size(0) {}
 };

 class KernelPack {
@@ -118,6 +128,7 @@ class KernelMod {
   virtual const std::vector<size_t> &GetWorkspaceSizeList() const = 0;
   virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                       const std::vector<AddressPtr> &outputs, void *stream_ptr) = 0;
+  virtual device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) { return nullptr; }
   virtual std::vector<size_t> GenParameters() { return {}; }
   virtual void ReleaseResource() {}
@@ -83,8 +83,8 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
   while (!build_manger->IsAllTaskFinish()) {
     int task_id = -1;
     std::string task_result;
-    std::string pre_build_result;
-    auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
+    std::string build_result;
+    auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result);
     if (!ret) {
       MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
     }
@@ -94,7 +94,7 @@ std::map<int32_t, KernelModPtr> KernelFusion(const std::vector<FusionScopeInfo>
                    << " change to single op build.";
       build_failed_num++;
     }
-    auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, false);
+    auto kernel_mod_item = build_manger->TaskFinishProcess(task_id, build_result, false);
     if (kernel_mod_item.second != nullptr) {
       (void)kernel_mod_ret.emplace(kernel_mod_item);
     }
@@ -18,6 +18,7 @@
 #include <memory>
 #include <algorithm>
 #include "backend/kernel_compiler/aicpu/aicpu_kernel_metadata.h"
+#include "backend/kernel_compiler/host/host_kernel_metadata.h"
 #include "backend/kernel_compiler/rts/rt_kernel_info.h"
 #include "backend/kernel_compiler/hccl/hccl_kernel_metadata.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.h"
@@ -86,6 +87,9 @@ void KernelQueryAll(const CNodePtr &kernel_node,
   if (kernel_info_list->empty()) {
     HcclMetadataInfo(kernel_node, kernel_info_list);
   }
+  if (kernel_info_list->empty()) {
+    HostMetadataInfo(kernel_node, kernel_info_list);
+  }
   if (kernel_info_list->empty()) {
     MS_EXCEPTION(NotExistsError)
       << "Failed to obtain operator info, Please check whether the operator info is registered, Op full name:"
@@ -102,6 +102,7 @@ class OpInfo {
     kernel_name_ = opinfo.kernel_name();
     partial_flag_ = opinfo.partial_flag_;
     dynamic_format_ = opinfo.dynamic_format_;
+    dynamic_shape_ = opinfo.dynamic_shape_;
     op_pattern_ = opinfo.op_pattern();
     processor_ = opinfo.processor_;
     for (const auto &attr : opinfo.attrs_ptr()) {
@@ -122,12 +123,14 @@ class OpInfo {
   std::string fusion_type() const { return fusion_type_; }
   std::string kernel_name() const { return kernel_name_; }
   OpPattern op_pattern() const { return op_pattern_; }
+  bool dynamic_shape() const { return dynamic_shape_; }
   std::string processor() const { return processor_; }
   std::vector<std::shared_ptr<OpAttr>> attrs_ptr() const { return attrs_ptr_; }
   std::vector<std::shared_ptr<OpIOInfo>> inputs_ptr() const { return inputs_ptr_; }
   std::vector<std::shared_ptr<OpIOInfo>> outputs_ptr() const { return outputs_ptr_; }
   const std::unordered_map<size_t, size_t> &ref_infos() const { return ref_infos_; }

+  void set_dynamic_shape(bool dynamic_shape) { dynamic_shape_ = dynamic_shape; }
   void set_op_name(const std::string &op_name) { op_name_ = op_name; }
   void set_imply_type(const OpImplyType imply_type) { imply_type_ = imply_type; }
   void set_impl_path(const std::string &impl_path) { impl_path_ = impl_path; }
@@ -149,7 +152,8 @@ class OpInfo {
   void ClearOutputs() { (void)outputs_ptr_.clear(); }
   bool equals_to(const std::shared_ptr<OpInfo> &other_info) const {
     return this->op_name_ == other_info->op_name_ && this->imply_type_ == other_info->imply_type_ &&
-           this->processor_ == other_info->processor_;
+           this->processor_ == other_info->processor_ && this->op_pattern_ == other_info->op_pattern_ &&
+           this->dynamic_shape_ == other_info->dynamic_shape_;
   }

  private:
@@ -163,6 +167,7 @@ class OpInfo {
   std::string kernel_name_;
   bool partial_flag_ = false;
   bool dynamic_format_ = false;
+  bool dynamic_shape_ = false;
   OpPattern op_pattern_ = kCommonPattern;
   std::string processor_;
   std::vector<std::shared_ptr<OpAttr>> attrs_ptr_;
@ -38,6 +38,7 @@ constexpr auto kDynamicFormat = "dynamicFormat";
|
||||||
constexpr auto kFormatAgnostic = "formatAgnostic";
|
constexpr auto kFormatAgnostic = "formatAgnostic";
|
||||||
constexpr auto kBroadcast = "broadcast";
|
constexpr auto kBroadcast = "broadcast";
|
||||||
constexpr auto kReduce = "reduce";
|
constexpr auto kReduce = "reduce";
|
||||||
|
constexpr auto kDynamicShape = "dynamic_shape";
|
||||||
constexpr auto kDtypeFormat = "dtype_format";
|
constexpr auto kDtypeFormat = "dtype_format";
|
||||||
constexpr auto kAttr = "attr";
|
constexpr auto kAttr = "attr";
|
||||||
constexpr auto kIputs = "inputs";
|
constexpr auto kIputs = "inputs";
|
||||||
|
@ -111,6 +112,10 @@ void OpLib::DecodeTBESpecificInfo(const nlohmann::json &obj, const std::shared_p
|
||||||
op_info->set_kernel_name(obj.at(kKernelName));
|
op_info->set_kernel_name(obj.at(kKernelName));
|
||||||
op_info->set_partial_flag(obj.at(kPartialFlag));
|
op_info->set_partial_flag(obj.at(kPartialFlag));
|
||||||
|
|
||||||
|
if (obj.find(kDynamicShape) != obj.end()) {
|
||||||
|
op_info->set_dynamic_shape(obj.at(kDynamicShape));
|
||||||
|
}
|
||||||
|
|
||||||
if (obj.find(kOpPattern) != obj.end()) {
|
if (obj.find(kOpPattern) != obj.end()) {
|
||||||
std::string op_pattern = obj.at(kOpPattern);
|
std::string op_pattern = obj.at(kOpPattern);
|
||||||
auto find_iter = kOpPatternMap.find(op_pattern);
|
auto find_iter = kOpPatternMap.find(op_pattern);
|
||||||
|
@ -322,7 +327,7 @@ bool OpLib::DecodeInputOutput(const nlohmann::json &obj, const OpImplyType imply
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type) {
|
std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType imply_type, bool is_dynamic_shape) {
|
||||||
if (!OpLib::RegOpFromLocalInfo()) {
|
if (!OpLib::RegOpFromLocalInfo()) {
|
||||||
MS_LOG(INFO) << "Warning reg local op info failed.";
|
MS_LOG(INFO) << "Warning reg local op info failed.";
|
||||||
}
|
}
|
||||||
|
@ -338,16 +343,20 @@ std::shared_ptr<OpInfo> OpLib::FindOp(const std::string &op_name, OpImplyType im
|
||||||
for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
|
for (auto [iter, end] = op_info_.equal_range(op_name); iter != end; ++iter) {
|
||||||
auto &op_info = iter->second;
|
auto &op_info = iter->second;
|
||||||
MS_EXCEPTION_IF_NULL(op_info);
|
MS_EXCEPTION_IF_NULL(op_info);
|
||||||
|
|
||||||
if (op_info->imply_type() != imply_type) {
|
if (op_info->imply_type() != imply_type) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if (imply_type == kAKG && op_info->processor() != target_processor) {
|
if (imply_type == kAKG && op_info->processor() != target_processor) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
if (is_dynamic_shape && !op_info->dynamic_shape()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
return op_info;
|
return op_info;
|
||||||
}
|
}
|
||||||
MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
MS_LOG(INFO) << "FindOp failed: opname: " << op_name << ", imply_type: " << ImplTypeToStr(imply_type)
|
||||||
<< ", current op num: " << op_info_.size();
|
<< ", current op num: " << op_info_.size() << " is_dynamic_shape:" << is_dynamic_shape;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,7 +32,8 @@ class OpLib {
|
||||||
virtual ~OpLib() = default;
|
virtual ~OpLib() = default;
|
||||||
static bool RegOp(const std::string &json_string, const std::string &impl_path);
|
static bool RegOp(const std::string &json_string, const std::string &impl_path);
|
||||||
static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); }
|
static void RegOpInfo(const std::shared_ptr<OpInfo> &opinfo) { op_info_.emplace(opinfo->op_name(), opinfo); }
|
||||||
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type);
|
static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, OpImplyType imply_type,
|
||||||
|
bool is_dynamic_shape = false);
|
||||||
static const std::multimap<std::string, std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
|
static const std::multimap<std::string, std::shared_ptr<OpInfo>> &GetAllOpsInfo() { return op_info_; }
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
|
|
|
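Taken together, the OpInfo/OpLib changes let kernel selection ask explicitly for the dynamic-shape registration of an operator. A minimal illustrative call (not part of this commit; the op name "Add" is only an example):

auto op_info = mindspore::kernel::OpLib::FindOp("Add", OpImplyType::kTBE, /*is_dynamic_shape=*/true);
if (op_info == nullptr) {
  // No dynamic-shape registration exists for this op; the caller has to fall back or fail.
  MS_LOG(INFO) << "No dynamic-shape op info registered for Add";
}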
@@ -21,9 +21,14 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h"
#include "utils/ms_context.h"
+ #include "runtime/device/kernel_runtime.h"
+ #include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"

using ge::model_runner::MemcpyAsyncTaskInfo;
using MemcpyAsyncTaskInfoPtr = std::shared_ptr<MemcpyAsyncTaskInfo>;
+ using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
+ using mindspore::device::ascend::MemcpyRtsDynamicKernel;
+ using MemcpyRtsDynamicKernelPtr = std::shared_ptr<MemcpyRtsDynamicKernel>;

namespace mindspore {
namespace kernel {

@@ -122,6 +127,32 @@ std::vector<TaskInfoPtr> MemCpyAsyncKernel::GenTask(const std::vector<AddressPtr
MS_EXCEPTION_IF_NULL(task_info_ptr);
return {task_info_ptr};
}
+ device::DynamicKernelPtr MemCpyAsyncKernel::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+ AddressPtrList kernel_inputs;
+ AddressPtrList kernel_workspaces;
+ AddressPtrList kernel_outputs;
+ device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+
+ if (kernel_inputs.size() != 1) {
+ MS_LOG(EXCEPTION) << "MemCpyAsync op inputs is not one";
+ }
+
+ if (kernel_outputs.size() != 1) {
+ MS_LOG(EXCEPTION) << "MemCpyAsync op output is not one";
+ }
+
+ if (kernel_outputs[0]->size < kernel_inputs[0]->size) {
+ MS_LOG(EXCEPTION) << "Check rtMemcpyAsync destMax < src size";
+ }
+ // input x -> memcpy_async -> AllReduce
+ if (kernel_outputs[0]->size > kernel_inputs[0]->size) {
+ MS_LOG(WARNING) << "Check rtMemcpyAsync destMax > src size";
+ }
+
+ return std::make_shared<MemcpyRtsDynamicKernel>(stream_ptr, cnode_ptr, kernel_outputs[0]->addr,
+ kernel_outputs[0]->size, kernel_inputs[0]->addr,
+ kernel_inputs[0]->size);
+ }
+
const std::vector<TypeId> data_type_list{kNumberTypeInt, kNumberTypeInt8, kNumberTypeInt16, kNumberTypeInt32,
kNumberTypeInt64, kNumberTypeUInt, kNumberTypeUInt8, kNumberTypeUInt16,

@@ -34,6 +34,7 @@ class MemCpyAsyncKernel : public RtKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+ device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;

private:
void GetInputOutputDataType(const AnfNodePtr &anf_node);

@@ -21,8 +21,10 @@
#include "framework/ge_runtime/task_info.h"
#include "runtime/device/ascend/profiling/profiling_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
+ #include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"

using ProfilerTraceTaskInfo = ge::model_runner::ProfilerTraceTaskInfo;
+ using mindspore::device::ascend::ProfilingRtsDynamicKernel;
using mindspore::device::ascend::ProfilingUtils;

namespace mindspore {

@@ -64,5 +66,9 @@ std::vector<TaskInfoPtr> ProfilingKernelMod::GenTask(const std::vector<AddressPt
std::make_shared<ProfilerTraceTaskInfo>(kernel_name_, stream_id, log_id_, notify_, flags_);
return {task_info_ptr};
}
+
+ device::DynamicKernelPtr ProfilingKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+ return std::make_shared<ProfilingRtsDynamicKernel>(stream_ptr, cnode_ptr, log_id_, notify_, flags_);
+ }
} // namespace kernel
} // namespace mindspore

@@ -27,6 +27,7 @@ class ProfilingKernelMod : public RtKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+ device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
bool Init(const AnfNodePtr &anf_node) override;

private:

@@ -29,157 +29,6 @@
namespace mindspore {
namespace kernel {
namespace tbe {
- static std::map<string, string> tbe_func_adapter_map = {
- {"softmax", "softmax_v2"},
- {"log_softmax", "log_softmax_v2"},
- {"apply_momentum", "apply_momentum_d"},
- {"apply_ftrl", "apply_ftrl_d"},
- {"re_lu6", "relu6"},
- {"re_lu6_grad", "relu6_grad"},
- {"re_lu", "relu"},
- {"reverse_v2", "reverse_v2_d"},
- {"re_luv2", "relu_v2"},
- {"p_re_lu", "prelu"},
- {"p_re_lu_grad", "prelu_grad"},
- {"tensor_add", "add"},
- {"reduce_mean", "reduce_mean_d"},
- {"reduce_max", "reduce_max_d"},
- {"reduce_min", "reduce_min_d"},
- {"avg_pool_grad", "avg_pool_grad_d"},
- {"avg_pool_grad_vm", "avg_pool_grad_d"},
- {"conv2d_backprop_filter", "conv2d_backprop_filter_d"},
- {"conv2d_backprop_input", "conv2d_backprop_input_d"},
- {"depthwise_conv2d_native", "depthwise_conv2d"},
- {"depthwise_conv2d_native_backprop_filter", "depthwise_conv2d_backprop_filter_d"},
- {"depthwise_conv2d_native_backprop_input", "depthwise_conv2d_backprop_input_d"},
- {"scatter_nd", "scatter_nd_d"},
- {"tile", "tile_d"},
- {"gather_v2", "gather_v2_d"},
- {"sparse_gather_v2", "gather_v2_d"},
- {"batch_mat_mul", "batch_matmul"},
- {"b_n_training_reduce", "bn_training_reduce"},
- {"b_n_training_update", "bn_training_update"},
- {"b_n_training_update_v2", "bn_training_update_v2"},
- {"b_n_training_update_v3", "bn_training_update_v3"},
- {"b_n_training_reduce_grad", "bn_training_reduce_grad"},
- {"b_n_training_update_grad", "bn_training_update_grad"},
- {"b_n_infer", "bn_infer"},
- {"b_n_infer_grad", "bn_infer_grad"},
- {"b_n_inference", "bninference_d"},
- {"n_pu_clear_float_status", "n_p_u_clear_float_status"},
- {"n_pu_get_float_status", "n_p_u_get_float_status"},
- {"n_pu_alloc_float_status", "n_p_u_alloc_float_status"},
- {"dropout_do_mask", "drop_out_do_mask"},
- {"strided_slice", "strided_slice_d"},
- {"strided_slice_grad", "strided_slice_grad_d"},
- {"sparse_apply_ftrl", "sparse_apply_ftrl_d"},
- {"sparse_apply_ftrl_v2", "sparse_apply_ftrl_v2_d"},
- {"apply_ada_max", "apply_ada_max_d"},
- {"apply_adadelta", "apply_adadelta_d"},
- {"apply_adagrad", "apply_adagrad_d"},
- {"apply_adagrad_v2", "apply_adagradv2_d"},
- {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
- {"sparse_apply_adagrad_v2", "sparse_apply_adagrad_v2_d"},
- {"apply_proximal_adagrad", "apply_proximal_adagrad_d"},
- {"sparse_apply_proximal_adagrad", "sparse_apply_proximal_adagrad_d"},
- {"apply_add_sign", "apply_add_sign_d"},
- {"apply_power_sign", "apply_power_sign_d"},
- {"apply_centered_rms_prop", "apply_centered_rms_prop_d"},
- {"transpose", "transpose_d"},
- {"fill", "fill_d"},
- {"unsorted_segment_sum", "unsorted_segment_sum_d"},
- {"unsorted_segment_prod", "unsorted_segment_prod_d"},
- {"concat", "concat_d"},
- {"slice", "slice_d"},
- {"reduce_sum", "reduce_sum_d"},
- {"inplace_add", "inplace_add_d"},
- {"inplace_sub", "inplace_sub_d"},
- {"one_hot", "one_hot_d"},
- {"sum", "reduce_sum_d"},
- {"lamb_next_mv_with_decay", "lamb_next_m_v_with_decay"},
- {"lamb_next_mv", "lamb_next_m_v"},
- {"split", "split_d"},
- {"split_v", "split_v_d"},
- {"resize_nearest_neighbor", "resize_nearest_neighbor_v2_d"},
- {"resize_nearest_neighbor_grad", "resize_nearest_neighbor_v2_grad_d"},
- {"pad", "pad_d"},
- {"argmax", "arg_max_d"},
- {"argmin", "arg_min_d"},
- {"space_to_batch", "space_to_batch_d"},
- {"batch_to_space", "batch_to_space_d"},
- {"space_to_batch_nd", "space_to_batch_nd_d"},
- {"batch_to_space_nd", "batch_to_space_nd_d"},
- {"resize_bilinear", "resize_bilinear_v2_d"},
- {"resize_bilinear_grad", "resize_bilinear_v2_grad"},
- {"adam", "apply_adam_d"},
- {"r_oi_align", "roi_align"},
- {"r_oi_align_grad", "roi_align_grad"},
- {"i_ou", "iou"},
- {"s_gd", "sgd"},
- {"l_rn", "lrn"},
- {"l_rn_grad", "lrn_grad"},
- {"l_ars_update", "lars_v2_update"},
- {"n_ms_with_mask", "nms_with_mask"},
- {"square_sum_all", "square_sum_all"},
- {"cum_sum", "cumsum_d"},
- {"range", "range_d"},
- {"lin_space", "lin_space_d"},
- {"inv_grad", "inv_grad"},
- {"apply_rms_prop", "apply_rms_prop_d"},
- {"cum_prod", "cumprod_d"},
- {"reduce_all", "reduce_all_d"},
- {"reduce_any", "reduce_any_d"},
- {"sparse_apply_adagrad", "sparse_apply_adagrad_d"},
- {"unsorted_segment_min", "unsorted_segment_min_d"},
- {"reduce_prod", "reduce_prod_d"},
- {"a_cos", "acos"},
- {"a_cos_grad", "acos_grad"},
- {"histogram_fixed_width", "histogram_fixed_width_d"},
- {"broadcast_to", "broadcast_to_d"},
- {"inplace_update", "inplace_update_d"},
- {"i_fmr", "ifmr"},
- {"matrix_diag", "matrix_diag_d"},
- {"matrix_diag_part", "matrix_diag_part_d"},
- {"matrix_set_diag", "matrix_set_diag_d"},
- {"l_stm_input_grad", "lstm_input_grad"}};
-
- void TbeAdapter::NormalizeFuncName(std::string *func_name) {
- if (func_name == nullptr) {
- MS_LOG(EXCEPTION) << "func_name is null";
- }
- std::string name_tmp;
- bool sub_head = false;
- for (string::iterator iter = func_name->begin(); iter != func_name->end(); ++iter) {
- if (islower(*iter)) {
- sub_head = false;
- }
- if (isdigit(*iter)) {
- sub_head = true;
- }
- if (isupper(*iter) && iter != func_name->begin()) {
- if (!sub_head) {
- (void)name_tmp.insert(name_tmp.end(), '_');
- sub_head = true;
- } else {
- string::iterator iter_next = iter + 1;
- if (iter_next != func_name->end()) {
- if (islower(*iter_next)) {
- (void)name_tmp.insert(name_tmp.end(), '_');
- }
- }
- }
- }
- (void)name_tmp.insert(name_tmp.end(), *iter);
- }
- (void)transform(name_tmp.begin(), name_tmp.end(), name_tmp.begin(), ::tolower);
- *func_name = name_tmp;
- auto iter = tbe_func_adapter_map.find(*func_name);
- if (iter != tbe_func_adapter_map.end()) {
- MS_LOG(INFO) << "Map actual op from me: " << *func_name << " to tbe op: " << iter->second;
- *func_name = iter->second;
- }
- }
-
std::unordered_set<std::string> input_order_adjusted_ops = {
"Conv2DBackpropInput", "Conv2DBackpropFilter", "LogSoftmaxGrad", "LayerNormGrad", "LayerNormXBackprop",
"LayerNormBetaGammaBackprop", "MinimumGrad", "MaximumGrad", "ApplyCenteredRMSProp"};

@@ -35,7 +35,6 @@ class TbeAdapter {
public:
TbeAdapter() = default;
~TbeAdapter() = default;
- static void NormalizeFuncName(std::string *func_name);

static void InputOrderPass(const std::string &op_name, std::vector<std::vector<nlohmann::json>> const &inputs_list,
nlohmann::json *inputs_json);

@@ -0,0 +1,139 @@
+ /**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #include <memory>
+ #include <string>
+ #include <utility>
+ #include <vector>
+ #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
+ #include "backend/session/anf_runtime_algorithm.h"
+
+ namespace mindspore {
+ namespace kernel {
+ namespace tbe {
+
+ bool TbeDynamicShapeUtil::IsDynamicShapeNode(const CNodePtr &cnode) {
+ MS_EXCEPTION_IF_NULL(cnode);
+ auto input_num = AnfAlgo ::GetInputTensorNum(cnode);
+ for (size_t i = 0; i < input_num; ++i) {
+ auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(cnode, i);
+ if (std::any_of(input_shape.begin(), input_shape.end(), [](const size_t &dim) { return dim < 0; })) {
+ MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
+ return true;
+ }
+ }
+ auto output_num = AnfAlgo ::GetOutputTensorNum(cnode);
+ for (size_t i = 0; i < output_num; ++i) {
+ auto output_shape = AnfAlgo::GetOutputInferShape(cnode, i);
+ if (std::any_of(output_shape.begin(), output_shape.end(), [](const size_t &dim) { return dim < 0; })) {
+ MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") is dynamic shape node.";
+ return true;
+ }
+ }
+ return false;
+ }
+
+ bool TbeDynamicShapeUtil::IsDynamicShapeNode(const AnfNodePtr &anf_node) {
+ MS_EXCEPTION_IF_NULL(anf_node);
+ if (anf_node->isa<CNode>()) {
+ auto cnode = anf_node->cast<CNodePtr>();
+ MS_EXCEPTION_IF_NULL(cnode);
+ return IsDynamicShapeNode(cnode);
+ }
+ return false;
+ }
+
+ void TbeDynamicShapeUtil::SetDynamicShapeAttr(const CNodePtr &cnode) {
+ auto is_dyanmic_shape = IsDynamicShapeNode(cnode);
+ AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(is_dyanmic_shape), cnode);
+ }
+
+ bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const AnfNodePtr &anf_node) {
+ MS_EXCEPTION_IF_NULL(anf_node);
+ if (anf_node->isa<CNode>()) {
+ auto cnode = anf_node->cast<CNodePtr>();
+ MS_EXCEPTION_IF_NULL(cnode);
+ return GetDynamicShapeAttr(cnode);
+ }
+ return false;
+ }
+
+ bool TbeDynamicShapeUtil::GetDynamicShapeAttr(const CNodePtr &cnode) {
+ MS_EXCEPTION_IF_NULL(cnode);
+ auto is_dynamic_shape = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
+ if (!is_dynamic_shape) {
+ MS_LOG(INFO) << "Node(" << cnode->fullname_with_scope() << ") does not has is_dynamic_shape attribute.";
+ return false;
+ }
+ is_dynamic_shape = AnfAlgo::GetNodeAttr<bool>(cnode, kAttrIsDynamicShape);
+ return is_dynamic_shape;
+ }
+
+ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const AnfNodePtr &anf_node) {
+ MS_EXCEPTION_IF_NULL(anf_node);
+ if (anf_node->isa<CNode>()) {
+ auto cnode = anf_node->cast<CNodePtr>();
+ MS_EXCEPTION_IF_NULL(cnode);
+ return FindOp(op_name, cnode);
+ }
+ return nullptr;
+ }
+
+ std::shared_ptr<OpInfo> TbeDynamicShapeUtil::FindOp(const std::string &op_name, const CNodePtr &cnode) {
+ MS_EXCEPTION_IF_NULL(cnode);
+ auto is_dynamic_shape = GetDynamicShapeAttr(cnode);
+ return mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE, is_dynamic_shape);
+ }
+
+ std::vector<std::pair<int, int>> TbeDynamicShapeUtil::GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index) {
+ MS_EXCEPTION_IF_NULL(anf_node);
+ auto input_range_min = AnfAlgo::GetInputMinShape(anf_node, index);
+ auto input_range_max = AnfAlgo::GetInputMaxShape(anf_node, index);
+ if (input_range_min.size() != input_range_max.size()) {
+ MS_EXCEPTION(ArgumentError) << "Input range size is not equal, min size: " << input_range_min.size()
+ << "max size: " << input_range_max.size();
+ }
+ if (input_range_min.empty() && input_range_max.empty()) {
+ return {{1, 1}};
+ }
+ std::vector<std::pair<int, int>> ret;
+ for (size_t i = 0; i < input_range_min.size(); ++i) {
+ ret.emplace_back(input_range_min[i], input_range_max[i]);
+ }
+ return ret;
+ }
+
+ std::vector<std::pair<int, int>> TbeDynamicShapeUtil::GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index) {
+ MS_EXCEPTION_IF_NULL(anf_node);
+ auto output_range_min = AnfAlgo::GetOutputMinShape(anf_node, index);
+ auto output_range_max = AnfAlgo::GetOutputMaxShape(anf_node, index);
+ if (output_range_min.size() != output_range_max.size()) {
+ MS_EXCEPTION(ArgumentError) << "Onput range size is not equal, min size: " << output_range_min.size()
+ << "max size: " << output_range_max.size();
+ }
+ if (output_range_max.empty() && output_range_min.empty()) {
+ return {{1, 1}};
+ }
+ std::vector<std::pair<int, int>> ret;
+ for (size_t i = 0; i < output_range_min.size(); ++i) {
+ ret.emplace_back(output_range_min[i], output_range_max[i]);
+ }
+ return ret;
+ }
+
+ } // namespace tbe
+ } // namespace kernel
+ } // namespace mindspore

@@ -0,0 +1,49 @@
+ /**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ #ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H
+ #define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H
+
+ #include <memory>
+ #include <string>
+ #include <utility>
+ #include <vector>
+ #include "mindspore/core/ir/anf.h"
+ #include "backend/kernel_compiler/oplib/oplib.h"
+ namespace mindspore {
+ namespace kernel {
+ namespace tbe {
+
+ class TbeDynamicShapeUtil {
+ public:
+ TbeDynamicShapeUtil() = default;
+ ~TbeDynamicShapeUtil() = default;
+ static bool IsDynamicShapeNode(const CNodePtr &cnode);
+ static bool IsDynamicShapeNode(const AnfNodePtr &anf_node);
+ static void SetDynamicShapeAttr(const CNodePtr &cnode);
+ static bool GetDynamicShapeAttr(const CNodePtr &cnode);
+ static bool GetDynamicShapeAttr(const AnfNodePtr &anf_node);
+ static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const AnfNodePtr &anf_node);
+ static std::shared_ptr<OpInfo> FindOp(const std::string &op_name, const CNodePtr &cnode);
+ static std::vector<std::pair<int, int>> GetInputDynamicRange(const AnfNodePtr &anf_node, size_t index);
+ static std::vector<std::pair<int, int>> GetOutputDynamicRange(const AnfNodePtr &anf_node, size_t index);
+ };
+
+ } // namespace tbe
+ } // namespace kernel
+ } // namespace mindspore
+
+ #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_TBE_TBE_DYNAMINC_SHAPE_UTIL_H
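The new TbeDynamicShapeUtil helper treats any negative inferred dimension (-1) as dynamic and reports one (min, max) pair per dimension for the TBE compile JSON. A rough illustration of the intended contract, with made-up shapes:

// Suppose input 0 infers as [-1, 224, 224, 3] with min shape [1, 224, 224, 3] and max shape [32, 224, 224, 3].
// IsDynamicShapeNode(cnode)            -> true, because of the -1 dimension
// GetInputDynamicRange(anf_node, 0)    -> {{1, 32}, {224, 224}, {224, 224}, {3, 3}}
// An empty (scalar) shape degenerates to {{1, 1}}.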
@@ -23,6 +23,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
+ #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h"
#include "utils/ms_context.h"
#include "runtime/dev.h"

@@ -61,6 +62,7 @@ constexpr auto kJDataType = "data_type";
constexpr auto kJOutputIndex = "output_index";
constexpr auto kJOutputDesc = "output_desc";
constexpr auto kJInputDesc = "input_desc";
+ constexpr auto kJRange = "range";
constexpr auto kVTypeInt = "int";
constexpr auto kVTypeStr = "str";
constexpr auto kVTypeBool = "bool";

@@ -89,24 +91,21 @@ constexpr auto kJKwdArgs = "kwds_args";
constexpr auto kJListArgs = "list_args";
constexpr auto kJSocVersion = "socVersion";
constexpr auto kSOC_VERSION = "SOC_VERSION";
+ constexpr auto kJIsDynamicShape = "is_dynamic_shape";

bool TbeKernelJsonCreator::GenTbeSingleKernelJson(const std::shared_ptr<mindspore::AnfNode> &anf_node,
nlohmann::json *kernel_json) {
MS_EXCEPTION_IF_NULL(anf_node);
MS_EXCEPTION_IF_NULL(kernel_json);
std::string op_name = AnfAlgo::GetCNodeName(anf_node);
- auto op_info_ptr = mindspore::kernel::OpLib::FindOp(op_name, OpImplyType::kTBE);
+ auto op_info_ptr = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, anf_node);
MS_EXCEPTION_IF_NULL(op_info_ptr);
(*kernel_json)[kPlatform] = kPlatTBE;
(*kernel_json)[kGenModel] = kSingle;
(*kernel_json)[kImplPath] = op_info_ptr->impl_path();
nlohmann::json op_info_json;
- if (op_info_ptr->impl_path().empty()) {
- tbe::TbeAdapter::NormalizeFuncName(&op_name);
- } else {
- op_name = op_info_ptr->kernel_name();
- }
- op_info_json[kJName] = op_name;
+ op_info_json[kJIsDynamicShape] = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node->cast<CNodePtr>());
+ op_info_json[kJName] = op_info_ptr->kernel_name();
// generate inputs json
nlohmann::json inputs_json;
if (!GenTbeInputsJson(anf_node, op_info_ptr, &inputs_json)) {

@@ -180,6 +179,7 @@ bool TbeKernelJsonCreator::GenInputDescJson(const std::shared_ptr<AnfNode> &anf_
input_desc_json[kJFormat] = format;
input_desc_json[kJValid] = value;
input_desc_json[kJParamType] = input_ptr->param_type();
+ input_desc_json[kJRange] = tbe::TbeDynamicShapeUtil::GetInputDynamicRange(anf_node, real_input_index);
input_list->emplace_back(input_desc_json);
}
return true;

@@ -359,8 +359,13 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
for (size_t i = 0; i < output_obj_num; i++) {
auto dtype = GetDeviceOutputType(anf_node, *output_idx);
auto format = GetDeviceOutputFormat(anf_node, *output_idx);
- auto shape = GetDeviceOutputShape(anf_node, *output_idx);
- std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
+ std::vector<int64_t> shape;
+ AnfAlgo::GetRealDynamicShape(GetDeviceOutputShape(anf_node, *output_idx), NOT_NULL(&shape));
+
+ std::vector<int64_t> ori_shape;
+ AnfAlgo::GetRealDynamicShape(AnfAlgo::GetOutputInferShape(anf_node, *output_idx), NOT_NULL(&ori_shape));
+ // std::vector<size_t> ori_shape = AnfAlgo::GetOutputInferShape(anf_node, *output_idx);
if (ori_shape.empty()) {
ori_shape.emplace_back(1);
}

@@ -373,6 +378,7 @@ void TbeKernelJsonCreator::GenOutputList(const std::shared_ptr<AnfNode> &anf_nod
output_obj[kJName] = output_ptr->name();
output_obj[kJValid] = true;
output_obj[kJParamType] = output_ptr->param_type();
+ output_obj[kJRange] = tbe::TbeDynamicShapeUtil::GetOutputDynamicRange(anf_node, *output_idx);
output_list->emplace_back(output_obj);
(*output_idx)++;
}

@@ -575,48 +581,76 @@ std::string TbeKernelJsonCreator::GetDeviceOutputFormat(const AnfNodePtr &anf_no
return format;
}

+ void GetInputSizeList(const nlohmann::json &input_json, std::vector<size_t> *input_size_list,
+ const AnfNodePtr &anf_node) {
+ for (size_t i = 0; i < input_json.size(); i++) {
+ for (size_t m = 0; m < input_json[i].size(); m++) {
+ size_t size_i = 1;
+ if (input_json[i][m][kJValid] == false) {
+ std::string input_name = input_json[i][m][kJName];
+ MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
+ continue;
+ }
+ for (size_t j = 0; j < input_json[i][m][kJShape].size(); ++j) {
+ if (input_json[i][m][kJShape][j] == -1) {
+ auto input_max_shape = AnfAlgo::GetInputMaxShape(anf_node, i);
+ if (j >= input_max_shape.size()) {
+ MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
+ }
+ MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << input_max_shape[j];
+ size_i *= input_max_shape[j];
+ continue;
+ }
+ size_i *= static_cast<size_t>(input_json[i][m][kJShape][j]);
+ }
+ std::string dtype = input_json[i][m][kJDtype];
+ size_t nbyte = tbe::GetDtypeNbyte(dtype);
+ size_i *= nbyte;
+ input_size_list->push_back(size_i);
+ }
+ }
+ }
+
+ void GetOutputSizeList(const nlohmann::json &output_json, std::vector<size_t> *output_size_list,
+ const AnfNodePtr &anf_node) {
+ for (size_t i = 0; i < output_json.size(); i++) {
+ for (size_t m = 0; m < output_json[i].size(); m++) {
+ size_t size_i = 1;
+ if (output_json[i][m][kJValid] == false) {
+ std::string output_name = output_json[i][m][kJName];
+ MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
+ continue;
+ }
+ for (size_t j = 0; j < output_json[i][m][kJShape].size(); ++j) {
+ if (output_json[i][m][kJShape][j] == -1) {
+ auto output_max_shape = AnfAlgo::GetOutputMaxShape(anf_node, i);
+ if (j >= output_max_shape.size()) {
+ MS_LOG(EXCEPTION) << "Invalid Dynamic Shape Max Shape";
+ }
+ MS_LOG(INFO) << "Change -1 Shape to Max Shape:" << output_max_shape[j];
+ size_i *= output_max_shape[j];
+ continue;
+ }
+ size_i *= static_cast<size_t>(output_json[i][m][kJShape][j]);
+ }
+ std::string dtype = output_json[i][m][kJDtype];
+ size_t nbyte = tbe::GetDtypeNbyte(dtype);
+ size_i *= nbyte;
+ output_size_list->push_back(size_i);
+ }
+ }
+ }
+
bool TbeKernelBuild::GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
- std::vector<size_t> *output_size_list) {
+ std::vector<size_t> *output_size_list, const AnfNodePtr &anf_node) {
if (input_size_list == nullptr || output_size_list == nullptr) {
MS_LOG(ERROR) << "Input size or output size is nullptr";
return false;
}
input_size_list->clear();
output_size_list->clear();
- for (size_t i = 0; i < kernel_json[kJOpInfo][kJInputs].size(); i++) {
- for (size_t m = 0; m < kernel_json[kJOpInfo][kJInputs][i].size(); m++) {
- size_t size_i = 1;
- if (kernel_json[kJOpInfo][kJInputs][i][m][kJValid] == false) {
- std::string input_name = kernel_json[kJOpInfo][kJInputs][i][m][kJName];
- MS_LOG(INFO) << "Input name:" << input_name << "is optional, valid is false.";
- continue;
- }
- for (const auto &j : kernel_json[kJOpInfo][kJInputs][i][m][kJShape]) {
- size_i *= static_cast<size_t>(j);
- }
- std::string dtype = kernel_json[kJOpInfo][kJInputs][i][m][kJDtype];
- size_t nbyte = tbe::GetDtypeNbyte(dtype);
- size_i *= nbyte;
- input_size_list->push_back(size_i);
- }
- }
- for (size_t i = 0; i < kernel_json[kJOpInfo][kJOutputs].size(); i++) {
- for (size_t m = 0; m < kernel_json[kJOpInfo][kJOutputs][i].size(); m++) {
- size_t size_i = 1;
- if (kernel_json[kJOpInfo][kJOutputs][i][m][kJValid] == false) {
- std::string output_name = kernel_json[kJOpInfo][kJOutputs][i][m][kJName];
- MS_LOG(INFO) << "Output name:" << output_name << " is optional, valid is false.";
- continue;
- }
- for (const auto &j : kernel_json[kJOpInfo][kJOutputs][i][m][kJShape]) {
- size_i *= static_cast<size_t>(j);
- }
- std::string dtype = kernel_json[kJOpInfo][kJOutputs][i][m][kJDtype];
- size_t nbyte = tbe::GetDtypeNbyte(dtype);
- size_i *= nbyte;
- output_size_list->push_back(size_i);
- }
- }
+ GetInputSizeList(kernel_json[kJOpInfo][kJInputs], input_size_list, anf_node);
+ GetOutputSizeList(kernel_json[kJOpInfo][kJOutputs], output_size_list, anf_node);
return true;
}
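When a shape entry in the kernel JSON is -1, the new size helpers substitute the recorded maximum dimension before multiplying out the buffer size. A hedged numeric illustration (all values made up):

// shape [-1, 1024], dtype "float16", max shape [32, 1024]
size_t size_i = 1;
size_i *= 32;    // the -1 entry is replaced by max_shape[0]
size_i *= 1024;  // static dimensions are used as-is
size_i *= 2;     // tbe::GetDtypeNbyte("float16") == 2
// size_i == 65536 bytes reserved for this tensor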
@@ -678,17 +712,18 @@ void TbeKernelBuild::GenFusionComputeCommonJson(const mindspore::CNodePtr &cnode
MS_EXCEPTION_IF_NULL(fusion_kernel_name);
// gen others
auto origin_type = AnfAlgo::GetCNodeName(cnode);
+ auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(origin_type, cnode);
// replace special op type for buffer fusion op
auto type = GetRealOpType(origin_type);
(*compute_op_str)[kJtype] = type;
- tbe::TbeAdapter::NormalizeFuncName(&type);
+ auto kernel_name = op_info_ptr->kernel_name();
- (*compute_op_str)[kJFuncName] = type;
+ (*compute_op_str)[kJFuncName] = kernel_name;
(*compute_op_str)[kJModuleName] = std::string("impl.") + type;
(*compute_op_str)[kJName] = cnode->fullname_with_scope();
(*compute_op_str)[kJPattern] = GetNodeFusionType(cnode);
(*compute_op_str)[kJPyModulePath] = "/usr/local/Ascend/opp/op_impl/build_in/ai_core/tbe";
(void)(*fusion_kernel_name).append("_");
- (void)(*fusion_kernel_name).append(type);
+ (void)(*fusion_kernel_name).append(kernel_name);
}

void TbeKernelBuild::GenFusionComputePreBuildJson(const mindspore::CNodePtr &cnode, nlohmann::json *compute_op_str) {

@@ -952,7 +987,7 @@ size_t TbeKernelBuild::GetOptionalInput(const mindspore::CNodePtr &cnode, bool i
}
MS_EXCEPTION_IF_NULL(cnode);
auto node_name = AnfAlgo::GetCNodeName(cnode);
- auto op_info = OpLib::FindOp(node_name, kTBE);
+ auto op_info = tbe::TbeDynamicShapeUtil::FindOp(node_name, cnode);
MS_EXCEPTION_IF_NULL(cnode);
if (op_info->inputs_ptr().size() < (cnode->inputs().size() - 1)) {
MS_EXCEPTION(ArgumentError) << "op info error, node name:" << cnode->fullname_with_scope();

@@ -38,7 +38,7 @@ class TbeKernelBuild {

public:
static bool GetIOSize(const nlohmann::json &kernel_json, std::vector<size_t> *input_size_list,
- std::vector<size_t> *output_size_list);
+ std::vector<size_t> *output_size_list, const AnfNodePtr &anf_node);
// Ub Fuison
static bool GenFusionScopeJson(const std::vector<AnfNodePtr> &input_nodes,
const std::vector<AnfNodePtr> &compute_nodes, nlohmann::json *fusion_json,

@@ -19,11 +19,14 @@
#include "runtime/rt.h"
#include "utils/ms_context.h"
#include "graphengine/inc/framework/ge_runtime/task_info.h"
+ #include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h"
+ #include "runtime/device/kernel_runtime.h"

namespace mindspore {
namespace kernel {
using TbeTaskInfoPtr = std::shared_ptr<ge::model_runner::TbeTaskInfo>;
using tbe::KernelManager;
+ using AddressPtrList = std::vector<mindspore::kernel::AddressPtr>;
bool TbeKernelMod::Launch(const std::vector<mindspore::kernel::AddressPtr> &inputs,
const std::vector<mindspore::kernel::AddressPtr> &workspace,
const std::vector<mindspore::kernel::AddressPtr> &outputs, void *stream_ptr) {

@@ -105,6 +108,49 @@ std::vector<TaskInfoPtr> TbeKernelMod::GenTask(const std::vector<AddressPtr> &in
return {task_info_ptr};
}

+ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) {
+ AddressPtrList kernel_inputs;
+ AddressPtrList kernel_workspaces;
+ AddressPtrList kernel_outputs;
+ device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
+
+ // Get para_size from json
+ auto kernel_json_info = kernel_pack_->kernel_json_info();
+ auto op_para_size = kernel_json_info.op_para_size;
+
+ // Get stub_function
+ uint32_t block_dim = 1; // default blockdim equal to 1.
+ auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
+ if (func_stub == 0) {
+ MS_LOG(EXCEPTION) << "GenFuncStub failed.";
+ }
+ const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
+
+ // Generate args
+ std::vector<void *> runtime_args;
+ (void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args),
+ [](const AddressPtr &input) -> void * { return input->addr; });
+ (void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args),
+ [](const AddressPtr &output) -> void * { return output->addr; });
+ if (!kernel_workspaces.empty()) {
+ (void)std::transform(std::begin(kernel_workspaces), std::end(kernel_workspaces), std::back_inserter(runtime_args),
+ [](const AddressPtr &addr) -> void * { return addr->addr; });
+ }
+
+ void *tiling_data_ptr = nullptr;
+ if (op_para_size > 0) {
+ auto ret = rtMalloc(&tiling_data_ptr, op_para_size, RT_MEMORY_HBM);
+ if (ret != RT_ERROR_NONE) {
+ MS_LOG(EXCEPTION) << "rtMalloc tiling data failed";
+ }
+ runtime_args.push_back(tiling_data_ptr);
+ }
+
+ auto executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(
+ stub_func_ptr, block_dim, tiling_data_ptr, op_para_size, stream_ptr, cnode_ptr, runtime_args);
+ return executor;
+ }
+
vector<size_t> TbeKernelMod::GenParameters() {
auto kernel_json_info = kernel_pack_->kernel_json_info();
return kernel_json_info.parameters;
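For a dynamic-shape TBE kernel the launch arguments are packed as inputs, then outputs, then workspaces, with the tiling buffer appended only when op_para_size is positive. Illustrative layout (placeholder names, not from the diff):

// runtime_args for a 2-input / 1-output op with one workspace and op_para_size > 0:
// { in0->addr, in1->addr, out0->addr, ws0->addr, tiling_data_ptr }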
@@ -42,6 +42,7 @@ class TbeKernelMod : public AscendKernelMod {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
std::vector<TaskInfoPtr> GenTask(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspaces,
const std::vector<AddressPtr> &outputs, uint32_t stream_id) override;
+ device::DynamicKernelPtr GenDynamicKernel(const CNodePtr &cnode_ptr, void *stream_ptr) override;
std::vector<size_t> GenParameters() override;

private:

@@ -15,13 +15,11 @@
*/

#include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"

#include <memory>
#include <set>
#include <algorithm>
#include <vector>
#include <string>

#include "utils/ms_context.h"
#include "backend/kernel_compiler/tbe/tbe_adapter.h"
#include "backend/kernel_compiler/tbe/tbe_kernel_build.h"

@@ -29,6 +27,7 @@
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/tbe/tbe_convert_utils.h"
#include "backend/kernel_compiler/tbe/tbe_utils.h"
+ #include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"

namespace mindspore {
namespace kernel {

@@ -52,15 +51,18 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
// get size
std::vector<size_t> input_size_list;
std::vector<size_t> output_size_list;
- (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
+ (void)TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, anf_node);
// search cache
const std::string &json_name = creator.json_name();
- if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get())) {
- MS_LOG(INFO) << "Use cached kernel, kernel json name:." << json_name;
+ auto IsDynamicShape = tbe::TbeDynamicShapeUtil::GetDynamicShapeAttr(anf_node);
+ if (build_manger->SearchInCache(json_name, processor, input_size_list, output_size_list, anf_node.get()) &&
+ !IsDynamicShape) {
+ MS_LOG(INFO) << "Node:" << anf_node->fullname_with_scope() << " Use cached kernel, kernel json name:."
+ << json_name;
continue;
}
// same op not need build, but need wait build finish to set kernel mode
- if (processed_kernel.find(json_name) != processed_kernel.end()) {
+ if (processed_kernel.find(json_name) != processed_kernel.end() && !IsDynamicShape) {
build_manger->SaveSameOpInfo(anf_node, json_name, input_size_list, output_size_list);
continue;
}

@@ -72,8 +74,8 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
while (!build_manger->IsAllTaskFinish()) {
int task_id = -1;
std::string task_result;
- std::string pre_build_result;
+ std::string build_result;
- auto ret = build_manger->WaitOne(&task_id, &task_result, &pre_build_result);
+ auto ret = build_manger->WaitOne(&task_id, &task_result, &build_result);
if (!ret) {
MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
}

@@ -81,7 +83,7 @@ bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes) {
if (task_result != "Success") {
MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
}
- (void)build_manger->TaskFinishProcess(task_id);
+ (void)build_manger->TaskFinishProcess(task_id, build_result);
}
return build_manger->GenSameOpKernelMod();
}

@@ -93,7 +95,7 @@ void ParallelBuildManager::SaveTaskInfo(int32_t task_id, const mindspore::AnfNod
const std::vector<size_t> &output_size_list, int32_t scope_id) {
MS_LOG(INFO) << "SaveTaskInfo, task id: " << task_id;
struct KernelBuildTaskInfo task_info;
- task_info.node = anf_node.get();
+ task_info.node = anf_node;
task_info.json_name = json_name;
if (anf_node == nullptr) {
task_info.processor = tbe::kProcessorAiCore;

@@ -111,7 +113,38 @@ bool ParallelBuildManager::IsAllTaskFinish() const {
return task_map_.empty();
}

- std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, bool set_kernel_mod) {
+ void ParallelBuildManager::PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result) {
+ auto task_iter = pre_task_map_.find(task_id);
+ if (task_iter == pre_task_map_.end()) {
+ MS_EXCEPTION(ArgumentError) << "can find pre task_id:" << task_id;
+ }
+ auto node = task_iter->second;
+ auto builder =
+ std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
+ std::string start_flag = "fusion_pattern_start";
+ std::string end_flag = "fusion_pattern_end";
+ int start = pre_build_result.find(start_flag);
+ int end = pre_build_result.find(end_flag);
+ if (start != -1 && end != -1 && end >= start) {
+ std::string result = pre_build_result.substr(start + start_flag.size(), end - start - start_flag.size());
+ if (result.empty()) {
+ (void)pre_task_map_.erase(task_iter);
+ return;
+ }
+ transform(result.begin(), result.end(), result.begin(), ::toupper);
+ AnfAlgo::SetNodeAttr(kAttrFusionType, MakeValue(result), node);
+ FusionType fusion_type = tbe::GetFusionType(result);
+ builder->SetFusionType(fusion_type);
+ AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), node.get());
+ }
+ (void)pre_task_map_.erase(task_iter);
+ }
+
+ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t task_id, const std::string &build_ret,
+ bool set_kernel_mod) {
+ auto compile_info = ProcessBuildRetStr(build_ret);
+ MS_LOG(DEBUG) << "Tbe build ret:" << compile_info;
+
auto task_iter = task_map_.find(task_id);
if (task_iter == task_map_.end()) {
MS_EXCEPTION(ArgumentError) << "can find task_id:" << task_id;

@@ -133,7 +166,9 @@ std::pair<int32_t, KernelModPtr> ParallelBuildManager::TaskFinishProcess(int32_t
task_iter->second.output_size_list, kernel_pack);
MS_EXCEPTION_IF_NULL(kernel_mod);
if (set_kernel_mod) {
- AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node);
+ AnfAlgo::SetKernelMod(kernel_mod, task_iter->second.node.get());
+ AnfAlgo::SetNodeAttr(kAttrCompileInfo, MakeValue(compile_info), task_iter->second.node);
+ MS_LOG(DEBUG) << "Set Node Attr compile_info:" << compile_info;
}
auto ret = std::make_pair(task_iter->second.scope_id, kernel_mod);
(void)task_map_.erase(task_iter);
@ -145,7 +180,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node,
|
||||||
const std::vector<size_t> &input_size_list,
|
const std::vector<size_t> &input_size_list,
|
||||||
const std::vector<size_t> &output_size_list) {
|
const std::vector<size_t> &output_size_list) {
|
||||||
struct KernelBuildTaskInfo task_info;
|
struct KernelBuildTaskInfo task_info;
|
||||||
task_info.node = anf_node.get();
|
task_info.node = anf_node;
|
||||||
task_info.json_name = json_name;
|
task_info.json_name = json_name;
|
||||||
task_info.processor = tbe::GetProcessor(anf_node);
|
task_info.processor = tbe::GetProcessor(anf_node);
|
||||||
task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end());
|
task_info.input_size_list.assign(input_size_list.begin(), input_size_list.end());
|
||||||
|
@ -156,7 +191,7 @@ void ParallelBuildManager::SaveSameOpInfo(const mindspore::AnfNodePtr &anf_node,
|
||||||
bool ParallelBuildManager::GenSameOpKernelMod() const {
|
bool ParallelBuildManager::GenSameOpKernelMod() const {
|
||||||
for (const auto &task_info : same_op_list_) {
|
for (const auto &task_info : same_op_list_) {
|
||||||
bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list,
|
bool ret = SearchInCache(task_info.json_name, task_info.processor, task_info.input_size_list,
|
||||||
task_info.output_size_list, task_info.node);
|
task_info.output_size_list, task_info.node.get());
|
||||||
if (!ret) {
|
if (!ret) {
|
||||||
MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache.";
|
MS_LOG(INFO) << "can't find " << task_info.json_name << " in cache.";
|
||||||
return false;
|
return false;
|
||||||
|
@ -212,5 +247,20 @@ void ParallelBuildManager::ResetTaskInfo() {
|
||||||
same_op_list_.clear();
|
same_op_list_.clear();
|
||||||
AscendKernelBuildClient::Instance().TbeReset();
|
AscendKernelBuildClient::Instance().TbeReset();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::string ParallelBuildManager::ProcessBuildRetStr(const std::string &build_result) {
|
||||||
|
std::string start_flag = "fusion_pattern_start";
|
||||||
|
std::string end_flag = "fusion_pattern_end";
|
||||||
|
int start = build_result.find(start_flag);
|
||||||
|
int end = build_result.find(end_flag);
|
||||||
|
if (start != -1 && end != -1 && end >= start) {
|
||||||
|
std::string result = build_result.substr(start + start_flag.size(), end - start - start_flag.size());
|
||||||
|
if (!result.empty()) {
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "";
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace kernel
|
} // namespace kernel
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
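Note: both PreTaskFinishProcess and ProcessBuildRetStr above carve the payload out of the TBE build result by scanning for the "fusion_pattern_start" / "fusion_pattern_end" markers. A minimal standalone sketch of that extraction (not the MindSpore API; it uses std::string::npos instead of the int/-1 comparison above):

#include <iostream>
#include <string>

// Return the text between the two markers, or "" when either marker is missing.
std::string ExtractFusionPattern(const std::string &build_result) {
  const std::string start_flag = "fusion_pattern_start";
  const std::string end_flag = "fusion_pattern_end";
  auto start = build_result.find(start_flag);
  auto end = build_result.find(end_flag);
  if (start == std::string::npos || end == std::string::npos || end < start) {
    return "";
  }
  // Keep only the payload between the markers.
  return build_result.substr(start + start_flag.size(), end - start - start_flag.size());
}

int main() {
  std::string ret = "log ... fusion_pattern_startElemWisefusion_pattern_end ...";
  std::cout << ExtractFusionPattern(ret) << std::endl;  // prints "ElemWise"
  return 0;
}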
@@ -31,7 +31,7 @@ namespace kernel {
 bool TbeOpParallelBuild(const std::vector<AnfNodePtr> &anf_nodes);

 struct KernelBuildTaskInfo {
-  AnfNode *node;
+  AnfNodePtr node;
   std::string processor;
   std::string json_name;
   std::vector<size_t> input_size_list;
@@ -53,16 +53,21 @@ class ParallelBuildManager {
                      const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
                      AnfNode *node) const;
   bool IsAllTaskFinish() const;
-  std::pair<int32_t, KernelModPtr> TaskFinishProcess(int32_t task_id, bool set_kernel_mod = true);
+  void PreTaskFinishProcess(int32_t task_id, const std::string &pre_build_result);
+  std::pair<int32_t, KernelModPtr> TaskFinishProcess(int32_t task_id, const std::string &build_ret,
+                                                     bool set_kernel_mod = true);
   KernelModPtr GenKernelMod(const string &json_name, const string &processor,
                             const std::vector<size_t> &input_size_list, const std::vector<size_t> &output_size_list,
                             const KernelPackPtr &kernel_pack) const;

   // Interactive with real backend, who could be implemented by Python.
-  int StartCompileOp(const nlohmann::json &kernel_json);
-  bool WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result);
+  static int StartCompileOp(const nlohmann::json &kernel_json);
+  static bool WaitOne(int *task_id, std::string *task_result, std::string *build_result);
   void ResetTaskInfo();

+ private:
+  std::string ProcessBuildRetStr(const std::string &build_result);
+
  private:
   std::map<int32_t, AnfNodePtr> pre_task_map_;
   std::map<int32_t, KernelBuildTaskInfo> task_map_;
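Note: an illustrative sketch (not the real ParallelBuildManager) of the bookkeeping the class above grows around pre_task_map_ / task_map_: pending tasks are tracked in a map keyed by task id and each finish-process call consumes its entry.

#include <cstdint>
#include <iostream>
#include <map>
#include <string>

class TaskTracker {
 public:
  void SavePreTask(int32_t task_id, const std::string &node_name) { pre_task_map_[task_id] = node_name; }
  bool PreTaskFinish(int32_t task_id) {
    auto it = pre_task_map_.find(task_id);
    if (it == pre_task_map_.end()) {
      return false;  // unknown task id
    }
    // ... consume the pre-build result for it->second here ...
    pre_task_map_.erase(it);
    return true;
  }
  bool AllFinished() const { return pre_task_map_.empty(); }

 private:
  std::map<int32_t, std::string> pre_task_map_;
};

int main() {
  TaskTracker tracker;
  tracker.SavePreTask(7, "Conv2D");
  tracker.PreTaskFinish(7);
  std::cout << std::boolalpha << tracker.AllFinished() << std::endl;  // true
  return 0;
}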
@@ -30,6 +30,7 @@
 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_reduce_selecter.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_select/common_utils.h"
 #include "backend/kernel_compiler/tbe/tbe_kernel_select/tbe_property_checker.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
 #include "backend/session/kernel_build_client.h"

 namespace mindspore {
@@ -54,7 +55,8 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
   MS_EXCEPTION_IF_NULL(cnode_ptr_);
   MS_EXCEPTION_IF_NULL(kernel_info_list_);
   node_name_ = AnfAlgo::GetCNodeName(cnode_ptr_);
-  auto op_info_ptr = OpLib::FindOp(node_name_, kTBE);
+  auto op_info_ptr = tbe::TbeDynamicShapeUtil::FindOp(node_name_, cnode_ptr_);
   if (!op_info_ptr) {
     MS_LOG(INFO) << "Warning: Cann't find tbe core opinfo, node type: " << node_name_;
     return;
@@ -81,6 +83,7 @@ void TbeKernelSelect::TbeMetadataInfoEx() {
   }
   // check support
   FilterInVaildKernelInfo();
+  MS_LOG(INFO) << "End get kernel build info size: " << kernel_info_list_->size() << ", after tbe select.";
 }

 void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
@@ -23,6 +23,7 @@
 #include "backend/kernel_compiler/kernel_query.h"
 #include "backend/kernel_compiler/oplib/oplib.h"
 #include "backend/session/anf_runtime_algorithm.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"

 namespace mindspore {
 namespace opt {
@@ -62,7 +63,7 @@ class KernelQuery {
     if (!node->isa<CNode>()) {
       return false;
     }
-    auto op_info = mindspore::kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(node), kernel::kTBE);
+    auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(node), node);
     if (op_info != nullptr) {
       return op_info->is_ref();
     }
@@ -75,8 +76,8 @@ class OpFinder {
  public:
   OpFinder() = default;
   virtual ~OpFinder() = default;
-  virtual int GetOpRegisteredOutputNum(const std::string &op_name) {
-    auto op_info = kernel::OpLib::FindOp(op_name, kernel::kTBE);
+  virtual int GetOpRegisteredOutputNum(const std::string &op_name, const CNodePtr &cnode) {
+    auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
     if (op_info == nullptr) {
       return -1;
     }
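Note: the body of TbeDynamicShapeUtil::FindOp is not part of these hunks, so the following is only a hypothetical sketch of the dispatch idea the call sites above suggest (registry, key scheme, and names are made up): look an op up in a registry and prefer its dynamic-shape registration when the querying node is tagged as dynamic.

#include <iostream>
#include <map>
#include <memory>
#include <string>

struct OpInfo {
  std::string name;
  bool dynamic_shape = false;
};
using OpInfoPtr = std::shared_ptr<OpInfo>;

class FakeOpRegistry {
 public:
  void Register(const std::string &name, bool dynamic_shape) {
    ops_[Key(name, dynamic_shape)] = std::make_shared<OpInfo>(OpInfo{name, dynamic_shape});
  }
  // Prefer the dynamic-shape registration when the querying node is dynamic.
  OpInfoPtr Find(const std::string &name, bool node_is_dynamic) const {
    if (node_is_dynamic) {
      auto it = ops_.find(Key(name, true));
      if (it != ops_.end()) return it->second;
    }
    auto it = ops_.find(Key(name, false));
    return it == ops_.end() ? nullptr : it->second;
  }

 private:
  static std::string Key(const std::string &name, bool dynamic_shape) { return name + (dynamic_shape ? "#dyn" : ""); }
  std::map<std::string, OpInfoPtr> ops_;
};

int main() {
  FakeOpRegistry registry;
  registry.Register("Unique", true);   // dynamic-shape implementation
  registry.Register("Unique", false);  // static implementation
  auto info = registry.Find("Unique", /*node_is_dynamic=*/true);
  std::cout << (info && info->dynamic_shape ? "dynamic variant" : "static variant") << std::endl;
  return 0;
}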
@@ -46,6 +46,9 @@ const AnfNodePtr InsertPadForNMSWithMask::Process(const FuncGraphPtr &func_graph
                                                   const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
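Note: this is the guard pattern the backend passes adopt throughout the rest of this commit: a rewrite that depends on statically known shapes bails out early and leaves the node untouched when it is tagged as dynamic shape. A minimal standalone sketch of the idea (names are illustrative, not the MindSpore pass API):

#include <iostream>
#include <optional>
#include <vector>

struct Node {
  bool is_dynamic_shape = false;
  std::vector<int> infer_shape;  // -1 marks an unknown dimension
};

// Returns the rewritten node, or nullopt when the pass is skipped.
std::optional<Node> ProcessPass(const Node &node) {
  if (node.is_dynamic_shape) {
    return std::nullopt;  // skip the optimization, keep the original node
  }
  Node out = node;  // a shape-dependent rewrite would happen here
  return out;
}

int main() {
  Node dyn{true, {-1, 768}};
  std::cout << (ProcessPass(dyn) ? "rewritten" : "skipped") << std::endl;  // "skipped"
  return 0;
}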
@@ -36,7 +36,7 @@ session::KernelWithIndex FindRefOriginNode(const AnfNodePtr &node) {
     auto cnode = cur_node->cast<CNodePtr>();
     MS_EXCEPTION_IF_NULL(cnode);
     std::string op_name = AnfAlgo::GetCNodeName(cnode);
-    auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
+    auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
     // deal ref op
     if (op_info != nullptr && op_info->is_ref()) {
       auto ref_infos = op_info->ref_infos();
@@ -223,7 +223,7 @@ const AnfNodePtr DealRefTransAndCast::Process(const FuncGraphPtr &graph, const A
   DealBroadCastAsRef(graph, cnode);

   auto op_name = AnfAlgo::GetCNodeName(cnode);
-  auto op_info = mindspore::kernel::OpLib::FindOp(op_name, kernel::kTBE);
+  auto op_info = mindspore::kernel::tbe::TbeDynamicShapeUtil::FindOp(op_name, cnode);
   if (op_info == nullptr || !op_info->is_ref()) {
     return nullptr;
   }
@@ -65,6 +65,9 @@ const AnfNodePtr ConcatFission::Process(const FuncGraphPtr &func_graph, const An
                                         const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // The real input begins with index 1.
@@ -86,6 +86,9 @@ const AnfNodePtr LayerNormGradSplit::Process(const FuncGraphPtr &graph, const An
                                              const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   if (cnode->inputs().size() != kLayerNormGradInputNum) {
     return nullptr;
@@ -72,6 +72,9 @@ const BaseRef PackFission::DefinePattern() const {
 const AnfNodePtr PackFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // The real input begins with index 1.
@@ -105,6 +105,9 @@ const AnfNodePtr ReduceMinFission::Process(const FuncGraphPtr &graph, const AnfN
   if (graph == nullptr || node == nullptr) {
     return nullptr;
   }
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   CheckCNodeInputSize(cnode, 2);
@@ -174,6 +174,9 @@ const BaseRef SplitFission::DefinePattern() const {

 const AnfNodePtr SplitFission::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   // Check output num
@@ -127,6 +127,9 @@ const BaseRef TopKSplit::DefinePattern() const {
 const AnfNodePtr TopKSplit::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const {
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(node);
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return nullptr;
+  }
   auto kernel_graph = func_graph->cast<KernelGraphPtr>();
   // set value node as topk's input
   auto cnode = node->cast<CNodePtr>();
@@ -86,7 +86,7 @@ const AnfNodePtr AddInputToOutput::Process(const FuncGraphPtr &func_graph, const
   if (!InputToOutputRegistry::Instance().GetRegisterByOpName(op_name, &reg)) {
     return nullptr;
   }
-  int output_num = op_finder_->GetOpRegisteredOutputNum(op_name);
+  int output_num = op_finder_->GetOpRegisteredOutputNum(op_name, cnode);
   // No need add output when it is not a tbe op.
   if (output_num == -1) {
     return nullptr;
@@ -84,6 +84,9 @@ bool QuitFusion(const FuncGraphPtr &graph, const AnfNodePtr &mul0_anf, const Anf
     MS_LOG(INFO) << "mul's second input is not addn";
     return true;
   }
+  if (AnfAlgo::IsDynamicShape(addn)) {
+    return true;
+  }
   std::vector<size_t> shape = AnfAlgo::GetOutputInferShape(addn, 0);
   if (shape.size() != 2 || !(shape[1] == 1024 || shape[1] == 768)) {
     MS_LOG(INFO) << "Addn's infer shape is not equal [x,1024] or [x,768]";
@@ -53,6 +53,9 @@ CNodePtr CreateReluV2(const FuncGraphPtr &graph, const CNodePtr &relu) {

   // ReluV2's 2rd output is mask whose data type is uint8
   TypeId mask_dtype = kNumberTypeUInt8;
+  if (AnfAlgo::IsDynamicShape(relu)) {
+    return nullptr;
+  }
   std::vector<size_t> mask_shape = AnfAlgo::GetOutputInferShape(relu, 0);
   if (mask_shape.size() != 4) {
     MS_LOG(DEBUG) << "relu's infer shape size not equal 4";
@@ -29,6 +29,9 @@ bool CheckValueNodeInputOfMul(const AnfNodePtr &node) {
   if (!node->isa<ValueNode>()) {
     return false;
   }
+  if (AnfAlgo::IsDynamicShape(node)) {
+    return false;
+  }
   std::vector<size_t> mul_input_shape = AnfAlgo::GetOutputInferShape(node, 0);
   return mul_input_shape.empty() || (mul_input_shape.size() == 1 && mul_input_shape[0] == 1);
 }
@@ -85,6 +85,9 @@ const AnfNodePtr MulAddNFusion::Process(const FuncGraphPtr &graph, const AnfNode
       break;
     }
   }
+  if (AnfAlgo::IsDynamicShape(mul->input(lossscale_input_index))) {
+    return nullptr;
+  }
   auto constant_shape = AnfAlgo::GetOutputInferShape(mul->input(lossscale_input_index), 0);
   if (!(constant_shape.size() == 0 || (constant_shape.size() == 1 && constant_shape[0] == 1))) {
     MS_LOG(DEBUG) << "The const input of Mul node must be scalar or shape=(1,), but shape size is "
@@ -45,6 +45,10 @@ const AnfNodePtr RemoveReshapePair::Process(const FuncGraphPtr &func_graph, cons
   if (IsUsedByOthers(func_graph, in_reshape)) {
     return nullptr;
   }
+
+  if (AnfAlgo::IsDynamicShape(out_reshape) || AnfAlgo::IsDynamicShape(in_reshape)) {
+    return nullptr;
+  }
   auto output_shape = AnfAlgo::GetOutputDeviceShape(out_reshape, 0);
   auto input_shape = AnfAlgo::GetInputDeviceShape(in_reshape, 0);
   if (kernel::IsSameShape(input_shape, output_shape)) {
@@ -50,6 +50,9 @@ const AnfNodePtr ReshapeTransposeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(transpose_cnode);
   auto reshape_cnode = CheckAnfNodeIfCNodeAndInputSize(transpose_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(reshape_cnode);
+  if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
+    return nullptr;
+  }
   std::vector<size_t> reshape_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(reshape_cnode, 0);
   std::vector<size_t> transpose_output0_shape = AnfAlgo::GetOutputInferShape(transpose_cnode, 0);
   if (!CheckShapeDimInfo(reshape_input0_shape) || !CheckShapeDimInfo(transpose_output0_shape)) {
@@ -50,6 +50,9 @@ const AnfNodePtr TransposeReshapeFusion::Process(const FuncGraphPtr &func_graph,
   MS_EXCEPTION_IF_NULL(reshape_cnode);
   auto transpose_cnode = CheckAnfNodeIfCNodeAndInputSize(reshape_cnode->input(1), kBackendReshapeInputNum);
   MS_EXCEPTION_IF_NULL(transpose_cnode);
+  if (AnfAlgo::IsDynamicShape(transpose_cnode) || AnfAlgo::IsDynamicShape(reshape_cnode)) {
+    return nullptr;
+  }
   std::vector<size_t> reshape_output0_shape = AnfAlgo::GetOutputInferShape(reshape_cnode, 0);
   std::vector<size_t> transpose_input0_shape = AnfAlgo::GetPrevNodeOutputInferShape(transpose_cnode, 0);
   if (!CheckShapeDimInfo(reshape_output0_shape) || !CheckShapeDimInfo(transpose_input0_shape)) {
@@ -26,6 +26,8 @@
 #include "base/base_ref.h"
 #include "backend/session/anf_runtime_algorithm.h"
 #include "base/core_ops.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
+#include "frontend/operator/ops.h"
 #include "utils/ms_utils.h"
 #include "runtime/device/kernel_info.h"
 #include "utils/ms_context.h"
@@ -394,6 +396,7 @@ bool IsNopNode(const AnfNodePtr &node) {
       context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET) != kGPUDevice) {
     return false;
   }

   static std::unordered_set<std::string> nop_nodes = {prim::kPrimReshape->name(), kExpandDimsOpName,
                                                       prim::kPrimSqueeze->name(), prim::kPrimFlatten->name(),
                                                       kFlattenGradOpName};
@@ -55,6 +55,10 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
         continue;
       }
     }
+    if (AnfAlgo::IsDynamicShape(cnode)) {
+      MS_LOG(INFO) << "current node is dynamic shape " << cnode->fullname_with_scope();
+      continue;
+    }
     ConstInputToAttr(cnode, reg.GetConstInputAttrInfo());
   }
   return node;
@@ -28,6 +28,7 @@
 #include "backend/kernel_compiler/kernel.h"
 #include "backend/kernel_compiler/kernel_build_info.h"
 #include "common/trans.h"
+#include "abstract/param_validator.h"

 namespace mindspore {
 namespace session {
@@ -42,12 +43,27 @@ namespace {
 constexpr size_t kNopNodeInputSize = 2;
 constexpr size_t kNopNodeRealInputIndex = 1;

+bool IsShapeDynamic(const abstract::ShapePtr &shape) {
+  MS_EXCEPTION_IF_NULL(shape);
+  return std::any_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s < 0; });
+}
+
 std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
   MS_EXCEPTION_IF_NULL(shape);
   std::vector<size_t> shape_size_t;
+  if (IsShapeDynamic(shape)) {
+    if (std::all_of(shape->max_shape().begin(), shape->max_shape().end(), [](int s) { return s >= 0; })) {
+      std::transform(shape->max_shape().begin(), shape->max_shape().end(), std::back_inserter(shape_size_t), IntToSize);
+    } else {
+      MS_LOG(EXCEPTION) << "Invalid Max Shape";
+    }
+  } else {
     std::transform(shape->shape().begin(), shape->shape().end(), std::back_inserter(shape_size_t), IntToSize);
+  }
   return shape_size_t;
 }

+enum ShapeType { kMaxShape, kMinShape };
 }  // namespace

 AnfNodePtr AnfRuntimeAlgorithm::GetTupleGetItemRealInput(const CNodePtr &tuple_get_item) {
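Note: the TransShapeToSizet change above falls back to the registered max_shape so a worst-case buffer size can still be computed for dynamic dimensions. A standalone sketch of that fallback on simplified types (not the MindSpore abstract::Shape API):

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<size_t> ShapeToSizet(const std::vector<int> &shape, const std::vector<int> &max_shape) {
  bool dynamic = false;
  for (int s : shape) {
    if (s < 0) dynamic = true;  // negative dimension means "unknown at compile time"
  }
  const std::vector<int> &src = dynamic ? max_shape : shape;
  std::vector<size_t> result;
  for (int s : src) {
    if (s < 0) throw std::runtime_error("Invalid Max Shape");
    result.push_back(static_cast<size_t>(s));
  }
  return result;
}

int main() {
  auto sizes = ShapeToSizet({-1, 224, 224, 3}, {32, 224, 224, 3});
  std::cout << sizes[0] << std::endl;  // 32: worst-case batch dimension
  return 0;
}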
@@ -1206,19 +1222,6 @@ TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputPrecision(const AnfNodePtr &node, s
   return GetCNodeOutputPrecision(kernel_with_index.first);
 }

-bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
-  if (!node->isa<CNode>()) {
-    return false;
-  }
-  auto cnode = node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(cnode);
-  auto has_attr = AnfAlgo::HasNodeAttr(kAttrIsDynamicShape, cnode);
-  if (!has_attr) {
-    return false;
-  }
-  return AnfAlgo::GetNodeAttr<bool>(node, kAttrIsDynamicShape);
-}
-
 bool AnfRuntimeAlgorithm::IsCondControlKernel(const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   if (node->inputs().empty()) {
@@ -1252,5 +1255,96 @@ bool AnfRuntimeAlgorithm::IsIndependentNode(const CNodePtr &node) {
   }
   return true;
 }

+bool AnfRuntimeAlgorithm::GetBooleanAttr(const AnfNodePtr &node, const std::string &attr) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  auto cnode = node->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(cnode);
+  auto has_attr = AnfAlgo::HasNodeAttr(attr, cnode);
+  if (!has_attr) {
+    return false;
+  }
+  return AnfAlgo::GetNodeAttr<bool>(node, attr);
+}
+
+bool AnfRuntimeAlgorithm::IsDynamicShape(const AnfNodePtr &node) {
+  return GetBooleanAttr(node, kAttrInputIsDynamicShape) || GetBooleanAttr(node, kAttrOutputIsDynamicShape);
+}
+
+void AnfRuntimeAlgorithm::GetRealDynamicShape(const std::vector<size_t> &shape,
+                                              NotNull<std::vector<int64_t> *> dynamic_shape) {
+  for (auto size : shape) {
+    if (size == SIZE_MAX) {
+      dynamic_shape->push_back(-1);
+    } else {
+      dynamic_shape->push_back(SizeToLong(size));
+    }
+  }
+}
+
+std::vector<int> GetShapeFromSequeueShape(const abstract::SequeueShapePtr &sequeue_shape_ptr, size_t index,
+                                          ShapeType type) {
+  MS_EXCEPTION_IF_NULL(sequeue_shape_ptr);
+  auto shape_list = sequeue_shape_ptr->shape();
+  if (index >= shape_list.size()) {
+    MS_LOG(EXCEPTION) << "Output Index:" << index << " >= " << shape_list.size();
+  }
+
+  auto shape = shape_list[index];
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->isa<abstract::Shape>()) {
+    auto shape_ptr = shape->cast<abstract::ShapePtr>();
+    if (type == kMaxShape) {
+      return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
+    } else {
+      return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid Shape Type In Shape List";
+  }
+}
+
+std::vector<int> AnfRuntimeAlgorithm::GetInputMaxShape(const AnfNodePtr &anf_node, size_t index) {
+  auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
+  return GetOutputMaxShape(input_node_with_index.first, input_node_with_index.second);
+}
+
+std::vector<int> AnfRuntimeAlgorithm::GetInputMinShape(const AnfNodePtr &anf_node, size_t index) {
+  auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, index);
+  return GetOutputMinShape(input_node_with_index.first, input_node_with_index.second);
+}
+
+std::vector<int> AnfRuntimeAlgorithm::GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto shape = anf_node->Shape();
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->isa<abstract::Shape>()) {
+    auto shape_ptr = shape->cast<abstract::ShapePtr>();
+    return shape_ptr->max_shape().empty() ? shape_ptr->shape() : shape_ptr->max_shape();
+  } else if (shape->isa<abstract::SequeueShape>()) {
+    auto shape_ptr = shape->cast<abstract::SequeueShapePtr>();
+    return GetShapeFromSequeueShape(shape_ptr, index, kMaxShape);
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid Shape Type";
+  }
+}
+
+std::vector<int> AnfRuntimeAlgorithm::GetOutputMinShape(const AnfNodePtr &anf_node, size_t index) {
+  MS_EXCEPTION_IF_NULL(anf_node);
+  auto shape = anf_node->Shape();
+  MS_EXCEPTION_IF_NULL(shape);
+  if (shape->isa<abstract::Shape>()) {
+    auto shape_ptr = shape->cast<abstract::ShapePtr>();
+    return shape_ptr->min_shape().empty() ? shape_ptr->shape() : shape_ptr->min_shape();
+  } else if (shape->isa<abstract::SequeueShape>()) {
+    auto shape_ptr = shape->cast<abstract::SequeueShapePtr>();
+    return GetShapeFromSequeueShape(shape_ptr, index, kMinShape);
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid Shape Type";
+  }
+}
 }  // namespace session
 }  // namespace mindspore
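Note: the accessors added above treat min_shape/max_shape as optional hints on a dynamic shape and fall back to the static shape when a hint is missing. A standalone sketch of that fallback on a simplified struct (field names are illustrative):

#include <iostream>
#include <vector>

struct ShapeInfo {
  std::vector<int> shape;      // -1 marks a dynamic dimension
  std::vector<int> min_shape;  // may be empty
  std::vector<int> max_shape;  // may be empty
};

std::vector<int> GetMaxShape(const ShapeInfo &s) { return s.max_shape.empty() ? s.shape : s.max_shape; }
std::vector<int> GetMinShape(const ShapeInfo &s) { return s.min_shape.empty() ? s.shape : s.min_shape; }

int main() {
  ShapeInfo s{{-1, 128}, {1, 128}, {64, 128}};
  std::cout << GetMaxShape(s)[0] << " " << GetMinShape(s)[0] << std::endl;  // "64 1"
  return 0;
}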
@@ -221,6 +221,12 @@ class AnfRuntimeAlgorithm {
   static bool IsDynamicShape(const AnfNodePtr &node);
   static bool IsCondControlKernel(const CNodePtr &node);
   static bool IsIndependentNode(const CNodePtr &node);
+  static bool GetBooleanAttr(const AnfNodePtr &node, const std::string &attr);
+  static void GetRealDynamicShape(const std::vector<size_t> &shape, NotNull<std::vector<int64_t> *> dynamic_shape);
+  static std::vector<int> GetInputMaxShape(const AnfNodePtr &anf_node, size_t index);
+  static std::vector<int> GetInputMinShape(const AnfNodePtr &anf_node, size_t index);
+  static std::vector<int> GetOutputMaxShape(const AnfNodePtr &anf_node, size_t index);
+  static std::vector<int> GetOutputMinShape(const AnfNodePtr &anf_node, size_t index);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -127,6 +127,9 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
   MS_LOG(INFO) << "Start";
   std::vector<KernelGraphPtr> all_graphs;
   auto root_graph = ConstructKernelGraph(func_graph, &all_graphs);
+  // Update Graph Dynamic Shape Attr
+  UpdateGraphDynamicShapeAttr(NOT_NULL(root_graph));
+  root_graph->UpdateGraphDynamicAttr();
   BackendOptimization(all_graphs);
   // empty graph dont entry to backend
   if (root_graph->execution_order().empty()) {
@@ -136,6 +139,7 @@ GraphId AscendSession::CompileGraphImpl(NotNull<FuncGraphPtr> func_graph) {
     InitRuntimeResource();
     return root_graph->graph_id();
   }

   // create parameter for multiple branch
   std::set<KernelGraphPtr> memo;
   CreateMultiBranchOutput(NOT_NULL(root_graph), NOT_NULL(&memo));
@@ -1201,6 +1201,17 @@ void KernelGraph::RemoveNodeFromGraph(const AnfNodePtr &node) {
   }
 }

+void KernelGraph::UpdateGraphDynamicAttr() {
+  for (const auto &cnode : execution_order_) {
+    if (AnfAlgo::IsDynamicShape(cnode)) {
+      MS_LOG(INFO) << "Update Graph Dynamic Attr";
+      is_dynamic_shape_ = true;
+      return;
+    }
+  }
+  is_dynamic_shape_ = false;
+}
+
 std::string KernelGraph::ToString() const { return std::string("kernel_graph_").append(std::to_string(graph_id_)); }

 KernelGraph::~KernelGraph() {
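Note: UpdateGraphDynamicAttr above reduces to an OR over the per-node dynamic-shape flags of the execution order. A minimal standalone sketch of that reduction (types are illustrative):

#include <algorithm>
#include <iostream>
#include <vector>

struct KernelNode {
  bool input_is_dynamic = false;
  bool output_is_dynamic = false;
};

// The graph is dynamic as soon as any node in the execution order is.
bool GraphIsDynamic(const std::vector<KernelNode> &execution_order) {
  return std::any_of(execution_order.begin(), execution_order.end(),
                     [](const KernelNode &n) { return n.input_is_dynamic || n.output_is_dynamic; });
}

int main() {
  std::vector<KernelNode> order{{false, false}, {true, false}};
  std::cout << std::boolalpha << GraphIsDynamic(order) << std::endl;  // true
  return 0;
}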
@@ -37,7 +37,13 @@ namespace session {
 using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
 class KernelGraph : public FuncGraph {
  public:
-  KernelGraph() : graph_id_(0), start_label_(nullptr), end_goto_(nullptr), null_output_(false), current_epoch_(0) {
+  KernelGraph()
+      : graph_id_(0),
+        start_label_(nullptr),
+        end_goto_(nullptr),
+        null_output_(false),
+        current_epoch_(0),
+        is_dynamic_shape_(false) {
     inputs_ = std::make_shared<std::vector<AnfNodePtr>>();
     execution_order_ = {};
     executable_ = true;
@@ -161,6 +167,7 @@ class KernelGraph : public FuncGraph {
   void set_child_graph_result(const std::vector<AnfNodePtr> &child_graph_result) {
     child_graph_result_ = child_graph_result;
   }

   void InsertTupleParameterToMakeTupleMap(const AnfNodePtr &param, const AnfNodePtr &make_tuple) {
     if (tuple_parameter_to_make_tuple_map_.find(param) != tuple_parameter_to_make_tuple_map_.end()) {
       return;
@@ -176,6 +183,9 @@ class KernelGraph : public FuncGraph {
   }
   void RemoveNodeFromGraph(const AnfNodePtr &node);

+  void UpdateGraphDynamicAttr();
+  bool is_dynamic_shape() const { return is_dynamic_shape_; }
+
  private:
   // remove value node form graph
   bool RemoveValueNodeFromGraph(const ValueNodePtr &value_node);
@@ -247,10 +257,10 @@ class KernelGraph : public FuncGraph {
   std::unordered_map<AnfNodePtr, std::unordered_map<int, tensor::TensorPtr>> internal_outputs_tensor_map_;
   uint32_t current_epoch_;
   std::unordered_map<AnfNodePtr, AnfNodePtr> tuple_parameter_to_make_tuple_map_;

   std::set<AnfNodePtr> visited_nodes_;
   std::map<AnfNodePtr, AnfNodePtr> edge_to_;
   std::stack<AnfNodePtr> loop_nodes_;
+  bool is_dynamic_shape_;
 };
 }  // namespace session
 using KernelGraphPtr = std::shared_ptr<session::KernelGraph>;
@@ -35,6 +35,7 @@
 #include "ir/dtype.h"
 #include "ir/anf.h"
 #include "ir/func_graph_cloner.h"
+#include "utils/utils.h"
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 #include "ps/worker.h"
 #include "ps/common.h"
@@ -1405,6 +1406,97 @@ void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tens
   executor_->RunGraphAsync(shared_from_this(), graph_id, inputs, outputs);
 }

+bool IsDynamicShape(const NotNull<abstract::ShapePtr> &shape) {
+  return !std::all_of(shape->shape().begin(), shape->shape().end(), [](int s) { return s > 0; });
+}
+
+bool IsNodeOutputDynamicShape(const CNodePtr &anf_node_ptr) {
+  MS_EXCEPTION_IF_NULL(anf_node_ptr);
+  auto base_shape = anf_node_ptr->Shape();
+  if (base_shape == nullptr) {
+    MS_LOG(INFO) << "Invalid bash shape ptr, node:" << anf_node_ptr->fullname_with_scope();
+    return false;
+  }
+  if (base_shape->isa<abstract::Shape>()) {
+    if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
+      return true;
+    }
+  } else if (base_shape->isa<abstract::TupleShape>()) {
+    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+    MS_EXCEPTION_IF_NULL(tuple_shape);
+
+    for (size_t i = 0; i < tuple_shape->size(); ++i) {
+      auto b_shp = (*tuple_shape)[i];
+      if (!b_shp->isa<abstract::Shape>()) {
+        continue;
+      }
+      if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+bool IsNodeInputDynamicShape(const CNodePtr &anf_node_ptr) {
+  MS_EXCEPTION_IF_NULL(anf_node_ptr);
+  auto input_num = AnfAlgo::GetInputTensorNum(anf_node_ptr);
+  for (size_t i = 0; i < input_num; ++i) {
+    auto input_with_index = AnfAlgo::GetPrevNodeOutput(anf_node_ptr, i);
+    auto input = input_with_index.first;
+    auto index = input_with_index.second;
+    MS_EXCEPTION_IF_NULL(input);
+
+    auto base_shape = input->Shape();
+    if (base_shape == nullptr) {
+      MS_LOG(INFO) << "Invalid shape ptr, node:" << input->fullname_with_scope();
+      continue;
+    }
+    if (base_shape->isa<abstract::Shape>()) {
+      if (IsDynamicShape(NOT_NULL(base_shape->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    } else if (base_shape->isa<abstract::TupleShape>()) {
+      auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
+      MS_EXCEPTION_IF_NULL(tuple_shape);
+
+      if (index >= tuple_shape->size()) {
+        MS_LOG(INFO) << "Node:" << anf_node_ptr->fullname_with_scope() << "Invalid index:" << index
+                     << " and tuple_shape size:" << tuple_shape->size();
+        continue;
+      }
+
+      auto b_shp = (*tuple_shape)[index];
+      if (!b_shp->isa<abstract::Shape>()) {
+        continue;
+      }
+      if (IsDynamicShape(NOT_NULL(b_shp->cast<abstract::ShapePtr>()))) {
+        return true;
+      }
+    }
+  }
+  return false;
+}
+
+void SessionBasic::UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph) {
+  for (const auto &cnode : root_graph->execution_order()) {
+    auto output_dynamic = IsNodeOutputDynamicShape(NOT_NULL(cnode));
+    auto input_dynamic = IsNodeInputDynamicShape(NOT_NULL(cnode));
+    if (output_dynamic || input_dynamic) {
+      AnfAlgo::SetNodeAttr(kAttrIsDynamicShape, MakeValue(true), cnode);
+      MS_LOG(INFO) << "Set Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
+    }
+    if (output_dynamic) {
+      AnfAlgo::SetNodeAttr(kAttrOutputIsDynamicShape, MakeValue(true), cnode);
+      MS_LOG(INFO) << "Set Output Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
+    }
+    if (input_dynamic) {
+      AnfAlgo::SetNodeAttr(kAttrInputIsDynamicShape, MakeValue(true), cnode);
+      MS_LOG(INFO) << "Set Input Dynamic Shape Attr to Node:" << cnode->fullname_with_scope();
+    }
+  }
+}
+
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
 void SessionBasic::AssignParamKey(const KernelGraphPtr &kernel_graph) {
   if (!ps::Util::IsRoleOfWorker()) {
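Note: the detection above marks a node dynamic on the output side when any output dimension is non-positive, and dynamic on the input side when any producer output it consumes is. A standalone sketch of that per-node classification on simplified types (not the MindSpore abstract shapes):

#include <iostream>
#include <vector>

using Shape = std::vector<int>;  // -1 marks an unknown dimension

bool ShapeIsDynamic(const Shape &shape) {
  for (int s : shape) {
    if (s <= 0) return true;
  }
  return false;
}

struct NodeShapes {
  std::vector<Shape> input_shapes;   // shapes of the producer outputs this node consumes
  std::vector<Shape> output_shapes;  // shapes this node produces
};

void TagDynamicFlags(const NodeShapes &node, bool *input_dynamic, bool *output_dynamic) {
  *input_dynamic = false;
  *output_dynamic = false;
  for (const auto &s : node.input_shapes) {
    if (ShapeIsDynamic(s)) *input_dynamic = true;
  }
  for (const auto &s : node.output_shapes) {
    if (ShapeIsDynamic(s)) *output_dynamic = true;
  }
}

int main() {
  NodeShapes n{{{-1, 128}}, {{32, 10}}};
  bool in_dyn = false, out_dyn = false;
  TagDynamicFlags(n, &in_dyn, &out_dyn);
  std::cout << std::boolalpha << in_dyn << " " << out_dyn << std::endl;  // "true false"
  return 0;
}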
@@ -172,6 +172,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void AddParameterToGraphInputs(const std::vector<AnfNodePtr> &parameters, KernelGraph *graph);
   void InitInternalOutputParameter(const AnfNodePtr &out_node, const AnfNodePtr &parameter);
   AnfNodePtr FindPullNode(const AnfNodePtr &push_node, const std::vector<AnfNodePtr> &node_list);
+  void UpdateGraphDynamicShapeAttr(const NotNull<KernelGraphPtr> &root_graph);

   std::unordered_map<GraphId, std::shared_ptr<KernelGraph>> graphs_;
   std::unordered_map<GraphInfo, std::shared_ptr<KernelGraph>> run_op_graphs_;
@@ -713,5 +713,16 @@ EvalResultPtr EvalOnePrim(const PrimitivePtr &primitive, const AbstractBasePtrLi
   auto eval_result = trivial_evaluator->EvalPrim(nullptr, arg_specs);
   return eval_result;
 }

+AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list) {
+  MS_EXCEPTION_IF_NULL(prim);
+  auto &prim_eval_implement_map = GetPrimitiveToEvalImplMap();
+  auto ret = prim_eval_implement_map.find(prim);
+  if (ret == prim_eval_implement_map.end()) {
+    MS_LOG(EXCEPTION) << "Get infer shape function failed, primitive name:" << prim->name()
+                      << " primitive type:" << prim->type_name();
+  }
+  return ret->second.impl_(nullptr, prim, args_spec_list);
+}
 }  // namespace abstract
 }  // namespace mindspore
@@ -302,6 +302,8 @@ AbstractBasePtr FromValue(const T &value, bool broaden = false) {
 }

 EvalResultPtr EvalOnePrim(const PrimitivePtr &p, const AbstractBasePtrList &arg_specs);

+AbstractBasePtr CppInferShape(const PrimitivePtr &prim, const AbstractBasePtrList &args_spec_list);
 }  // namespace abstract
 }  // namespace mindspore
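Note: CppInferShape above is a registry lookup: a primitive maps to its C++ shape-inference implementation and the lookup fails loudly when no entry is registered. A standalone sketch of that pattern (the registry, key type, and rule here are illustrative, not the MindSpore eval-impl map):

#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

using Shape = std::vector<int>;
using InferFunc = std::function<Shape(const std::vector<Shape> &)>;

std::map<std::string, InferFunc> &InferRegistry() {
  static std::map<std::string, InferFunc> registry;
  return registry;
}

Shape InferShape(const std::string &prim_name, const std::vector<Shape> &input_shapes) {
  auto it = InferRegistry().find(prim_name);
  if (it == InferRegistry().end()) {
    throw std::runtime_error("Get infer shape function failed, primitive name:" + prim_name);
  }
  return it->second(input_shapes);
}

int main() {
  // Register a trivial element-wise rule: output shape equals the first input shape.
  InferRegistry()["Relu"] = [](const std::vector<Shape> &in) { return in.at(0); };
  auto out = InferShape("Relu", {{8, 16}});
  std::cout << out[0] << "x" << out[1] << std::endl;  // 8x16
  return 0;
}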
@@ -1,5 +1,5 @@
 file(GLOB_RECURSE DEVICE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} "common/*.cc"
-    "kernel_info.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
+    "kernel_info.cc" "executor/dynamic_kernel.cc" "kernel_runtime.cc" "memory_manager.cc" "kernel_runtime_manager.cc" "convert_tensor_utils.cc"
     )

 if (ENABLE_GPU)
@@ -372,7 +372,7 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con
   // get size
   std::vector<size_t> input_size_list;
   std::vector<size_t> output_size_list;
-  (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list);
+  (void)kernel::TbeKernelBuild::GetIOSize(kernel_json, &input_size_list, &output_size_list, nullptr);
   std::string json_name = kernel_json[op_info_str][kernel_name_str];
   // op build
   if (constructed_kernel.find(json_name) == constructed_kernel.end()) {
@@ -382,15 +382,15 @@ kernel::KernelModPtr AscendDeviceAddress::CompileTransDataAndObtainKernelMod(con
   while (!build_manager->IsAllTaskFinish()) {
     int task_id = -1;
     std::string task_result;
-    std::string pre_build_result;
-    auto ret = build_manager->WaitOne(&task_id, &task_result, &pre_build_result);
+    std::string build_result;
+    auto ret = build_manager->WaitOne(&task_id, &task_result, &build_result);
     if (!ret) {
       MS_EXCEPTION(ArgumentError) << "Build Failed. wait one ret:" << ret << ", task id:" << task_id;
     }
     if (task_result != "Success") {
       MS_EXCEPTION(ArgumentError) << "task compile Failed, task id:" << task_id << ", cause:" << task_result;
     }
-    (void)build_manager->TaskFinishProcess(task_id, false);
+    (void)build_manager->TaskFinishProcess(task_id, build_result, false);
   }
   constructed_kernel.insert(json_name);
   // search cache
@@ -46,12 +46,22 @@
 #ifdef MEM_REUSE_DEBUG
 #include "backend/optimizer/mem_reuse/mem_reuse_checker.h"
 #endif
+#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
+#include "runtime/device/ascend/executor/executor_callback.h"
+#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
+#include "profiler/device/ascend/ascend_profiling.h"
+#include "profiler/device/ascend/profiling_context.h"
+#include "profiler/device/ascend/rt_callback_manager.h"

 using ge::model_runner::ModelRunner;
 using mindspore::device::ascend::ProfilingManager;
 using mindspore::device::ascend::ProfilingUtils;
 using mindspore::device::ascend::tasksink::TaskGenerator;
 using mindspore::kernel::tbe::TbeUtils;
+using mindspore::profiler::ascend::AscendProfiler;
+using mindspore::profiler::ascend::CallbackManager;
+using mindspore::profiler::ascend::GetTid;
+using mindspore::profiler::ascend::kCallback;
 using std::vector;

 constexpr uint32_t kTupleTaskId = 0;
@@ -135,6 +145,8 @@ void AscendKernelRuntime::ClearGraphModelMap() {
   // tell users which dump kernel name not used
   DumpJsonParser::GetInstance().PrintUnusedKernel();

+  graph_dynamic_kernel_map_.clear();
+
   for (auto &iter : graph_model_map_) {
     MS_LOG(INFO) << "Ge UnloadModel " << iter.first;
     auto ret = ModelRunner::Instance().UnloadModel(iter.first);
@@ -160,6 +172,13 @@ void AscendKernelRuntime::ClearGraphRuntimeResource(uint32_t graph_id, const std
     MS_LOG(DEBUG) << "GraphId:" << graph_id << " not found";
   }

+  MS_LOG(DEBUG) << "Clear graph:" << graph_id << " dynamic kernels";
+  if (auto dynamic_kernel_iter = graph_dynamic_kernel_map_.find(graph_id);
+      dynamic_kernel_iter != graph_dynamic_kernel_map_.end()) {
+    MS_LOG(DEBUG) << "Start Clear graph:" << graph_id << " dynamic kernel";
+    graph_dynamic_kernel_map_.erase(dynamic_kernel_iter);
+  }
+
   MS_LOG(DEBUG) << "Clear graph:" << graph_id << " runtime resource";
   if (auto model_iter = graph_model_map_.find(graph_id); model_iter != graph_model_map_.end()) {
     MS_LOG(DEBUG) << "Ge UnloadModel " << graph_id;
@@ -233,6 +252,7 @@ bool AscendKernelRuntime::Init() {
     InnerSetContext();
     return true;
   }
+  OpTilingCalculater::GetInstance().Init();
   // Start up profiling before rtSetDevice
   bool ret = ProfilingManager::GetInstance().StartupProfiling(device_id_);
   if (!ret) {
@@ -342,6 +362,11 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
   if (!is_task_sink) {
     return true;
   }
+  // Do HcomExecutorInitialize
+  if (graph->is_dynamic_shape() && !HcclExecutorManager::GetInstance().Initialize()) {
+    MS_LOG(ERROR) << "Init Hccl Executor Failed";
+    return false;
+  }
   if (!GenTask(graph)) {
     return false;
   }
@ -351,8 +376,35 @@ bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
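// Dynamic shape graphs skip task sink: build one DynamicKernel per node in execution order and cache them by graph id.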
bool AscendKernelRuntime::GenDynamicKernel(const session::KernelGraph *graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_LOG(INFO) << "GenDynamicKernel start";
|
||||||
|
auto cnode_list = graph->execution_order();
|
||||||
|
std::vector<DynamicKernelPtr> dynamic_kernels;
|
||||||
|
for (const auto &cnode : cnode_list) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode);
|
||||||
|
MS_LOG(INFO) << "Generate node:" << cnode->fullname_with_scope() << " dynamic kernel";
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||||
|
auto dynamic_kernel = kernel_mod->GenDynamicKernel(cnode, stream_);
|
||||||
|
MS_EXCEPTION_IF_NULL(dynamic_kernel);
|
||||||
|
dynamic_kernel->Initialize();
|
||||||
|
dynamic_kernels.emplace_back(dynamic_kernel);
|
||||||
|
}
|
||||||
|
auto ret = graph_dynamic_kernel_map_.try_emplace(graph->graph_id(), dynamic_kernels);
|
||||||
|
if (!ret.second) {
|
||||||
|
MS_LOG(ERROR) << "Graph:" << graph->graph_id() << " already generator executor";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "GenDynamicKernel end";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||||
InnerSetContext();
|
InnerSetContext();
|
||||||
|
if (graph->is_dynamic_shape()) {
|
||||||
|
MS_LOG(INFO) << "Dynamic Shape Graph Generate Dynamic kernel";
|
||||||
|
return GenDynamicKernel(graph);
|
||||||
|
}
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
|
MS_EXCEPTION(NotExistsError) << "session::KernelGraph is NULL!";
|
||||||
}
|
}
|
||||||
|
@ -407,6 +459,11 @@ bool AscendKernelRuntime::GenTask(const session::KernelGraph *graph) {
|
||||||
|
|
||||||
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
bool AscendKernelRuntime::LoadTask(const session::KernelGraph *graph) {
|
||||||
InnerSetContext();
|
InnerSetContext();
|
||||||
|
if (graph->is_dynamic_shape()) {
|
||||||
|
MS_LOG(INFO) << "Dynamic Shape Graph Skip Load Task Step";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
if (graph == nullptr) {
|
if (graph == nullptr) {
|
||||||
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
|
MS_EXCEPTION(NotExistsError) << "Null pointer graph, LoadTask failed. ";
|
||||||
}
|
}
|
||||||
|
@ -520,9 +577,70 @@ bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, De
|
||||||
return ret;
|
return ret;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
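// Launch path for dynamic shape graphs: sync the stream before kernels with depends, re-infer shapes and
// refresh args for dynamic-shape kernels, execute, then drain the registered PostExecute callbacks.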
bool AscendKernelRuntime::RunDynamicKernelAsync(const session::KernelGraph *graph) {
|
||||||
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
MS_LOG(INFO) << "RunExecutorAsync start. GraphId:" << graph->graph_id();
|
||||||
|
|
||||||
|
auto iter = graph_dynamic_kernel_map_.find(graph->graph_id());
|
||||||
|
if (iter == graph_dynamic_kernel_map_.end()) {
|
||||||
|
MS_LOG(ERROR) << "GraphId:" << graph->graph_id() << " Not Found! Please generator executor first";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Profiling Init
|
||||||
|
auto &async_profiler = AscendProfiler::GetInstance();
|
||||||
|
auto &rt_callback = CallbackManager::GetInstance(stream_);
|
||||||
|
rt_callback.Init();
|
||||||
|
|
||||||
|
auto dynamic_kernels = iter->second;
|
||||||
|
for (const auto &dynamic_kernel : dynamic_kernels) {
|
||||||
|
if (dynamic_kernel->have_depends()) {
|
||||||
|
MS_LOG(INFO) << "Match Dynamic Kernel, Start SyncStream";
|
||||||
|
if (!SyncStream()) {
|
||||||
|
MS_LOG(ERROR) << "SyncStream failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (dynamic_kernel->is_dynamic_shape()) {
|
||||||
|
ExecutorCallback::GetInstance().Consume();
|
||||||
|
dynamic_kernel->InferShape();
|
||||||
|
dynamic_kernel->UpdateArgs();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Enable profiling trace point start
|
||||||
|
rt_callback.RegisterCallback(
|
||||||
|
[&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] start"); });
|
||||||
|
|
||||||
|
dynamic_kernel->Execute();
|
||||||
|
|
||||||
|
// Enable profiling trace point end
|
||||||
|
rt_callback.RegisterCallback(
|
||||||
|
[&]() { RECORD_CALLBACK_EVENT(&async_profiler, dynamic_kernel->GetKernelName().c_str(), "[Callback] end"); });
|
||||||
|
|
||||||
|
ExecutorCallback::GetInstance().RegistCallback([&dynamic_kernel] { dynamic_kernel->PostExecute(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!SyncStream()) {
|
||||||
|
MS_LOG(ERROR) << "SyncStream failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
ExecutorCallback::GetInstance().Consume();
|
||||||
|
|
||||||
|
rt_callback.Destroy();
|
||||||
|
async_profiler.Dump(std::cout);
|
||||||
|
async_profiler.Reset();
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
bool AscendKernelRuntime::RunTask(const session::KernelGraph *graph) {
|
||||||
InnerSetContext();
|
InnerSetContext();
|
||||||
MS_EXCEPTION_IF_NULL(graph);
|
MS_EXCEPTION_IF_NULL(graph);
|
||||||
|
if (graph->is_dynamic_shape()) {
|
||||||
|
MS_LOG(INFO) << "Dynamic Shape Graph Run Task Async";
|
||||||
|
return RunDynamicKernelAsync(graph);
|
||||||
|
}
|
||||||
|
|
||||||
MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
|
MS_LOG(INFO) << "RunTask start. GraphId:" << graph->graph_id();
|
||||||
|
|
||||||
auto context_ptr = MsContext::GetInstance();
|
auto context_ptr = MsContext::GetInstance();
|
||||||
|
@ -657,7 +775,12 @@ bool AscendKernelRuntime::DestroyHccl() {
|
||||||
MS_LOG(INFO) << "Hccl is not enable, no need to close.";
|
MS_LOG(INFO) << "Hccl is not enable, no need to close.";
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
// Dynamic Shape Hccl Finalize
|
||||||
|
if (!HcclExecutorManager::GetInstance().Finalize()) {
|
||||||
|
MS_LOG(ERROR) << "Dynamic Shape Hccl Finalize Failed";
|
||||||
|
}
|
||||||
HcclResult res = hcom_destroy();
|
HcclResult res = hcom_destroy();
|
||||||
|
|
||||||
if (res != HCCL_SUCCESS) {
|
if (res != HCCL_SUCCESS) {
|
||||||
MS_LOG(ERROR) << "Hccl destroy failed";
|
MS_LOG(ERROR) << "Hccl destroy failed";
|
||||||
return false;
|
return false;
|
||||||
|
|
|
@ -40,6 +40,8 @@ class AscendKernelRuntime : public KernelRuntime {
|
||||||
bool Init() override;
|
bool Init() override;
|
||||||
bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
|
bool LoadData(session::KernelGraph *graph, Debugger *debugger) override;
|
||||||
bool GenTask(const session::KernelGraph *graph);
|
bool GenTask(const session::KernelGraph *graph);
|
||||||
|
bool GenDynamicKernel(const session::KernelGraph *graph) override;
|
||||||
|
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override;
|
||||||
bool LoadTask(const session::KernelGraph *graph);
|
bool LoadTask(const session::KernelGraph *graph);
|
||||||
bool RunTask(const session::KernelGraph *graph);
|
bool RunTask(const session::KernelGraph *graph);
|
||||||
bool Load(session::KernelGraph *graph, bool is_task_sink) override;
|
bool Load(session::KernelGraph *graph, bool is_task_sink) override;
|
||||||
|
|
|
@ -34,7 +34,7 @@ const uint32_t kHcomMaxTask = 5;
|
||||||
const uint32_t kCommonMaxTask = 350;
|
const uint32_t kCommonMaxTask = 350;
|
||||||
|
|
||||||
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
void AscendStreamAssign::AssignStream(const NotNull<KernelGraphPtr> &graph_ptr) {
|
||||||
if (IsTaskSink()) {
|
if (IsTaskSink() && !graph_ptr->is_dynamic_shape()) {
|
||||||
Reset();
|
Reset();
|
||||||
SetLoopSink();
|
SetLoopSink();
|
||||||
ReorderIndependentOrders(graph_ptr);
|
ReorderIndependentOrders(graph_ptr);
|
||||||
|
|
|
@ -24,7 +24,7 @@
|
||||||
#include "runtime/mem.h"
|
#include "runtime/mem.h"
|
||||||
#include "runtime/kernel.h"
|
#include "runtime/kernel.h"
|
||||||
#include "runtime/rt_model.h"
|
#include "runtime/rt_model.h"
|
||||||
#include "runtime/device/ascend/dump/ge_dump.h"
|
#include "runtime/device/ascend/ge_types_convert.h"
|
||||||
#include "proto/op_mapping_info.pb.h"
|
#include "proto/op_mapping_info.pb.h"
|
||||||
#include "utils/ms_context.h"
|
#include "utils/ms_context.h"
|
||||||
#include "debug/data_dump/dump_json_parser.h"
|
#include "debug/data_dump/dump_json_parser.h"
|
||||||
|
@ -369,13 +369,13 @@ void DataDumper::DumpKernelOutput(const CNodePtr &kernel, void *args, NotNull<ai
|
||||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i);
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(kernel, i);
|
||||||
|
|
||||||
aicpu::dump::Output output;
|
aicpu::dump::Output output;
|
||||||
output.set_data_type(GetGeDataType(data_type));
|
output.set_data_type(GeTypesConvert::GetGeDataType(data_type));
|
||||||
output.set_format(GetGeFormat(output_format, output_shape.size()));
|
output.set_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
|
||||||
MS_EXCEPTION_IF_NULL(output.mutable_shape());
|
MS_EXCEPTION_IF_NULL(output.mutable_shape());
|
||||||
for (auto dim : output_shape) {
|
for (auto dim : output_shape) {
|
||||||
output.mutable_shape()->add_dim(dim);
|
output.mutable_shape()->add_dim(dim);
|
||||||
}
|
}
|
||||||
output.set_original_output_format(GetGeFormat(output_format, output_shape.size()));
|
output.set_original_output_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
|
||||||
output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + offset);
|
output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(args)) + offset);
|
||||||
// device address data size
|
// device address data size
|
||||||
auto address = AnfAlgo::GetOutputAddr(kernel, i);
|
auto address = AnfAlgo::GetOutputAddr(kernel, i);
|
||||||
|
@ -409,8 +409,8 @@ void DataDumper::DumpKernelInput(const CNodePtr &kernel, void *args, NotNull<aic
|
||||||
}
|
}
|
||||||
auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
|
auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
|
||||||
|
|
||||||
input.set_data_type(GetGeDataType(output_type));
|
input.set_data_type(GeTypesConvert::GetGeDataType(output_type));
|
||||||
input.set_format(GetGeFormat(output_format, output_shape.size()));
|
input.set_format(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
|
||||||
MS_EXCEPTION_IF_NULL(input.mutable_shape());
|
MS_EXCEPTION_IF_NULL(input.mutable_shape());
|
||||||
for (auto dim : output_shape) {
|
for (auto dim : output_shape) {
|
||||||
input.mutable_shape()->add_dim(dim);
|
input.mutable_shape()->add_dim(dim);
|
||||||
|
|
|
@ -0,0 +1,182 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/ascend/executor/ai_core_dynamic_kernel.h"
|
||||||
|
|
||||||
|
#include <regex>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <memory>
|
||||||
|
#include "framework/common/debug/log.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
|
||||||
|
#include "register/op_tiling.h"
|
||||||
|
#include "utils/convert_utils_base.h"
|
||||||
|
#include "utils/ms_context.h"
|
||||||
|
#include "runtime/device/kernel_runtime_manager.h"
|
||||||
|
#include "pipeline/jit/static_analysis/static_analysis.h"
|
||||||
|
#include "common/trans.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
AiCoreDynamicKernel::~AiCoreDynamicKernel() {
|
||||||
|
if (tiling_data_ptr_ != nullptr) {
|
||||||
|
auto ret = rtFree(tiling_data_ptr_);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "rtFree tiling_data_ptr_ failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCoreDynamicKernel::Execute() {
|
||||||
|
if (stream_ == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "stream_ptr should not be nullptr.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Start Execute node:" << cnode_ptr_->fullname_with_scope();
|
||||||
|
rtL2Ctrl_t *l2ctrl = nullptr;
|
||||||
|
auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size());
|
||||||
|
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "End Execute node:" << cnode_ptr_->fullname_with_scope();
|
||||||
|
}
|
||||||
|
|
||||||
|
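// The compile_info attr arrives as a Python-style string; quote the special numeric tokens and
// lower-case the booleans so that nlohmann::json can parse it.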
std::string ReplaceInvalidJsonStr(const std::string &str) {
|
||||||
|
auto ret = std::regex_replace(str, std::regex("100000000"), R"("100000000")");
|
||||||
|
ret = std::regex_replace(ret, std::regex("100000001"), R"("100000001")");
|
||||||
|
ret = std::regex_replace(ret, std::regex("100000002"), R"("100000002")");
|
||||||
|
ret = std::regex_replace(ret, std::regex("True"), R"(true)");
|
||||||
|
ret = std::regex_replace(ret, std::regex("False"), R"(false)");
|
||||||
|
return ret;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCoreDynamicKernel::ParseCompileJson() {
|
||||||
|
if (!AnfAlgo::IsDynamicShape(cnode_ptr_)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (!AnfAlgo::HasNodeAttr(kAttrCompileInfo, cnode_ptr_)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Get compile_info failed";
|
||||||
|
}
|
||||||
|
auto compile_info_attr = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrCompileInfo);
|
||||||
|
std::replace(compile_info_attr.begin(), compile_info_attr.end(), '\'', '\"');
|
||||||
|
compile_info_attr = ReplaceInvalidJsonStr(compile_info_attr);
|
||||||
|
MS_LOG(INFO) << "Get compile_info:" << compile_info_attr;
|
||||||
|
|
||||||
|
try {
|
||||||
|
compile_info_json_ = std::make_shared<nlohmann::json>(nlohmann::json::parse(compile_info_attr));
|
||||||
|
} catch (nlohmann::json::parse_error &e) {
|
||||||
|
MS_LOG(EXCEPTION) << "parse json failed, error:" << e.what();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (AnfAlgo::HasNodeAttr(kAttrFusionType, cnode_ptr_)) {
|
||||||
|
auto fusion_type = AnfAlgo::GetNodeAttr<std::string>(cnode_ptr_, kAttrFusionType);
|
||||||
|
MS_LOG(INFO) << "Get fusion_type:" << fusion_type;
|
||||||
|
(*compile_info_json_)["_pattern"] = fusion_type;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCoreDynamicKernel::Initialize() {
|
||||||
|
DynamicKernel::Initialize();
|
||||||
|
ParseCompileJson();
|
||||||
|
}
|
||||||
|
|
||||||
|
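// Rebuild the runtime args on every launch: input addresses, output addresses, workspace addresses,
// and finally the device tiling buffer when tiling data is present.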
void AiCoreDynamicKernel::UpdateArgs() {
|
||||||
|
ComputeTiling();
|
||||||
|
|
||||||
|
if (!CopyTilingToDevice()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Copy tiling to device failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
AllocateWorkspace();
|
||||||
|
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
|
||||||
|
AddressPtrList kernel_inputs;
|
||||||
|
AddressPtrList kernel_workspaces;
|
||||||
|
AddressPtrList kernel_outputs;
|
||||||
|
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||||
|
|
||||||
|
runtime_args_.clear();
|
||||||
|
|
||||||
|
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args_),
|
||||||
|
[](const AddressPtr &input) -> void * { return input->addr; });
|
||||||
|
(void)std::transform(std::begin(kernel_outputs), std::end(kernel_outputs), std::back_inserter(runtime_args_),
|
||||||
|
[](const AddressPtr &output) -> void * { return output->addr; });
|
||||||
|
// Update workspace
|
||||||
|
if (!workspace_addr_.empty()) {
|
||||||
|
(void)std::transform(std::begin(workspace_addr_), std::end(workspace_addr_), std::back_inserter(runtime_args_),
|
||||||
|
[](const DeviceAddressPtr &address_ptr) -> void * { return address_ptr->GetMutablePtr(); });
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_dynamic_shape_ && !tiling_data_.empty() && tiling_data_ptr_ != nullptr) {
|
||||||
|
runtime_args_.push_back(tiling_data_ptr_);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
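// Host-side tiling: CalculateTiling fills an optiling::OpRunInfo from the compile_info json and the
// depend tensors, producing block_dim, the workspace sizes and the serialized tiling data.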
void AiCoreDynamicKernel::ComputeTiling() {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||||
|
MS_LOG(INFO) << "Start compute tiling of:" << cnode_ptr_->fullname_with_scope();
|
||||||
|
optiling::OpRunInfo op_run_info;
|
||||||
|
|
||||||
|
OpTilingCalculater::GetInstance().CalculateTiling(NOT_NULL(cnode_ptr_), NOT_NULL(compile_info_json_),
|
||||||
|
depend_tensor_map_, NOT_NULL(&op_run_info));
|
||||||
|
block_dim_ = op_run_info.block_dim;
|
||||||
|
workspaces_size_ = op_run_info.workspaces;
|
||||||
|
tiling_data_ = op_run_info.tiling_data.str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCoreDynamicKernel::AllocateWorkspace() {
|
||||||
|
auto ms_context = MsContext::GetInstance();
|
||||||
|
MS_EXCEPTION_IF_NULL(ms_context);
|
||||||
|
auto device_id = ms_context->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||||
|
auto runtime_instance = KernelRuntimeManager::Instance().GetSingleKernelRuntime(kAscendDevice, device_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(runtime_instance);
|
||||||
|
|
||||||
|
workspace_addr_.clear();
|
||||||
|
for (auto size : workspaces_size_) {
|
||||||
|
auto device_address_ptr = std::make_shared<AscendDeviceAddress>(nullptr, size);
|
||||||
|
auto device_ptr = runtime_instance->MallocMem(MemType::kDynamicMem, size, device_address_ptr);
|
||||||
|
if (device_ptr == nullptr) {
|
||||||
|
MS_LOG(EXCEPTION) << "MallocMem from memory pool failed";
|
||||||
|
}
|
||||||
|
workspace_addr_.emplace_back(device_address_ptr);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
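// Copy the computed tiling data into the pre-allocated device buffer; op_para_size from the TBE build
// bounds the copy, and an empty tiling blob skips the rtMemcpyAsync.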
bool AiCoreDynamicKernel::CopyTilingToDevice() {
|
||||||
|
if (tiling_data_.size() > op_para_size_) {
|
||||||
|
MS_LOG(EXCEPTION) << "compute tiling size:" << tiling_data_.size()
|
||||||
|
<< " larger than tbe build op_para_size:" << op_para_size_;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (tiling_data_.empty() || tiling_data_ptr_ == nullptr) {
|
||||||
|
MS_LOG(INFO) << "tiling size is 0, skip rtMemcpyAsync";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = rtMemcpyAsync(tiling_data_ptr_, tiling_data_.size(), tiling_data_.c_str(), tiling_data_.size(),
|
||||||
|
RT_MEMCPY_HOST_TO_DEVICE_EX, stream_);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "tiling rtMemcpyAsync failed, ret:" << ret;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCoreDynamicKernel::PostExecute() {}
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,70 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_
|
||||||
|
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
#include <map>
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "nlohmann/json.hpp"
|
||||||
|
#include "ir/tensor.h"
|
||||||
|
#include "runtime/device/device_address.h"
|
||||||
|
#include "mindspore/ccsrc/runtime/device/executor/dynamic_kernel.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
class AiCoreDynamicKernel : public DynamicKernel {
|
||||||
|
public:
|
||||||
|
AiCoreDynamicKernel(const void *stub_func, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size,
|
||||||
|
void *stream, const CNodePtr &cnode_ptr, const std::vector<void *> &runtime_args)
|
||||||
|
: DynamicKernel(stream, cnode_ptr),
|
||||||
|
stub_func_(stub_func),
|
||||||
|
block_dim_(block_dim),
|
||||||
|
tiling_data_ptr_(tiling_data_ptr),
|
||||||
|
op_para_size_(op_para_size),
|
||||||
|
runtime_args_(runtime_args) {}
|
||||||
|
~AiCoreDynamicKernel() override;
|
||||||
|
|
||||||
|
void Execute() override;
|
||||||
|
void UpdateArgs() override;
|
||||||
|
void Initialize() override;
|
||||||
|
void PostExecute() override;
|
||||||
|
|
||||||
|
protected:
|
||||||
|
void AllocateWorkspace();
|
||||||
|
void ParseCompileJson();
|
||||||
|
|
||||||
|
private:
|
||||||
|
const void *stub_func_;
|
||||||
|
uint32_t block_dim_;
|
||||||
|
void *tiling_data_ptr_; // device ptr
|
||||||
|
uint32_t op_para_size_; // size of tiling_data_ptr_
|
||||||
|
std::vector<void *> runtime_args_;
|
||||||
|
std::string tiling_data_;
|
||||||
|
std::vector<int64_t> workspaces_size_;
|
||||||
|
std::vector<DeviceAddressPtr> workspace_addr_;
|
||||||
|
std::shared_ptr<nlohmann::json> compile_info_json_;
|
||||||
|
|
||||||
|
void ComputeTiling();
|
||||||
|
bool CopyTilingToDevice();
|
||||||
|
};
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CORE_DYNAMIC_KERNEL_H_
|
|
@ -0,0 +1,204 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/ascend/executor/ai_cpu_dynamic_kernel.h"
|
||||||
|
#include <vector>
|
||||||
|
#include <memory>
|
||||||
|
#include <algorithm>
|
||||||
|
#include "runtime/mem.h"
|
||||||
|
#include "runtime/kernel.h"
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
|
||||||
|
#include "runtime/device/ascend/executor/executor_callback.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
AiCpuDynamicKernel::~AiCpuDynamicKernel() {
|
||||||
|
// free dev ptr
|
||||||
|
if (ext_info_addr_dev_ == nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto ret = rtFree(ext_info_addr_dev_);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "rtFree failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCpuDynamicKernel::UpdateArgs() {
|
||||||
|
if (!UpdateInputOutputAddr()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Update input output failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
if (is_dynamic_shape_ && !UpdateExtInfo()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Update ExtInfo failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCpuDynamicKernel::Execute() {
|
||||||
|
MS_LOG(INFO) << "Execute AiCpuDynamicKerenl Start";
|
||||||
|
auto ret = rtCpuKernelLaunchWithFlag(
|
||||||
|
reinterpret_cast<const void *>(so_name_.c_str()), reinterpret_cast<const void *>(kernel_name_.c_str()), 1,
|
||||||
|
reinterpret_cast<const void *>(args_.data()), args_.size(), nullptr, stream_, RT_KERNEL_DEFAULT);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call rtCpuKernelLaunchWithFlag Failed";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
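// Parse the aicpu ext info (for dynamic shape nodes), copy it to device memory, and patch the
// AicpuParamHead in args_ with its device address and length.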
void AiCpuDynamicKernel::Initialize() {
|
||||||
|
// Initialize dynamic-shape related info for this aicpu node.
|
||||||
|
MS_LOG(INFO) << "Initialize node:" << cnode_ptr_->fullname_with_scope();
|
||||||
|
DynamicKernel::Initialize();
|
||||||
|
|
||||||
|
input_num_ = AnfAlgo::GetInputTensorNum(cnode_ptr_);
|
||||||
|
output_num_ = AnfAlgo::GetOutputTensorNum(cnode_ptr_);
|
||||||
|
|
||||||
|
// Parse aicpu ext info
|
||||||
|
if (is_dynamic_shape_) {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||||
|
ext_info_handler_ =
|
||||||
|
std::make_shared<AicpuExtInfoHandler>(cnode_ptr_->fullname_with_scope(), input_num_, output_num_, DEPEND_COMPUTE);
|
||||||
|
ext_info_handler_->Parse(ext_info_data_);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ext_info_data_.empty()) {
|
||||||
|
MS_LOG(INFO) << "No need to copy to device, ext_info_data_ is empty. ";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Allocate ext info addr in device
|
||||||
|
auto ret = rtMalloc(&ext_info_addr_dev_, ext_info_data_.size(), RT_MEMORY_HBM);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call rtMalloc ext_info_addr_dev_ failed";
|
||||||
|
}
|
||||||
|
ext_info_size_ = ext_info_data_.size();
|
||||||
|
|
||||||
|
ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_data_.data(), ext_info_data_.size(),
|
||||||
|
RT_MEMCPY_HOST_TO_DEVICE);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call rtMemcpy ext_info_addr_dev_ failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
auto aicpu_param_head = reinterpret_cast<kernel::AicpuParamHead *>(args_.data());
|
||||||
|
aicpu_param_head->extInfoLength = ext_info_size_;
|
||||||
|
aicpu_param_head->extInfoAddr = reinterpret_cast<uint64_t>(ext_info_addr_dev_);
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AiCpuDynamicKernel::UpdateInputOutputAddr() {
|
||||||
|
std::vector<uint64_t> io_addrs;
|
||||||
|
io_addrs.reserve(input_num_ + output_num_);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < input_num_; ++i) {
|
||||||
|
auto input_addr = AnfAlgo::GetPrevNodeOutputAddr(cnode_ptr_, i);
|
||||||
|
io_addrs.emplace_back(reinterpret_cast<uintptr_t>(input_addr->GetMutablePtr()));
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < output_num_; ++i) {
|
||||||
|
auto output_addr = AnfAlgo::GetOutputAddr(cnode_ptr_, i);
|
||||||
|
io_addrs.emplace_back(reinterpret_cast<uintptr_t>(output_addr->GetMutablePtr()));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (args_.empty()) {
|
||||||
|
MS_LOG(ERROR) << "args_ is empty";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto io_ptr = args_.data() + sizeof(kernel::AicpuParamHead);
|
||||||
|
auto ret =
|
||||||
|
memcpy_s(io_ptr, args_.size() - sizeof(kernel::AicpuParamHead), &io_addrs[0], sizeof(uint64_t) * io_addrs.size());
|
||||||
|
if (ret != 0) {
|
||||||
|
MS_LOG(EXCEPTION) << "Memcpy input output addr failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
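// Refresh the host-side ext info with the current device shapes and overwrite the device copy; output
// shapes are skipped for DEPEND_COMPUTE ops because they are only known after execution.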
bool AiCpuDynamicKernel::UpdateExtInfo() {
|
||||||
|
MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " start";
|
||||||
|
if (input_num_ == 0 && output_num_ == 0) {
|
||||||
|
MS_LOG(INFO) << "Node:" << cnode_ptr_->fullname_with_scope() << " no need to update output shape";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < input_num_; ++i) {
|
||||||
|
ext_info_handler_->UpdateInputShapeAndType(i, NOT_NULL(cnode_ptr_));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (unknow_type_ != DEPEND_COMPUTE) {
|
||||||
|
for (size_t i = 0; i < output_num_; ++i) {
|
||||||
|
ext_info_handler_->UpdateOutputShapeAndType(i, NOT_NULL(cnode_ptr_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ret = rtMemcpy(ext_info_addr_dev_, ext_info_size_, ext_info_handler_->GetExtInfo(),
|
||||||
|
ext_info_handler_->GetExtInfoLen(), RT_MEMCPY_HOST_TO_DEVICE);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "UpdateExtInfo rtMemcpy failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "UpdateExtInfo of " << cnode_ptr_->fullname_with_scope() << " end";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
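// After a DEPEND_COMPUTE kernel runs, the real output shapes live in the device-side ext info: copy it
// back to host and write the shapes and types into the node's inferred abstract.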
bool AiCpuDynamicKernel::UpdateOutputShapeFromExtInfo() {
|
||||||
|
if (input_num_ == 0) {
|
||||||
|
MS_LOG(WARNING) << "input num is 0";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "UpdateOutputShapeFromExtInfo start";
|
||||||
|
auto ret = rtMemcpy(ext_info_handler_->GetExtInfo(), ext_info_handler_->GetExtInfoLen(), ext_info_addr_dev_,
|
||||||
|
ext_info_size_, RT_MEMCPY_DEVICE_TO_HOST);
|
||||||
|
if (ret != RT_ERROR_NONE) {
|
||||||
|
MS_LOG(ERROR) << "rtMemcpy output shape failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
MS_LOG(INFO) << "rtMemcpy from device to host success";
|
||||||
|
|
||||||
|
std::vector<TypeId> type_ids;
|
||||||
|
std::vector<std::vector<size_t>> shapes;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < output_num_; ++i) {
|
||||||
|
MS_LOG(INFO) << "Get output:" << output_num_ << " Shape";
|
||||||
|
std::vector<int64_t> shape;
|
||||||
|
TypeId type_id;
|
||||||
|
ext_info_handler_->GetOutputShapeAndType(i, NOT_NULL(&shape), NOT_NULL(&type_id));
|
||||||
|
|
||||||
|
for (auto x : shape) {
|
||||||
|
MS_LOG(INFO) << "Update output:" << i << " shape:" << x;
|
||||||
|
}
|
||||||
|
|
||||||
|
type_ids.emplace_back(type_id);
|
||||||
|
std::vector<size_t> size_t_shape;
|
||||||
|
std::transform(shape.begin(), shape.end(), std::back_inserter(size_t_shape), LongToSize);
|
||||||
|
shapes.emplace_back(size_t_shape);
|
||||||
|
}
|
||||||
|
|
||||||
|
AnfAlgo::SetOutputInferTypeAndShape(type_ids, shapes, cnode_ptr_.get());
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AiCpuDynamicKernel::PostExecute() {
|
||||||
|
MS_LOG(INFO) << "Aicpu " << cnode_ptr_->fullname_with_scope() << " PostExecute";
|
||||||
|
if (AnfAlgo::IsDynamicShape(cnode_ptr_) && unknow_type_ == DEPEND_COMPUTE) {
|
||||||
|
MS_LOG(INFO) << "Update aicpu kernel output shape from ext_info";
|
||||||
|
UpdateOutputShapeFromExtInfo();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,76 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_
|
||||||
|
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <memory>
|
||||||
|
#include "runtime/device/executor/dynamic_kernel.h"
|
||||||
|
#include "ir/anf.h"
|
||||||
|
#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
class AiCpuDynamicKernel : public DynamicKernel {
|
||||||
|
public:
|
||||||
|
AiCpuDynamicKernel(void *stream, const CNodePtr &cnode_ptr, const std::string &args, const std::string &ext_info_data,
|
||||||
|
const std::string &so_name, const std::string &kernel_name)
|
||||||
|
: DynamicKernel(stream, cnode_ptr),
|
||||||
|
args_(args),
|
||||||
|
ext_info_data_(ext_info_data),
|
||||||
|
so_name_(so_name),
|
||||||
|
kernel_name_(kernel_name),
|
||||||
|
ext_info_handler_(nullptr),
|
||||||
|
ext_info_addr_dev_(nullptr),
|
||||||
|
ext_info_size_(0),
|
||||||
|
input_num_(0),
|
||||||
|
output_num_(0),
|
||||||
|
unknow_type_(DEPEND_COMPUTE) {}
|
||||||
|
|
||||||
|
~AiCpuDynamicKernel() override;
|
||||||
|
|
||||||
|
void UpdateArgs() override;
|
||||||
|
void Execute() override;
|
||||||
|
void Initialize() override;
|
||||||
|
void PostExecute() override;
|
||||||
|
|
||||||
|
// Get Compute Shape from ExtInfo
|
||||||
|
bool UpdateOutputShapeFromExtInfo();
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string args_;
|
||||||
|
std::string ext_info_data_;
|
||||||
|
std::string so_name_;
|
||||||
|
std::string kernel_name_;
|
||||||
|
|
||||||
|
std::shared_ptr<AicpuExtInfoHandler> ext_info_handler_;
|
||||||
|
void *ext_info_addr_dev_;
|
||||||
|
size_t ext_info_size_;
|
||||||
|
|
||||||
|
size_t input_num_;
|
||||||
|
size_t output_num_;
|
||||||
|
|
||||||
|
UnknowShapeOpType unknow_type_;
|
||||||
|
|
||||||
|
bool UpdateInputOutputAddr();
|
||||||
|
bool UpdateExtInfo();
|
||||||
|
};
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AI_CPU_DYNAMIC_KERNEL_H_
|
|
@ -0,0 +1,218 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/ascend/executor/aicpu_ext_info_handle.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include "backend/session/anf_runtime_algorithm.h"
|
||||||
|
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
namespace {
|
||||||
|
// If the dim count does not reach kMaxShapeDims(8), use INT64_MIN to mark the end of the dims.
|
||||||
|
constexpr int64_t kDimEndFlag = INT64_MIN;
|
||||||
|
} // namespace
|
||||||
|
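// The ext info blob is a sequence of (infoType, infoLen, infoMsg) records; walk it once and keep
// pointers into the local copy for the input/output ShapeAndType entries.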
bool AicpuExtInfoHandler::Parse(const std::string &ext_info) {
|
||||||
|
MS_LOG(INFO) << "Parse Node:" << node_name_ << " start";
|
||||||
|
if (ext_info.empty()) {
|
||||||
|
MS_LOG(ERROR) << "Node:" << node_name_ << " ext_info is empty";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ext_info_len_ = ext_info.size();
|
||||||
|
ext_info_.reset(new (std::nothrow) uint8_t[ext_info_len_]);
|
||||||
|
MS_EXCEPTION_IF_NULL(ext_info_);
|
||||||
|
|
||||||
|
(void)memcpy_s(ext_info_.get(), ext_info_len_, ext_info.c_str(), ext_info.size());
|
||||||
|
|
||||||
|
input_shape_and_type_.clear();
|
||||||
|
output_shape_and_type_.clear();
|
||||||
|
|
||||||
|
auto ext_info_data = ext_info_.get();
|
||||||
|
size_t offset = 0;
|
||||||
|
while (offset + sizeof(AicpuExtInfo) <= ext_info_len_) {
|
||||||
|
auto aicpu_ext_info = reinterpret_cast<AicpuExtInfo *>(ext_info_data + offset);
|
||||||
|
MS_EXCEPTION_IF_NULL(aicpu_ext_info);
|
||||||
|
switch (aicpu_ext_info->infoType) {
|
||||||
|
case kernel::FWK_ADPT_EXT_SHAPE_TYPE:
|
||||||
|
if (!ParseExtShapeType(aicpu_ext_info)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parse ext shape type failed.";
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case kernel::FWK_ADPT_EXT_INPUT_SHAPE:
|
||||||
|
if (!ParseExtInputShape(aicpu_ext_info)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parse ext input shape failed.";
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
case kernel::FWK_ADPT_EXT_OUTPUT_SHAPE:
|
||||||
|
if (!ParseExtOutputShape(aicpu_ext_info)) {
|
||||||
|
MS_LOG(EXCEPTION) << "Parse ext output shape failed.";
|
||||||
|
}
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
MS_LOG(INFO) << "Ignore Node:" << node_name_ << " infoType:" << aicpu_ext_info->infoType
|
||||||
|
<< " infoLen:" << aicpu_ext_info->infoLen;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
offset += sizeof(AicpuExtInfo);
|
||||||
|
offset += aicpu_ext_info->infoLen;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (offset != ext_info_len_) {
|
||||||
|
MS_LOG(EXCEPTION) << "Node:" << node_name_ << " ext_info format error, parse not reach end, offset=" << offset
|
||||||
|
<< ", ext_info_len" << ext_info_len_;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_ << " parse ext info end.";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::ParseExtShapeType(AicpuExtInfo *aicpu_ext_info) {
|
||||||
|
if (aicpu_ext_info->infoLen != sizeof(int32_t)) {
|
||||||
|
MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as infoLen must be " << sizeof(int32_t)
|
||||||
|
<< " but got:" << aicpu_ext_info->infoLen;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto type = reinterpret_cast<const int32_t *>(aicpu_ext_info->infoMsg);
|
||||||
|
|
||||||
|
if (*type != unknown_type_) {
|
||||||
|
MS_LOG(ERROR) << "Node:" << node_name_ << " parse ext shape type failed as need:" << unknown_type_
|
||||||
|
<< " but got:" << *type;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_ << "parse ext shape type success infoLen=" << aicpu_ext_info->infoLen;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::ParseExtInputShape(AicpuExtInfo *aicpu_ext_info) {
|
||||||
|
auto need_len = input_num_ * sizeof(AicpuShapeAndType);
|
||||||
|
|
||||||
|
if (aicpu_ext_info->infoLen != need_len) {
|
||||||
|
MS_LOG(ERROR) << "Node:" << node_name_
|
||||||
|
<< " parse ext input shape failed as aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen
|
||||||
|
<< " and need_len:" << need_len;
|
||||||
|
}
|
||||||
|
auto input = reinterpret_cast<AicpuShapeAndType *>(aicpu_ext_info->infoMsg);
|
||||||
|
|
||||||
|
for (uint32_t index = 0; index < input_num_; ++index) {
|
||||||
|
input_shape_and_type_.emplace_back(&input[index]);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_.c_str() << " parse ext input shape success infoLen=" << aicpu_ext_info->infoLen;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info) {
|
||||||
|
auto need_len = output_num_ * sizeof(AicpuShapeAndType);
|
||||||
|
if (aicpu_ext_info->infoLen != need_len) {
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_
|
||||||
|
<< " parse ext output shape failed, aicpu_ext_info->infoLen:" << aicpu_ext_info->infoLen
|
||||||
|
<< " need_len:" << need_len;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto output = reinterpret_cast<AicpuShapeAndType *>(aicpu_ext_info->infoMsg);
|
||||||
|
for (uint32_t index = 0; index < output_num_; ++index) {
|
||||||
|
output_shape_and_type_.emplace_back(&output[index]);
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_ << " parse ext output shape success infoLen=" << aicpu_ext_info->infoLen;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::UpdateInputShapeAndType(uint32_t input_index, const NotNull<AnfNodePtr> &anf_node) {
|
||||||
|
if (input_index >= input_num_) {
|
||||||
|
MS_LOG(ERROR) << "input_index=" << input_index << " >= input_num_:" << input_num_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto input_shape = AnfAlgo::GetInputDeviceShape(anf_node, input_index);
|
||||||
|
auto data_type = AnfAlgo::GetInputDeviceDataType(anf_node, input_index);
|
||||||
|
std::vector<int64_t> tmp_shape;
|
||||||
|
std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(tmp_shape), SizeToLong);
|
||||||
|
return UpdateShapeAndType(tmp_shape, data_type, NOT_NULL(input_shape_and_type_[input_index]));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::UpdateOutputShapeAndType(uint32_t output_index, const NotNull<AnfNodePtr> &anf_node) {
|
||||||
|
if (output_index >= output_num_) {
|
||||||
|
MS_LOG(ERROR) << "output_index:" << output_index << " >= output_num_:" << output_num_;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto shape = AnfAlgo::GetOutputDeviceShape(anf_node, output_index);
|
||||||
|
auto max_shape = AnfAlgo::GetOutputMaxShape(anf_node, output_index);
|
||||||
|
if (shape.size() != max_shape.size()) {
|
||||||
|
MS_LOG(ERROR) << "shape size != max_shape size";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
|
if (i < max_shape.size() && shape[i] == SIZE_MAX) {
|
||||||
|
MS_LOG(INFO) << "Node:" << node_name_ << " update shape from SIZE_MAX to " << max_shape[i];
|
||||||
|
shape[i] = max_shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> tmp_shape;
|
||||||
|
std::transform(shape.begin(), shape.end(), std::back_inserter(tmp_shape), SizeToLong);
|
||||||
|
return UpdateShapeAndType(tmp_shape, AnfAlgo::GetOutputDeviceDataType(anf_node, output_index),
|
||||||
|
NOT_NULL(output_shape_and_type_[output_index]));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::GetOutputShapeAndType(uint32_t output_index, NotNull<std::vector<int64_t> *> shape,
|
||||||
|
NotNull<TypeId *> data_type) {
|
||||||
|
MS_LOG(INFO) << "Get " << node_name_ << " Output:" << output_index << " Shape And Type";
|
||||||
|
GetShapeAndType(NOT_NULL(output_shape_and_type_[output_index]), shape, data_type);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool AicpuExtInfoHandler::UpdateShapeAndType(const std::vector<int64_t> &shape, TypeId data_type,
|
||||||
|
NotNull<AicpuShapeAndType *> shape_and_type) {
|
||||||
|
if (shape.empty() || shape.size() > kernel::kMaxShapeDims) {
|
||||||
|
MS_LOG(ERROR) << "Invalid shape:" << shape.size();
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
size_t index = 0;
|
||||||
|
for (; index < shape.size(); ++index) {
|
||||||
|
shape_and_type->dims[index] = shape[index];
|
||||||
|
}
|
||||||
|
if (index < kernel::kMaxShapeDims) {
|
||||||
|
shape_and_type->dims[index] = kDimEndFlag;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Currently only the shape is updated; updating the type is not supported.
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void AicpuExtInfoHandler::GetShapeAndType(NotNull<const AicpuShapeAndType *> shape_and_type,
|
||||||
|
NotNull<std::vector<int64_t> *> shape, NotNull<TypeId *> data_type) {
|
||||||
|
for (int64_t tmpDim : shape_and_type->dims) {
|
||||||
|
if (tmpDim == kDimEndFlag) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
shape->emplace_back(tmpDim);
|
||||||
|
MS_LOG(INFO) << "Debug tmpDim:" << tmpDim;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto ms_type = kernel::AicpuOpUtil::ProtoTypeToMsType(shape_and_type->type);
|
||||||
|
if (ms_type == -1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Unspport Proto Type:" << shape_and_type->type;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Debug ms_type:" << ms_type;
|
||||||
|
*data_type = static_cast<TypeId>(ms_type);
|
||||||
|
}
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,88 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_
|
||||||
|
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <utility>
|
||||||
|
#include <memory>
|
||||||
|
#include "backend/kernel_compiler/aicpu/aicpu_util.h"
|
||||||
|
#include "utils/contract.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
// for unknown shape op type
|
||||||
|
enum UnknowShapeOpType {
|
||||||
|
DEPEND_IN_SHAPE = 1,     // op output shape is derived from the input shapes
|
||||||
|
DEPEND_CONST_VALUE = 2,  // op output shape is derived from a const input value
|
||||||
|
DEPEND_SHAPE_RANGE = 3,  // op output shape is derived from a shape range
|
||||||
|
DEPEND_COMPUTE = 4       // op output shape is only known after the op has executed
|
||||||
|
};
|
||||||
|
|
||||||
|
using AicpuShapeAndType = kernel::ShapeAndType;
|
||||||
|
using AicpuExtInfo = kernel::ExtInfo;
|
||||||
|
|
||||||
|
class AicpuExtInfoHandler {
|
||||||
|
public:
|
||||||
|
AicpuExtInfoHandler(std::string node_name, uint32_t input_num, uint32_t output_num, UnknowShapeOpType unknown_type)
|
||||||
|
: node_name_(std::move(node_name)),
|
||||||
|
input_num_(input_num),
|
||||||
|
output_num_(output_num),
|
||||||
|
unknown_type_(unknown_type),
|
||||||
|
ext_info_len_(0) {}
|
||||||
|
|
||||||
|
~AicpuExtInfoHandler() = default;
|
||||||
|
|
||||||
|
uint8_t *GetExtInfo() const { return ext_info_.get(); }
|
||||||
|
size_t GetExtInfoLen() const { return ext_info_len_; }
|
||||||
|
|
||||||
|
bool Parse(const std::string &ext_info);
|
||||||
|
|
||||||
|
bool UpdateInputShapeAndType(uint32_t input_index, const NotNull<AnfNodePtr> &anf_node);
|
||||||
|
|
||||||
|
bool UpdateOutputShapeAndType(uint32_t output_index, const NotNull<AnfNodePtr> &anf_node);
|
||||||
|
|
||||||
|
bool GetOutputShapeAndType(uint32_t output_index, NotNull<std::vector<int64_t> *> shape, NotNull<TypeId *> data_type);
|
||||||
|
|
||||||
|
private:
|
||||||
|
bool ParseExtShapeType(AicpuExtInfo *aicpu_ext_info);
|
||||||
|
bool ParseExtInputShape(AicpuExtInfo *aicpu_ext_info);
|
||||||
|
bool ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info);
|
||||||
|
|
||||||
|
static bool UpdateShapeAndType(const std::vector<int64_t> &shape, TypeId data_type,
|
||||||
|
NotNull<AicpuShapeAndType *> shape_and_type);
|
||||||
|
|
||||||
|
static void GetShapeAndType(NotNull<const AicpuShapeAndType *> shape_and_type, NotNull<std::vector<int64_t> *> shape,
|
||||||
|
NotNull<TypeId *> data_type);
|
||||||
|
|
||||||
|
private:
|
||||||
|
const std::string node_name_;
|
||||||
|
const uint32_t input_num_;
|
||||||
|
const uint32_t output_num_;
|
||||||
|
UnknowShapeOpType unknown_type_;
|
||||||
|
size_t ext_info_len_;
|
||||||
|
|
||||||
|
std::unique_ptr<uint8_t[]> ext_info_;
|
||||||
|
std::vector<AicpuShapeAndType *> input_shape_and_type_;
|
||||||
|
std::vector<AicpuShapeAndType *> output_shape_and_type_;
|
||||||
|
};
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_AICPU_EXT_INFO_HANDLE_H_
|
|
@ -0,0 +1,41 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/ascend/executor/executor_callback.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
void ExecutorCallback::RegistCallback(const std::function<void()> &callback) {
|
||||||
|
std::lock_guard<std::mutex> guard(lock_);
|
||||||
|
callback_queue_.push(callback);
|
||||||
|
}
|
||||||
|
|
||||||
|
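// Drain the queued callbacks under the lock; RunDynamicKernelAsync registers PostExecute work here and
// consumes it before the next shape inference and again after the final stream sync.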
void ExecutorCallback::Consume() {
|
||||||
|
std::lock_guard<std::mutex> guard(lock_);
|
||||||
|
while (!callback_queue_.empty()) {
|
||||||
|
auto callback_func = callback_queue_.front();
|
||||||
|
callback_queue_.pop();
|
||||||
|
if (!callback_func) {
|
||||||
|
MS_LOG(EXCEPTION) << "callback_func is empty";
|
||||||
|
}
|
||||||
|
callback_func();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,49 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_
|
||||||
|
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_
|
||||||
|
|
||||||
|
#include <queue>
|
||||||
|
#include <mutex>
|
||||||
|
#include <functional>
|
||||||
|
#include "utils/ms_utils.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
class ExecutorCallback {
|
||||||
|
public:
|
||||||
|
static ExecutorCallback &GetInstance() {
|
||||||
|
static ExecutorCallback instance;
|
||||||
|
return instance;
|
||||||
|
}
|
||||||
|
|
||||||
|
void RegistCallback(const std::function<void()> &callback);
|
||||||
|
void Consume();
|
||||||
|
|
||||||
|
private:
|
||||||
|
ExecutorCallback() = default;
|
||||||
|
~ExecutorCallback() = default;
|
||||||
|
DISABLE_COPY_AND_ASSIGN(ExecutorCallback);
|
||||||
|
|
||||||
|
std::queue<std::function<void()>> callback_queue_;
|
||||||
|
std::mutex lock_;
|
||||||
|
};
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_CALLBACK_H_
|
|
@ -0,0 +1,187 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "runtime/device/ascend/executor/hccl_dynamic_kernel.h"
|
||||||
|
|
||||||
|
#include <dlfcn.h>
|
||||||
|
#include <vector>
|
||||||
|
#include "hccl/hcom.h"
|
||||||
|
#include "common/opskernel/ge_task_info.h"
|
||||||
|
#include "utils/log_adapter.h"
|
||||||
|
#include "runtime/device/kernel_runtime.h"
|
||||||
|
#include "backend/kernel_compiler/hccl/hcom_util.h"
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Find so in RPATH or LD_LIBRARY_PATH (/usr/local/Ascend/fwkacllib/lib64/)
|
||||||
|
constexpr auto kHcomGraphAdaptorPath = "libhcom_graph_adaptor.so";
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace device {
|
||||||
|
namespace ascend {
|
||||||
|
void HcclDynamicKernel::UpdateArgs() {
|
||||||
|
if (!is_dynamic_shape_) {
|
||||||
|
MS_LOG(INFO) << "Not Dynamic Shape";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Start to UpdateArgs";
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
// Update input, output, count
|
||||||
|
AddressPtrList kernel_inputs;
|
||||||
|
AddressPtrList kernel_workspaces;
|
||||||
|
AddressPtrList kernel_outputs;
|
||||||
|
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||||
|
if (kernel_inputs.empty() || kernel_outputs.empty()) {
|
||||||
|
MS_LOG(EXCEPTION) << "Inputs or outputs is empty";
|
||||||
|
}
|
||||||
|
auto input0 = kernel_inputs.at(0);
|
||||||
|
auto output0 = kernel_outputs.at(0);
|
||||||
|
MS_EXCEPTION_IF_NULL(input0);
|
||||||
|
MS_EXCEPTION_IF_NULL(output0);
|
||||||
|
|
||||||
|
// Update Hccl input and output
|
||||||
|
input_ptr_ = input0->addr;
|
||||||
|
output_ptr_ = output0->addr;
|
||||||
|
|
||||||
|
std::vector<std::vector<size_t>> hccl_kernel_input_shape_list;
|
||||||
|
if (!HcomUtil::GetKernelInputShape(cnode_ptr_, &hccl_kernel_input_shape_list)) {
|
||||||
|
MS_LOG(EXCEPTION) << "GetKernelInputShape fail!";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<HcclDataType> hccl_data_type_list;
|
||||||
|
if (!HcomUtil::GetHcomDataType(cnode_ptr_, &hccl_data_type_list)) {
|
||||||
|
MS_LOG(EXCEPTION) << "GetHcomDataType fail!";
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update Hccl count
|
||||||
|
if (!HcomUtil::GetHcomCount(cnode_ptr_, hccl_data_type_list, hccl_kernel_input_shape_list, &count_)) {
|
||||||
|
MS_LOG(EXCEPTION) << "GetHcomCount fail!";
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Update Hccl count:" << count_;
|
||||||
|
}
|
||||||
|
|
||||||
|
void HcclDynamicKernel::StaticShapeExecute() {
|
||||||
|
MS_EXCEPTION_IF_NULL(cnode_ptr_);
|
||||||
|
auto kernel_mod = AnfAlgo::GetKernelMod(cnode_ptr_);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||||
|
AddressPtrList kernel_inputs;
|
||||||
|
AddressPtrList kernel_workspaces;
|
||||||
|
AddressPtrList kernel_outputs;
|
||||||
|
KernelRuntime::GenLaunchArgs(*kernel_mod, cnode_ptr_, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||||
|
kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, stream_);
|
||||||
|
}
|
||||||
|
|
||||||
|
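// Dynamic-shape HCCL ops are enqueued to the hcom executor through the dlsym-resolved EnqueueHcomOpertion;
// the call is asynchronous, so block on the condition variable until the completion callback fires.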
void HcclDynamicKernel::Execute() {
|
||||||
|
MS_LOG(INFO) << "Start Execute";
|
||||||
|
if (!is_dynamic_shape_) {
|
||||||
|
MS_LOG(INFO) << "Not Dynamic, call hcom api";
|
||||||
|
StaticShapeExecute();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto handle = HcclExecutorManager::GetInstance().handle();
|
||||||
|
auto EnqueueHcomOperation =
|
||||||
|
(HcclResult(*)(ge::HcomOpertion, std::function<void(HcclResult status)>))dlsym(handle, "EnqueueHcomOpertion");
|
||||||
|
if (EnqueueHcomOperation == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Failed to get EnqueueHcomOperation function";
|
||||||
|
if (dlclose(handle) != 0) {
|
||||||
|
MS_LOG(WARNING) << "Failed to close hcom handle";
|
||||||
|
}
|
||||||
|
MS_LOG(EXCEPTION) << "Hccl dynamic kernel execute failed";
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
ge::HcomOpertion op_info;
|
||||||
|
op_info.hcclType = hccl_type_;
|
||||||
|
op_info.inputPtr = input_ptr_;
|
||||||
|
op_info.outputPtr = output_ptr_;
|
||||||
|
op_info.dataType = data_type_;
|
||||||
|
op_info.opType = op_type_;
|
||||||
|
op_info.root = root_;
|
||||||
|
op_info.count = count_;
|
||||||
|
|
||||||
|
auto callback = [this](HcclResult status) {
|
||||||
|
if (status != HCCL_SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "HcomExcutorInitialize failed, ret:" << status;
|
||||||
|
}
|
||||||
|
std::lock_guard<std::mutex> lock(this->hccl_mutex_);
|
||||||
|
this->cond_.notify_all();
|
||||||
|
MS_LOG(INFO) << "hccl callback success.";
|
||||||
|
};
|
||||||
|
|
||||||
|
auto hccl_ret = EnqueueHcomOperation(op_info, callback);
|
||||||
|
if (hccl_ret != HCCL_SUCCESS) {
|
||||||
|
MS_LOG(EXCEPTION) << "Call EnqueueHcomOperation failed";
|
||||||
|
}
|
||||||
|
|
||||||
|
std::unique_lock<std::mutex> ulock(hccl_mutex_);
|
||||||
|
cond_.wait(ulock);
|
||||||
|
MS_LOG(INFO) << "Execute success";
|
||||||
|
}
|
||||||
|
|
||||||
|
void HcclDynamicKernel::PostExecute() {}
|
||||||
|
|
||||||
|
bool HcclExecutorManager::Initialize() {
|
||||||
|
if (initialized_) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
initialized_ = true;
|
||||||
|
MS_LOG(INFO) << "Start Initialize Hccl DynamicKernel";
|
||||||
|
handle_ = dlopen(kHcomGraphAdaptorPath, RTLD_NOW | RTLD_GLOBAL);
|
||||||
|
if (handle_ == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dlopen failed, path:" << kHcomGraphAdaptorPath;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto HcomExecutorInitialize = (HcclResult(*)())dlsym(handle_, "HcomExcutorInitialize");
|
||||||
|
if (HcomExecutorInitialize == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "dlsym HcomExecutorInitialize failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
HcclResult hccl_ret = HcomExecutorInitialize();
|
||||||
|
if (hccl_ret == HCCL_E_PTR) {
|
||||||
|
MS_LOG(WARNING) << "Hccl comm is null, hcom executor initialize is not required";
|
||||||
|
} else if (hccl_ret == HCCL_SUCCESS) {
|
||||||
|
MS_LOG(INFO) << "Hcom DynamicKernel Initialize success";
|
||||||
|
} else {
|
||||||
|
MS_LOG(ERROR) << "Hcom DynamicKernel Initialize failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HcclExecutorManager::Finalize() {
|
||||||
|
auto HcomExecutorFinalize = (HcclResult(*)())dlsym(handle_, "HcomExcutorFinalize");
|
||||||
|
if (HcomExecutorFinalize == nullptr) {
|
||||||
|
MS_LOG(ERROR) << "Faile to dlsym HcomExecutorFinalize";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
HcclResult hccl_ret = HcomExecutorFinalize();
|
||||||
|
if (hccl_ret != HCCL_SUCCESS) {
|
||||||
|
MS_LOG(ERROR) << "Hcom DynamicKernel Finalize failed";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (dlclose(handle_) != 0) {
|
||||||
|
MS_LOG(ERROR) << "Failed to close hcom handle";
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
MS_LOG(INFO) << "Hccl DynamicKernel Finalize failed";
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace ascend
|
||||||
|
} // namespace device
|
||||||
|
} // namespace mindspore
|
|
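Execute() above hands the collective operation to the dynamically loaded hcom adaptor and then blocks on a condition variable until the completion callback fires. The following standalone sketch (an editor's illustration, not part of this commit; FakeEnqueue is a hypothetical stand-in for the dlsym'd EnqueueHcomOpertion symbol) shows the same handshake with nothing but standard C++:

// Illustrative sketch, assuming only the C++11 standard library.
#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>

// Stand-in for the adaptor: runs the "collective" asynchronously, then invokes the callback.
static void FakeEnqueue(std::function<void(int)> done) {
  std::thread([done] { done(0); }).detach();
}

int main() {
  std::mutex mu;
  std::condition_variable cond;
  bool finished = false;

  FakeEnqueue([&](int status) {
    std::lock_guard<std::mutex> lock(mu);  // take the lock before signalling, as the real callback does
    finished = true;
    cond.notify_all();
    std::cout << "callback status:" << status << std::endl;
  });

  std::unique_lock<std::mutex> ulock(mu);
  cond.wait(ulock, [&] { return finished; });  // block until the asynchronous op reports completion
  std::cout << "Execute success" << std::endl;
  return 0;
}

The sketch uses the predicate form of wait(), which stays correct even if the callback runs before the waiting thread reaches the condition variable.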
@ -0,0 +1,82 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_

#include <condition_variable>
#include <mutex>
#include <string>
#include "runtime/device/executor/dynamic_kernel.h"
#include "utils/ms_utils.h"

namespace mindspore {
namespace device {
namespace ascend {
class HcclDynamicKernel : public DynamicKernel {
 public:
  HcclDynamicKernel(const std::string &hccl_type, void *input_ptr, void *output_ptr, uint64_t count, int32_t data_type,
                    int32_t op_type, int32_t root, void *stream, const CNodePtr &cnode_ptr)
      : DynamicKernel(stream, cnode_ptr),
        hccl_type_(hccl_type),
        input_ptr_(input_ptr),
        output_ptr_(output_ptr),
        count_(count),
        data_type_(data_type),
        op_type_(op_type),
        root_(root) {}
  ~HcclDynamicKernel() override = default;
  void UpdateArgs() override;
  void Execute() override;
  void PostExecute() override;

 private:
  std::string hccl_type_;
  void *input_ptr_;
  void *output_ptr_;
  uint64_t count_{0};
  int32_t data_type_{0};
  int32_t op_type_{0};
  int32_t root_{0};
  std::mutex hccl_mutex_;
  std::condition_variable cond_;

  void StaticShapeExecute();
};

class HcclExecutorManager {
 public:
  static HcclExecutorManager &GetInstance() {
    static HcclExecutorManager instance;
    return instance;
  }

  bool Initialize();
  bool Finalize();
  void *handle() { return handle_; }

 private:
  HcclExecutorManager() = default;
  ~HcclExecutorManager() = default;
  DISABLE_COPY_AND_ASSIGN(HcclExecutorManager);

  void *handle_{nullptr};
  bool initialized_{false};
};
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HCCL_DYNAMIC_KERNEL_H_
@ -0,0 +1,36 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_

#include "runtime/device/executor/dynamic_kernel.h"

namespace mindspore {
namespace device {
namespace ascend {
class HostDynamicKernel : public DynamicKernel {
 public:
  HostDynamicKernel(void *stream, const CNodePtr &cnode_ptr) : DynamicKernel(stream, cnode_ptr) {}
  ~HostDynamicKernel() override = default;
  void UpdateArgs() override {}
  void Execute() override = 0;
  void PostExecute() override {}
};
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_HOST_DYNAMIC_KERNEL_H_
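The dynamic-kernel hooks above follow a fixed per-step order: infer the new shapes, refresh kernel arguments, launch, and optionally post-process. The standalone sketch below (an editor's illustration, not part of this commit; DynamicKernelBase and HostStyleKernel are made-up analogues of DynamicKernel and HostDynamicKernel) shows that lifecycle, with a host-style kernel overriding only Execute() just as HostDynamicKernel leaves Execute() pure virtual:

// Illustrative analogue of the DynamicKernel lifecycle, standard C++ only.
#include <iostream>
#include <memory>
#include <vector>

class DynamicKernelBase {
 public:
  virtual ~DynamicKernelBase() = default;
  virtual void InferShape() { std::cout << "infer output shape on host\n"; }
  virtual void UpdateArgs() = 0;
  virtual void Execute() = 0;
  virtual void PostExecute() {}
};

class HostStyleKernel : public DynamicKernelBase {
 public:
  void UpdateArgs() override {}  // nothing to refresh, the op itself runs on host
  void Execute() override { std::cout << "recompute shape-dependent output on host\n"; }
};

int main() {
  std::vector<std::shared_ptr<DynamicKernelBase>> kernels;
  kernels.emplace_back(std::make_shared<HostStyleKernel>());
  for (const auto &kernel : kernels) {
    // the order in which a runtime drives each dynamic kernel every step
    kernel->InferShape();
    kernel->UpdateArgs();
    kernel->Execute();
    kernel->PostExecute();
  }
  return 0;
}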
@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/executor/rts/memcpy_rts_dynamic_kernel.h"

#include "runtime/mem.h"

namespace mindspore {
namespace device {
namespace ascend {
void MemcpyRtsDynamicKernel::Execute() {
  auto status = rtMemcpyAsync(dst_, dest_max_, src_, count_, RT_MEMCPY_DEVICE_TO_DEVICE, stream_);
  if (status != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "MemCpyAsync op rtMemcpyAsync failed!";
  }
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
@ -0,0 +1,45 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_

#include "runtime/device/executor/dynamic_kernel.h"

namespace mindspore {
namespace device {
namespace ascend {
class MemcpyRtsDynamicKernel : public DynamicKernel {
 public:
  MemcpyRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, void *dst, uint32_t dest_max, void *src,
                         uint32_t count)
      : DynamicKernel(stream, cnode_ptr), dst_(dst), dest_max_(dest_max), src_(src), count_(count) {}
  ~MemcpyRtsDynamicKernel() override = default;

  void UpdateArgs() override {}
  void Execute() override;
  void PostExecute() override {}

 private:
  void *dst_;
  uint32_t dest_max_;
  void *src_;
  uint32_t count_;
};
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_MEMCPY_RTS_DYNAMIC_KERNEL_H_
@ -0,0 +1,32 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/executor/rts/profiling_rts_dynamic_kernel.h"

#include "runtime/base.h"

namespace mindspore {
namespace device {
namespace ascend {
void ProfilingRtsDynamicKernel::Execute() {
  auto rt_ret = rtProfilerTrace(log_id_, notify_, flags_, stream_);
  if (rt_ret != RT_ERROR_NONE) {
    MS_LOG(EXCEPTION) << "Call rtProfilerTrace failed";
  }
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
@ -0,0 +1,43 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_

#include "runtime/device/executor/dynamic_kernel.h"

namespace mindspore {
namespace device {
namespace ascend {
class ProfilingRtsDynamicKernel : public DynamicKernel {
 public:
  ProfilingRtsDynamicKernel(void *stream, const CNodePtr &cnode_ptr, uint64_t log_id, bool notify, uint32_t flags)
      : DynamicKernel(stream, cnode_ptr), log_id_(log_id), notify_(notify), flags_(flags) {}
  ~ProfilingRtsDynamicKernel() override = default;

  void UpdateArgs() override {}
  void Execute() override;
  void PostExecute() override {}

 private:
  uint64_t log_id_;
  bool notify_;
  uint32_t flags_;
};
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_RTS_PROFILING_RTS_DYNAMIC_KERNEL_H_
@ -0,0 +1,188 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/executor/tiling/op_tiling_calculater.h"
#include <dlfcn.h>
#include <map>
#include <vector>
#include <memory>
#include <string>
#include "backend/session/anf_runtime_algorithm.h"
#include "runtime/device/ascend/ge_types_convert.h"
#include "utils/utils.h"
#include "external/graph/tensor.h"

namespace mindspore {
namespace device {
namespace ascend {
ge::Tensor MakeTempGeTensor(TypeId type_id) {
  auto ge_type = GeTypesConvert::TransTypeIdToGeDataType(type_id);
  ge::TensorDesc tensor_desc;
  tensor_desc.SetDataType(ge_type);
  ge::Tensor ge_tensor;
  ge_tensor.SetTensorDesc(tensor_desc);
  return ge_tensor;
}

void FeedTeOpTensorInputArg(const NotNull<CNodePtr> &cnode,
                            NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
  MS_LOG(INFO) << "FeedTeOpTensorInputArg start, node:" << cnode->fullname_with_scope();
  auto input_size = AnfAlgo::GetInputTensorNum(cnode.get());

  // Skip Dynamic Shape Depend Input

  for (size_t i = 0; i < input_size; ++i) {
    auto input_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode.get(), i);
    auto input_node = input_node_with_index.first;
    auto input_index = input_node_with_index.second;
    auto output_shape = AnfAlgo::GetOutputDeviceShape(input_node, input_index);
    auto output_format = AnfAlgo::GetOutputFormat(input_node, input_index);
    auto output_dtype = AnfAlgo::GetOutputDeviceDataType(input_node, input_index);
    auto iter = type_name_map.find(output_dtype);
    if (iter == type_name_map.end()) {
      MS_LOG(EXCEPTION) << "Cannot find typeId:" << output_dtype;
    }
    auto ge_output_dtype = iter->second;

    optiling::TeOpTensorArg tensor_arg;
    optiling::TeOpTensor tensor;
    tensor_arg.arg_type = optiling::TA_SINGLE;
    tensor.dtype = ge_output_dtype;
    tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());

    tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
    MS_LOG(INFO) << "Tiling Format:" << tensor.format;
    tensor_arg.tensor.emplace_back(tensor);
    tensor_arg_list->emplace_back(tensor_arg);
  }
}

void FeedTeOpTensorOutputArg(const NotNull<CNodePtr> &cnode,
                             NotNull<std::vector<optiling::TeOpTensorArg> *> tensor_arg_list) {
  MS_LOG(INFO) << "FeedTeOpTensorOutputArg start, node:" << cnode->fullname_with_scope();
  auto output_size = AnfAlgo::GetOutputTensorNum(cnode.get());
  for (size_t i = 0; i < output_size; ++i) {
    auto output_shape = AnfAlgo::GetOutputDeviceShape(cnode.get(), i);
    auto output_format = AnfAlgo::GetOutputFormat(cnode.get(), i);
    auto data_type = AnfAlgo::GetOutputDeviceDataType(cnode.get(), i);
    auto iter = type_name_map.find(data_type);
    if (iter == type_name_map.end()) {
      MS_LOG(EXCEPTION) << "Cannot find typeId:" << data_type;
    }

    optiling::TeOpTensorArg tensor_arg;
    optiling::TeOpTensor tensor;
    tensor_arg.arg_type = optiling::TA_SINGLE;
    tensor.dtype = iter->second;
    tensor.shape.insert(tensor.shape.end(), output_shape.begin(), output_shape.end());
    tensor.format = GeTypesConvert::GetGeTilingFormat(GeTypesConvert::GetGeFormat(output_format, output_shape.size()));
    MS_LOG(INFO) << "Tiling Format:" << tensor.format;
    tensor_arg.tensor.emplace_back(tensor);
    tensor_arg_list->emplace_back(tensor_arg);
  }
}

void FeedTeOpConstTensor(const NotNull<CNodePtr> &cnode, const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                         NotNull<std::map<std::string, optiling::TeConstTensorData> *> const_inputs) {
  MS_LOG(INFO) << "FeedTeOpConstTensor start, node:" << cnode->fullname_with_scope();
  if (!AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode.get())) {
    MS_LOG(INFO) << "No input depend found, " << cnode->fullname_with_scope();
    return;
  }

  auto depends_list = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode.get(), kDynamicShapeDepends);
  for (auto index : depends_list) {
    auto iter = depend_tensor_map.find(IntToSize(index));
    if (iter == depend_tensor_map.end()) {
      MS_LOG(EXCEPTION) << "Index not found in depend_tensor_map";
    }

    auto const_tensor = iter->second;

    auto have_input_names_attr = AnfAlgo::HasNodeAttr("input_names", cnode);
    if (!have_input_names_attr) {
      MS_LOG(EXCEPTION) << "cnode:" << cnode->fullname_with_scope() << " no input_names attr";
    }
    auto input_names_attr = AnfAlgo::GetNodeAttr<std::vector<std::string>>(cnode.get(), "input_names");
    if (IntToSize(index) >= input_names_attr.size()) {
      MS_LOG(EXCEPTION) << "input index" << index << " >= input_name_attr.size:" << input_names_attr.size();
    }
    auto input_name = input_names_attr[index];
    MS_LOG(INFO) << "input_name is " << input_name;
    auto type_id = AnfAlgo::GetPrevNodeOutputDeviceDataType(cnode.get(), index);
    const_inputs->try_emplace(
      input_name, optiling::TeConstTensorData{static_cast<const uint8_t *>(const_tensor->data_c()),
                                              IntToSize(const_tensor->DataSize()), MakeTempGeTensor(type_id)});
  }
  MS_LOG(INFO) << "FeedTeOpConstTensor end";
}

void OpTilingCalculater::Init() {
  MS_LOG(INFO) << "Start init OpTilingCalculater";
  tiling_func_map_ = optiling::OpTilingInterf::RegisteredOpInterf();
  MS_LOG(INFO) << "tiling_func_map_ size:" << tiling_func_map_.size();
  for (const auto &iter : tiling_func_map_) {
    MS_LOG(INFO) << "Register tiling func:" << iter.first;
  }
}

std::string GetRealOpType(const std::string &op_type) {
  static const std::map<std::string, std::string> kOpTypeMap = {
    {"SparseApplyFtrl", "SparseApplyFtrlD"},
  };
  auto iter = kOpTypeMap.find(op_type);
  if (iter == kOpTypeMap.end()) {
    return op_type;
  }
  return iter->second;
}

void OpTilingCalculater::CalculateTiling(const NotNull<CNodePtr> &cnode,
                                         const NotNull<std::shared_ptr<nlohmann::json>> &compile_info_json,
                                         const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                                         NotNull<optiling::OpRunInfo *> op_run_info) {
  optiling::TeOpParas op_param;
  std::string op_type = AnfAlgo::GetCNodeName(cnode.get());
  MS_LOG(INFO) << "[DynamicShape] calculate tiling, op_type:" << op_type;

  FeedTeOpTensorInputArg(cnode, NOT_NULL(&op_param.inputs));
  FeedTeOpTensorOutputArg(cnode, NOT_NULL(&op_param.outputs));
  FeedTeOpConstTensor(cnode, depend_tensor_map, NOT_NULL(&op_param.const_inputs));

  op_type = GetRealOpType(op_type);
  auto iter = tiling_func_map_.find(op_type);
  if (iter == tiling_func_map_.end()) {
    iter = tiling_func_map_.find("AutoTiling");
    if (iter == tiling_func_map_.end()) {
      MS_LOG(EXCEPTION) << "AutoTiling Func Not Found";
    }
  }

  MS_LOG(INFO) << "Get tiling func:" << iter->first;

  if (iter != tiling_func_map_.end()) {
    bool ret = (iter->second)(op_type, op_param, *compile_info_json.get(), *op_run_info);
    if (!ret) {
      MS_LOG(EXCEPTION) << "Calculate tiling failed";
    }
  } else {
    MS_LOG(EXCEPTION) << "Tiling func not found";
  }
  MS_LOG(INFO) << "CalculateTiling success";
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
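CalculateTiling above resolves its tiling function in two steps: remap known op-type aliases, then fall back to the registered "AutoTiling" function when no dedicated entry exists. The standalone sketch below (an editor's illustration, not part of this commit; the lambdas stand in for registered optiling functions) shows that lookup order with a plain std::map:

// Illustrative sketch of the alias-remap + AutoTiling fallback lookup, standard C++ only.
#include <functional>
#include <iostream>
#include <map>
#include <string>

using TilingFunc = std::function<bool(const std::string &)>;

// Alias remap mirroring GetRealOpType above.
std::string RealOpType(const std::string &op_type) {
  static const std::map<std::string, std::string> kOpTypeMap = {{"SparseApplyFtrl", "SparseApplyFtrlD"}};
  auto iter = kOpTypeMap.find(op_type);
  return iter == kOpTypeMap.end() ? op_type : iter->second;
}

int main() {
  std::map<std::string, TilingFunc> tiling_func_map = {
    {"SparseApplyFtrlD",
     [](const std::string &op) {
       std::cout << "dedicated tiling for " << op << "\n";
       return true;
     }},
    {"AutoTiling",
     [](const std::string &op) {
       std::cout << "auto tiling for " << op << "\n";
       return true;
     }},
  };

  for (std::string raw : {"SparseApplyFtrl", "ReluV2"}) {
    std::string op_type = RealOpType(raw);        // 1. remap aliases first
    auto iter = tiling_func_map.find(op_type);    // 2. prefer a dedicated tiling function
    if (iter == tiling_func_map.end()) {
      iter = tiling_func_map.find("AutoTiling");  // 3. otherwise fall back to the generic one
    }
    if (iter == tiling_func_map.end() || !iter->second(op_type)) {
      std::cerr << "Calculate tiling failed for " << raw << "\n";
      return 1;
    }
  }
  return 0;
}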
@ -0,0 +1,55 @@
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_

#include <map>
#include <memory>
#include <string>
#include "utils/ms_utils.h"
#include "utils/contract.h"
#include "ir/anf.h"
#include "ir/tensor.h"
#include "register/op_tiling.h"

namespace mindspore {
namespace device {
namespace ascend {
class OpTilingCalculater {
 public:
  static OpTilingCalculater &GetInstance() {
    static OpTilingCalculater instance;
    return instance;
  }

  void Init();
  void CalculateTiling(const NotNull<CNodePtr> &cnode,
                       const NotNull<std::shared_ptr<nlohmann::json>> &compile_info_json,
                       const std::map<uint32_t, tensor::TensorPtr> &depend_tensor_map,
                       NotNull<optiling::OpRunInfo *> op_run_info);

 private:
  OpTilingCalculater() = default;
  ~OpTilingCalculater() = default;
  DISABLE_COPY_AND_ASSIGN(OpTilingCalculater);

  std::map<std::string, optiling::OpTilingFunc> tiling_func_map_;
};
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_TILING_OP_TILING_CALCULATE_H_
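OpTilingCalculater and HcclExecutorManager both use the function-local-static ("Meyers") singleton with copy and assignment disabled. The standalone sketch below (an editor's illustration, not part of this commit; TilingRegistry is a made-up analogue) shows the same shape with the copy/assign operators deleted by hand instead of through the DISABLE_COPY_AND_ASSIGN macro:

// Illustrative Meyers-singleton sketch, standard C++ only.
#include <iostream>
#include <map>
#include <string>

class TilingRegistry {
 public:
  static TilingRegistry &GetInstance() {
    static TilingRegistry instance;  // constructed once, on first use, thread-safe since C++11
    return instance;
  }
  TilingRegistry(const TilingRegistry &) = delete;
  TilingRegistry &operator=(const TilingRegistry &) = delete;

  void Init() { funcs_["AutoTiling"] = "builtin"; }
  size_t size() const { return funcs_.size(); }

 private:
  TilingRegistry() = default;
  ~TilingRegistry() = default;
  std::map<std::string, std::string> funcs_;
};

int main() {
  TilingRegistry::GetInstance().Init();
  std::cout << "registered funcs:" << TilingRegistry::GetInstance().size() << "\n";
  return 0;
}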
@ -0,0 +1,137 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/ascend/ge_types_convert.h"

namespace mindspore {
namespace device {
namespace ascend {
ge::proto::DataType GeTypesConvert::GetGeDataType(TypeId type_id) {
  static const std::map<TypeId, ge::proto::DataType> data_type_map = {
    {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED},     {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
    {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
    {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8},     {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
    {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16},   {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
    {TypeId::kNumberTypeInt64, ge::proto::DT_INT64},     {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
    {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64},   {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
    {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE},
  };
  MS_LOG(INFO) << "Vm origin type_id:" << type_id;
  auto iter = data_type_map.find(type_id);
  if (iter == data_type_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
  }
  return iter->second;
}

ge::DataType GeTypesConvert::TransTypeIdToGeDataType(TypeId type_id) {
  static const std::map<TypeId, ge::DataType> data_type_map = {
    {TypeId::kNumberTypeFloat, ge::DataType::DT_FLOAT},     {TypeId::kNumberTypeFloat32, ge::DataType::DT_FLOAT},
    {TypeId::kNumberTypeFloat16, ge::DataType::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::DataType::DT_INT8},
    {TypeId::kNumberTypeInt16, ge::DataType::DT_INT16},     {TypeId::kNumberTypeUInt16, ge::DataType::DT_UINT16},
    {TypeId::kNumberTypeUInt8, ge::DataType::DT_UINT8},     {TypeId::kNumberTypeInt32, ge::DataType::DT_INT32},
    {TypeId::kNumberTypeInt, ge::DataType::DT_INT32},       {TypeId::kNumberTypeInt64, ge::DataType::DT_INT64},
    {TypeId::kNumberTypeUInt32, ge::DataType::DT_UINT32},   {TypeId::kNumberTypeUInt, ge::DataType::DT_UINT32},
    {TypeId::kNumberTypeUInt64, ge::DataType::DT_UINT64},   {TypeId::kNumberTypeBool, ge::DataType::DT_BOOL},
    {TypeId::kNumberTypeFloat64, ge::DataType::DT_DOUBLE},  {TypeId::kTypeUnknown, ge::DataType::DT_UNDEFINED}};
  auto iter = data_type_map.find(type_id);
  if (iter == data_type_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
  }
  return iter->second;
}

GeFormat GeTypesConvert::GetGeFormat(const std::string &format, size_t shape_size) {
  static const std::map<std::string, GeFormat> format_map = {
    // default format: nchw, fractal_nz?
    {kOpFormat_DEFAULT, kFormat_NCHW},
    {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0},
    {kOpFormat_ND, kFormat_ND},
    {kOpFormat_NCHW, kFormat_NCHW},
    {kOpFormat_NHWC, kFormat_NHWC},
    {kOpFormat_HWCN, kFormat_HWCN},
    {kOpFormat_NC1HWC0, kFormat_NC1HWC0},
    {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z},
    {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ},
    {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0},
    {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04},
    {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04},
    {kOpFormat_NDHWC, kFormat_NDHWC},
  };
  MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
  if (format == kOpFormat_DEFAULT) {
    return shape_size == 4 ? kFormat_NCHW : kFormat_ND;
  }
  auto iter = format_map.find(format);
  if (iter == format_map.end()) {
    MS_LOG(EXCEPTION) << "Invalid format:" << format;
  }
  return iter->second;
}

std::string GeTypesConvert::GetGeTilingFormat(GeFormat ge_format) {
  static const std::map<GeFormat, std::string> kFormatToStringMap = {
    {kFormat_NCHW, "NCHW"},
    {kFormat_NHWC, "NHWC"},
    {kFormat_ND, "ND"},
    {kFormat_NC1HWC0, "NC1HWC0"},
    {kFormat_FRACTAL_Z, "FRACTAL_Z"},
    {kFormat_NC1C0HWPAD, "NC1C0HWPAD"},
    {kFormat_NHWC1C0, "NHWC1C0"},
    {kFormat_FSR_NCHW, "FSR_NCHW"},
    {kFormat_FRACTAL_DECONV, "FRACTAL_DECONV"},
    {kFormat_C1HWNC0, "C1HWNC0"},
    {kFormat_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"},
    {kFormat_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"},
    {kFormat_NC1HWC0_C04, "NC1HWC0_C04"},
    {kFormat_FRACTAL_Z_C04, "FRACTAL_Z_C04"},
    {kFormat_CHWN, "CHWN"},
    {kFormat_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"},
    {kFormat_NC1KHKWHWC0, "NC1KHKWHWC0"},
    {kFormat_BN_WEIGHT, "BN_WEIGHT"},
    {kFormat_FILTER_HWCK, "FILTER_HWCK"},
    {kFormat_HWCN, "HWCN"},
    {kFormat_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"},
    {kFormat_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"},
    {kFormat_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"},
    {kFormat_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"},
    {kFormat_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"},
    {kFormat_MD, "MD"},
    {kFormat_NDHWC, "NDHWC"},
    {kFormat_NCDHW, "NCDHW"},
    {kFormat_DHWCN, "DHWCN"},
    {kFormat_DHWNC, "DHWNC"},
    {kFormat_NDC1HWC0, "NDC1HWC0"},
    {kFormat_FRACTAL_Z_3D, "FRACTAL_Z_3D"},
    {kFormat_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"},
    {kFormat_C1HWNCoC0, "C1HWNCoC0"},
    {kFormat_FRACTAL_NZ, "FRACTAL_NZ"},
    {kFormat_CN, "CN"},
    {kFormat_NC, "NC"},
    {kFormat_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"},
    {kFormat_FRACTAL_Z_G, "FRACTAL_Z_G"},
    {kFormat_RESERVED, "FORMAT_RESERVED"},
    {kFormat_ALL, "ALL"}};

  auto iter = kFormatToStringMap.find(ge_format);
  if (iter == kFormatToStringMap.end()) {
    MS_LOG(EXCEPTION) << "Invalid ge_format:" << ge_format;
  }
  return iter->second;
}
}  // namespace ascend
}  // namespace device
}  // namespace mindspore
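GetGeFormat above treats the framework's default format as a special case: a rank-4 shape resolves to NCHW, anything else to ND, and only non-default formats go through the table lookup. The standalone sketch below (an editor's illustration, not part of this commit; Fmt, "DefaultFormat" and ResolveFormat are made-up analogues) isolates that rule:

// Illustrative sketch of the default-format special case, standard C++ only.
#include <iostream>
#include <map>
#include <stdexcept>
#include <string>

enum class Fmt { NCHW, NHWC, ND };

Fmt ResolveFormat(const std::string &format, size_t shape_size) {
  static const std::map<std::string, Fmt> format_map = {{"NCHW", Fmt::NCHW}, {"NHWC", Fmt::NHWC}, {"ND", Fmt::ND}};
  if (format == "DefaultFormat") {
    return shape_size == 4 ? Fmt::NCHW : Fmt::ND;  // the tensor rank decides what "default" means
  }
  auto iter = format_map.find(format);
  if (iter == format_map.end()) {
    throw std::invalid_argument("Invalid format:" + format);
  }
  return iter->second;
}

int main() {
  std::cout << static_cast<int>(ResolveFormat("DefaultFormat", 4)) << "\n";  // 0 -> NCHW
  std::cout << static_cast<int>(ResolveFormat("DefaultFormat", 2)) << "\n";  // 2 -> ND
  std::cout << static_cast<int>(ResolveFormat("NHWC", 4)) << "\n";           // 1 -> NHWC
  return 0;
}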
@ -22,28 +22,11 @@
 #include "proto/ge_dtype.pb.h"
 #include "ir/dtype/type_id.h"
 #include "utils/utils.h"
+#include "external/graph/types.h"

 namespace mindspore {
 namespace device {
 namespace ascend {
-static ge::proto::DataType GetGeDataType(TypeId type_id) {
-  static const std::map<TypeId, ge::proto::DataType> data_type_map = {
-    {TypeId::kTypeUnknown, ge::proto::DT_UNDEFINED}, {TypeId::kNumberTypeFloat32, ge::proto::DT_FLOAT},
-    {TypeId::kNumberTypeFloat16, ge::proto::DT_FLOAT16}, {TypeId::kNumberTypeInt8, ge::proto::DT_INT8},
-    {TypeId::kNumberTypeUInt8, ge::proto::DT_UINT8}, {TypeId::kNumberTypeInt16, ge::proto::DT_INT16},
-    {TypeId::kNumberTypeUInt16, ge::proto::DT_UINT16}, {TypeId::kNumberTypeInt32, ge::proto::DT_INT32},
-    {TypeId::kNumberTypeInt64, ge::proto::DT_INT64}, {TypeId::kNumberTypeUInt32, ge::proto::DT_UINT32},
-    {TypeId::kNumberTypeUInt64, ge::proto::DT_UINT64}, {TypeId::kNumberTypeBool, ge::proto::DT_BOOL},
-    {TypeId::kNumberTypeFloat64, ge::proto::DT_DOUBLE},
-  };
-  MS_LOG(INFO) << "Vm origin type_id:" << type_id;
-  auto iter = data_type_map.find(type_id);
-  if (iter == data_type_map.end()) {
-    MS_LOG(EXCEPTION) << "Invalid data type:" << type_id;
-  }
-  return iter->second;
-}
-
 enum GeFormat {
   kFormat_NCHW = 0,  // NCHW
   kFormat_NHWC,      // NHWC
@ -83,37 +66,21 @@ enum GeFormat {
   kFormat_NC,
   kFormat_DHWNC,
   kFormat_FRACTAL_Z_3D_TRANSPOSE,  // 3D filter(transpose) input tensor format
+  kFormat_FRACTAL_ZN_LSTM,
+  kFormat_FRACTAL_Z_G,
   kFormat_RESERVED,
   kFormat_ALL
 };

-static GeFormat GetGeFormat(const std::string &format, size_t shape_size) {
-  static const std::map<std::string, GeFormat> format_map = {
-    // default format: nchw, fractal_nz?
-    {kOpFormat_DEFAULT, kFormat_NCHW},
-    {kOpFormat_NC1KHKWHWC0, kFormat_NC1KHKWHWC0},
-    {kOpFormat_ND, kFormat_ND},
-    {kOpFormat_NCHW, kFormat_NCHW},
-    {kOpFormat_NHWC, kFormat_NHWC},
-    {kOpFormat_HWCN, kFormat_HWCN},
-    {kOpFormat_NC1HWC0, kFormat_NC1HWC0},
-    {kOpFormat_FRAC_Z, kFormat_FRACTAL_Z},
-    {kOpFormat_FRAC_NZ, kFormat_FRACTAL_NZ},
-    {kOpFormat_C1HWNCoC0, kFormat_C1HWNCoC0},
-    {kOpFormat_NC1HWC0_C04, kFormat_NC1HWC0_C04},
-    {kOpFormat_FRACTAL_Z_C04, kFormat_FRACTAL_Z_C04},
-    {kOpFormat_NDHWC, kFormat_NDHWC},
-  };
-  MS_LOG(INFO) << "GetGeFormat format:" << format << " shape_size:" << shape_size;
-  if (format == kOpFormat_DEFAULT) {
-    return shape_size == 4 ? kFormat_NCHW : kFormat_ND;
-  }
-  auto iter = format_map.find(format);
-  if (iter == format_map.end()) {
-    MS_LOG(EXCEPTION) << "Invalid format:" << format;
-  }
-  return iter->second;
-}
+class GeTypesConvert {
+ public:
+  GeTypesConvert() = default;
+  ~GeTypesConvert() = default;
+  static ge::proto::DataType GetGeDataType(TypeId type_id);
+  static GeFormat GetGeFormat(const std::string &format, size_t shape_size);
+  static std::string GetGeTilingFormat(GeFormat ge_format);
+  static ge::DataType TransTypeIdToGeDataType(TypeId type_id);
+};
 }  // namespace ascend
 }  // namespace device
 }  // namespace mindspore
@ -27,6 +27,7 @@
 #include "backend/kernel_compiler/tbe/tbe_kernel_parallel_build.h"
 #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.h"
 #include "backend/kernel_compiler/aicpu/aicpu_kernel_build.h"
+#include "backend/kernel_compiler/host/host_kernel_build.h"
 #include "backend/kernel_compiler/hccl/hccl_kernel_build.h"
 #include "backend/kernel_compiler/rts/rt_kernel_build.h"
 #include "backend/kernel_compiler/tbe/tbe_utils.h"
@ -47,6 +48,10 @@ static kernel::KernelModPtr SerialCompileImpl(const AnfNodePtr &anf_node) {
       kernel_mod_ptr = kernel::AicpuOpBuild(anf_node);
       break;
     }
+    case KernelType::HOST_KERNEL: {
+      kernel_mod_ptr = kernel::HostOpBuild(anf_node);
+      break;
+    }
     case KernelType::RT_KERNEL: {
       kernel_mod_ptr = kernel::RtOpBuild(anf_node);
       break;
@ -22,6 +22,10 @@
 #include <utility>
 #include <algorithm>
 #include <map>
+#include <unordered_map>
+#include <unordered_set>
+#include "utils/ms_utils.h"
+#include "backend/kernel_compiler/tbe/tbe_dynaminc_shape_util.h"
 #include "debug/anf_ir_dump.h"
 #include "frontend/operator/ops.h"
 #include "utils/ms_context.h"
@ -493,7 +497,7 @@ void SetTensorDeviceInfo(const kernel::KernelBuildInfo &selected_kernel_info, co
   }
   // we set special device info of a input tensor.
   bool is_ref = false;
-  auto op_info = kernel::OpLib::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel::kTBE);
+  auto op_info = kernel::tbe::TbeDynamicShapeUtil::FindOp(AnfAlgo::GetCNodeName(kernel_node), kernel_node);
   if (op_info != nullptr) {
     is_ref = op_info->is_ref();
   }
@ -44,6 +44,8 @@ class CPUKernelRuntime : public KernelRuntime {
               VectorRef *outputs);
   void IncreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
   void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
+  bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
+  bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }

 protected:
   bool SyncStream() override { return true; };
@ -0,0 +1,128 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "runtime/device/executor/dynamic_kernel.h"
#include <vector>
#include "backend/session/anf_runtime_algorithm.h"
#include "common/trans.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "abstract/dshape.h"
#include "abstract/param_validator.h"

namespace mindspore {
namespace device {
void DynamicKernel::Initialize() {
  MS_LOG(INFO) << "Init Start";
  is_dynamic_shape_ = AnfAlgo::IsDynamicShape(cnode_ptr_);
  if (!is_dynamic_shape_) {
    MS_LOG(INFO) << "cnode is not dynamic shape:" << cnode_ptr_->fullname_with_scope();
    return;
  }

  is_input_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrInputIsDynamicShape);
  is_output_dynamic_shape_ = AnfAlgo::GetBooleanAttr(cnode_ptr_, kAttrOutputIsDynamicShape);

  auto have_depends = AnfAlgo::HasNodeAttr(kDynamicShapeDepends, cnode_ptr_);
  if (!have_depends) {
    MS_LOG(WARNING) << "No dynamic_shape_depends found";
    return;
  }
  MS_LOG(INFO) << "Have depends";
  auto depends_list = AnfAlgo::GetNodeAttr<std::vector<int>>(cnode_ptr_, kDynamicShapeDepends);
  // Save depend input tensor. Sync data in InferShape.
  for (auto depend : depends_list) {
    auto pre_node_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, depend);
    auto output_addr = AnfAlgo::GetPrevNodeMutableOutputAddr(cnode_ptr_, depend);
    std::vector<int> shapes = trans::GetRuntimePaddingShape(pre_node_with_index.first, pre_node_with_index.second);
    auto host_type = AnfAlgo::GetOutputInferDataType(pre_node_with_index.first, pre_node_with_index.second);
    auto out_tensor = std::make_shared<tensor::Tensor>(host_type, shapes);
    out_tensor->set_device_address(output_addr);

    auto ret = depend_tensor_map_.try_emplace(depend, out_tensor);
    if (!ret.second) {
      MS_LOG(EXCEPTION) << "Insert map failed";
    }
  }
  MS_LOG(INFO) << "Init End";
}

bool IsTupleGetItem(const AnfNodePtr &anf_node) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    return false;
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input0 = cnode->input(0);
  return IsPrimitive(input0, prim::kPrimTupleGetItem);
}

void DynamicKernel::InferShape() {
  if (!is_input_dynamic_shape_ && is_output_dynamic_shape_ && !have_depends()) {
    return;
  }
  MS_EXCEPTION_IF_NULL(cnode_ptr_);
  MS_LOG(INFO) << "InferShape start, node:" << cnode_ptr_->fullname_with_scope();

  auto inputs = cnode_ptr_->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs";
  }
  AbstractBasePtrList args_spec_list;
  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);

  auto input_size = AnfAlgo::GetInputTensorNum(cnode_ptr_);
  for (size_t i = 0; i < input_size; ++i) {
    auto input_with_index = AnfAlgo::GetPrevNodeOutput(cnode_ptr_, i);
    auto real_input = input_with_index.first;

    MS_EXCEPTION_IF_NULL(real_input);
    auto ret = depend_tensor_map_.find(i);
    if (ret != depend_tensor_map_.end()) {
      auto tensor_ptr = ret->second;
      MS_EXCEPTION_IF_NULL(tensor_ptr);
      // sync data from device to host
      tensor_ptr->data_sync();
      real_input->abstract()->set_value(tensor_ptr);
    }

    auto cnode_input = cnode_ptr_->input(i + 1);
    MS_EXCEPTION_IF_NULL(cnode_input);
    if (IsTupleGetItem(cnode_input)) {
      auto base_shape = real_input->Shape();
      if (!base_shape->isa<abstract::TupleShape>()) {
        MS_LOG(EXCEPTION) << "Node:" << cnode_ptr_->fullname_with_scope()
                          << " input is a tuple_get_item but real input node shape is not a TupleShape";
      }
      auto tuple_ptr = base_shape->cast<abstract::TupleShapePtr>();
      MS_EXCEPTION_IF_NULL(tuple_ptr);
      auto tuple_get_item_index = AnfAlgo::GetTupleGetItemOutIndex(cnode_input->cast<CNodePtr>());
      auto real_shape = tuple_ptr->shape().at(tuple_get_item_index);
      auto abstract_tensor = cnode_input->abstract()->cast<abstract::AbstractTensorPtr>();
      MS_EXCEPTION_IF_NULL(abstract_tensor);
      args_spec_list.emplace_back(std::make_shared<abstract::AbstractTensor>(abstract_tensor->element(), real_shape));
    } else if (cnode_input->isa<CNode>() && AnfAlgo::GetCNodeName(cnode_input) == prim::kPrimReshape->name()) {
      args_spec_list.emplace_back(cnode_input->abstract());
    } else {
      args_spec_list.emplace_back(real_input->abstract());
    }
  }

  auto eval_result = abstract::CppInferShape(primitive, args_spec_list);
  cnode_ptr_->set_abstract(eval_result);
}
}  // namespace device
}  // namespace mindspore
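Initialize() above records each value-depend input exactly once, relying on std::map::try_emplace and the bool it returns to catch a duplicated depend index. The standalone sketch below (an editor's illustration, not part of this commit; string values stand in for the saved host tensors) isolates that guard:

// Illustrative sketch of the try_emplace duplicate guard, standard C++17 only.
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<uint32_t, std::string> depend_tensor_map;
  for (uint32_t depend : {1u, 2u, 1u}) {  // the repeated index models a malformed depends attribute
    auto ret = depend_tensor_map.try_emplace(depend, "host tensor for input " + std::to_string(depend));
    if (!ret.second) {
      std::cerr << "Insert map failed, duplicated depend index:" << depend << "\n";
      return 1;
    }
  }
  std::cout << "saved " << depend_tensor_map.size() << " depend tensors\n";
  return 0;
}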
@ -0,0 +1,62 @@
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_
#define MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_

#include <memory>
#include <string>
#include <map>
#include "ir/anf.h"
#include "ir/tensor.h"

namespace mindspore {
namespace device {

constexpr auto kDynamicShapeDepends = "dynamic_shape_depends";

class DynamicKernel {
 public:
  DynamicKernel(void *stream, const CNodePtr &cnode_ptr)
      : stream_(stream),
        cnode_ptr_(cnode_ptr),
        is_dynamic_shape_(false),
        is_input_dynamic_shape_(false),
        is_output_dynamic_shape_(false) {}
  virtual ~DynamicKernel() = default;
  virtual void InferShape();
  virtual void UpdateArgs() = 0;
  virtual void Execute() = 0;
  virtual void PostExecute() = 0;
  bool is_dynamic_shape() const { return is_dynamic_shape_; }
  bool is_input_dynamic_shape() const { return is_input_dynamic_shape_; }
  bool is_output_dynamic_shape() const { return is_output_dynamic_shape_; }
  bool have_depends() const { return !depend_tensor_map_.empty(); }
  virtual void Initialize();
  std::string GetKernelName() { return cnode_ptr_->fullname_with_scope(); }

 protected:
  void *stream_;
  const CNodePtr cnode_ptr_;
  bool is_dynamic_shape_;
  bool is_input_dynamic_shape_;
  bool is_output_dynamic_shape_;
  std::map<uint32_t, tensor::TensorPtr> depend_tensor_map_;
};
using DynamicKernelPtr = std::shared_ptr<DynamicKernel>;
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_EXECUTOR_EXECUTOR_H_
@ -43,6 +43,8 @@ class GPUKernelRuntime : public KernelRuntime {
                     const std::vector<CNodePtr> &execution_order) override;
   void AssignMemory(session::KernelGraph *graph) override;
   bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
+  bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
+  bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }

 protected:
   DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,