forked from mindspore-Ecosystem/mindspore
commit
4141404213
|
@ -33,15 +33,15 @@
|
|||
namespace mindspore {
|
||||
namespace custom_gpu_demo {
|
||||
|
||||
class CustomAddKernel : public kernel::Kernel {
|
||||
class CustomAddKernelGpu : public kernel::Kernel {
|
||||
public:
|
||||
CustomAddKernel(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
|
||||
const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
|
||||
bool fp16_enable)
|
||||
CustomAddKernelGpu(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
|
||||
const schema::Primitive *primitive, const mindspore::Context *ctx,
|
||||
const std::string &build_options, bool fp16_enable)
|
||||
: Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
|
||||
opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
|
||||
}
|
||||
~CustomAddKernel() override { FreeWeight(); }
|
||||
~CustomAddKernelGpu() override { FreeWeight(); }
|
||||
// Prepare will be called during graph compilation
|
||||
int Prepare() override {
|
||||
const std::string kernel_name_ = "ElementAdd";
|
||||
|
@ -195,7 +195,7 @@ class CustomAddKernel : public kernel::Kernel {
|
|||
registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
|
||||
|
||||
int PreProcess() {
|
||||
int ret;
|
||||
int ret = 0;
|
||||
ret = ReSize();
|
||||
if (ret != lite::RET_OK) {
|
||||
return ret;
|
||||
|
@ -240,7 +240,7 @@ std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<MSTensor> &in
|
|||
bool fp16_enable = false;
|
||||
|
||||
std::cout << "using fp32 add.\n" << std::endl;
|
||||
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
|
||||
return std::make_shared<CustomAddKernelGpu>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
|
||||
}
|
||||
|
||||
std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<MSTensor> &inputs,
|
||||
|
@ -251,7 +251,7 @@ std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<MSTensor>
|
|||
bool fp16_enable = true;
|
||||
|
||||
std::cout << "using fp16 add." << std::endl;
|
||||
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
|
||||
return std::make_shared<CustomAddKernelGpu>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
|
||||
}
|
||||
|
||||
} // namespace custom_gpu_demo
|
||||
|
|
|
@ -71,6 +71,5 @@ void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, c
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace custom_common
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -31,7 +31,6 @@
|
|||
using mindspore::kernel::CLErrorCode;
|
||||
|
||||
namespace mindspore::registry::opencl {
|
||||
|
||||
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
|
||||
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
|
||||
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();
|
||||
|
|
Loading…
Reference in New Issue