fix bug in arithmetic op: handle 2-D output shapes in the OpenCL kernel

chenzupeng 2020-10-10 17:03:52 +08:00
parent a4e46c8e75
commit 00a1933f34
3 changed files with 25 additions and 27 deletions

@@ -50,6 +50,12 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
 void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() {
   local_size_ = {16, 16};
+  if (out_tensors_[0]->shape().size() == 2) {
+    size_t H = out_tensors_[0]->shape()[0];
+    size_t W = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
+    global_size_ = {W, H};
+    return;
+  }
   if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
     size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
     size_t W = out_tensors_[0]->Width();
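
Note: rank-2 outputs match none of the three image formats below, so the old code presumably fell into the error/format-specific paths; the new early return handles them directly. A minimal standalone sketch of that 2-D branch follows. It assumes UP_DIV is the usual ceiling-division macro and C4NUM equals 4 (both come from MindSpore Lite headers and are redefined here only so the snippet compiles on its own), and it replaces the tensor with a plain shape vector:

#include <cstdio>
#include <vector>

#define C4NUM 4                               // assumption: 4 channels per slice
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))  // assumption: ceiling division

int main() {
  // Hypothetical 2-D output tensor: shape = {H, W} = {3, 10}.
  std::vector<int> shape = {3, 10};
  size_t H = shape[0];
  size_t W = UP_DIV(shape[1], C4NUM);  // 10 packed into groups of 4 -> 3
  // Matches the new early-return path: global_size_ = {W, H}.
  printf("global_size = {%zu, %zu}\n", W, H);  // prints {3, 3}
  return 0;
}
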
@@ -74,18 +80,23 @@ void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
 int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
   size_t im_dst_x, im_dst_y;
-  if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
-    im_dst_x = out_tensors_[0]->Width();
-    im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
-    im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
-    im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
-    im_dst_y = out_tensors_[0]->Batch();
-    im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+  if (out_tensors_[0]->shape().size() == 2) {
+    im_dst_x = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
+    im_dst_y = out_tensors_[0]->shape()[0];
   } else {
-    MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat();
-    return RET_ERROR;
+    if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
+      im_dst_x = out_tensors_[0]->Width();
+      im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+    } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
+      im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+      im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
+    } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
+      im_dst_y = out_tensors_[0]->Batch();
+      im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
+    } else {
+      MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat();
+      return RET_ERROR;
+    }
   }
   size_t img_dtype = CL_FLOAT;
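
The reworked GetImageSize() applies the same idea: the 2-D case is checked first, and only higher-rank tensors fall through to the format-specific branches. A small sketch of just the 2-D path, under the same assumptions as above (UP_DIV and C4NUM redefined, tensor replaced by a shape vector; the helper name Image2dSize2D is made up for illustration):

#include <cstdio>
#include <utility>
#include <vector>

#define C4NUM 4
#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

// Returns {im_dst_x, im_dst_y} for a rank-2 tensor; a full implementation
// would keep the NC4HW4 / NHWC4 / NC4 branches for other ranks.
std::pair<size_t, size_t> Image2dSize2D(const std::vector<int> &shape) {
  size_t im_dst_x = UP_DIV(shape[1], C4NUM);
  size_t im_dst_y = shape[0];
  return {im_dst_x, im_dst_y};
}

int main() {
  auto size = Image2dSize2D({3, 10});  // same hypothetical {H=3, W=10} tensor
  printf("image size = {%zu, %zu}\n", size.first, size.second);  // {3, 3}
  return 0;
}
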
@@ -335,22 +346,7 @@ int ArithmeticOpenCLKernel::Run() {
   }
   ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
-  int H = 0;
-  int W = 0;
-  if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
-    H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
-    W = out_tensors_[0]->Width();
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
-    H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
-    W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
-  } else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
-    H = out_tensors_[0]->Batch();
-    W = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
-  } else {
-    MS_LOG(ERROR) << "Error output type " << out_tensors_[0]->GetFormat();
-    return RET_ERROR;
-  }
-  cl_int2 output_shape{W, H};
+  cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
   ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
   ocl_runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
   return RET_OK;
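
In Run(), the per-format recomputation of W and H is dropped; the output_shape kernel argument is now taken directly from global_size_, which was already computed (including the new 2-D path) in Image2dGetWorkGroupSize(). A tiny sketch of that cast, with cl_int2 stood in by a local struct so it builds without the OpenCL headers:

#include <cstdio>
#include <vector>

// Stand-in for OpenCL's cl_int2 so the snippet is self-contained.
struct Int2 { int x, y; };

int main() {
  // Hypothetical global size as produced for the {3, 10} 2-D tensor above.
  std::vector<size_t> global_size = {3, 3};
  Int2 output_shape{static_cast<int>(global_size[0]),
                    static_cast<int>(global_size[1])};
  // The same pair is what gets passed as the kernel's output-shape argument.
  printf("output_shape = {%d, %d}\n", output_shape.x, output_shape.y);
  return 0;
}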

@@ -162,4 +162,5 @@ kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector<lite::Tensor *>
 }
 REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator)
 }  // namespace mindspore::kernel

@@ -172,4 +172,5 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> &
 }
 REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
+REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
 }  // namespace mindspore::kernel
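
The last two files only add Float16 registrations for the BiasAdd and PReLU kernel creators. REG_KERNEL's actual implementation is not shown in this diff; as a rough, hypothetical sketch of the general pattern it implies (a registry keyed by device, data type, and op, so each supported data type needs its own entry):

#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical registry sketch -- NOT the real MindSpore Lite REG_KERNEL.
// Key: (device, data type, op); value: a creator callback.
using Key = std::tuple<std::string, std::string, std::string>;
using Creator = std::function<const char *()>;

std::map<Key, Creator> &Registry() {
  static std::map<Key, Creator> registry;
  return registry;
}

void Register(const Key &key, Creator creator) { Registry()[key] = std::move(creator); }

int main() {
  // One entry per data type, mirroring the two REG_KERNEL lines per op above.
  Register({"GPU", "Float32", "BiasAdd"}, [] { return "fp32 BiasAdd kernel"; });
  Register({"GPU", "Float16", "BiasAdd"}, [] { return "fp16 BiasAdd kernel"; });
  auto it = Registry().find({"GPU", "Float16", "BiasAdd"});
  printf("%s\n", it != Registry().end() ? it->second() : "not registered");
  return 0;
}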