fix argminmax bug when axis=3

This commit is contained in:
wandongdong 2021-02-26 01:18:42 -08:00
parent b730069942
commit 4f04337c7a
8 changed files with 168 additions and 19 deletions

View File

@ -16,15 +16,17 @@ __kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global
if (X >= src_size.x || Y >= src_size.y) {
return;
}
int offset = X + Y * src_size.z;
int align_c4 = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0;
int align_in = 0;
int align_out = 0;
bool keep_dims = cus_size.y;
int width = shape.z * shape.w;
int offset = X + Y * src_size.z;
int align_c4_in = (flags.z != 3) ? (X / shape.w) * (C4NUM - shape.w & 0x00000003) : 0;
int align_c4_out =
(flags.z == 3 && flags.w == 1 && !keep_dims) ? (Y / shape.z) * (C4NUM - shape.z & 0x00000003) : align_c4_in;
int align_in = 0;
int align_out = 0;
if (flags.z == 3) {
align_in = (Y / shape.z) * cus_size.z;
align_out = (Y / shape.z) * cus_size.w;
align_out = (Y / ((flags.w > 1 || keep_dims) ? shape.z : shape.z * shape.y)) * cus_size.w;
}
if (flags.z == 0) {
align_in = X / (width)*cus_size.z;
@ -34,7 +36,7 @@ __kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global
align_out = (Y / shape.y) * cus_size.w;
}
for (int k = 0; k < src_size.w; ++k) {
int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4 + align_in);
int idx0 = (X + k * strides.x) + Y * strides.y + (align_c4_in + align_in);
int idx1 = offset + k * src_size.x;
ids[idx1] = k;
buf[idx1] = src_data[idx0];
@ -66,7 +68,7 @@ __kernel void argminmax(__global FLT *src_data, __global FLT *dst_data, __global
}
}
for (int k = 0; k < flags.w; ++k) {
int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4 + align_out);
int idx0 = (X + k * strides.z) + Y * strides.w + (align_c4_out + align_out);
int idx1 = flags.y ? (offset + (src_size.w - k - 1) * src_size.x) : (offset + k * src_size.x);
if (flags.x) {
dst_data[idx0] = buf[idx1];

View File

@ -259,6 +259,6 @@ __kernel void transpose_general_NHWC4(__read_only image2d_t src_data, __write_on
result_tmp.c_array[i] = src_tmp.c_array[out_index[de_perm.w] % 4];
}
}
int CO4_SIZE = UP_DIV(in_shape.w, 4);
int CO4_SIZE = UP_DIV(out_shape.w, 4);
WRITE_IMAGE(dst_data, (int2)(Y * CO4_SIZE + Z, X), result_tmp.vector);
}

View File

@ -93,11 +93,13 @@ void ArgMinMaxOpenCLKernel::SetGlobalLocal() {
std::accumulate(in_shape.begin(), in_shape.begin() + param->axis_, 1, std::multiplies<int>()),
std::accumulate(in_shape.begin() + param->axis_, in_shape.end(), 1, std::multiplies<int>()),
static_cast<int>(in_shape.at(param->axis_))};
int out_axis = (param->axis_ == 3 && param->topk_ == 1 && !param->keep_dims_) ? 4 : param->axis_;
strides_ = {
std::accumulate(in_shape_align.begin() + param->axis_ + 1, in_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(in_shape_align.begin() + param->axis_, in_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + param->axis_ + 1, out_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + param->axis_, out_shape_align.end(), 1, std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + std::min(out_axis + 1, 4), out_shape_align.end(), 1,
std::multiplies<int>()),
std::accumulate(out_shape_align.begin() + out_axis, out_shape_align.end(), 1, std::multiplies<int>()),
};
switch (param->axis_) {
case 0:

View File

@ -86,7 +86,7 @@ void Conv2dTransposeOpenCLKernel::SetGlobalLocal() {
int oh = out_tensors_[0]->shape()[1];
int ow = out_tensors_[0]->shape()[2];
local_size_ = {16, 1, 16};
global_size_ = {(size_t)UP_ROUND(oh / 2, stride_h), (size_t)UP_ROUND(ow / 2, stride_w), (size_t)co4};
global_size_ = {(size_t)UP_ROUND(UP_DIV(oh, 2), stride_h), (size_t)UP_ROUND(UP_DIV(ow, 2), stride_w), (size_t)co4};
AlignGlobalLocal(global_size_, local_size_);
}

View File

@ -58,9 +58,16 @@ int SplitOpenCLKernel::RunAxis0() {
}
int SplitOpenCLKernel::CheckSpecs() {
if (out_tensors_.size() != 2 || in_tensors_.size() != 1) {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
return RET_ERROR;
auto param = reinterpret_cast<SplitParameter *>(this->op_parameter_);
if (param->split_dim_) {
if (out_tensors_.size() != 2 || in_tensors_.size() != 1) {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << ", out size: " << out_tensors_.size();
return RET_ERROR;
}
if (param->num_split_ != 2) {
MS_LOG(ERROR) << "num_split_(should be 2): " << param->num_split_;
return RET_ERROR;
}
}
if (in_tensors_.at(0)->IsConst()) {
MS_LOG(ERROR) << "in_tensors_ must be tensor";
@ -72,11 +79,6 @@ int SplitOpenCLKernel::CheckSpecs() {
return RET_ERROR;
}
}
auto param = reinterpret_cast<SplitParameter *>(this->op_parameter_);
if (param->num_split_ != 2 && param->num_split_ != 1) {
MS_LOG(ERROR) << "num_split_ only supported 1 or 2 yet";
return RET_ERROR;
}
if (param->split_dim_ < 0 || param->split_dim_ > 3) {
MS_LOG(ERROR) << "split_dim_ must between 0~3";
return RET_ERROR;

View File

@ -50,6 +50,7 @@ void OpenCLSubGraph::ReplaceOutTensorAndKernelToNull(const std::vector<lite::Ten
kernels.begin(), kernels.end(),
[this, &in_tensors, &i](kernel::LiteKernel *kv) {
MS_ASSERT(kv);
if (kv == nullptr) return false;
return std::find_if(kv->in_tensors().begin(), kv->in_tensors().end(),
[&in_tensors, &i](lite::Tensor *xv) { return xv == in_tensors.at(i); }) !=
kv->in_tensors().end() &&

View File

@ -248,4 +248,119 @@ TEST_F(TestOpenCL_ArgMinMax, dim10axis2topk1index) {
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
}
}
TEST_F(TestOpenCL_ArgMinMax, dim43axis1topk1index) {
schema::PrimitiveType type = schema::PrimitiveType_ArgMaxFusion;
int axis = 1;
int topk = 1;
bool out_value = false;
std::vector<int> input_shape = {2, 2, 2, 14};
std::vector<int> output_shape = {1, 2, 2, 14};
float input_data[] = {
2732., 10799., 9845., 3264., 13123., 4859., 14019., 15719., 9225., 7891., 4373., 5874., 14116., 14935.,
15430., 15832., 6744., 3468., 14650., 705., 15846., 2599., 10327., 2222., 7768., 2897., 9893., 537.,
11085., 6216., 6921., 6036., 2163., 5072., 4851., 7877., 2046., 1871., 7599., 2496., 15186., 8291.,
10200., 15537., 755., 797., 659., 3219., 15246., 8615., 7456., 16321., 3337., 2745., 4735., 8736.,
6687., 714., 2292., 8343., 10915., 14846., 11723., 11122., 1207., 6172., 8994., 10368., 10368., 10148.,
7221., 6021., 3622., 3560., 8948., 12561., 14671., 12676., 1641., 11306., 13754., 14879., 4984., 4353.,
13633., 12263., 12201., 10297., 14627., 12134., 11383., 15115., 8622., 7250., 4187., 14208., 10638., 2659.,
9781., 2956., 10873., 16298., 12372., 2251., 4420., 13062., 7108., 1071., 12927., 14324., 5251., 13260.};
float output_data[] = {1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1,
1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(type, axis, topk, out_value);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
}
}
TEST_F(TestOpenCL_ArgMinMax, dim43axis3topk1index) {
schema::PrimitiveType type = schema::PrimitiveType_ArgMaxFusion;
int axis = 3;
int topk = 1;
bool out_value = false;
std::vector<int> input_shape = {1, 13, 13, 6};
std::vector<int> output_shape = {1, 13, 13};
float input_data[] = {
2732., 10799., 9845., 3264., 13123., 4859., 14019., 15719., 9225., 7891., 4373., 5874., 14116., 14935.,
15430., 15832., 6744., 3468., 14650., 705., 15846., 2599., 10327., 2222., 7768., 2897., 9893., 537.,
11085., 6216., 6921., 6036., 2163., 5072., 4851., 7877., 2046., 1871., 7599., 2496., 15186., 8291.,
10200., 15537., 755., 797., 659., 3219., 15246., 8615., 7456., 16321., 3337., 2745., 4735., 8736.,
6687., 714., 2292., 8343., 10915., 14846., 11723., 11122., 1207., 6172., 8994., 10368., 10368., 10148.,
7221., 6021., 3622., 3560., 8948., 12561., 14671., 12676., 1641., 11306., 13754., 14879., 4984., 4353.,
13633., 12263., 12201., 10297., 14627., 12134., 11383., 15115., 8622., 7250., 4187., 14208., 10638., 2659.,
9781., 2956., 10873., 16298., 12372., 2251., 4420., 13062., 7108., 1071., 12927., 14324., 5251., 13260.,
7012., 9396., 14312., 3918., 9359., 1684., 11491., 7098., 15127., 10959., 2957., 4469., 14165., 8752.,
13617., 9797., 14505., 5795., 1472., 7263., 7365., 11870., 8448., 6001., 3762., 13604., 10146., 9008.,
16221., 2435., 1634., 15914., 973., 4464., 10215., 11157., 8393., 10623., 13824., 14218., 2418., 12843.,
13242., 3455., 6167., 5819., 12418., 6521., 6242., 7742., 9123., 15070., 14459., 10179., 6738., 14254.,
2787., 7316., 4305., 2610., 5531., 6926., 15401., 15418., 15041., 7204., 6922., 4182., 15403., 13160.,
10251., 15106., 307., 13392., 14368., 5302., 1152., 6950., 8467., 5294., 13866., 13683., 1208., 2492.,
10728., 15949., 14622., 13592., 8829., 770., 15875., 8286., 11490., 5995., 15629., 14192., 2344., 13640.,
3091., 12895., 3912., 1434., 6594., 5368., 8372., 10563., 7148., 7997., 3854., 8032., 15620., 8131.,
4845., 10379., 5116., 10838., 3533., 2937., 9837., 4939., 15032., 9744., 3224., 5021., 10389., 1134.,
25., 9680., 956., 1913., 2934., 13429., 9661., 13907., 2721., 10088., 928., 15588., 5627., 11003.,
6265., 5446., 469., 10527., 8717., 1863., 1720., 5272., 591., 6185., 2322., 15912., 11702., 207.,
15115., 4262., 14447., 3421., 12281., 5249., 15071., 10102., 11052., 8408., 15741., 8216., 12355., 11218.,
5103., 7939., 2282., 1740., 6118., 12579., 5846., 13310., 15037., 3781., 2775., 2603., 14368., 7179.,
13928., 6356., 1162., 14006., 12267., 13733., 15997., 16028., 623., 15848., 8962., 13851., 4051., 1241.,
10903., 9013., 4403., 1198., 13460., 2997., 5661., 15939., 10787., 807., 11657., 2121., 12585., 15255.,
8067., 3886., 8922., 6066., 15212., 9987., 1823., 15113., 11658., 11803., 12717., 199., 1447., 5181.,
10581., 13153., 11308., 11042., 10146., 5208., 6177., 14981., 14568., 4863., 6180., 1792., 1483., 14626.,
8389., 894., 14261., 5374., 15440., 13758., 136., 10429., 6273., 10449., 9584., 14627., 13688., 3419.,
168., 6004., 2852., 12464., 9753., 4419., 8039., 8700., 15139., 3186., 5918., 5149., 1777., 3361.,
8338., 5393., 4317., 14676., 4605., 2562., 6213., 13669., 9100., 4652., 13173., 11005., 13634., 11887.,
6235., 11605., 423., 11815., 16331., 11926., 11678., 11921., 6854., 967., 4370., 9052., 6187., 5203.,
433., 13097., 6237., 13486., 1429., 12489., 12121., 2546., 14816., 14299., 329., 3612., 16363., 8401.,
6761., 14266., 3968., 8150., 15935., 1040., 6250., 8356., 8798., 7704., 6772., 5311., 9411., 9523.,
12424., 9144., 13147., 11357., 6011., 2798., 13399., 8352., 2195., 4680., 6599., 9303., 3085., 15674.,
5713., 5240., 10100., 11191., 15168., 10955., 732., 5028., 8473., 13088., 7594., 12046., 4566., 9500.,
7444., 16338., 3396., 13846., 5347., 7034., 595., 647., 12232., 573., 6797., 5637., 8448., 11400.,
11471., 14799., 16309., 5259., 9220., 6567., 4444., 2989., 13594., 586., 14132., 5102., 7601., 11483.,
11059., 739., 13161., 4882., 11637., 5410., 15923., 16030., 437., 3898., 12203., 1847., 9724., 1020.,
6930., 941., 11095., 8641., 11590., 5610., 11317., 9008., 15454., 2107., 14672., 9882., 13948., 4259.,
11834., 945., 13418., 8393., 7468., 1805., 15225., 1862., 8742., 3751., 9864., 15373., 2040., 903.,
14032., 15352., 14870., 8696., 8015., 14297., 5896., 12003., 7942., 7377., 9671., 14804., 5593., 16322.,
13884., 12688., 3128., 7026., 3821., 2711., 8472., 1028., 2660., 13292., 2353., 10583., 5662., 7734.,
8345., 12052., 7521., 10597., 10937., 12695., 15771., 1053., 2977., 5491., 3893., 2679., 11187., 4950.,
14838., 12295., 2665., 3057., 14473., 6838., 3968., 851., 9592., 5028., 3793., 7316., 8053., 7152.,
3331., 8318., 5930., 8769., 5652., 804., 5444., 3024., 112., 1967., 650., 4333., 1384., 13278.,
14171., 13867., 63., 3999., 3988., 2502., 13577., 3516., 12891., 2671., 13731., 2387., 10060., 5394.,
3441., 8010., 10466., 13537., 1963., 5763., 2956., 7396., 3898., 3969., 14705., 7296., 4903., 13336.,
8890., 292., 14691., 9029., 14470., 4099., 5346., 7033., 4776., 14780., 13729., 7452., 6980., 4122.,
736., 10488., 4461., 1971., 11465., 13749., 8389., 13217., 1671., 10877., 606., 2120., 12534., 6996.,
9351., 1731., 10453., 15835., 7788., 3395., 6246., 8020., 10567., 8787., 5343., 2304., 11909., 3419.,
1131., 15262., 14281., 2003., 11783., 11413., 10213., 7644., 13704., 1707., 9774., 8192., 7528., 691.,
13862., 13401., 11338., 2547., 10978., 2683., 8535., 15456., 6995., 12570., 6862., 6176., 11379., 6598.,
5985., 4524., 827., 10041., 6834., 10413., 14057., 3204., 11705., 93., 15707., 13713., 2467., 3778.,
404., 5037., 9401., 13263., 375., 16036., 3945., 10942., 15876., 497., 7666., 7373., 9630., 13677.,
14167., 8930., 4515., 6729., 3290., 10167., 1562., 13686., 13334., 8652., 15055., 15714., 11866., 3123.,
15334., 1838., 16080., 15933., 9660., 6959., 14330., 14440., 4736., 3466., 4043., 6029., 12615., 4702.,
5638., 7853., 15605., 5534., 13839., 14505., 6310., 13621., 2987., 4690., 11655., 3292., 2881., 5801.,
15170., 7282., 11100., 8526., 8933., 9435., 15606., 8292., 2463., 10461., 13490., 7676., 8366., 8797.,
7794., 3745., 4876., 3808., 9961., 9040., 9282., 5576., 13299., 2173., 9354., 4720., 6874., 1179.,
8888., 7288., 12609., 2496., 2757., 12120., 7458., 4047., 2051., 6844., 3310., 7845., 15531., 1747.,
11096., 15942., 7828., 9094., 3868., 4723., 4998., 4930., 604., 8156., 3686., 9061., 3451., 3781.,
13421., 9545., 100., 4790., 9037., 6037., 5627., 8863., 3665., 3107., 8429., 15603., 14586., 14728.,
6910., 9497., 21., 6573., 1253., 6102., 8592., 13465., 9198., 3191., 9893., 8063., 13697., 13701.,
1734., 6540., 3418., 8778., 15355., 5046., 7246., 9022., 9800., 14535., 16173., 3205., 15919., 12987.,
5290., 4547., 6282., 4850., 1337., 3547., 13657., 387., 5245., 10958., 3922., 1221., 14010., 1924.,
7185., 8901., 8639., 350., 8856., 3715., 12613., 8616., 4260., 7738., 9393., 3511., 10904., 673.,
1938., 8033., 12750., 8945., 7303., 1973., 13035., 12334., 15856., 15348., 10879., 15265., 8529., 5277.,
11788., 11894., 10030., 11126., 12576., 7970., 115., 5719., 10876., 12697., 11438., 10738., 8043., 15924.,
8169., 12910., 5696., 3404., 12150., 10825., 4242., 2820., 1799., 2691., 9264., 11340., 6437., 12404.,
9709., 9776., 6253., 10194., 10419., 10801., 11335., 16218., 11697., 4078., 5405., 4611., 8266., 15956.,
6634., 11585., 6007., 3604., 3280., 5162., 5618., 28., 1434., 2903., 3252., 6448., 14274., 9830.,
8969., 7426., 11636., 15212., 14057., 13145., 13692., 9077., 612., 4186., 9284., 8809., 9738., 4108.,
5736., 10465., 10661., 4263., 9120., 9594., 11553., 9114., 99., 7385., 2354., 10584., 15570., 5908.,
13022., 16028., 1608., 5394., 13721., 9112., 7719., 11498., 13947., 284., 13976., 11922., 16067., 14508.,
16083., 10541., 11062., 0., 10378., 4803.};
float output_data[] = {4, 1, 3, 2, 4, 5, 4, 1, 3, 1, 1, 1, 4, 3, 4, 1, 5, 3, 1, 0, 0, 2, 5, 2, 3, 1, 2, 1, 1,
1, 0, 0, 5, 4, 2, 1, 1, 0, 4, 2, 5, 3, 3, 5, 2, 2, 0, 5, 0, 3, 1, 2, 3, 3, 2, 2, 1, 1,
1, 0, 1, 1, 0, 3, 1, 0, 0, 5, 1, 4, 4, 2, 4, 2, 3, 2, 1, 1, 2, 4, 4, 0, 5, 2, 4, 2, 0,
2, 1, 0, 5, 0, 3, 3, 2, 4, 2, 0, 3, 0, 2, 2, 0, 1, 2, 2, 3, 3, 1, 2, 1, 4, 1, 2, 2, 3,
2, 4, 2, 5, 2, 2, 3, 1, 0, 4, 2, 1, 2, 2, 0, 2, 0, 2, 5, 3, 5, 4, 2, 3, 1, 1, 1, 0, 0,
4, 4, 1, 0, 4, 4, 1, 2, 5, 1, 5, 1, 3, 3, 0, 4, 3, 0, 4, 2, 5, 2, 4, 0};
for (auto fp16_enable : {false, true}) {
auto *param = CreateParameter(type, axis, topk, out_value);
TestMain({{input_shape, input_data, VAR}}, {output_shape, output_data}, param, fp16_enable, 1e-1, 1e-1, true);
}
}
} // namespace mindspore::lite::opencl::test

View File

@ -102,5 +102,32 @@ TEST_F(TestOpenCL_Conv2dTranspose, test1) {
{output_shape, output_data}, param, fp16_enable);
}
}
TEST_F(TestOpenCL_Conv2dTranspose, test2) {
int n = 1;
int h = 2;
int w = 2;
int oh = 5;
int ow = 5;
int ci = 2;
int co = 1;
int kh = 3;
int kw = 3;
std::vector<int> pad = {0, 0, 0, 0};
float input_data[] = {0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0};
float weight_data[] = {0.0, 2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0,
1.0, 3.0, 5.0, 7.0, 9.0, 11.0, 13.0, 15.0, 17.0};
float bias_data[] = {0.5};
float output_data[] = {1.5, 3.5, 8.5, 13.5, 23.5, 7.5, 9.5, 44.5, 43.5, 53.5, 18.5, 38.5, 128.5,
106.5, 142.5, 59.5, 77.5, 180.5, 111.5, 137.5, 113.5, 131.5, 312.5, 189.5, 215.5};
for (auto fp16_enable : {false, true}) {
std::vector<int> input_shape, weight_shape, bias_shape, output_shape;
auto *param =
CreateParameter(n, h, w, ci, co, kh, kw, pad, oh, ow, &input_shape, &weight_shape, &bias_shape, &output_shape);
TestMain({{input_shape, input_data, VAR},
{weight_shape, weight_data, CONST_TENSOR},
{bias_shape, bias_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable);
}
}
} // namespace mindspore::lite::opencl::test