dynamic shape bug fix

liubuyu 2021-03-19 17:57:31 +08:00
parent 83b56cac85
commit a03452ff2c
10 changed files with 87 additions and 28 deletions

View File

@@ -18,7 +18,7 @@ import os
import sys
from te.platform.cce_conf import te_set_version
from te.platform.fusion_util import fusion_op
import te
import tbe.common.context.op_info as operator_info
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# pylint: disable=wrong-import-position
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path
@@ -68,6 +68,7 @@ def build_op(build_type, json_str, tune_mode=None):
check_kernel_info(kernel_info)
te_set_version(kernel_info["op_info"]["socVersion"])
op_name = kernel_info['op_info']['name']
op_type = kernel_info['op_info']['Type']
try:
custom_flag = False
@@ -117,10 +118,13 @@ def build_op(build_type, json_str, tune_mode=None):
# with te.op.dynamic():
import tbe.common.context.op_context as op_context
with op_context.OpContext("dynamic"):
op_info = operator_info.OpInfo(op_type, op_type)
op_context.get_context().add_op_info(op_info)
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
compile_info = op_context.get_context().get_compile_info()
if tune_mode is not None:
return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name
return te.op.get_compile_info()
return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
return compile_info
else:
res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
if tune_mode is not None:

View File

@@ -113,19 +113,12 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
AddressPtrList kernel_workspaces;
AddressPtrList kernel_outputs;
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr);
// Get para_size from json
auto kernel_json_info = kernel_pack_->kernel_json_info();
auto op_para_size = kernel_json_info.op_para_size;
// Get stub_function
uint32_t block_dim = 1; // default block dim equals 1.
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
if (func_stub == 0) {
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
}
const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
// Generate args
std::vector<void *> runtime_args;
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args),
@@ -146,8 +139,26 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
runtime_args.push_back(tiling_data_ptr);
}
auto executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(
stub_func_ptr, block_dim, tiling_data_ptr, op_para_size, stream_ptr, cnode_ptr, runtime_args);
// Get stub_function
uint32_t block_dim = 1; // default block dim equals 1.
device::DynamicKernelPtr executor = nullptr;
std::string origin_key;
void *handle = nullptr;
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim, dynamic_flag, &handle, &origin_key);
if (dynamic_flag) {
if (func_stub != 1) {
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
}
executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(handle, block_dim, tiling_data_ptr, op_para_size,
stream_ptr, cnode_ptr, runtime_args, origin_key);
} else {
if (func_stub == 0) {
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
}
const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(stub_func_ptr, block_dim, tiling_data_ptr,
op_para_size, stream_ptr, cnode_ptr, runtime_args);
}
return executor;
}
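Taken together with the loader changes below, this hunk switches dynamic-shape kernels from stub-function addressing to a whole-binary handle: the binary is registered once and individual device functions are resolved by name at launch time. A condensed sketch of that protocol, using only the runtime entry points that appear in this commit (binary setup and error handling abbreviated; dev_func, runtime_args, args_size, stream, and kernel_info are assumed to be prepared as in the hunks above):

// Sketch only, not part of the commit: the dynamic-shape register-and-launch flow.
rtDevBinary_t dev_bin{};
dev_bin.magic = RT_DEV_BINARY_MAGIC_ELF;  // resolved from the kernel json "magic" field
dev_bin.version = 0;                      // this commit changes the version field from 2 to 0
void *handle = nullptr;
// 1) Register every kernel contained in the binary at once; keep the module handle.
if (rtRegisterAllKernel(&dev_bin, &handle) != RT_ERROR_NONE) {
  return;  // registration failed
}
// 2) Launch by (handle, device function name) instead of a per-function stub pointer.
if (rtKernelLaunchWithHandle(handle, dev_func.c_str(), block_dim, runtime_args.data(),
                             args_size, nullptr /* l2ctrl */, stream,
                             kernel_info.c_str()) != RT_ERROR_NONE) {
  return;  // launch failed
}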

View File

@@ -116,8 +116,8 @@ KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::s
return SearchCache(kernel_name, processor);
}
int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module,
const string &magic) {
int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, const string &magic,
const bool dynamic_flag) {
static std::map<string, uint32_t> magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
{"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN},
{"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU},
@@ -132,8 +132,9 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf
}
dev_bin.magic = iter->second;
dev_bin.length = kernel_buffer.len;
dev_bin.version = 2;
if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) {
dev_bin.version = 0;
auto ret = dynamic_flag ? rtRegisterAllKernel(&dev_bin, module) : rtDevBinaryRegister(&dev_bin, module);
if (RT_ERROR_NONE != ret) {
MS_LOG(INFO) << "Call runtime " << (dynamic_flag ? "rtRegisterAllKernel" : "rtDevBinaryRegister") << " error.";
return -1;
}
@@ -141,7 +142,8 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf
}
uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload,
uint32_t *block_dim) {
uint32_t *block_dim, const bool dynamic_flag, void **handle,
std::string *origin_key) {
auto kernel = kernel_pack.GetKernel();
if (kernel == nullptr) {
MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr.";
@@ -162,14 +164,24 @@ uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel
if (iter != info_table_.end()) {
auto kernelmeta = iter->second;
*block_dim = kernelmeta->block_dim_;
return kernelmeta->func_stub_;
if (!dynamic_flag) {
return kernelmeta->func_stub_;
}
}
}
void *module = nullptr;
if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) {
if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic, dynamic_flag) != 0) {
MS_LOG(INFO) << "Call runtime BinaryRegister error.";
if (module != nullptr) {
(void)rtDevBinaryUnRegister(module);
}
return 0;
}
if (dynamic_flag) {
*handle = module;
*origin_key = func_name;
return 1;
}
// Increment the counter to differentiate funcs.
uintptr_t func_stub = ++kernel_stub_gen_;
if (RT_ERROR_NONE !=

View File

@@ -59,13 +59,16 @@ using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>;
class KernelManager {
public:
static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim);
static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim,
const bool dynamic_flag = false, void **handle = nullptr,
std::string *origin_key = nullptr);
static std::string GetStubFuncName(const KernelPackPtr &kernel_pack);
private:
KernelManager() = default;
~KernelManager() = default;
static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic);
static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic,
const bool dynamic_flag);
static std::unordered_map<string, KernelMetaPtr> info_table_;
static uintptr_t kernel_stub_gen_;
};
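Because the new parameters are defaulted, existing static-shape call sites keep compiling unchanged; only the dynamic path opts into the extra out-parameters. A minimal sketch of both call sites (a kernel_pack is assumed to be in scope):

uint32_t block_dim = 1;
// Static shape: the original two-output call, dynamic_flag defaults to false;
// the return value is the address of the generated stub function.
uintptr_t func_stub = KernelManager::GenFuncStub(kernel_pack, false, &block_dim);
// Dynamic shape: the module handle and origin key come back through the new
// out-parameters, and a return value of 1 signals success.
void *handle = nullptr;
std::string origin_key;
uintptr_t ret = KernelManager::GenFuncStub(kernel_pack, false, &block_dim, true, &handle, &origin_key);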

View File

@@ -45,11 +45,23 @@ void AiCoreDynamicKernel::Execute() {
}
auto cnode = cnode_ptr_.lock();
MS_EXCEPTION_IF_NULL(cnode);
MS_LOG(INFO) << "Start Execute node:" << cnode->fullname_with_scope();
auto node_info = cnode->fullname_with_scope();
MS_LOG(INFO) << "Start Execute node:" << node_info;
rtL2Ctrl_t *l2ctrl = nullptr;
auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size());
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
if (handle_ != nullptr) {
const auto dev_func =
origin_key_.find("kernel0") != origin_key_.npos ? origin_key_ : origin_key_ + "_" + std::to_string(tiling_key_);
const auto kernel_info = node_info + "/" + std::to_string(tiling_key_);
if (RT_ERROR_NONE != rtKernelLaunchWithHandle(handle_, dev_func.c_str(), block_dim_, runtime_args_.data(),
args_size, l2ctrl, stream_, kernel_info.c_str())) {
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunchWithHandle error.";
}
} else {
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
}
}
MS_LOG(INFO) << "End Execute node:" << cnode->fullname_with_scope();
}
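The device function name passed to rtKernelLaunchWithHandle is derived from the origin key plus the tiling key computed in ComputeTiling below; binaries whose origin key already names a concrete kernel ("kernel0") are launched as-is. The same rule as a standalone helper (a sketch, not part of this commit):

// Sketch: how Execute() above derives the device function name.
std::string DevFuncName(const std::string &origin_key, uint32_t tiling_key) {
  if (origin_key.find("kernel0") != std::string::npos) {
    return origin_key;  // origin key already names a concrete kernel
  }
  return origin_key + "_" + std::to_string(tiling_key);  // select the tiling variant
}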
@@ -127,6 +139,7 @@ void AiCoreDynamicKernel::ComputeTiling() {
block_dim_ = op_run_info.block_dim;
workspaces_size_ = op_run_info.workspaces;
tiling_data_ = op_run_info.tiling_data.str();
tiling_key_ = op_run_info.tiling_key;
}
void AiCoreDynamicKernel::AllocateWorkspace() {

View File

@@ -40,6 +40,15 @@ class AiCoreDynamicKernel : public DynamicKernel {
tiling_data_ptr_(tiling_data_ptr),
op_para_size_(op_para_size),
runtime_args_(runtime_args) {}
AiCoreDynamicKernel(void *handle, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size, void *stream,
const CNodePtr &cnode_ptr, const std::vector<void *> &runtime_args, const std::string &ori_key)
: DynamicKernel(stream, cnode_ptr),
handle_(handle),
block_dim_(block_dim),
tiling_data_ptr_(tiling_data_ptr),
op_para_size_(op_para_size),
runtime_args_(runtime_args),
origin_key_(ori_key) {}
~AiCoreDynamicKernel() override;
void Execute() override;
@@ -53,6 +62,7 @@ class AiCoreDynamicKernel : public DynamicKernel {
private:
const void *stub_func_;
void *handle_{nullptr};
uint32_t block_dim_;
void *tiling_data_ptr_; // device ptr
uint32_t op_para_size_; // size of tiling_data_ptr_
@@ -62,6 +72,8 @@ class AiCoreDynamicKernel : public DynamicKernel {
std::vector<DeviceAddressPtr> workspace_addr_;
std::shared_ptr<nlohmann::json> compile_info_json_;
optiling::OpCompileInfo op_compile_info_{};
uint32_t tiling_key_{0};
const std::string origin_key_{""};
void ComputeTiling();
bool CopyTilingToDevice();

View File

@@ -58,7 +58,7 @@ def test_unique_ascend():
assert (output[1].asnumpy() == expect2).all()
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

View File

@@ -36,7 +36,7 @@ class NetWithEmbeddingLookUp(nn.Cell):
return out
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.platform_x86_gpu_training

View File

@@ -56,7 +56,7 @@ def test_ftrl_net():
[[0.6821311, 0.6821311]],
[[0.6821311, 0.6821311]]]))
@pytest.mark.level2
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard

View File

@@ -161,6 +161,10 @@ RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFa
return RT_ERROR_NONE;
}
RTS_API rtError_t rtRegisterAllKernel(const rtDevBinary_t *bin, void **module) { return RT_ERROR_NONE; }
RTS_API rtError_t rtDevBinaryUnRegister(void *handle) { return RT_ERROR_NONE; }
RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
return RT_ERROR_NONE;
}