forked from mindspore-Ecosystem/mindspore
[MS][LITE] opencl operator `softmax` support image2d
This commit is contained in:
parent 8ed3b423e0
commit b132311556
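
All of the kernels touched below compute a plain softmax along the channel axis. As a point of reference only (this snippet is not part of the commit), a scalar C++ version of the same computation looks like the sketch below; note that, like the OpenCL kernels in this change, it does not subtract the per-row maximum before exponentiating.

#include <cmath>
#include <vector>

// Reference softmax over one row of `x`; the OpenCL kernels below do the same
// arithmetic, only with the channels packed into float4 slices.
std::vector<float> SoftmaxReference(const std::vector<float> &x) {
  float sum = 0.0f;
  for (float v : x) sum += std::exp(v);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = std::exp(x[i]) / sum;
  return y;
}
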
@@ -32,6 +32,10 @@ enum Format : int {
CKHW,
KHWC,
CHWK,
HW,
HW4,
NC,
NC4,
NC4HW4 = 100,
NUM_OF_FORMAT
}

@@ -104,7 +104,7 @@ int Executor::TransformTensorLayoutFp32(tensor::Tensor *tensor, schema::Format d
allocator->Free(src_data);
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}

@@ -116,7 +116,7 @@ int Executor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::Format
MS_ASSERT(4 == tensor->shape().size());
// auto src_format = tensor->GetFormat();
// todo
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in uint8";
return RET_ERROR;
}

@@ -104,8 +104,8 @@ bool Tensor::operator==(const Value &other) const {
}

int32_t Tensor::Batch() const {
if (this->shape_.size() != 4) {
MS_LOG(ERROR) << "tensor should have 4 dim";
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {

@@ -115,6 +115,8 @@ int32_t Tensor::Batch() const {
case schema::Format_NC4HW4:
case schema::Format_KCHW:
case schema::Format_KHWC:
case schema::Format_NC:
case schema::Format_NC4:
return this->shape_[0];
case schema::Format_HWCK:
case schema::Format_CHWK:

@@ -124,19 +126,21 @@ int32_t Tensor::Batch() const {
case schema::Format_CKHW:
return this->shape_[1];
default:
MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_);
MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_);
return -1;
}
}

int32_t Tensor::Channel() const {
if (this->shape_.size() != 4) {
MS_LOG(ERROR) << "tensor should have 4 dim";
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {
case schema::Format_NCHW:
case schema::Format_KCHW:
case schema::Format_NC:
case schema::Format_NC4:
return this->shape_[1];
case schema::Format_HWCK:
return this->shape_[2];

@@ -155,8 +159,8 @@ int32_t Tensor::Channel() const {
}

int32_t Tensor::Height() const {
if (this->shape_.size() != 4) {
MS_LOG(ERROR) << "tensor should have 4 dim";
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {

@@ -172,16 +176,18 @@ int32_t Tensor::Height() const {
return this->shape_[1];
case schema::Format_HWCK:
case schema::Format_HWKC:
case schema::Format_HW:
case schema::Format_HW4:
return this->shape_[0];
default:
MS_LOG(ERROR) << "Unsupport format: " << schema::EnumNameFormat(this->format_);
MS_LOG(ERROR) << "Unsupported format: " << schema::EnumNameFormat(this->format_);
return -1;
}
}

int32_t Tensor::Width() const {
if (this->shape_.size() != 4) {
MS_LOG(ERROR) << "tensor should have 4 dim";
if (this->shape_.size() != 4 && this->shape_.size() != 2) {
MS_LOG(ERROR) << "Unsupported tensor shape: " << this->shape().size();
return -1;
}
switch (this->format_) {

@@ -197,12 +203,24 @@ int32_t Tensor::Width() const {
return this->shape_[2];
case schema::Format_HWCK:
case schema::Format_HWKC:
case schema::Format_HW:
case schema::Format_HW4:
return this->shape_[1];
default:
return -1;
}
}

int32_t Tensor::ElementsC4Num() const {
int32_t result = 0;
if (this->shape_.size() == 4) {
result = Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4);
} else if (this->shape_.size() == 2) {
result = this->shape_[0] * ((this->shape_[1] + 3) / 4 * 4);
}
return result;
}

std::string Tensor::ToString() const {
std::ostringstream oss;
oss << "Format: " << schema::EnumNameFormat(this->format_);

@@ -235,7 +253,7 @@ std::string Tensor::ToString() const {
}
} break;
default:
oss << "Unsupport data type to print";
oss << "Unsupported data type to print";
break;
}
return oss.str();

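The Batch/Channel/Height/Width accessors above now also accept 2-D shapes (for the NC/NC4 and HW/HW4 formats), and ElementsC4Num moved out of line so it can pad the channel (or last) dimension up to a multiple of 4. A small standalone sketch of that padding rule, with hypothetical helper names, for illustration only:

#include <cassert>

// Hypothetical re-statement of the C4-padded element count (not the library code):
// 4-D tensors pad the channel dimension, 2-D tensors pad the last dimension.
int ElementsC4Num4D(int batch, int height, int width, int channel) {
  return batch * height * width * ((channel + 3) / 4 * 4);
}

int ElementsC4Num2D(int dim0, int dim1) { return dim0 * ((dim1 + 3) / 4 * 4); }

int main() {
  assert(ElementsC4Num4D(1, 2, 2, 8) == 32);   // channel already a multiple of 4
  assert(ElementsC4Num4D(1, 1, 1, 6) == 8);    // 6 channels padded to 8
  assert(ElementsC4Num2D(1, 100) == 104);      // a {1, 100} tensor pads 100 -> 104
  return 0;
}
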
@@ -66,7 +66,7 @@ class Tensor : public mindspore::tensor::MetaTensor {

int32_t Width() const;

int32_t ElementsC4Num() const { return Batch() * Height() * Width() * ((Channel() + 3) / 4 * 4); }
int32_t ElementsC4Num() const;

int DataSize() const { return this->ElementsNum(); }

@@ -37,7 +37,7 @@ int Cast::InferShape(std::vector<tensor::Tensor *> inputs_, std::vector<tensor::
return RET_INPUT_TENSOR_ERROR;
}
if (kSupportDataType.find(input->data_type()) == kSupportDataType.end()) {
MS_LOG(ERROR) << "Unsupport input data type " << input->data_type();
MS_LOG(ERROR) << "Unsupported input data type " << input->data_type();
return RET_INPUT_TENSOR_ERROR;
}
if (cast_prim->dstT() != kNumberTypeFloat && cast_prim->dstT() != kNumberTypeFloat32) {

@@ -74,7 +74,7 @@ int CastCPUKernel::DoCast(int thread_id) {
Float32ToInt32(reinterpret_cast<float *>(input->Data()) + offset,
reinterpret_cast<int32_t *>(output_data) + offset, data_num);
} else {
MS_LOG(ERROR) << "Unsupport datatype from " << input_data_type << " to " << output_data_type;
MS_LOG(ERROR) << "Unsupported datatype from " << input_data_type << " to " << output_data_type;
return RET_ERROR;
}
} else {

@@ -88,7 +88,7 @@ int CastCPUKernel::DoCast(int thread_id) {
reinterpret_cast<float *>(output_data) + offset, data_num);
break;
default:
MS_LOG(ERROR) << "Unsupport input data type " << input_data_type;
MS_LOG(ERROR) << "Unsupported input data type " << input_data_type;
return RET_ERROR;
}
}

@@ -1,21 +1,15 @@
#define SLICES 4

int DivideRoundUp(int n, int div) {
int q = n / div;
return n % div == 0 ? q : q + 1;
}

__kernel void SoftMax(__global float4 *input, __global float4 *output, const int4 input_shape) {
int X = get_global_id(0); // width
int Y = get_global_id(1); // height
int H = input_shape.y;
int W = input_shape.z;
int C = input_shape.w;
__kernel void SoftMax_BUF(__global float4 *input, __global float4 *output, const int4 input_shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
int H = input_shape.x;
int W = input_shape.y;
int C = input_shape.z;
int S = input_shape.w;

if (X >= W || Y >= H) return;

float sum = 0.0f;
for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) {
for (int d = 0; d < S; ++d) {
float4 t = input[(Y * W + X * H) * C + d];
sum += exp(t.x);
if (d * 4 + 1 < C) sum += exp(t.y);

@@ -23,10 +17,34 @@ __kernel void SoftMax(__global float4 *input, __global float4 *output, const int
if (d * 4 + 3 < C) sum += exp(t.w);
}

for (int d = 0; d < DivideRoundUp(C, SLICES); ++d) {
for (int d = 0; d < S; ++d) {
float4 t = input[(Y * W + X * H) * C + d];
t = exp(t) / sum;
float4 result = convert_float4(t);
output[(Y * W + X * H) * C + d] = result;
}
}

__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;

__kernel void SoftMax_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape) {
int X = get_global_id(0);
int Y = get_global_id(1);
if (X >= input_shape.x || Y >= input_shape.y) return;

float sum = 0.0f;
for (int d = 0; d < input_shape.w; ++d) {
float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X));
sum += exp(t.x);
if (d * 4 + 1 < input_shape.z) sum += exp(t.y);
if (d * 4 + 2 < input_shape.z) sum += exp(t.z);
if (d * 4 + 3 < input_shape.z) sum += exp(t.w);
}

for (int d = 0; d < input_shape.w; ++d) {
float4 t = read_imagef(input, smp_none, (int2)(Y * input_shape.w + d, X));
t = exp(t) / sum;
float4 result = convert_float4(t);
write_imagef(output, (int2)(Y * input_shape.w + d, X), result);
}
}

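The new SoftMax_IMG kernel reads the tensor from an image2d rather than a linear buffer. Judging from GetImageSize in softmax.cc further down (image width = W * ceil(C/4), image height = H) and from the read coordinates used above, the addressing appears to be the usual NHWC4 packing; the helper below is hypothetical, not part of the commit, and only spells that mapping out.

#include <cstdio>

struct ImageCoord {
  int x;     // texel column in the image2d
  int y;     // texel row in the image2d
  int lane;  // which component of the float4 texel holds the channel
};

// Assumed NHWC4 mapping: channel c of pixel (h, w) sits in lane c % 4 of the
// texel at column w * slices + c / 4, row h, where slices = ceil(channels / 4).
ImageCoord NHWC4ToImage2D(int h, int w, int c, int channels) {
  const int slices = (channels + 3) / 4;
  return ImageCoord{w * slices + c / 4, h, c % 4};
}

int main() {
  // For the 1x2x2x8 shape used by the new unit test, the image is 4 texels wide and 2 high.
  ImageCoord p = NHWC4ToImage2D(/*h=*/1, /*w=*/1, /*c=*/5, /*channels=*/8);
  std::printf("x=%d y=%d lane=%d\n", p.x, p.y, p.lane);  // prints x=3 y=1 lane=1
  return 0;
}
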
@@ -0,0 +1,50 @@
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;

// what is mask and args.slices_x32
__kernel void SoftMax1x1_IMG(__read_only image2d_t input, __write_only image2d_t output, const float4 mask,
const int slices, const int slices_x32) {
int tid = get_local_id(0);
int slices_count = 0;
int offset = 0;
float sum = 0.0f;
do {
int z = offset + tid;
if (z < slices) {
float4 mask_temp = z == slices - 1 ? mask : (float4)(1.0f);
float4 src = read_imagef(input, smp_none, (int2)(0, 0));
sum += dot(mask_temp, exp(src));
offset += 32;
}
slices_count++;
} while (slices_count < slices_x32);

__local float4 tmp[8];
__local float *tmpx1 = (__local float *)tmp;
tmpx1[tid] = sum;
barrier(CLK_LOCAL_MEM_FENCE);
if (tid == 0) {
sum = dot((float4)(1.0f), tmp[0]);
sum += dot((float4)(1.0f), tmp[1]);
sum += dot((float4)(1.0f), tmp[2]);
sum += dot((float4)(1.0f), tmp[3]);
sum += dot((float4)(1.0f), tmp[4]);
sum += dot((float4)(1.0f), tmp[5]);
sum += dot((float4)(1.0f), tmp[6]);
sum += dot((float4)(1.0f), tmp[7]);
tmpx1[0] = 1.0f / sum;
}
barrier(CLK_LOCAL_MEM_FENCE);
sum = tmpx1[0];

offset = 0;
slices_count = 0;
do {
int z = offset + tid;
if (z < slices) {
float4 res = convert_float4(exp(read_imagef(input, smp_none, (int2)(0, 0))) * sum);
write_imagef(output, (int2)(0, 0), res);
offset += 32;
}
slices_count++;
} while (slices_count < slices_x32);
}

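SoftMax1x1_IMG distributes the slices of a single softmax row over a 32-wide work-group: each work-item accumulates exp-sums for slices tid, tid+32, tid+64, and so on, the last slice is masked so padding lanes do not contribute, and the partial sums are combined through the __local float4 tmp[8] buffer (32 floats viewed as 8 float4). Below is a CPU sketch of the masked, slice-wise accumulation only, under the assumption that padding lanes in the last slice may hold arbitrary data; it is not the kernel itself.

#include <cmath>
#include <cstdio>
#include <vector>

// Masked exp-sum over float4-packed channels; padding lanes beyond `channels`
// are multiplied by 0 so their (arbitrary) contents never reach the sum.
float MaskedExpSum(const std::vector<float> &packed, int channels) {
  const int slices = (channels + 3) / 4;
  float sum = 0.0f;
  for (int s = 0; s < slices; ++s) {
    for (int lane = 0; lane < 4; ++lane) {
      const float mask = (s * 4 + lane < channels) ? 1.0f : 0.0f;
      sum += mask * std::exp(packed[s * 4 + lane]);
    }
  }
  return sum;
}

int main() {
  // 6 real channels, 2 padding lanes filled with garbage values.
  std::vector<float> packed = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f, 99.f, 99.f};
  std::printf("sum=%.3f\n", MaskedExpSum(packed, 6));  // the 99s are masked out
  return 0;
}
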
@@ -1,5 +1,5 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@@ -17,69 +17,143 @@
#include "src/runtime/kernel/opencl/kernel/softmax.h"
#include <string>
#include <set>
#include "include/errorcode.h"
#include "src/kernel_registry.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "src/runtime/kernel/opencl/utils.h"
#ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp32/softmax.cl.inc"
#include "src/runtime/kernel/opencl/cl/fp32/softmax1x1.cl.inc"
#endif

using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::schema::PrimitiveType_SoftMax;

namespace mindspore {
namespace kernel {
int SoftmaxOpenCLKernel::Init() {
std::string kernel_name = "SoftMax";
if (parameter_->axis_ != -1 && parameter_->axis_ != 3) {
MS_LOG(ERROR) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_;
return -1;
}
namespace mindspore::kernel {

auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
#ifdef PROGRAM_WITH_IL
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
#else
std::set<std::string> build_options;
std::string source = softmax_source_fp32;
std::string program_name = "SoftMax";
ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return 0;
std::vector<float> SoftmaxOpenCLKernel::GetMaskForLastChannel(int channels) {
std::vector<float> mask{4, 0.0f};
const int reminder = channels % 4 == 0 ? 4 : channels % 4;
for (int i = 0; i < reminder; ++i) {
mask[i] = 1.0f;
}
return mask;
}

int SoftmaxOpenCLKernel::InitBuffer() { return 0; }
int SoftmaxOpenCLKernel::ReSize() { return 0; }
int SoftmaxOpenCLKernel::InitGlobalSize() {
const size_t global_x = out_tensors_[0]->Height();
const size_t global_y = out_tensors_[0]->Width();
const size_t global_z = 1;
global_size_ = {global_x, global_y, global_z};
return lite::RET_OK;
}

int SoftmaxOpenCLKernel::SetWorkGroupSize() {
// set work group size
InitGlobalSize();
int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())());
local_size_ = GetCommonLocalSize(global_size_, max_work_group_size);
global_size_ = GetCommonGlobalSize(local_size_, global_size_);
return lite::RET_OK;
}

int SoftmaxOpenCLKernel::SetWorkGroupSize1x1() {
local_size_ = {32, 1, 1};
global_size_ = {32, 1, 1};
return lite::RET_OK;
}

int SoftmaxOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t im_dst_x, im_dst_y;
if (onexone_flag_) {
im_dst_x = UP_DIV(in_tensors_[0]->shape()[1], C4NUM);
im_dst_y = 1;
} else {
size_t CO4 = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
im_dst_x = out_tensors_[0]->Width() * CO4;
im_dst_y = out_tensors_[0]->Height();
}
#ifdef ENABLE_FP16
size_t img_dtype = CL_HALF_FLOAT;
#else
size_t img_dtype = CL_FLOAT;
#endif
img_size->clear();
std::vector<size_t> vec{im_dst_x, im_dst_y, img_dtype};
*img_size = vec;
return RET_OK;
}

int SoftmaxOpenCLKernel::Init() {
std::string kernel_name = "SoftMax";
std::string program_name = "SoftMax";
std::string source = softmax_source_fp32;
runtime_ = lite::opencl::OpenCLRuntime::GetInstance();

if (in_tensors_[0]->shape().size() == 4 && parameter_->axis_ == 3) {
// support 4d tensor
onexone_flag_ = false;
} else if (in_tensors_[0]->shape().size() == 2 && parameter_->axis_ == 1) {
// support 2d tensor
kernel_name += "1x1";
program_name += "1x1";
source = softmax1x1_source_fp32;
onexone_flag_ = true;
} else {
MS_LOG(EXCEPTION) << "Init `Softmax` kernel failed: Unsupported axis: " << parameter_->axis_;
}
#ifdef PROGRAM_WITH_IL
runtime_->CreateKernelFromIL(kernel_(), kernel_name);
#else
if (mem_type_ == MEM_TYPE::BUF) {
kernel_name += "_BUF";
program_name += "_BUF";
} else {
kernel_name += "_IMG";
program_name += "_IMG";
}
std::set<std::string> build_options;
runtime_->LoadSource(program_name, source);
out_tensors_[0]->SetFormat(schema::Format_NHWC4);
runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif
MS_LOG(DEBUG) << kernel_name << " Init Done!";
return lite::RET_OK;
}

int SoftmaxOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running!";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
auto allocator = ocl_runtime->GetAllocator();
std::cout << "run" << std::endl;

// global and local workers
const uint32_t grid_x = in_tensors_[0]->shape()[2];  // W
const uint32_t grid_y = in_tensors_[0]->shape()[1];  // H
const uint32_t grid_z = 1;
std::vector<size_t> global = {grid_x, grid_y, grid_z};
std::vector<size_t> local = {1, 1, 1};

// input and output
cl::Buffer *input = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(in_tensors_[0]->Data()));
cl::Buffer *output = reinterpret_cast<cl::Buffer *>(allocator->GetDeviceBuffer(out_tensors_[0]->Data()));
cl_int4 input_size = {in_tensors_[0]->shape()[0], in_tensors_[0]->shape()[1], in_tensors_[0]->shape()[2],
in_tensors_[0]->shape()[3]};
// attribute
int arg_idx = 0;
ocl_runtime->SetKernelArg(kernel_, arg_idx++, *input);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, *output);
ocl_runtime->SetKernelArg(kernel_, arg_idx++, input_size);
if (onexone_flag_) {
int channel_size = in_tensors_[0]->shape()[1];
int slices = UP_DIV(channel_size, C4NUM);
cl_int slices_x32 = UP_DIV(slices, 32);
auto mask_ = GetMaskForLastChannel(channel_size);
cl_float4 mask = {mask_[0], mask_[1], mask_[2], mask_[3]};

runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
runtime_->SetKernelArg(kernel_, arg_idx++, mask);
runtime_->SetKernelArg(kernel_, arg_idx++, slices);
runtime_->SetKernelArg(kernel_, arg_idx, slices_x32);
SetWorkGroupSize1x1();
} else {
int slices = UP_DIV(out_tensors_[0]->Channel(), C4NUM);
cl_int4 input_shape = {in_tensors_[0]->Height(), in_tensors_[0]->Width(), in_tensors_[0]->Channel(), slices};

runtime_->SetKernelArg(kernel_, arg_idx++, in_tensors_[0]->Data());
runtime_->SetKernelArg(kernel_, arg_idx++, out_tensors_[0]->Data());
runtime_->SetKernelArg(kernel_, arg_idx, input_shape);
SetWorkGroupSize();
}

// run opengl kernel
ocl_runtime->RunKernel(kernel_, global, local, nullptr);

return 0;
runtime_->RunKernel(kernel_, global_size_, local_size_, nullptr);
return lite::RET_OK;
}

kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::Tensor *> &inputs,

@@ -104,5 +178,4 @@ kernel::LiteKernel *OpenCLSoftMaxKernelCreator(const std::vector<lite::tensor::T
}

REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_SoftMax, OpenCLSoftMaxKernelCreator)
} // namespace kernel
} // namespace mindspore
} // namespace mindspore::kernel

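On the host side, the 1x1 path passes three extra arguments to the kernel: the per-lane mask for the last slice, the slice count, and the number of 32-slice passes. Below is a standalone sketch of that argument preparation (assumed semantics, not the kernel source; UpDiv mirrors the UP_DIV macro). Note that the sketch builds the mask with std::vector<float> mask(4, 0.0f), i.e. four zero-initialized lanes, whereas the brace form {4, 0.0f} used in GetMaskForLastChannel above creates a two-element vector in C++.

#include <vector>

constexpr int UpDiv(int x, int y) { return (x + y - 1) / y; }

// Hypothetical helper mirroring GetMaskForLastChannel: ones for real channels in
// the last float4 slice, zeros for padding lanes.
std::vector<float> MaskForLastSlice(int channels) {
  std::vector<float> mask(4, 0.0f);
  const int remainder = channels % 4 == 0 ? 4 : channels % 4;
  for (int i = 0; i < remainder; ++i) mask[i] = 1.0f;
  return mask;
}

int main() {
  const int channels = 6;
  const int slices = UpDiv(channels, 4);         // 2 float4 slices
  const int slices_x32 = UpDiv(slices, 32);      // 1 pass of the 32-wide work-group
  const auto mask = MaskForLastSlice(channels);  // {1, 1, 0, 0}
  (void)slices;
  (void)slices_x32;
  (void)mask;
  return 0;
}
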
@@ -23,29 +23,37 @@
#include "src/runtime/kernel/arm/nnacl/fp32/softmax.h"
#include "src/runtime/opencl/opencl_runtime.h"

namespace mindspore {
namespace kernel {
class SoftmaxOpenCLKernel : public LiteKernel {
namespace mindspore::kernel {

class SoftmaxOpenCLKernel : public OpenCLKernel {
public:
explicit SoftmaxOpenCLKernel(OpParameter *parameter,
const std::vector<lite::tensor::Tensor *> &inputs,
explicit SoftmaxOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs, nullptr, nullptr) {
: OpenCLKernel(parameter, inputs, outputs) {
parameter_ = reinterpret_cast<SoftmaxParameter *>(parameter);
}
~SoftmaxOpenCLKernel() override{};

~SoftmaxOpenCLKernel() override{};
int Init() override;
int ReSize() override;
int Run() override;
int InitBuffer();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;

int InitGlobalSize();
int SetWorkGroupSize1x1();
int SetWorkGroupSize();
std::vector<float> GetMaskForLastChannel(int channels);

private:
SoftmaxParameter *parameter_;
cl::Kernel kernel_;
SoftmaxParameter *parameter_;
lite::opencl::OpenCLRuntime *runtime_;
enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG};

bool onexone_flag_{false};
std::vector<size_t> local_size_;
std::vector<size_t> global_size_;
};
} // namespace kernel
} // namespace mindspore

} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_SOFTMAX_H_

@@ -175,4 +175,3 @@ std::string CLErrorCode(cl_int error_code) {
}
} // namespace kernel
} // namespace mindspore

@@ -85,4 +85,3 @@ std::string CLErrorCode(cl_int error_code);
} // namespace mindspore::kernel

#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_UTILS_H_

@@ -19,6 +19,7 @@
#include "utils/log_adapter.h"
#include "src/runtime/opencl/opencl_runtime.h"
#include "include/errorcode.h"
#include "src/runtime/kernel/opencl/utils.h"

namespace mindspore::lite::opencl {

@@ -128,7 +129,7 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector<size_t>& img_size)
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format,
img_size[0], img_size[1], 0, nullptr, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
MS_LOG(ERROR) << "Create OpenCL Image2D failed!" << kernel::CLErrorCode(ret);
UnLock();
delete buffer;
return nullptr;

@@ -187,7 +188,7 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v
cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR,
image_format, img_size[0], img_size[1], 0, data, &ret);
if (ret != CL_SUCCESS) {
MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")";
MS_LOG(ERROR) << "Create OpenCL Image2D failed - " << kernel::CLErrorCode(ret);
UnLock();
delete buffer;
return nullptr;

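The allocator hunks above replace the raw numeric error code in the log with the readable string from kernel::CLErrorCode. The sketch below only illustrates that pattern; FakeCLErrorCode is a stand-in with a few hard-coded codes so the snippet compiles on its own, while the real lookup lives in the utils.cc file changed above.

#include <iostream>
#include <string>

// Stand-in for kernel::CLErrorCode; maps a few cl_int error codes to their names.
std::string FakeCLErrorCode(int error_code) {
  switch (error_code) {
    case 0: return "CL_SUCCESS";
    case -4: return "CL_MEM_OBJECT_ALLOCATION_FAILURE";
    case -61: return "CL_INVALID_BUFFER_SIZE";
    default: return "UNKNOWN_OPENCL_ERROR";
  }
}

int main() {
  const int ret = -61;  // pretend an image allocation failed
  if (ret != 0 /* CL_SUCCESS */) {
    std::cerr << "Create OpenCL Image2D failed! " << FakeCLErrorCode(ret) << "\n";
  }
  return 0;
}
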
@@ -52,6 +52,7 @@ int OpenCLExecutor::Run(std::vector<tensor::Tensor *> &inputs, std::vector<tenso
std::vector<size_t> img_size;
op_kernel->GetImageSize(i, &img_size);
auto data_ptr = op_allocator->Malloc(output->Size(), img_size);

output->SetData(data_ptr);
} else {
output->MallocData(allocator);

@@ -109,7 +110,7 @@ int OpenCLExecutor::TransformTensorLayout(tensor::Tensor *tensor, schema::Format
case kNumberTypeFloat32:
return TransformTensorLayoutFp32(tensor, src_format, dst_format, trans_dir);
default:
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format);
return RET_ERROR;
}

@@ -160,7 +161,7 @@ int OpenCLExecutor::TransformTensorLayoutToBuffer(tensor::Tensor *tensor, schema
// TODO(wandongdong): add support !!
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}

@@ -194,7 +195,7 @@ int OpenCLExecutor::TransformTensorLayoutToImage(tensor::Tensor *tensor, schema:
allocator_->Free(src_data);
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}

@@ -216,7 +217,7 @@ int OpenCLExecutor::TransformTensorLayoutFromImage(tensor::Tensor *tensor, schem
allocator_->Free(src_data);
return RET_OK;
} else {
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in float32";
return RET_ERROR;
}

@@ -228,7 +229,7 @@ int OpenCLExecutor::TransformTensorLayoutUint8(tensor::Tensor *tensor, schema::F
MS_ASSERT(4 == tensor->shape().size());
// auto src_format = tensor->GetFormat();
// todo
MS_LOG(ERROR) << "Unsupport layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
MS_LOG(ERROR) << "Unsupported layout transform: " << schema::EnumNameFormat(tensor->GetFormat()) << " to "
<< schema::EnumNameFormat(dst_format) << " in uint8";
return RET_ERROR;
}

@@ -17,76 +17,90 @@
#include <memory>
#include "mindspore/core/utils/log_adapter.h"
#include "common/common_test.h"
#include "mindspore/lite/src/common/file_utils.h"
#include "mindspore/lite/src/runtime/opencl/opencl_runtime.h"
#include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h"
#include "mindspore/lite/src/runtime/kernel/opencl/kernel/softmax.h"
#include "mindspore/lite/test/ut/src/runtime/kernel/opencl/utils_tests.h"

namespace mindspore {

class TestSoftmaxOpenCL : public mindspore::CommonTest {};

void InitSoftaxParam(SoftmaxParameter *param) { param->axis_ = -1; }

TEST_F(TestSoftmaxOpenCL, SoftmaxFp32) {
std::cout << "======" << std::endl;
MS_LOG(INFO) << "start TEST_F TestSoftmaxOpenCL";
void RunTestCase(std::vector<int> input_shape, std::vector<int> output_shape, std::string input_file,
std::string expect_file, SoftmaxParameter *param, schema::Format format) {
std::cout << "runtime" << std::endl;
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();

MS_LOG(INFO) << "create SoftmaxParameter";
auto param = new SoftmaxParameter();
InitSoftaxParam(param);
// define tensor
MS_LOG(INFO) << "defineTensor";
std::cout << "defineTensor" << std::endl;

MS_LOG(INFO) << "create Tensors";
std::vector<int> shape_in = {1, 2, 2, 1};
std::vector<int> shape_out = {1, 2, 2, 1};
auto data_type = kNumberTypeFloat32;
auto tensorType = schema::NodeType_ValueNode;
lite::tensor::Tensor *tensor_in = new lite::tensor::Tensor(data_type, shape_in, schema::Format_NCHW, tensorType);
lite::tensor::Tensor *tensor_out = new lite::tensor::Tensor(data_type, shape_out, schema::Format_NCHW, tensorType);
std::vector<lite::tensor::Tensor *> inputs{tensor_in};
std::vector<lite::tensor::Tensor *> outputs{tensor_out};
auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, format, tensorType);
auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, format, tensorType);
std::vector<lite::tensor::Tensor *> inputs{input_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor};

MS_LOG(INFO) << "create OpenCL Kernel";
auto *Softmax_kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
Softmax_kernel->Init();
std::vector<kernel::LiteKernel *> kernels{Softmax_kernel};
// run
MS_LOG(INFO) << "NewOpenCLKernel";
std::cout << "NewOpenCLKernel" << std::endl;
auto *kernel = new kernel::SoftmaxOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
MS_LOG(INFO) << "KernelInit";
std::cout << "KernelInit" << std::endl;
kernel->Init();

MS_LOG(INFO) << "create SubGraphOpenCLKernel";
std::cout << "LiteKernel" << std::endl;
std::vector<kernel::LiteKernel *> kernels{kernel};
inputs[0]->MallocData(allocator);
std::cout << "SubGraphOpenCLKernel" << std::endl;
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
MS_LOG(INFO) << "pGraphinit";
pGraph->Init();

MS_LOG(INFO) << "initialize data";
std::vector<lite::tensor::Tensor *> tensor_map = {tensor_in};
for (auto &tensor_file : tensor_map) {
auto tensor = tensor_file;
size_t size = tensor->Size();
const float data[4] = {std::log(1.0f), std::log(2.0f), std::log(3.0f), std::log(4.0f)};
memcpy(tensor->Data(), data, size);
}
// load data
MS_LOG(INFO) << "load data1";

MS_LOG(INFO) << "pGraph->Run()";
LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file);
auto *input_data = reinterpret_cast<float *>(input_tensor->Data());
printf("\ninput[0:10]:");
for (int i = 0; i < 10; i++) {
printf("[%d]:%.3f ", i, input_data[i]);
}
printf("\n\n");

MS_LOG(INFO) << "Run";
pGraph->Run();

MS_LOG(INFO) << "==================output data=================";
float *output_data = reinterpret_cast<float *>(tensor_out->Data());
size_t output_size = tensor_out->Size();

printf("output:");
for (int i = 0; i < 4; i++) {
printf("%.3f ", output_data[i]);
}
printf("\n");
float expect[4] = {1.0f, 2.0f, 3.0f, 4.0f};

for (int i = 0; i < tensor_out->ElementsNum(); ++i) {
if (std::fabs(output_data[i] - expect[i]) > 1e-5) {
printf("idx[%d] except=%.3f output=%.3f .", i, expect[i], output_data[i]);
}
}
printf("\nTest all close OK for %zu!\n", output_size);
lite::CompareOutputData(output_data, expect, 4);
MS_LOG(INFO) << "compare result";
std::cout << "compare result" << std::endl;
CompareOutput(output_tensor, expect_file);
}

TEST_F(TestSoftmaxOpenCL, Softmax_1) {
std::vector<int> input_shape = {1, 2, 2, 8};
std::vector<int> output_shape = {1, 2, 2, 8};
std::string input_file = "softmax_in.bin";
std::string expect_file = "softmax_out.bin";
auto param = new SoftmaxParameter;
param->axis_ = 3;
schema::Format format = schema::Format_NHWC4;

RunTestCase(input_shape, output_shape, input_file, expect_file, param, format);
}

// TEST_F(TestSoftmaxOpenCL, Softmax_1x1) {
// std::vector<int> input_shape = {1, 100};
// std::vector<int> output_shape = {1, 100};
// std::string input_file = "softmax1x1_in.bin";
// std::string expect_file = "softmax1x1_out.bin";
// auto param = new SoftmaxParameter;
// param->axis_ = 1;
// schema::Format format = schema::Format_NHWC4;
//
// RunTestCase(input_shape, output_shape, input_file, expect_file, param, format);
//}

} // namespace mindspore

@@ -40,13 +40,13 @@ void CompareOutput(lite::tensor::Tensor *output_tensor, const std::string &file_
size_t output_size = output_tensor->Size();
float *expect_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &output_size));

printf("output[0:10]:");
for (int i = 0; i < 10; i++) {
printf("output[0:12]:");
for (int i = 0; i < 12; i++) {
printf("[%d]:%.3f ", i, output_data[i]);
}
printf("\n");
printf("expect[0:10]:");
for (int i = 0; i < 10; i++) {
printf("expect[0:12]:");
for (int i = 0; i < 12; i++) {
printf("[%d]:%.3f ", i, expect_data[i]);
}
printf("\n");

@@ -157,7 +157,7 @@ int WeightFormatPass::ShapeFormatTrans(GraphNode *graphNode) {
} else if (opType == schema::PrimitiveType_DeConv2D) {
weightTensor->format = schema::Format_CHWK;
} else {
MS_LOG(ERROR) << "unsupport format";
MS_LOG(ERROR) << "Unsupported format";
return -1;
}
} break;

@@ -184,7 +184,7 @@ size_t GetDataTypeSize(const TypeId &data_type) {
return sizeof(int64_t);
default:
MS_LOG(ERROR) << data_type;
MS_LOG(ERROR) << "unsupport datatype";
MS_LOG(ERROR) << "Unsupported datatype";
return RET_ERROR;
}
}