From 53dbabf0267040e589974ee3ea233b5af4a125e2 Mon Sep 17 00:00:00 2001 From: zhaodezan Date: Thu, 16 Dec 2021 10:48:59 +0800 Subject: [PATCH] fix consie distance --- mindspore/lite/tools/benchmark/benchmark_base.h | 4 ++-- .../lite/tools/benchmark/benchmark_unified_api.cc | 14 +++++++------- .../lite/tools/benchmark/benchmark_unified_api.h | 2 -- 3 files changed, 9 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/tools/benchmark/benchmark_base.h b/mindspore/lite/tools/benchmark/benchmark_base.h index 64faf2f9cbc..617ef55852f 100644 --- a/mindspore/lite/tools/benchmark/benchmark_base.h +++ b/mindspore/lite/tools/benchmark/benchmark_base.h @@ -134,7 +134,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType", "Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT"); AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5); - AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", 0.99); + AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1); AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes", "Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", ""); #ifdef ENABLE_OPENGL_TEXTURE @@ -171,7 +171,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { std::string benchmark_data_file_; std::string benchmark_data_type_ = "FLOAT"; float accuracy_threshold_ = 0.5; - float cosine_distance_threshold_ = 0.5; + float cosine_distance_threshold_ = -1.1; // Resize std::string resize_dims_in_; std::vector> resize_dims_; diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc index 31dcd17c932..7156253a90d 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.cc +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.cc @@ -714,14 +714,14 @@ int BenchmarkUnifiedApi::MarkAccuracy() { std::cerr << "Compare output error " << status << std::endl; return status; } -#ifdef SUPPORT_34XX - status = CompareOutputByCosineDistance(this->flags_->cosine_distance_threshold_); - if (status != RET_OK) { - MS_LOG(ERROR) << "Compare output error by consine distance " << status; - std::cerr << "Compare output error by consine distance" << status << std::endl; - return status; + if (this->flags_->cosine_distance_threshold_ >= -1) { + status = CompareOutputByCosineDistance(this->flags_->cosine_distance_threshold_); + if (status != RET_OK) { + MS_LOG(ERROR) << "Compare output error by consine distance " << status; + std::cerr << "Compare output error by consine distance" << status << std::endl; + return status; + } } -#endif return RET_OK; } diff --git a/mindspore/lite/tools/benchmark/benchmark_unified_api.h b/mindspore/lite/tools/benchmark/benchmark_unified_api.h index 78ae5784959..350ed32b01c 100644 --- a/mindspore/lite/tools/benchmark/benchmark_unified_api.h +++ b/mindspore/lite/tools/benchmark/benchmark_unified_api.h @@ -117,8 +117,6 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase { MSKernelCallBack ms_before_call_back_ = nullptr; MSKernelCallBack ms_after_call_back_ = nullptr; - - float cosine_distance_threshold_ = 0.99; }; } // namespace mindspore::lite