From d97bd37f57867bfcb3162c0417448dc4dcabc08a Mon Sep 17 00:00:00 2001 From: Pengyongrong Date: Sun, 13 Dec 2020 20:25:54 -0800 Subject: [PATCH] fix some problems for power --- .../lite/src/runtime/kernel/opencl/kernel/power.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc index 9de4afbf817..2bc1c9afc73 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/power.cc @@ -32,16 +32,10 @@ using mindspore::schema::PrimitiveType_Power; namespace mindspore::kernel { int PowerOpenCLKernel::CheckSpecs() { - auto param = reinterpret_cast(this->op_parameter_); - broadcast_ = param->broadcast_; if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) { MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size(); return RET_ERROR; } - if (in_tensors_.size() == 1 && !broadcast_) { - MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 "; - return RET_ERROR; - } if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) { MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size() << "!=" << in_tensors_.at(1)->shape().size(); @@ -143,12 +137,14 @@ void PowerOpenCLKernel::SetGlobalLocal() { } int PowerOpenCLKernel::Prepare() { + if (in_tensors_.size() == 1) { + broadcast_ = true; + } use_fp16_enable_ = ocl_runtime_->GetFp16Enable(); auto param = reinterpret_cast(this->op_parameter_); std::string kernel_name = "power"; std::string source = power_source; std::string program_name = "power"; - broadcast_ = param->broadcast_; if (broadcast_ && in_tensors_.size() == 1) { power_ = param->power_; kernel_name += "_broadcast";