move AkgCallBack to the first input of akg cpu kernel and add an extend_data value
This commit is contained in:
parent
383c1406dc
commit
09e14c690e
2
akg
2
akg
|
@ -1 +1 @@
|
|||
Subproject commit a9cbf642063fb1086a93e8bc6be6feb145689817
|
||||
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1
|
|
@ -45,6 +45,7 @@ struct AkgCallBack {
|
|||
void *parallel_launch_func;
|
||||
void *(*malloc_func)(size_t);
|
||||
void (*free_func)(void *);
|
||||
void *extend_data = nullptr;
|
||||
|
||||
AkgCallBack() {
|
||||
parallel_launch_func = reinterpret_cast<void *>(&AkgParallelLaunch::AkgLaunchFunc);
|
||||
|
@ -126,13 +127,14 @@ bool AkgCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::v
|
|||
MS_LOG(ERROR) << "GetFunction failed. kernel: " << kernel_name_;
|
||||
return false;
|
||||
}
|
||||
static AkgCallBack akg_callback = AkgCallBack();
|
||||
std::vector<void *> runtimeargs;
|
||||
runtimeargs.reserve(inputs.size() + outputs.size() + 1);
|
||||
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
|
||||
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
|
||||
[](const AddressPtr &input) { return input->addr; });
|
||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
|
||||
[](const AddressPtr &output) { return output->addr; });
|
||||
static AkgCallBack akg_callback = AkgCallBack();
|
||||
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
|
||||
using AkgCpuKernelFunction = void (*)(void *);
|
||||
reinterpret_cast<AkgCpuKernelFunction>(launch_func_)(reinterpret_cast<void *>(runtimeargs.data()));
|
||||
return true;
|
||||
|
|
Loading…
Reference in New Issue