move AkgCallBack to the first input of akg cpu kernel and add an extend_data value
This commit is contained in:
parent
383c1406dc
commit
09e14c690e
2
akg
2
akg
|
@ -1 +1 @@
|
||||||
Subproject commit a9cbf642063fb1086a93e8bc6be6feb145689817
|
Subproject commit e3f2411858e34499fce13ec00ea35e1292d441b1
|
|
@ -45,6 +45,7 @@ struct AkgCallBack {
|
||||||
void *parallel_launch_func;
|
void *parallel_launch_func;
|
||||||
void *(*malloc_func)(size_t);
|
void *(*malloc_func)(size_t);
|
||||||
void (*free_func)(void *);
|
void (*free_func)(void *);
|
||||||
|
void *extend_data = nullptr;
|
||||||
|
|
||||||
AkgCallBack() {
|
AkgCallBack() {
|
||||||
parallel_launch_func = reinterpret_cast<void *>(&AkgParallelLaunch::AkgLaunchFunc);
|
parallel_launch_func = reinterpret_cast<void *>(&AkgParallelLaunch::AkgLaunchFunc);
|
||||||
|
@ -126,13 +127,14 @@ bool AkgCpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::v
|
||||||
MS_LOG(ERROR) << "GetFunction failed. kernel: " << kernel_name_;
|
MS_LOG(ERROR) << "GetFunction failed. kernel: " << kernel_name_;
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
static AkgCallBack akg_callback = AkgCallBack();
|
||||||
std::vector<void *> runtimeargs;
|
std::vector<void *> runtimeargs;
|
||||||
|
runtimeargs.reserve(inputs.size() + outputs.size() + 1);
|
||||||
|
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
|
||||||
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
|
(void)std::transform(std::begin(inputs), std::end(inputs), std::back_inserter(runtimeargs),
|
||||||
[](const AddressPtr &input) { return input->addr; });
|
[](const AddressPtr &input) { return input->addr; });
|
||||||
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
|
(void)std::transform(std::begin(outputs), std::end(outputs), std::back_inserter(runtimeargs),
|
||||||
[](const AddressPtr &output) { return output->addr; });
|
[](const AddressPtr &output) { return output->addr; });
|
||||||
static AkgCallBack akg_callback = AkgCallBack();
|
|
||||||
(void)runtimeargs.emplace_back(reinterpret_cast<void *>(&akg_callback));
|
|
||||||
using AkgCpuKernelFunction = void (*)(void *);
|
using AkgCpuKernelFunction = void (*)(void *);
|
||||||
reinterpret_cast<AkgCpuKernelFunction>(launch_func_)(reinterpret_cast<void *>(runtimeargs.data()));
|
reinterpret_cast<AkgCpuKernelFunction>(launch_func_)(reinterpret_cast<void *>(runtimeargs.data()));
|
||||||
return true;
|
return true;
|
||||||
|
|
Loading…
Reference in New Issue