!11542 【MS】【LITE】【GPU】 fix opencl fusion bug

From: @wangdongxu6
Reviewed-by: @ddwsky,@zhanghaibo5
Signed-off-by: @ddwsky
This commit is contained in:
mindspore-ci-bot 2021-01-22 20:22:19 +08:00 committed by Gitee
commit 1082c91158
2 changed files with 28 additions and 28 deletions

View File

@ -18,6 +18,7 @@
#include "src/runtime/kernel/opencl/utils.h"
#include "include/errorcode.h"
#include "nnacl/fp32/activation_fp32.h"
#include "nnacl/scale.h"
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
@ -60,12 +61,18 @@ std::pair<bool, FusionEltwiseParameter *> CheckSupportOrCreateParam(
param = reinterpret_cast<FusionEltwiseParameter *>(eltwise->GetParameter());
eltwise->ClearParameter();
}
} else if (IsArithmetic(node_type)) {
auto act_type =
static_cast<ActivationType>(reinterpret_cast<ArithmeticParameter *>(op_parameter)->activation_type_);
} else if (IsArithmetic(node_type) || node_type == schema::PrimitiveType_Scale) {
auto *arith_param = reinterpret_cast<ArithmeticParameter *>(op_parameter);
auto *scale_param = reinterpret_cast<ScaleParameter *>(op_parameter);
auto act_type = static_cast<ActivationType>(
node_type == schema::PrimitiveType_Scale ? scale_param->activation_type_ : arith_param->activation_type_);
EltwiseOperator act_operator = Activation2Operator(act_type);
support =
node->in_tensors().size() == 2 && SupportedOperators.count(operator_) && SupportedOperators.count(act_operator);
support = SupportedOperators.count(operator_) && SupportedOperators.count(act_operator);
if (node_type == schema::PrimitiveType_Scale) {
support = support && node->in_tensors().size() == 3 && scale_param->axis_ == -1;
} else {
support = support && (node->in_tensors().size() == 2);
}
if (create_param) {
param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map);
MS_ASSERT(param);
@ -83,12 +90,6 @@ std::pair<bool, FusionEltwiseParameter *> CheckSupportOrCreateParam(
param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map);
MS_ASSERT(param);
}
} else if (node_type == schema::PrimitiveType_Scale) {
support = node->in_tensors().size() == 3 && SupportedOperators.count(operator_);
if (create_param) {
param = new (std::nothrow) FusionEltwiseParameter(operator_, node->name(), node->in_tensors(), replace_map);
MS_ASSERT(param);
}
} else if (node_type == schema::PrimitiveType_Activation) {
auto act_type = static_cast<ActivationType>(reinterpret_cast<ActivationParameter *>(op_parameter)->type_);
EltwiseOperator act_operator = Activation2Operator(act_type);
@ -141,15 +142,11 @@ bool IsEltwiseAndOperatorSupported(LiteKernel *node) {
}
int FusionEltwiseOpenCLKernel::Prepare() {
static std::set<std::string> code_map;
std::string source = Codegen();
code_map.insert(source);
std::string program_name = "FusionEltwise" + std::to_string(code_map.size());
std::string program_name = "FusionEltwise\n" + source;
std::string kernel_name = "FusionEltwise";
ocl_runtime_->LoadSource(program_name, source);
ocl_runtime_->BuildKernel(kernel_, program_name, kernel_name);
InitWeights();
SetGlobalLocal();
SetConstArgs();
@ -390,20 +387,22 @@ std::string FusionEltwiseOpenCLKernel::CodegenCore(FusionEltwiseParameter *param
std::string FusionEltwiseOpenCLKernel::GetFormatVarName(std::string name) {
if (var_names_.count(name)) {
return name;
}
if (name.empty()) {
name = "_var_" + std::to_string(var_names_.size());
return simplify_var_name_ ? var_names_[name] : name;
} else {
char c = name.front();
if (c != '_' && !std::isalpha(c)) {
name = '_' + name;
if (name.empty()) {
name = "_var_" + std::to_string(var_names_.size());
} else {
char c = name.front();
if (c != '_' && !std::isalpha(c)) {
name = '_' + name;
}
std::replace_if(
name.begin(), name.end(), [](char c) { return !std::isalnum(c); }, '_');
}
std::replace_if(
name.begin(), name.end(), [](char c) { return !std::isalnum(c); }, '_');
auto new_name = "tmp" + std::to_string(var_names_.size());
var_names_.emplace(name, new_name);
return simplify_var_name_ ? new_name : name;
}
var_names_.insert(name);
return name;
}
int FusionEltwiseOpenCLKernel::GetTensorIdx(lite::Tensor *in_tensor) {

View File

@ -180,7 +180,8 @@ class FusionEltwiseOpenCLKernel : public OpenCLKernel {
return shape.empty() || (shape.size() == 1 && shape.front() == 1);
}
std::set<std::string> var_names_;
std::map<std::string, std::string> var_names_; // origin name -> simplified name
const bool simplify_var_name_{true};
std::vector<float> scalar_weights_;
std::vector<void *> buffer_weights_;
};