!23860 statistic fix

Merge pull request !23860 from QianliMa/statis
This commit is contained in:
i-robot 2021-09-22 08:49:16 +00:00 committed by Gitee
commit 4141404213
3 changed files with 8 additions and 10 deletions

View File

@ -33,15 +33,15 @@
namespace mindspore {
namespace custom_gpu_demo {
class CustomAddKernel : public kernel::Kernel {
class CustomAddKernelGpu : public kernel::Kernel {
public:
CustomAddKernel(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx, const std::string &build_options,
bool fp16_enable)
CustomAddKernelGpu(const std::vector<MSTensor> &inputs, const std::vector<MSTensor> &outputs,
const schema::Primitive *primitive, const mindspore::Context *ctx,
const std::string &build_options, bool fp16_enable)
: Kernel(inputs, outputs, primitive, ctx), build_options_(build_options), fp16_enable_(fp16_enable) {
opencl_runtime_ = new registry::opencl::OpenCLRuntimeWrapper();
}
~CustomAddKernel() override { FreeWeight(); }
~CustomAddKernelGpu() override { FreeWeight(); }
// Prepare will be called during graph compilation
int Prepare() override {
const std::string kernel_name_ = "ElementAdd";
@ -195,7 +195,7 @@ class CustomAddKernel : public kernel::Kernel {
registry::opencl::OpenCLRuntimeWrapper *opencl_runtime_;
int PreProcess() {
int ret;
int ret = 0;
ret = ReSize();
if (ret != lite::RET_OK) {
return ret;
@ -240,7 +240,7 @@ std::shared_ptr<kernel::Kernel> CustomAddCreator(const std::vector<MSTensor> &in
bool fp16_enable = false;
std::cout << "using fp32 add.\n" << std::endl;
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
return std::make_shared<CustomAddKernelGpu>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
}
std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<MSTensor> &inputs,
@ -251,7 +251,7 @@ std::shared_ptr<kernel::Kernel> CustomAddFP16Creator(const std::vector<MSTensor>
bool fp16_enable = true;
std::cout << "using fp16 add." << std::endl;
return std::make_shared<CustomAddKernel>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
return std::make_shared<CustomAddKernelGpu>(inputs, outputs, primitive, ctx, build_options, fp16_enable);
}
} // namespace custom_gpu_demo

View File

@ -71,6 +71,5 @@ void PackNHWCToNHWC4(void *src, void *dst, bool src_is_fp16, bool dst_is_fp16, c
}
}
}
} // namespace custom_common
} // namespace mindspore

View File

@ -31,7 +31,6 @@
using mindspore::kernel::CLErrorCode;
namespace mindspore::registry::opencl {
Status OpenCLRuntimeWrapper::LoadSource(const std::string &program_name, const std::string &source) {
lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap;
lite::opencl::OpenCLRuntime *ocl_runtime = ocl_runtime_wrap.GetInstance();