fix some problems for power

This commit is contained in:
Pengyongrong 2020-12-13 20:25:54 -08:00
parent 521b059608
commit d97bd37f57
1 changed files with 3 additions and 7 deletions

View File

@ -32,16 +32,10 @@ using mindspore::schema::PrimitiveType_Power;
namespace mindspore::kernel {
int PowerOpenCLKernel::CheckSpecs() {
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
broadcast_ = param->broadcast_;
if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) {
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size();
return RET_ERROR;
}
if (in_tensors_.size() == 1 && !broadcast_) {
MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 ";
return RET_ERROR;
}
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
<< "!=" << in_tensors_.at(1)->shape().size();
@ -143,12 +137,14 @@ void PowerOpenCLKernel::SetGlobalLocal() {
}
int PowerOpenCLKernel::Prepare() {
if (in_tensors_.size() == 1) {
broadcast_ = true;
}
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
std::string kernel_name = "power";
std::string source = power_source;
std::string program_name = "power";
broadcast_ = param->broadcast_;
if (broadcast_ && in_tensors_.size() == 1) {
power_ = param->power_;
kernel_name += "_broadcast";