!27756 fix consie distance

Merge pull request !27756 from zhaodezan/master
This commit is contained in:
i-robot 2021-12-16 11:58:06 +00:00 committed by Gitee
commit 2823059f6e
3 changed files with 9 additions and 11 deletions

View File

@ -134,7 +134,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
AddFlag(&BenchmarkFlags::benchmark_data_type_, "benchmarkDataType",
"Benchmark data type. FLOAT | INT32 | INT8 | UINT8", "FLOAT");
AddFlag(&BenchmarkFlags::accuracy_threshold_, "accuracyThreshold", "Threshold of accuracy", 0.5);
AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", 0.99);
AddFlag(&BenchmarkFlags::cosine_distance_threshold_, "cosineDistanceThreshold", "cosine distance threshold", -1.1);
AddFlag(&BenchmarkFlags::resize_dims_in_, "inputShapes",
"Shape of input data, the format should be NHWC. e.g. 1,32,32,32:1,1,32,32,1", "");
#ifdef ENABLE_OPENGL_TEXTURE
@ -171,7 +171,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser {
std::string benchmark_data_file_;
std::string benchmark_data_type_ = "FLOAT";
float accuracy_threshold_ = 0.5;
float cosine_distance_threshold_ = 0.5;
float cosine_distance_threshold_ = -1.1;
// Resize
std::string resize_dims_in_;
std::vector<std::vector<int>> resize_dims_;

View File

@ -714,14 +714,14 @@ int BenchmarkUnifiedApi::MarkAccuracy() {
std::cerr << "Compare output error " << status << std::endl;
return status;
}
#ifdef SUPPORT_34XX
status = CompareOutputByCosineDistance(this->flags_->cosine_distance_threshold_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Compare output error by consine distance " << status;
std::cerr << "Compare output error by consine distance" << status << std::endl;
return status;
if (this->flags_->cosine_distance_threshold_ >= -1) {
status = CompareOutputByCosineDistance(this->flags_->cosine_distance_threshold_);
if (status != RET_OK) {
MS_LOG(ERROR) << "Compare output error by consine distance " << status;
std::cerr << "Compare output error by consine distance" << status << std::endl;
return status;
}
}
#endif
return RET_OK;
}

View File

@ -117,8 +117,6 @@ class MS_API BenchmarkUnifiedApi : public BenchmarkBase {
MSKernelCallBack ms_before_call_back_ = nullptr;
MSKernelCallBack ms_after_call_back_ = nullptr;
float cosine_distance_threshold_ = 0.99;
};
} // namespace mindspore::lite