forked from mindspore-Ecosystem/mindspore
fix equal
This commit is contained in:
parent
3e8504a20e
commit
dd287634a9
|
@ -30,7 +30,7 @@ namespace mindspore {
|
|||
namespace ops {
|
||||
namespace {
|
||||
template <typename T>
|
||||
void EqualImpl(void *x1, void *x2, void *result, size_t size) {
|
||||
void EqualImpl(void *x1, void *x2, void *result, size_t size, bool need_broad_cast) {
|
||||
MS_EXCEPTION_IF_NULL(x1);
|
||||
MS_EXCEPTION_IF_NULL(x2);
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
|
@ -38,7 +38,11 @@ void EqualImpl(void *x1, void *x2, void *result, size_t size) {
|
|||
T *x2_data = static_cast<T *>(x2);
|
||||
auto result_data = static_cast<bool *>(result);
|
||||
for (size_t i = 0; i < size; ++i) {
|
||||
result_data[i] = x1_data[i] == x2_data[i];
|
||||
if (need_broad_cast) {
|
||||
result_data[i] = x1_data[i] == x2_data[0];
|
||||
} else {
|
||||
result_data[i] = x1_data[i] == x2_data[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,34 +95,38 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
MS_EXCEPTION_IF_NULL(x2_tensor);
|
||||
auto type_id = x1_tensor->data_type();
|
||||
auto data_size = x1_tensor->DataSize();
|
||||
bool need_broad_cast = false;
|
||||
if (x1_tensor->DataSize() != x2_tensor->DataSize() && x2_tensor->DataSize() == 1) {
|
||||
need_broad_cast = true;
|
||||
}
|
||||
auto result_tensor = std::make_shared<tensor::Tensor>(kNumberTypeBool, result_shape->shape());
|
||||
switch (type_id) {
|
||||
case kNumberTypeBool: {
|
||||
EqualImpl<bool>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<bool>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt: {
|
||||
EqualImpl<int>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<int>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt8: {
|
||||
EqualImpl<int8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<int8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt16: {
|
||||
EqualImpl<int16_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<int16_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt32: {
|
||||
EqualImpl<int32_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<int32_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeInt64: {
|
||||
EqualImpl<int64_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<int64_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeUInt8: {
|
||||
EqualImpl<uint8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<uint8_t>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat: {
|
||||
|
@ -126,7 +134,7 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
break;
|
||||
}
|
||||
case kNumberTypeFloat16: {
|
||||
EqualImpl<float16>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<float16>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size, need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeFloat32: {
|
||||
|
@ -138,11 +146,13 @@ ValuePtr EqualInferValue(const PrimitivePtr &prim, const std::vector<AbstractBas
|
|||
break;
|
||||
}
|
||||
case kNumberTypeComplex64: {
|
||||
EqualImpl<std::complex<float>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<std::complex<float>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size,
|
||||
need_broad_cast);
|
||||
break;
|
||||
}
|
||||
case kNumberTypeComplex128: {
|
||||
EqualImpl<std::complex<double>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size);
|
||||
EqualImpl<std::complex<double>>(x1_tensor->data_c(), x2_tensor->data_c(), result_tensor->data_c(), data_size,
|
||||
need_broad_cast);
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
|
|
|
@ -147,7 +147,7 @@ Status ModelWorker::Init(const char *model_buf, size_t size) {
|
|||
MS_LOG(INFO) << "ms model init done.";
|
||||
origin_worker_inputs_ = model_->GetInputs();
|
||||
origin_worker_outputs_ = model_->GetOutputs();
|
||||
if (origin_worker_outputs_.empty() || origin_worker_outputs_.empty()) {
|
||||
if (origin_worker_outputs_.empty() || origin_worker_inputs_.empty()) {
|
||||
MS_LOG(ERROR) << "model worker get empty input/output.";
|
||||
delete model_;
|
||||
model_ = nullptr;
|
||||
|
|
Loading…
Reference in New Issue