forked from mindspore-Ecosystem/mindspore
fix bug in op arithmetic
This commit is contained in:
parent
a4e46c8e75
commit
00a1933f34
|
@ -50,6 +50,12 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
|
|||
|
||||
void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() {
|
||||
local_size_ = {16, 16};
|
||||
if (out_tensors_[0]->shape().size() == 2) {
|
||||
size_t H = out_tensors_[0]->shape()[0];
|
||||
size_t W = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
|
||||
global_size_ = {W, H};
|
||||
return;
|
||||
}
|
||||
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
||||
size_t H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
size_t W = out_tensors_[0]->Width();
|
||||
|
@ -74,18 +80,23 @@ void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
|
|||
|
||||
int ArithmeticOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
|
||||
size_t im_dst_x, im_dst_y;
|
||||
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
||||
im_dst_x = out_tensors_[0]->Width();
|
||||
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
|
||||
im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
|
||||
im_dst_y = out_tensors_[0]->Batch();
|
||||
im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
if (out_tensors_[0]->shape().size() == 2) {
|
||||
im_dst_x = UP_DIV(out_tensors_[0]->shape()[1], C4NUM);
|
||||
im_dst_y = out_tensors_[0]->shape()[0];
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat();
|
||||
return RET_ERROR;
|
||||
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
||||
im_dst_x = out_tensors_[0]->Width();
|
||||
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
|
||||
im_dst_x = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
im_dst_y = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
|
||||
im_dst_y = out_tensors_[0]->Batch();
|
||||
im_dst_x = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Unsupport data format " << out_tensors_[0]->GetFormat();
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
|
||||
size_t img_dtype = CL_FLOAT;
|
||||
|
@ -335,22 +346,7 @@ int ArithmeticOpenCLKernel::Run() {
|
|||
}
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->data_c());
|
||||
|
||||
int H = 0;
|
||||
int W = 0;
|
||||
if (out_tensors_[0]->GetFormat() == schema::Format_NC4HW4) {
|
||||
H = out_tensors_[0]->Batch() * out_tensors_[0]->Height() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
W = out_tensors_[0]->Width();
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NHWC4) {
|
||||
H = out_tensors_[0]->Batch() * out_tensors_[0]->Height();
|
||||
W = out_tensors_[0]->Width() * UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
} else if (out_tensors_[0]->GetFormat() == schema::Format_NC4) {
|
||||
H = out_tensors_[0]->Batch();
|
||||
W = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
|
||||
} else {
|
||||
MS_LOG(ERROR) << "Error output type " << out_tensors_[0]->GetFormat();
|
||||
return RET_ERROR;
|
||||
}
|
||||
cl_int2 output_shape{W, H};
|
||||
cl_int2 output_shape{static_cast<int>(global_size_[0]), static_cast<int>(global_size_[1])};
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_idx++, output_shape);
|
||||
ocl_runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
|
||||
return RET_OK;
|
||||
|
|
|
@ -162,4 +162,5 @@ kernel::LiteKernel *OpenCLBiasAddKernelCreator(const std::vector<lite::Tensor *>
|
|||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_BiasAdd, OpenCLBiasAddKernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -172,4 +172,5 @@ kernel::LiteKernel *OpenCLPReluKernelCreator(const std::vector<lite::Tensor *> &
|
|||
}
|
||||
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
|
||||
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_PReLU, OpenCLPReluKernelCreator)
|
||||
} // namespace mindspore::kernel
|
||||
|
|
Loading…
Reference in New Issue