add new ops named layer_norm for gpu

This commit is contained in:
Pengyongrong 2020-12-23 02:06:24 -08:00
parent 98a3f318c7
commit f8847e427e
10 changed files with 484 additions and 2 deletions

View File

@ -1,9 +1,9 @@
if (ENABLE_GITEE)
set(REQ_URL "https://gitee.com/mirrors/flatbuffers/repository/archive/v2020.06.16.tar.gz")
set(REQ_URL "https://gitee.com/mirrors/OpenCL-Headers/repository/archive/v2020.06.16.tar.gz")
set(MD5 "fc7627b5a8a95ecbe3d5df43bc88aa44")
set(PKG_GIT_TAG "")
__download_pkg_with_git(OpenCL-Headers ${REQ_URL} ${PKG_GIT_TAG} ${MD5})
set(REQ_URL "https://gitee.com/mirrors/flatbuffers/repository/archive/v2.0.12.tar.gz")
set(REQ_URL "https://gitee.com/mirrors/OpenCL-CLHPP/repository/archive/v2.0.12.tar.gz")
set(MD5 "bd00fca8f861b3b65660d719f00a58dd")
set(PKG_GIT_TAG "")
__download_pkg_with_git(OpenCL-CLHPP ${REQ_URL} ${PKG_GIT_TAG} ${MD5})

View File

@ -0,0 +1,103 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#define UP_DIV(x, y) (((x) + (y) - (1)) / (y))
#define C4NUM 4
__kernel void ComputeMeanVarDim1NHWC4(__read_only image2d_t src_data, __global FLT *mean_, __global FLT *variance_,
int4 in_shape, int normalized_shape_size) {
int X = get_global_id(0); // n*h
int Y = get_global_id(1); // w
if (X > in_shape.x * in_shape.y || Y > in_shape.z || in_shape.y == 0) {
return;
}
int n = X / in_shape.y;
int h = X % in_shape.y;
int w = Y;
int ci4 = UP_DIV(in_shape.w, C4NUM);
int remainder = in_shape.w % C4NUM;
FLT4 mean_temp = {0.0f, 0.0f, 0.0f, 0.0f};
FLT4 var_temp = {0.0f, 0.0f, 0.0f, 0.0f};
FLT mean = 0.0f;
FLT var = 0.0f;
// compute mean
for (int i = 0; i < ci4; ++i) {
FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, n * in_shape.y + h));
mean_temp += result_temp;
}
mean = (mean_temp.x + mean_temp.y + mean_temp.z + mean_temp.w) / normalized_shape_size;
mean_temp.x = mean_temp.y = mean_temp.z = mean_temp.w = mean;
// compute var
for (int i = 0; i < ci4; ++i) {
FLT4 result_temp = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + i, n * in_shape.y + h));
if ((i + 1) * C4NUM <= in_shape.w) {
var_temp += (result_temp - mean_temp) * (result_temp - mean_temp);
} else {
if (remainder == 1) {
mean_temp.x = mean;
mean_temp.y = mean_temp.z = mean_temp.w = 0.0f;
} else if (remainder == 2) {
mean_temp.x = mean_temp.y = mean;
mean_temp.z = mean_temp.w = 0.0f;
} else {
mean_temp.x = mean_temp.y = mean_temp.z = mean;
mean_temp.w = 0.0f;
}
var_temp += (result_temp - mean_temp) * (result_temp - mean_temp);
}
}
var = (var_temp.x + var_temp.y + var_temp.z + var_temp.w) / normalized_shape_size;
// write result to dst
int postion = (n * in_shape.y + h) * in_shape.z + w;
mean_[postion] = mean;
variance_[postion] = var;
}
__kernel void LayerNormalization_NHWC4(__read_only image2d_t src_data, __write_only image2d_t dst_data,
__global FLT *mean_, __global FLT *variance_, __global FLT *gamma_,
__global FLT *beta_, int4 in_shape, float epsilon_, int normalized_dims_,
int elementwise_affine_) {
int X = get_global_id(0); // n*h
int Y = get_global_id(1); // w
int Z = get_global_id(2); // c4
if (X >= in_shape.x * in_shape.y || Y >= in_shape.z || Z >= in_shape.w || in_shape.y == 0) {
return;
}
int n = X / in_shape.y;
int h = X % in_shape.y;
int w = Y;
int c = Z;
int ci4 = UP_DIV(in_shape.w, C4NUM);
int postion_mv = 0;
int postion_gb = 0;
if (normalized_dims_ == 1) {
postion_mv = (n * in_shape.y + h) * in_shape.z + w;
postion_gb = c * C4NUM;
} else if (normalized_dims_ == 2) {
postion_mv = n * in_shape.y + h;
postion_gb = w * ci4 * C4NUM + c * C4NUM;
} else if (normalized_dims_ == 3) {
postion_mv = n;
postion_gb = (h * in_shape.z + w) * ci4 * C4NUM + c * C4NUM;
}
FLT4 result = {0.0f, 0.0f, 0.0f, 0.0f};
FLT4 result_in = READ_IMAGE(src_data, smp_none, (int2)(w * ci4 + c, n * in_shape.y + h));
if (elementwise_affine_) {
result.x = ((result_in.x - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb] +
beta_[postion_gb];
result.y = ((result_in.y - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 1] +
beta_[postion_gb + 1];
result.z = ((result_in.z - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 2] +
beta_[postion_gb + 2];
result.w = ((result_in.w - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_)) * gamma_[postion_gb + 3] +
beta_[postion_gb + 3];
} else {
result.x = ((result_in.x - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_));
result.y = ((result_in.y - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_));
result.z = ((result_in.z - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_));
result.w = ((result_in.w - mean_[postion_mv]) / sqrt(variance_[postion_mv] + epsilon_));
}
WRITE_IMAGE(dst_data, (int2)((w * ci4 + c), (n * in_shape.y + h)), result);
}

View File

@ -0,0 +1,250 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cstring>
#include <algorithm>
#include <set>
#include <string>
#include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/kernel/layer_norm.h"
#include "nnacl/layer_norm_parameter.h"
#include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/cl/layer_norm.cl.inc"
using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_LayerNorm;
namespace mindspore::kernel {
int LayerNormOpenCLKernel::CheckSpecs() {
auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_);
if (param->elementwise_mode_ == ELEMENTWISE_PER_CHANNEL) {
if (in_tensors_.size() != 3) {
MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
return RET_ERROR;
}
if (param->normalized_dims_ > in_tensors_.at(0)->shape().size()) {
MS_LOG(ERROR) << " invalid normalized_shape_ size" << param->normalized_dims_ << std::endl;
return RET_ERROR;
}
} else if (param->elementwise_mode_ == ELEMENTWISE_NOT) {
if (in_tensors_.size() != 1) {
MS_LOG(ERROR) << " invalid in_tensors_ size" << in_tensors_.size() << std::endl;
return RET_ERROR;
}
} else {
MS_LOG(ERROR) << "Unsupported elementwise_mode_" << param->elementwise_mode_;
return RET_ERROR;
}
if (in_tensors_.at(0)->shape().size() != 4 || out_tensors_.size() != 1) {
MS_LOG(ERROR) << "UnSupported in_tensors_.shape.size: " << in_tensors_.at(0)->shape().size()
<< " out_tensors_.size(): " << out_tensors_.size();
return RET_ERROR;
}
if (param->normalized_dims_ != 1) {
MS_LOG(ERROR) << "UnSupported normalized_shape_ size: " << param->normalized_dims_;
return RET_ERROR;
}
return RET_OK;
}
void LayerNormGetWorkGroup(const std::vector<size_t> &global, std::vector<size_t> *local, int max_size) {
const int max_divider = 8;
const int max_x = 4, max_y = 8;
int x = std::min(GetMaxDivisorStrategy1(global[0], max_divider), max_x);
int yz = max_size / x;
int y = std::min(std::min(GetMaxDivisorStrategy1(global[1], max_divider), yz), max_y);
int z = std::min(yz / y, static_cast<int>(UP_DIV(global[2], 2)));
local->clear();
local->push_back(x);
local->push_back(y);
local->push_back(z);
}
void LayerNormOpenCLKernel::SetConstArgs() {
int arg_cn = 6;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_shape_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, epsilon_);
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, normalized_dims_);
if (elementwise_affine_) {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, 1);
} else {
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, 0);
}
ocl_runtime_->SetKernelArg(kernel_mean_var_, 3, in_shape_);
ocl_runtime_->SetKernelArg(kernel_mean_var_, 4, normalized_shape_size_);
}
void AlignMeanVarGlobalLocal(const std::vector<int> &global, const std::vector<int> &local, cl::NDRange *global_range,
cl::NDRange *local_range) {
*local_range = cl::NDRange(local[0], local[1], local[2]);
*global_range =
cl::NDRange(UP_ROUND(global[0], local[0]), UP_ROUND(global[1], local[1]), UP_ROUND(global[2], local[2]));
}
void LayerNormOpenCLKernel::SetGlobalLocal() {
size_t OH = 1, OW = 1, OC = 1;
OH = in_shape_.s[0] * in_shape_.s[1];
OW = in_shape_.s[2];
OC = UP_DIV(in_shape_.s[3], C4NUM);
local_size_ = {1, 1, 1}; // init local
global_size_ = {OH, OW, OC};
const std::vector<size_t> &max_global = ocl_runtime_->GetWorkItemSize();
LayerNormGetWorkGroup(global_size_, &local_size_, max_global[0]);
OpenCLKernel::AlignGlobalLocal(global_size_, local_size_);
if (normalized_dims_ != in_tensors_.at(0)->shape().size()) {
if (normalized_dims_ == 1) {
OH = in_shape_.s[0] * in_shape_.s[1];
OW = in_shape_.s[2];
OC = 1;
} else if (normalized_dims_ == 2) {
OH = in_shape_.s[0] * in_shape_.s[1];
OW = 1;
OC = 1;
} else {
OH = in_shape_.s[0];
OW = 1;
OC = 1;
}
} else {
OH = 1;
OW = 1;
OC = 1;
}
AlignMeanVarGlobalLocal({static_cast<int>(OH), static_cast<int>(OW), static_cast<int>(OC)}, {1, 1, 1},
&global_mean_var_, &local_mean_var_);
}
int LayerNormOpenCLKernel::Initweight() {
auto allocator = ocl_runtime_->GetAllocator();
GpuTensorInfo img_info(in_tensors_.at(1)); // gamma
auto weight_tensor = in_tensors_.at(1);
size_t weight_size = img_info.Image2DSize;
// allocated memory for weight and init value
gamma_ = allocator->Malloc(weight_size);
beta_ = allocator->Malloc(weight_size);
allocator->MapBuffer(gamma_, CL_MAP_WRITE, nullptr, true);
allocator->MapBuffer(beta_, CL_MAP_WRITE, nullptr, true);
memset(gamma_, 0x01, weight_size);
memset(beta_, 0x00, weight_size);
if (weight_tensor->data_type() == kNumberTypeFloat16) {
if (use_fp16_enable_) {
memcpy(gamma_, in_tensors_.at(1)->data_c(), weight_size);
memcpy(beta_, in_tensors_.at(2)->data_c(), weight_size);
} else {
auto gamma_fp32 = reinterpret_cast<float *>(gamma_);
auto beta_fp32 = reinterpret_cast<float *>(beta_);
auto origin_gamma_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(1)->data_c());
auto origin_beta_fp16 = reinterpret_cast<float16_t *>(in_tensors_.at(2)->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
gamma_fp32[i] = static_cast<float>(origin_gamma_fp16[i]);
beta_fp32[i] = static_cast<float>(origin_beta_fp16[i]);
}
}
} else {
if (use_fp16_enable_) {
auto gamma_fp16 = reinterpret_cast<float16_t *>(gamma_);
auto beta_fp16 = reinterpret_cast<float16_t *>(beta_);
auto origin_gamma_fp32 = reinterpret_cast<float *>(in_tensors_.at(1)->data_c());
auto origin_beta_fp32 = reinterpret_cast<float *>(in_tensors_.at(2)->data_c());
for (int i = 0; i < img_info.ElementsNum; ++i) {
gamma_fp16[i] = static_cast<float16_t>(origin_gamma_fp32[i]);
beta_fp16[i] = static_cast<float16_t>(origin_beta_fp32[i]);
}
} else {
memcpy(gamma_, in_tensors_.at(1)->data_c(), weight_size);
memcpy(beta_, in_tensors_.at(2)->data_c(), weight_size);
}
}
allocator->UnmapBuffer(gamma_);
allocator->UnmapBuffer(beta_);
return RET_OK;
}
int LayerNormOpenCLKernel::Prepare() {
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
auto param = reinterpret_cast<LayerNormParameter *>(this->op_parameter_);
elementwise_affine_ = param->elementwise_mode_;
normalized_dims_ = param->normalized_dims_;
epsilon_ = param->epsilon_;
if (elementwise_affine_) {
int ret = Initweight();
if (ret) {
MS_LOG(ERROR) << "Initweight failed ";
return RET_ERROR;
}
}
auto allocator = ocl_runtime_->GetAllocator();
size_t mean_size = 1;
size_t size = in_tensors_.at(0)->shape().size() - normalized_dims_;
for (int i = 0; i < size; ++i) {
mean_size *= in_tensors_.at(0)->shape()[i];
}
size_t size_dtype = use_fp16_enable_ ? sizeof(float16_t) : sizeof(float);
mean_size *= size_dtype;
mean_ = allocator->Malloc(mean_size);
var_ = allocator->Malloc(mean_size);
GpuTensorInfo img_info(in_tensors_.at(0));
in_shape_.s[0] = img_info.N, in_shape_.s[1] = img_info.H, in_shape_.s[2] = img_info.W, in_shape_.s[3] = img_info.C;
for (int i = 0; i < normalized_dims_; ++i) {
normalized_shape_size_ *= param->normalized_shape_[i];
}
std::string kernel_name = "LayerNormalization_NHWC4";
std::string kernel_name_mean_var = "ComputeMeanVar";
std::set<std::string> build_options;
std::string source = layer_norm_source;
std::string program_name = "LayerNormalization";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name, build_options);
kernel_name_mean_var += "Dim" + std::to_string(normalized_dims_) + "NHWC4";
ocl_runtime_->BuildKernel(kernel_mean_var_, program_name, kernel_name_mean_var, build_options);
MS_LOG(DEBUG) << kernel_name << " Init Done!";
SetConstArgs();
SetGlobalLocal();
return RET_OK;
}
int LayerNormOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->name() << " Running! ";
int arg1_cn = 0;
ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, in_tensors_.at(0)->data_c()); // input tensor
ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, mean_, lite::opencl::MemType::BUF); // mean_
ocl_runtime_->SetKernelArg(kernel_mean_var_, arg1_cn++, var_, lite::opencl::MemType::BUF); // var_ return RET_OK;
ocl_runtime_->RunKernel(kernel_mean_var_, global_mean_var_, local_mean_var_, nullptr, &event_);
int arg_cn = 0;
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, in_tensors_.at(0)->data_c()); // input tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, out_tensors_.at(0)->data_c()); // out tensor
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, mean_, lite::opencl::MemType::BUF); // mean_
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, var_, lite::opencl::MemType::BUF); // var_
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, gamma_, lite::opencl::MemType::BUF); // gamma_
ocl_runtime_->SetKernelArg(kernel_, arg_cn++, beta_, lite::opencl::MemType::BUF); // beta_
ocl_runtime_->RunKernel(kernel_, global_range_, local_range_, nullptr, &event_);
return RET_OK;
}
REG_KERNEL(kGPU, kNumberTypeFloat32, PrimitiveType_LayerNorm, OpenCLKernelCreator<LayerNormOpenCLKernel>)
REG_KERNEL(kGPU, kNumberTypeFloat16, PrimitiveType_LayerNorm, OpenCLKernelCreator<LayerNormOpenCLKernel>)
} // namespace mindspore::kernel

View File

@ -0,0 +1,61 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LAYER_NORM_H_
#define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_LAYER_NORM_H_
#include <vector>
#include "src/runtime/kernel/opencl/opencl_kernel.h"
namespace mindspore::kernel {
class LayerNormOpenCLKernel : public OpenCLKernel {
public:
LayerNormOpenCLKernel(OpParameter *parameter, const std::vector<lite::Tensor *> &inputs,
const std::vector<lite::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs) {}
~LayerNormOpenCLKernel() override = default;
int Run() override;
int Prepare() override;
int CheckSpecs() override;
void SetConstArgs() override;
void SetGlobalLocal() override;
private:
int Initweight();
void GetMeanVar();
private:
cl::Kernel kernel_mean_var_;
cl::NDRange global_mean_var_, local_mean_var_;
bool use_fp16_enable_{false};
void *gamma_{nullptr};
void *mean_{nullptr};
void *var_{nullptr};
void *beta_{nullptr};
cl_int4 in_shape_{};
int elementwise_affine_;
int32_t normalized_dims_{1};
int normalized_shape_size_{1};
float epsilon_{0.0f};
cl::Kernel kernel_;
};
} // namespace mindspore::kernel
#endif

View File

@ -104,6 +104,10 @@ void TestMain(const std::vector<ArgsTupleWithDtype> &input_infos, std::tuple<std
return;
}
// call sub_graph->Init() after construct subgraph like scheduler.cc
MS_LOG(DEBUG) << "call sub_graph->Init()";
EXPECT_TRUE(sub_graph->Init() == RET_OK);
// simulating benchmark: session_->CompileGraph() -> PrepareKernels() -> OpenCLSubGraph.Prepare()
MS_LOG(DEBUG) << "call sub_graph->Prepare()";
EXPECT_TRUE(sub_graph->Prepare() == RET_OK); // will set Tensor's allocator be OpenCLAllocator

View File

@ -0,0 +1,62 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ut/src/runtime/kernel/opencl/common.h"
#include "nnacl/layer_norm_parameter.h"
namespace mindspore::lite::opencl::test {
class TestOpenCL_LayerNorm : public CommonTest {};
namespace {
// PrimitiveType_Stack: src/ops/populate/stack_populate.cc
OpParameter *CreateParameter(float epsilon, int normalized_dims_, std::vector<int> normalizedShape) {
auto *param = test::CreateParameter<LayerNormParameter>(schema::PrimitiveType_LayerNorm);
param->elementwise_mode_ = ELEMENTWISE_PER_CHANNEL;
param->epsilon_ = epsilon;
param->normalized_dims_ = normalized_dims_;
for (int i = 0; i < normalizedShape.size() && i < normalized_dims_; ++i) {
param->normalized_shape_[i] = normalizedShape[i];
}
return reinterpret_cast<OpParameter *>(param);
}
} // namespace
TEST_F(TestOpenCL_LayerNorm, test1) {
float epsilon = 1e-5;
int normalized_dims_ = 1;
std::vector<int> normalizedShape = {5};
std::vector<int> input_shape = {2, 3, 4, 5};
std::vector<int> gamma_shape = {1, 1, 1, 5};
std::vector<int> beta_shape = {1, 1, 1, 5};
std::vector<int> output_shape = {2, 3, 4, 5};
size_t input_size, gamma_size, beta_size, output_size;
std::string inputPpath = "./test_data/layernormfp32_input.bin";
std::string gammaPpath = "./test_data/gammafp32_input.bin";
std::string betaPpath = "./test_data/betafp32_input.bin";
std::string correctOutputPath = "./test_data/layernormfp32_output.bin";
auto input_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(inputPpath.c_str(), &input_size));
auto gamma_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(gammaPpath.c_str(), &gamma_size));
auto beta_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(betaPpath.c_str(), &beta_size));
auto output_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(correctOutputPath.c_str(), &output_size));
for (auto fp16_enable : {false}) {
auto *param = CreateParameter(epsilon, normalized_dims_, normalizedShape);
TestMain(
{{input_shape, input_data, VAR}, {gamma_shape, gamma_data, CONST_TENSOR}, {beta_shape, beta_data, CONST_TENSOR}},
{output_shape, output_data}, param, fp16_enable, fp16_enable ? 1e-3 : 1e-6);
}
}
} // namespace mindspore::lite::opencl::test

View File

@ -0,0 +1,2 @@
Sv■╬Ba≥╬Ъ9<=∙╒∙©cЪ©Н╨ц?E⌠y©°┤O> │╬©яб=▐ьз©┼Dж=╘0к>v╕©╟+д>пF?⌠═╬лЁ'>╧К█©М8;▐Ы8©q>╚?#░H©└█юsEg>хА@с?ycю©Я>7▌©┌ф$©cЗ╛©KQ?▓Т1©3╖C╬Т▐╧?╞┌?нЙ©╪c@ыХ╬Ы^р?©oБ?,$2?1e"?]О╝©Ю╗╬╖9v╬Ц?NTюхl ?║д┌>Х 2?}▌©Lё┤>Z~┘?ИK©цо
╬²│═?╬;╘|╦?╡▀б©▓▒~?≈Шё©~t░>#╓h©≥Ыl?Ш  ?─r©=СО╨©{Юq>∙▐?с8╨>}I?мU©X©~сy╬лx╦?P│>Б(©лbЭ╫Ь╓▐©╜хБ©r╖n?cф│╫ q|?хА>0┤©Щ╧┘©Ю|?ТЫ©еo?хp??╫Ф©?Pю╪[·@²©эЖ╝>0╩©2Vt©╣╛ч=ужk>0╗А╬╤hн>╓6╧=XЩ┘╬'яG=щg>°╝^╫zи┴? ь">8|Б>Ljг╩▀M▓©ш▌I?o>Щ>√z┤>▓/}©с%И╬╨и░©&[▌?