forked from mindspore-Ecosystem/mindspore
fix some problems for power
This commit is contained in:
parent
521b059608
commit
d97bd37f57
|
@ -32,16 +32,10 @@ using mindspore::schema::PrimitiveType_Power;
|
|||
namespace mindspore::kernel {
|
||||
|
||||
int PowerOpenCLKernel::CheckSpecs() {
|
||||
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
|
||||
broadcast_ = param->broadcast_;
|
||||
if ((in_tensors_.size() != 1 && in_tensors_.size() != 2) || out_tensors_.size() != 1) {
|
||||
MS_LOG(ERROR) << "in size: " << in_tensors_.size() << "out size: " << out_tensors_.size();
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_.size() == 1 && !broadcast_) {
|
||||
MS_LOG(ERROR) << "broadcast is supported when in_tensors_.size() == 1 ";
|
||||
return RET_ERROR;
|
||||
}
|
||||
if (in_tensors_.size() == 2 && in_tensors_.at(0)->shape().size() != in_tensors_.at(1)->shape().size()) {
|
||||
MS_LOG(ERROR) << "Unsupported input->shape.size " << in_tensors_.at(0)->shape().size()
|
||||
<< "!=" << in_tensors_.at(1)->shape().size();
|
||||
|
@ -143,12 +137,14 @@ void PowerOpenCLKernel::SetGlobalLocal() {
|
|||
}
|
||||
|
||||
int PowerOpenCLKernel::Prepare() {
|
||||
if (in_tensors_.size() == 1) {
|
||||
broadcast_ = true;
|
||||
}
|
||||
use_fp16_enable_ = ocl_runtime_->GetFp16Enable();
|
||||
auto param = reinterpret_cast<PowerParameter *>(this->op_parameter_);
|
||||
std::string kernel_name = "power";
|
||||
std::string source = power_source;
|
||||
std::string program_name = "power";
|
||||
broadcast_ = param->broadcast_;
|
||||
if (broadcast_ && in_tensors_.size() == 1) {
|
||||
power_ = param->power_;
|
||||
kernel_name += "_broadcast";
|
||||
|
|
Loading…
Reference in New Issue