move AkgCallBack to the first input of akg cpu kernel and add an extend_data value

This commit is contained in:
dayschan 2022-03-01 20:42:38 +08:00
parent 383c1406dc
commit 09e14c690e
2 changed files with 5 additions and 3 deletions

2
akg

@ -1 +1 @@
Subproject commit a9cbf642063fb1086a93e8bc6be6feb145689817
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1

View File

@ -45,6 +45,7 @@ struct AkgCallBack {
void *parallel_launch_func;
void *(*malloc_func)(size_t);
void (*free_func)(void *);
void *extend_data = nullptr;
AkgCallBack() {
parallel_launch_func = reinterpret_cast<void *>(&AkgParallelLaunch::AkgLaunchFunc);
@ -126,13 +127,14 @@ bool AkgCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::v
MS_LOG(ERROR) << "GetFunction failed. kernel: " << kernel_name_;
return false;
}
static AkgCallBack akg_callback = AkgCallBack();
std::vector<void *> runtimeargs;
runtimeargs.reserve(inputs.size() + outputs.size() + 1);
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
[](const AddressPtr &input) { return input->addr; });
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
[](const AddressPtr &output) { return output->addr; });
static AkgCallBack akg_callback = AkgCallBack();
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
using AkgCpuKernelFunction = void (*)(void *);
reinterpret_cast<AkgCpuKernelFunction>(launch_func_)(reinterpret_cast<void *>(runtimeargs.data()));
return true;