forked from mindspore-Ecosystem/mindspore
dynamic shape bug fix
This commit is contained in:
parent
83b56cac85
commit
a03452ff2c
|
@ -18,7 +18,7 @@ import os
|
|||
import sys
|
||||
from te.platform.cce_conf import te_set_version
|
||||
from te.platform.fusion_util import fusion_op
|
||||
import te
|
||||
import tbe.common.context.op_info as operator_info
|
||||
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
|
||||
# pylint: disable=wrong-import-position
|
||||
from tbe_common import check_kernel_info, get_args, get_built_in_impl_path
|
||||
|
@ -68,6 +68,7 @@ def build_op(build_type, json_str, tune_mode=None):
|
|||
check_kernel_info(kernel_info)
|
||||
te_set_version(kernel_info["op_info"]["socVersion"])
|
||||
op_name = kernel_info['op_info']['name']
|
||||
op_type = kernel_info['op_info']['Type']
|
||||
|
||||
try:
|
||||
custom_flag = False
|
||||
|
@ -117,10 +118,13 @@ def build_op(build_type, json_str, tune_mode=None):
|
|||
# with te.op.dynamic():
|
||||
import tbe.common.context.op_context as op_context
|
||||
with op_context.OpContext("dynamic"):
|
||||
op_info = operator_info.OpInfo(op_type, op_type)
|
||||
op_context.get_context().add_op_info(op_info)
|
||||
op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
compile_info = op_context.get_context().get_compile_info()
|
||||
if tune_mode is not None:
|
||||
return (te.op.get_compile_info()), (inputs_args, outputs_args, attrs_args), op_module_name
|
||||
return te.op.get_compile_info()
|
||||
return compile_info, (inputs_args, outputs_args, attrs_args), op_module_name
|
||||
return compile_info
|
||||
else:
|
||||
res = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
|
||||
if tune_mode is not None:
|
||||
|
|
|
@ -113,19 +113,12 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
|
|||
AddressPtrList kernel_workspaces;
|
||||
AddressPtrList kernel_outputs;
|
||||
device::KernelRuntime::GenLaunchArgs(*this, cnode_ptr, &kernel_inputs, &kernel_workspaces, &kernel_outputs);
|
||||
auto dynamic_flag = AnfAlgo::IsDynamicShape(cnode_ptr);
|
||||
|
||||
// Get para_size from json
|
||||
auto kernel_json_info = kernel_pack_->kernel_json_info();
|
||||
auto op_para_size = kernel_json_info.op_para_size;
|
||||
|
||||
// Get stub_function
|
||||
uint32_t block_dim = 1; // default blockdim equal to 1.
|
||||
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim);
|
||||
if (func_stub == 0) {
|
||||
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
|
||||
}
|
||||
const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
|
||||
|
||||
// Generate args
|
||||
std::vector<void *> runtime_args;
|
||||
(void)std::transform(std::begin(kernel_inputs), std::end(kernel_inputs), std::back_inserter(runtime_args),
|
||||
|
@ -146,8 +139,26 @@ device::DynamicKernelPtr TbeKernelMod::GenDynamicKernel(const CNodePtr &cnode_pt
|
|||
runtime_args.push_back(tiling_data_ptr);
|
||||
}
|
||||
|
||||
auto executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(
|
||||
stub_func_ptr, block_dim, tiling_data_ptr, op_para_size, stream_ptr, cnode_ptr, runtime_args);
|
||||
// Get stub_function
|
||||
uint32_t block_dim = 1; // default blockdim equal to 1.
|
||||
device::DynamicKernelPtr executor = nullptr;
|
||||
std::string origin_key;
|
||||
void *handle = nullptr;
|
||||
auto func_stub = KernelManager::GenFuncStub(*kernel_pack_, false, &block_dim, dynamic_flag, &handle, &origin_key);
|
||||
if (dynamic_flag) {
|
||||
if (func_stub != 1) {
|
||||
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
|
||||
}
|
||||
executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(handle, block_dim, tiling_data_ptr, op_para_size,
|
||||
stream_ptr, cnode_ptr, runtime_args, origin_key);
|
||||
} else {
|
||||
if (func_stub == 0) {
|
||||
MS_LOG(EXCEPTION) << "GenFuncStub failed.";
|
||||
}
|
||||
const void *stub_func_ptr = reinterpret_cast<void *>(func_stub);
|
||||
executor = std::make_shared<device::ascend::AiCoreDynamicKernel>(stub_func_ptr, block_dim, tiling_data_ptr,
|
||||
op_para_size, stream_ptr, cnode_ptr, runtime_args);
|
||||
}
|
||||
return executor;
|
||||
}
|
||||
|
||||
|
|
|
@ -116,8 +116,8 @@ KernelPackPtr TbeUtils::InsertCache(const std::string &kernel_name, const std::s
|
|||
return SearchCache(kernel_name, processor);
|
||||
}
|
||||
|
||||
int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module,
|
||||
const string &magic) {
|
||||
int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buffer, void **module, const string &magic,
|
||||
const bool dynamic_flag) {
|
||||
static std::map<string, uint32_t> magic_maps = {{"RT_DEV_BINARY_MAGIC_ELF", RT_DEV_BINARY_MAGIC_ELF},
|
||||
{"RT_DEV_BINARY_MAGIC_PLAIN", RT_DEV_BINARY_MAGIC_PLAIN},
|
||||
{"RT_DEV_BINARY_MAGIC_PLAIN_AICPU", RT_DEV_BINARY_MAGIC_PLAIN_AICPU},
|
||||
|
@ -132,8 +132,9 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf
|
|||
}
|
||||
dev_bin.magic = iter->second;
|
||||
dev_bin.length = kernel_buffer.len;
|
||||
dev_bin.version = 2;
|
||||
if (RT_ERROR_NONE != rtDevBinaryRegister(&dev_bin, module)) {
|
||||
dev_bin.version = 0;
|
||||
auto ret = dynamic_flag ? rtRegisterAllKernel(&dev_bin, module) : rtDevBinaryRegister(&dev_bin, module);
|
||||
if (RT_ERROR_NONE != ret) {
|
||||
MS_LOG(INFO) << "Call runtime rtDevBinaryRegister error.";
|
||||
return -1;
|
||||
}
|
||||
|
@ -141,7 +142,8 @@ int KernelManager::BinaryRegister(const mindspore::kernel::FlexArray &kernel_buf
|
|||
}
|
||||
|
||||
uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel_pack, bool force_reload,
|
||||
uint32_t *block_dim) {
|
||||
uint32_t *block_dim, const bool dynamic_flag, void **handle,
|
||||
std::string *origin_key) {
|
||||
auto kernel = kernel_pack.GetKernel();
|
||||
if (kernel == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid kernel pack, json or kernel is nullptr.";
|
||||
|
@ -162,14 +164,24 @@ uintptr_t KernelManager::GenFuncStub(const mindspore::kernel::KernelPack &kernel
|
|||
if (iter != info_table_.end()) {
|
||||
auto kernelmeta = iter->second;
|
||||
*block_dim = kernelmeta->block_dim_;
|
||||
return kernelmeta->func_stub_;
|
||||
if (!dynamic_flag) {
|
||||
return kernelmeta->func_stub_;
|
||||
}
|
||||
}
|
||||
}
|
||||
void *module = nullptr;
|
||||
if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic) != 0) {
|
||||
if (BinaryRegister((*kernel_pack.GetKernel()), &module, magic, dynamic_flag) != 0) {
|
||||
MS_LOG(INFO) << "Call runtime BinaryRegister error.";
|
||||
if (module != nullptr) {
|
||||
(void)rtDevBinaryUnRegister(module);
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
if (dynamic_flag) {
|
||||
*handle = module;
|
||||
*origin_key = func_name;
|
||||
return 1;
|
||||
}
|
||||
// to diff different funcs.
|
||||
uintptr_t func_stub = ++kernel_stub_gen_;
|
||||
if (RT_ERROR_NONE !=
|
||||
|
|
|
@ -59,13 +59,16 @@ using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>;
|
|||
|
||||
class KernelManager {
|
||||
public:
|
||||
static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim);
|
||||
static uintptr_t GenFuncStub(const KernelPack &kernel_pack, bool force_reload, uint32_t *block_dim,
|
||||
const bool dynamic_flag = false, void **handle = nullptr,
|
||||
std::string *origin_key = nullptr);
|
||||
static std::string GetStubFuncName(const KernelPackPtr &kernel_pack);
|
||||
|
||||
private:
|
||||
KernelManager() = default;
|
||||
~KernelManager() = default;
|
||||
static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic);
|
||||
static int BinaryRegister(const FlexArray &kernel_buffer, void **module, const string &magic,
|
||||
const bool dynamic_flag);
|
||||
static std::unordered_map<string, KernelMetaPtr> info_table_;
|
||||
static uintptr_t kernel_stub_gen_;
|
||||
};
|
||||
|
|
|
@ -45,11 +45,23 @@ void AiCoreDynamicKernel::Execute() {
|
|||
}
|
||||
auto cnode = cnode_ptr_.lock();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_LOG(INFO) << "Start Execute node:" << cnode->fullname_with_scope();
|
||||
auto node_info = cnode->fullname_with_scope();
|
||||
MS_LOG(INFO) << "Start Execute node:" << node_info;
|
||||
rtL2Ctrl_t *l2ctrl = nullptr;
|
||||
auto args_size = static_cast<uint32_t>(UlongToUint(sizeof(void *)) * runtime_args_.size());
|
||||
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
|
||||
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
|
||||
if (handle_ != nullptr) {
|
||||
const auto dev_func =
|
||||
origin_key_.find("kernel0") != origin_key_.npos ? origin_key_ : origin_key_ + "_" + std::to_string(tiling_key_);
|
||||
const auto kernel_info = node_info + "/" + std::to_string(tiling_key_);
|
||||
if (RT_ERROR_NONE != rtKernelLaunchWithHandle(handle_, dev_func.c_str(), block_dim_, runtime_args_.data(),
|
||||
args_size, l2ctrl, stream_, kernel_info.c_str())) {
|
||||
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunchWithHandle error.";
|
||||
}
|
||||
|
||||
} else {
|
||||
if (RT_ERROR_NONE != rtKernelLaunch(stub_func_, block_dim_, runtime_args_.data(), args_size, l2ctrl, stream_)) {
|
||||
MS_LOG(EXCEPTION) << "Call runtime rtKernelLaunch error.";
|
||||
}
|
||||
}
|
||||
MS_LOG(INFO) << "End Execute node:" << cnode->fullname_with_scope();
|
||||
}
|
||||
|
@ -127,6 +139,7 @@ void AiCoreDynamicKernel::ComputeTiling() {
|
|||
block_dim_ = op_run_info.block_dim;
|
||||
workspaces_size_ = op_run_info.workspaces;
|
||||
tiling_data_ = op_run_info.tiling_data.str();
|
||||
tiling_key_ = op_run_info.tiling_key;
|
||||
}
|
||||
|
||||
void AiCoreDynamicKernel::AllocateWorkspace() {
|
||||
|
|
|
@ -40,6 +40,15 @@ class AiCoreDynamicKernel : public DynamicKernel {
|
|||
tiling_data_ptr_(tiling_data_ptr),
|
||||
op_para_size_(op_para_size),
|
||||
runtime_args_(runtime_args) {}
|
||||
AiCoreDynamicKernel(void *handle, uint32_t block_dim, void *tiling_data_ptr, uint32_t op_para_size, void *stream,
|
||||
const CNodePtr &cnode_ptr, const std::vector<void *> &runtime_args, const std::string &ori_key)
|
||||
: DynamicKernel(stream, cnode_ptr),
|
||||
handle_(handle),
|
||||
block_dim_(block_dim),
|
||||
tiling_data_ptr_(tiling_data_ptr),
|
||||
op_para_size_(op_para_size),
|
||||
runtime_args_(runtime_args),
|
||||
origin_key_(ori_key) {}
|
||||
~AiCoreDynamicKernel() override;
|
||||
|
||||
void Execute() override;
|
||||
|
@ -53,6 +62,7 @@ class AiCoreDynamicKernel : public DynamicKernel {
|
|||
|
||||
private:
|
||||
const void *stub_func_;
|
||||
void *handle_{nullptr};
|
||||
uint32_t block_dim_;
|
||||
void *tiling_data_ptr_; // device ptr
|
||||
uint32_t op_para_size_; // size of tiling_data_ptr_
|
||||
|
@ -62,6 +72,8 @@ class AiCoreDynamicKernel : public DynamicKernel {
|
|||
std::vector<DeviceAddressPtr> workspace_addr_;
|
||||
std::shared_ptr<nlohmann::json> compile_info_json_;
|
||||
optiling::OpCompileInfo op_compile_info_{};
|
||||
uint32_t tiling_key_{0};
|
||||
const std::string origin_key_{""};
|
||||
|
||||
void ComputeTiling();
|
||||
bool CopyTilingToDevice();
|
||||
|
|
|
@ -58,7 +58,7 @@ def test_unique_ascend():
|
|||
assert (output[1].asnumpy() == expect2).all()
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
|
@ -36,7 +36,7 @@ class NetWithEmbeddingLookUp(nn.Cell):
|
|||
return out
|
||||
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
|
|
|
@ -56,7 +56,7 @@ def test_ftrl_net():
|
|||
[[0.6821311, 0.6821311]],
|
||||
[[0.6821311, 0.6821311]]]))
|
||||
|
||||
@pytest.mark.level2
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
|
|
|
@ -161,6 +161,10 @@ RTS_API rtError_t rtRegTaskFailCallbackByModule(const char *moduleName, rtTaskFa
|
|||
return RT_ERROR_NONE;
|
||||
}
|
||||
|
||||
// Presumably a stand-in for the runtime's rtRegisterAllKernel (confirm against
// the real runtime header). It ignores `bin`, leaves `*module` untouched, and
// always reports success.
RTS_API rtError_t rtRegisterAllKernel(const rtDevBinary_t *bin, void **module) {
  return RT_ERROR_NONE;
}
|
||||
|
||||
// Presumably a stand-in for the runtime's rtDevBinaryUnRegister (confirm
// against the real runtime header). It ignores `handle`, releases nothing, and
// always reports success.
RTS_API rtError_t rtDevBinaryUnRegister(void *handle) {
  return RT_ERROR_NONE;
}
|
||||
|
||||
// Presumably a stand-in for the runtime's rtMemsetAsync (confirm against the
// real runtime header). No memory is written and nothing is enqueued on
// `stream`; every argument is ignored and success is always reported.
RTS_API rtError_t rtMemsetAsync(void *ptr, uint64_t destMax, uint32_t value, uint64_t count, rtStream_t stream) {
  const rtError_t status = RT_ERROR_NONE;
  return status;
}
|
||||
|
|
Loading…
Reference in New Issue