!5823 [MSLITE] Support exponent tensor broadcast for power op
Merge pull request !5823 from zhanyuan/dev
This commit is contained in:
commit
023c93277a
|
@ -64,7 +64,9 @@ int Power::InferShape(std::vector<tensor::Tensor *> inputs, std::vector<tensor::
|
|||
return RET_OK;
|
||||
}
|
||||
if (exp_tensor != nullptr) {
|
||||
if (exp_tensor->shape() != x_tensor->shape() || exp_tensor->data_type() != x_tensor->data_type()) {
|
||||
if ((exp_tensor->shape().size() > 1 && exp_tensor->shape() != x_tensor->shape()) ||
|
||||
(exp_tensor->shape().size() == 1 && exp_tensor->shape()[0] != 1) ||
|
||||
exp_tensor->data_type() != x_tensor->data_type()) {
|
||||
MS_LOG(ERROR) << "Power inputs shape or type is not equal!";
|
||||
return RET_INPUT_TENSOR_ERROR;
|
||||
}
|
||||
|
|
|
@ -64,11 +64,11 @@ int PowerCPUKernel::RunImpl(int task_id) {
|
|||
bool broadcast = true;
|
||||
if (in_tensors_.size() == 2) {
|
||||
exp_addr = reinterpret_cast<float *>(in_tensors_[1]->Data());
|
||||
broadcast = false;
|
||||
broadcast = in_tensors_[0]->shape() == in_tensors_[1]->shape() ? false : true;
|
||||
}
|
||||
float *cur_exp = nullptr;
|
||||
if (broadcast) {
|
||||
cur_exp = &power_;
|
||||
cur_exp = in_tensors_.size() == 2 ? exp_addr : &power_;
|
||||
} else {
|
||||
cur_exp = exp_addr + stride * task_id;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue