fix equal

This commit is contained in:
yefeng 2023-02-10 11:07:29 +08:00
parent 3e8504a20e
commit dd287634a9
2 changed files with 23 additions and 13 deletions

View File

@ -30,7 +30,7 @@ namespace mindspore {
namespace ops {
namespace {
template <typename T>
void EqualImpl(void *x1, void *x2, void *result, size_t size) {
void EqualImpl(void *x1, void *x2, void *result, size_t size, bool need_broad_cast) {
MS_EXCEPTION_IF_NULL(x1);
MS_EXCEPTION_IF_NULL(x2);
MS_EXCEPTION_IF_NULL(result);
@ -38,7 +38,11 @@ void EqualImpl(void *x1, void *x2, void *result, size_t size) {
T *x2_data = static_cast<T *>(x2);
auto result_data = static_cast<bool *>(result);
for (size_t i = 0; i < size; ++i) {
result_data[i] = x1_data[i] == x2_data[i];
if (need_broad_cast) {
result_data[i] = x1_data[i] == x2_data[0];
} else {
result_data[i] = x1_data[i] == x2_data[i];
}
}
}
@ -91,34 +95,38 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
MS_EXCEPTION_IF_NULL(x2_tensor);
auto type_id = x1_tensor->data_type();
auto data_size = x1_tensor->DataSize();
bool need_broad_cast = false;
if (x1_tensor->DataSize() != x2_tensor->DataSize() && x2_tensor->DataSize() == 1) {
need_broad_cast = true;
}
auto result_tensor = std::make_shared<tensor::Tensor>(kNumberTypeBool, result_shape->shape());
switch (type_id) {
case kNumberTypeBool: {
EqualImpl<bool>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<bool>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeInt: {
EqualImpl<int>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<int>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeInt8: {
EqualImpl<int8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<int8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeInt16: {
EqualImpl<int16_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<int16_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeInt32: {
EqualImpl<int32_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<int32_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeInt64: {
EqualImpl<int64_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<int64_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeUInt8: {
EqualImpl<uint8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<uint8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeFloat: {
@ -126,7 +134,7 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
break;
}
case kNumberTypeFloat16: {
EqualImpl<float16>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<float16>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
break;
}
case kNumberTypeFloat32: {
@ -138,11 +146,13 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
break;
}
case kNumberTypeComplex64: {
EqualImpl<std::complex<float>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<std::complex<float>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size,
need_broad_cast);
break;
}
case kNumberTypeComplex128: {
EqualImpl<std::complex<double>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
EqualImpl<std::complex<double>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size,
need_broad_cast);
break;
}
default: {

View File

@ -147,7 +147,7 @@ Status ModelWorker::Init(const char *model_buf, size_t size) {
MS_LOG(INFO) << "ms model init done.";
origin_worker_inputs_ = model_->GetInputs();
origin_worker_outputs_ = model_->GetOutputs();
if (origin_worker_outputs_.empty() || origin_worker_outputs_.empty()) {
if (origin_worker_outputs_.empty() || origin_worker_inputs_.empty()) {
MS_LOG(ERROR) << "model worker get empty input/output.";
delete model_;
model_ = nullptr;