add int data support for opencl
This commit is contained in:
parent
ffa92acdef
commit
5bcf605b45
|
@ -11,23 +11,28 @@
|
|||
__constant sampler_t smp_zero = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
|
||||
__kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global FLT *buf, __global int *ids,
|
||||
int4 shape, int4 src_size, int4 cus_size, int4 strides, int4 flags) {
|
||||
int X = get_global_id(0); // reduce len
|
||||
int X = get_global_id(0); // lower reduce stride
|
||||
int Y = get_global_id(1); // upper axis accumulation
|
||||
if (X >= src_size.x || Y >= src_size.y) {
|
||||
return;
|
||||
}
|
||||
int offset = X + Y * src_size.z;
|
||||
int align_c4 = (flags.z != 3) ? (X / shape.w) * (shape.x) : 0;
|
||||
int align_c4 = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0;
|
||||
int align_in = 0;
|
||||
int align_out = 0;
|
||||
bool keep_dims = cus_size.y;
|
||||
int width = shape.z * shape.w;
|
||||
if (flags.z == 3) {
|
||||
align_in = (Y / shape.z) * cus_size.z;
|
||||
align_out = (Y / shape.z) * cus_size.w;
|
||||
}
|
||||
if (flags.z == 0) {
|
||||
align_in = X / (shape.y) * cus_size.z;
|
||||
align_in = X / (width)*cus_size.z;
|
||||
align_out = align_in;
|
||||
}
|
||||
if (flags.z == 2 && !keep_dims) {
|
||||
align_out = (Y / shape.y) * cus_size.w;
|
||||
}
|
||||
for (int k = 0; k < src_size.w; ++k) {
|
||||
int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
|
||||
int idx1 = offset + k * src_size.x;
|
||||
|
|
|
@ -61,8 +61,6 @@ void ArgMinMaxOpenCLKernel::SetConstArgs() {
|
|||
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
|
||||
cl_int4 in_shape{static_cast<int>(im_in_.N), static_cast<int>(im_in_.H), static_cast<int>(im_in_.W),
|
||||
static_cast<int>(im_in_.C)};
|
||||
in_shape.s[0] = UP_ROUND(im_in_.C, C4NUM) - im_in_.C;
|
||||
in_shape.s[1] = im_in_.W * im_in_.C;
|
||||
cl_int4 flags = {param->out_value_, param->get_max_, param->axis_, param->topk_};
|
||||
int arg_cnt = 2;
|
||||
ocl_runtime_->SetKernelArg(kernel_, arg_cnt++, buff_, lite::opencl::MemType::BUF);
|
||||
|
@ -77,17 +75,20 @@ void ArgMinMaxOpenCLKernel::SetConstArgs() {
|
|||
void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
|
||||
auto param = reinterpret_cast<ArgMinMaxParameter *>(op_parameter_);
|
||||
im_in_ = GpuTensorInfo(in_tensors_[0]);
|
||||
im_out_ = GpuTensorInfo(out_tensors_[0]);
|
||||
std::vector<size_t> in_shape = {im_in_.N, im_in_.H, im_in_.W, im_in_.C};
|
||||
auto in_shape_align = in_shape;
|
||||
in_shape_align[3] = UP_ROUND(in_shape[3], C4NUM);
|
||||
auto out_shape_align = in_shape_align;
|
||||
out_shape_align.at(param->axis_) = param->axis_ == 3 ? UP_ROUND(param->topk_, C4NUM) : param->topk_;
|
||||
std::vector<size_t> out_shape = {im_out_.N, im_out_.H, im_out_.W, im_out_.C};
|
||||
auto out_shape_align = out_shape;
|
||||
out_shape_align[3] = UP_ROUND(out_shape[3], C4NUM);
|
||||
int reduce_len = GetUpPow2(in_shape.at(param->axis_));
|
||||
int dtype_size = in_tensors_[0]->data_type() == kNumberTypeFloat16 ? sizeof(int16_t) : sizeof(float);
|
||||
cus_size_ = {reduce_len, static_cast<int>(im_in_.RowPitch() / dtype_size), 1, 1};
|
||||
cus_size_.s[2] = UP_ROUND(im_in_.width * C4NUM, cus_size_.s[1]) - im_in_.width * C4NUM;
|
||||
cus_size_.s[3] = im_in_.W * UP_ROUND(param->topk_, C4NUM);
|
||||
cus_size_.s[3] = UP_ROUND(cus_size_.s[3], cus_size_.s[1]) - cus_size_.s[3];
|
||||
int in_pitch = im_in_.RowPitch() / dtype_size;
|
||||
int out_pitch = im_out_.RowPitch() / dtype_size;
|
||||
cus_size_ = {reduce_len, param->keep_dims_, 1, 1};
|
||||
cus_size_.s[2] = in_pitch - im_in_.width * C4NUM;
|
||||
cus_size_.s[3] = out_pitch - im_out_.width * C4NUM;
|
||||
src_size_ = {std::accumulate(in_shape.begin() + param->axis_ + 1, in_shape.end(), 1, std::multiplies<int>()),
|
||||
std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
|
||||
std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
|
||||
|
@ -100,22 +101,25 @@ void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
|
|||
};
|
||||
switch (param->axis_) {
|
||||
case 0:
|
||||
strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, cus_size_.s[1]) * im_in_.H;
|
||||
strides_.s[0] = UP_ROUND(strides_.s[0] / im_in_.H, in_pitch) * im_in_.H;
|
||||
strides_.s[1] = strides_.s[0] * im_in_.N;
|
||||
strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, cus_size_.s[1]) * im_in_.H;
|
||||
strides_.s[2] = UP_ROUND(strides_.s[2] / im_in_.H, out_pitch) * im_in_.H;
|
||||
strides_.s[3] = strides_.s[2] * param->topk_;
|
||||
break;
|
||||
case 1:
|
||||
strides_.s[0] = UP_ROUND(strides_.s[0], cus_size_.s[1]);
|
||||
strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, cus_size_.s[1]) * im_in_.H;
|
||||
strides_.s[2] = UP_ROUND(strides_.s[2], cus_size_.s[1]);
|
||||
strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, cus_size_.s[1]) * param->topk_;
|
||||
strides_.s[0] = UP_ROUND(strides_.s[0], in_pitch);
|
||||
strides_.s[1] = UP_ROUND(strides_.s[1] / im_in_.H, in_pitch) * im_in_.H;
|
||||
// org dim(4,3) org axis(1,0)
|
||||
strides_.s[2] = UP_ROUND(strides_.s[2], out_pitch);
|
||||
strides_.s[3] = UP_ROUND(strides_.s[3] / param->topk_, out_pitch) * param->topk_;
|
||||
break;
|
||||
case 2:
|
||||
strides_.s[1] = UP_ROUND(strides_.s[1], cus_size_.s[1]);
|
||||
strides_.s[3] = UP_ROUND(strides_.s[3], cus_size_.s[1]);
|
||||
strides_.s[1] = UP_ROUND(strides_.s[1], in_pitch);
|
||||
// org dim(4,3,2) org axis(2,1,0)
|
||||
strides_.s[3] = param->keep_dims_ ? UP_ROUND(strides_.s[3], out_pitch) : strides_.s[2];
|
||||
break;
|
||||
default: // 3
|
||||
// org dim(4,3,2,1) org axis(3,2,1,0)
|
||||
break;
|
||||
}
|
||||
local_size_ = {1, 1, 1};
|
||||
|
@ -147,8 +151,10 @@ int ArgMinMaxOpenCLKernel::Prepare() {
|
|||
auto *param = reinterpret_cast<ArgMinMaxParameter *>(this->op_parameter_);
|
||||
param->dims_size_ = in_tensors_[0]->shape().size();
|
||||
param->axis_ = (param->axis_ + param->dims_size_) % param->dims_size_;
|
||||
param->axis_ = (4 - param->dims_size_) + param->axis_;
|
||||
param->axis_ = GetBroadcastGpuAxis(param->dims_size_, param->axis_);
|
||||
param->get_max_ = (Type() == PrimitiveType_ArgMax);
|
||||
param->keep_dims_ =
|
||||
param->keep_dims_ || param->topk_ > 1 || in_tensors_[0]->shape().size() == out_tensors_[0]->shape().size();
|
||||
|
||||
InitWeights();
|
||||
SetGlobalLocal();
|
||||
|
|
|
@ -44,6 +44,7 @@ class ArgMinMaxOpenCLKernel : public OpenCLKernel {
|
|||
void *buff_{nullptr};
|
||||
void *ids_{nullptr};
|
||||
GpuTensorInfo im_in_{GpuTensorInfo(nullptr)};
|
||||
GpuTensorInfo im_out_{GpuTensorInfo(nullptr)};
|
||||
cl_int4 src_size_;
|
||||
cl_int4 cus_size_;
|
||||
cl_int4 strides_;
|
||||
|
|
|
@ -105,6 +105,7 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
|
|||
auto allocator = ocl_runtime_->GetAllocator();
|
||||
bool is_fp16 = ocl_runtime_->GetFp16Enable();
|
||||
|
||||
size_t dtype_size = is_fp16 ? sizeof(int16_t) : sizeof(float);
|
||||
auto out_info = GpuTensorInfo(out_tensors_[0]);
|
||||
// weight: o, h, w, i; o == group, i == 1
|
||||
void *origin_weight = in_tensors_.at(kWeightIndex)->data_c();
|
||||
|
@ -121,7 +122,7 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
|
|||
size_t img_dtype = ocl_runtime_->GetFp16Enable() ? CL_HALF_FLOAT : CL_FLOAT;
|
||||
img_size = {(size_t)plane_out / C4NUM, (size_t)out_info.N * CO4, img_dtype};
|
||||
}
|
||||
pack_weight_size = is_fp16 ? pack_weight_size * sizeof(int16_t) : pack_weight_size * sizeof(float);
|
||||
pack_weight_size = pack_weight_size * dtype_size;
|
||||
auto ConvertFilter = [](void *src, void *dst, TypeId src_type, TypeId dst_type, size_t plane_in, size_t plane_out,
|
||||
size_t channel) {
|
||||
if (dst_type == kNumberTypeFloat16) {
|
||||
|
@ -173,18 +174,14 @@ int DepthwiseConv2dOpenCLKernel::InitWeights() {
|
|||
memcpy(dst, src, size * dtype_size);
|
||||
}
|
||||
};
|
||||
size_t dtype_size = sizeof(float);
|
||||
if (is_fp16 && in_tensors_.at(kBiasIndex)->data_type() == kNumberTypeFloat16) {
|
||||
dtype_size = sizeof(int16_t);
|
||||
}
|
||||
std::vector<char> temp_bias(pack_weight_size, 0);
|
||||
size_t bias_size = C4NUM * CO4 * dtype_size;
|
||||
std::vector<char> temp_bias(bias_size, 0);
|
||||
if (in_tensors_.size() == 3) {
|
||||
src_type = in_tensors_.at(kBiasIndex)->data_type();
|
||||
dst_type = is_fp16 ? kNumberTypeFloat16 : kNumberTypeFloat32;
|
||||
auto element_size = in_tensors_.at(kBiasIndex)->ElementsNum();
|
||||
ConvertBias(in_tensors_.at(kBiasIndex)->data_c(), temp_bias.data(), element_size, dtype_size, src_type, dst_type);
|
||||
}
|
||||
size_t bias_size = C4NUM * CO4 * dtype_size;
|
||||
bias_data_ = allocator->Malloc(bias_size, {}, temp_bias.data());
|
||||
if (bias_data_ == nullptr) {
|
||||
return RET_ERROR;
|
||||
|
|
|
@ -538,7 +538,7 @@ int TryMergeEltwiseEltwise(LiteKernel *node, std::vector<LiteKernel *> *nodes, s
|
|||
|
||||
} // namespace
|
||||
|
||||
void OpenCLSubGraph::Fusion() {
|
||||
int OpenCLSubGraph::FusionPass() {
|
||||
MS_LOG(DEBUG) << "start Fusion";
|
||||
|
||||
std::vector<LiteKernel *> input_nodes;
|
||||
|
@ -657,6 +657,7 @@ void OpenCLSubGraph::Fusion() {
|
|||
std::remove_if(nodes_.begin(), nodes_.end(), [&](LiteKernel *node) { return AIsInB(node, &removed_set); }),
|
||||
nodes_.end());
|
||||
MS_LOG(DEBUG) << "number of kernels(after fusion) : " << nodes_.size();
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -16,6 +16,8 @@
|
|||
|
||||
#include "src/runtime/kernel/opencl/opencl_subgraph.h"
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include <string>
|
||||
#include "src/runtime/opencl/opencl_executor.h"
|
||||
#include "src/runtime/kernel/opencl/utils.h"
|
||||
#include "include/errorcode.h"
|
||||
|
@ -189,19 +191,7 @@ int OpenCLSubGraph::GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors,
|
|||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
int OpenCLSubGraph::Init() {
|
||||
allocator_ = ocl_runtime_->GetAllocator();
|
||||
MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size();
|
||||
for (const auto tensor : in_tensors_) {
|
||||
MS_ASSERT(tensor);
|
||||
tensor->set_allocator(allocator_);
|
||||
}
|
||||
for (const auto tensor : out_tensors_) {
|
||||
MS_ASSERT(tensor);
|
||||
tensor->set_allocator(allocator_);
|
||||
}
|
||||
|
||||
int OpenCLSubGraph::InsertOpsPass() {
|
||||
GetInOutNodes();
|
||||
|
||||
std::vector<std::vector<kernel::LiteKernel *>> from_kernels_;
|
||||
|
@ -222,12 +212,34 @@ int OpenCLSubGraph::Init() {
|
|||
}
|
||||
nodes_.insert(nodes_.end(), out_convert_ops_.begin(), out_convert_ops_.end());
|
||||
GetInOutNodes();
|
||||
UpdateTensorDataType();
|
||||
Fusion();
|
||||
return RET_OK;
|
||||
}
|
||||
int OpenCLSubGraph::Init() {
|
||||
allocator_ = ocl_runtime_->GetAllocator();
|
||||
MS_LOG(DEBUG) << "input num=" << in_tensors_.size() << ", output num=" << out_tensors_.size();
|
||||
for (const auto tensor : in_tensors_) {
|
||||
MS_ASSERT(tensor);
|
||||
tensor->set_allocator(allocator_);
|
||||
}
|
||||
for (const auto tensor : out_tensors_) {
|
||||
MS_ASSERT(tensor);
|
||||
tensor->set_allocator(allocator_);
|
||||
}
|
||||
std::map<std::string, std::function<int(void)>> pass_manager{
|
||||
{"InsertOpsPass", std::bind(&OpenCLSubGraph::InsertOpsPass, this)},
|
||||
{"UpdateTensorDataTypePass", std::bind(&OpenCLSubGraph::UpdateTensorDataTypePass, this)},
|
||||
{"FusionPass", std::bind(&OpenCLSubGraph::FusionPass, this)}};
|
||||
for (auto iv : pass_manager) {
|
||||
auto ret = iv.second();
|
||||
if (ret != RET_OK) {
|
||||
MS_LOG(ERROR) << "Run Pass: " << iv.first << " failed.";
|
||||
return RET_ERROR;
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void OpenCLSubGraph::UpdateTensorDataType() {
|
||||
int OpenCLSubGraph::UpdateTensorDataTypePass() {
|
||||
bool is_fp16 = ocl_runtime_->GetFp16Enable();
|
||||
MS_ASSERT(in_tensors_[0]);
|
||||
if (is_fp16 && (in_tensors_[0]->data_type() == kNumberTypeFloat32)) {
|
||||
|
@ -245,6 +257,7 @@ void OpenCLSubGraph::UpdateTensorDataType() {
|
|||
}
|
||||
}
|
||||
}
|
||||
return RET_OK;
|
||||
}
|
||||
|
||||
void OpenCLSubGraph::GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors,
|
||||
|
|
|
@ -46,10 +46,11 @@ class OpenCLSubGraph : public SubGraphKernel {
|
|||
int ReSize() override;
|
||||
int Run() override;
|
||||
int Run(const KernelCallBack &before, const KernelCallBack &after) override { return this->Run(); };
|
||||
int InsertOpsPass();
|
||||
|
||||
private:
|
||||
void UnInit();
|
||||
void UpdateTensorDataType();
|
||||
int UpdateTensorDataTypePass();
|
||||
void ReplaceOutTensorAndKernelToNull(const std::vector<lite::Tensor *> &in_tensors,
|
||||
const std::vector<std::vector<kernel::LiteKernel *>> &in_kernels,
|
||||
lite::opencl::MemType mem_type);
|
||||
|
@ -64,7 +65,10 @@ class OpenCLSubGraph : public SubGraphKernel {
|
|||
void GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors,
|
||||
const std::vector<kernel::LiteKernel *> &in_kernels,
|
||||
std::vector<std::vector<kernel::LiteKernel *>> *out_kernels, bool is_from);
|
||||
void Fusion();
|
||||
int FusionPass();
|
||||
|
||||
public:
|
||||
using PassFunc = int (OpenCLSubGraph::*)(void);
|
||||
|
||||
private:
|
||||
lite::opencl::OpenCLAllocator *allocator_{nullptr};
|
||||
|
|
|
@ -330,4 +330,22 @@ std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape
|
|||
}
|
||||
return {image_x, image_y};
|
||||
}
|
||||
int GetBroadcastGpuAxis(int ndim, int ori_axis) {
|
||||
if (ori_axis >= ndim) {
|
||||
return ndim - 1;
|
||||
}
|
||||
int axis = 0;
|
||||
if (ndim == 1) {
|
||||
axis = 3;
|
||||
} else if (ndim == 2) {
|
||||
axis = ori_axis == 0 ? 0 : 3;
|
||||
} else if (ndim == 3) {
|
||||
axis = ori_axis == 0 ? 0 : ori_axis == 1 ? 2 : 3;
|
||||
} else if (ndim == 4) {
|
||||
axis = ori_axis;
|
||||
} else if (ndim > 4) {
|
||||
MS_LOG(ERROR) << "GPU doesn't support ndim>=" << ndim;
|
||||
}
|
||||
return axis;
|
||||
}
|
||||
} // namespace mindspore::kernel
|
||||
|
|
|
@ -61,6 +61,8 @@ std::vector<int> GetNHWCShape(const std::vector<int> &tensor_shape);
|
|||
|
||||
std::vector<size_t> GetImage2dShapeFromNHWC(const std::vector<int> &tensor_shape, schema::Format format);
|
||||
|
||||
int GetBroadcastGpuAxis(int ndim, int ori_axis);
|
||||
|
||||
template <class T1, class T2>
|
||||
void PackNCHWToNC4HW4(void *src, void *dst, int batch, int plane_in, int plane_out, int channel,
|
||||
const std::function<T2(T1)> &to_dtype) {
|
||||
|
|
|
@ -185,7 +185,7 @@ TEST_F(TestOpenCL_ArgMinMax, axis3topk2value) {
|
|||
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable);
|
||||
}
|
||||
}
|
||||
TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) {
|
||||
TEST_F(TestOpenCL_ArgMinMax, dim32axis1topk1index) {
|
||||
schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
|
||||
int axis = 1;
|
||||
int topk = 1;
|
||||
|
@ -200,4 +200,52 @@ TEST_F(TestOpenCL_ArgMinMax, axis1topk1index) {
|
|||
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
|
||||
}
|
||||
}
|
||||
TEST_F(TestOpenCL_ArgMinMax, dim43axis2topk1index) {
|
||||
schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
|
||||
int axis = 2;
|
||||
int topk = 1;
|
||||
bool out_value = false;
|
||||
std::vector<int> input_shape = {2, 2, 2, 14};
|
||||
std::vector<int> output_shape = {2, 2, 14};
|
||||
float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15,
|
||||
1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50, 30, 10, 20, 30,
|
||||
40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25,
|
||||
50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 10, 20, 30, 40, 90, 20, 11, 15,
|
||||
1, 50, 30, 45, 25, 50, 30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25};
|
||||
float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0,
|
||||
1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0};
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
auto *param = CreateParameter(type, axis, topk, out_value);
|
||||
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
|
||||
}
|
||||
}
|
||||
TEST_F(TestOpenCL_ArgMinMax, dim21axis2topk1index) {
|
||||
schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
|
||||
int axis = 0;
|
||||
int topk = 1;
|
||||
bool out_value = false;
|
||||
std::vector<int> input_shape = {2, 14};
|
||||
std::vector<int> output_shape = {14};
|
||||
float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50,
|
||||
30, 10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25};
|
||||
float output_data[] = {1, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0};
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
auto *param = CreateParameter(type, axis, topk, out_value);
|
||||
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
|
||||
}
|
||||
}
|
||||
TEST_F(TestOpenCL_ArgMinMax, dim10axis2topk1index) {
|
||||
schema::PrimitiveType type = schema::PrimitiveType_ArgMax;
|
||||
int axis = 0;
|
||||
int topk = 1;
|
||||
bool out_value = false;
|
||||
std::vector<int> input_shape = {14};
|
||||
std::vector<int> output_shape = {1};
|
||||
float input_data[] = {10, 20, 30, 40, 90, 20, 11, 15, 1, 50, 30, 45, 25, 50};
|
||||
float output_data[] = {4};
|
||||
for (auto fp16_enable : {false, true}) {
|
||||
auto *param = CreateParameter(type, axis, topk, out_value);
|
||||
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
|
||||
}
|
||||
}
|
||||
} // namespace mindspore::lite::opencl::test
|
||||
|
|
Loading…
Reference in New Issue