diff --git a/mindspore/lite/include/context.h b/mindspore/lite/include/context.h index f7e54727048..7df42a180a1 100644 --- a/mindspore/lite/include/context.h +++ b/mindspore/lite/include/context.h @@ -65,7 +65,7 @@ class MS_API Context { virtual ~Context(); public: - bool float16_priority = true; /**< allow priority select float16 kernel */ + bool float16_priority = false; /**< allow priority select float16 kernel */ DeviceContext device_ctx_{DT_CPU}; int thread_num_ = 2; /**< thread number config for thread pool */ std::shared_ptr allocator = nullptr; diff --git a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc index cb487411bf0..2237a1cbbcc 100644 --- a/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc +++ b/mindspore/lite/src/runtime/kernel/arm/fp16/arithmetic_fp16.cc @@ -446,21 +446,21 @@ kernel::LiteKernel *CpuArithmeticFp16KernelCreator(const std::vectorcpu_bind_mode_ = NO_BIND; } context->thread_num_ = _flags->numThreads; + context->float16_priority = _flags->fp16Priority; session = session::LiteSession::CreateSession(context); delete (context); if (session == nullptr) { @@ -503,6 +504,7 @@ int Benchmark::Init() { MS_LOG(INFO) << "AccuracyThreshold = " << this->_flags->accuracyThreshold; MS_LOG(INFO) << "WarmUpLoopCount = " << this->_flags->warmUpLoopCount; MS_LOG(INFO) << "NumThreads = " << this->_flags->numThreads; + MS_LOG(INFO) << "Fp16Priority = " << this->_flags->fp16Priority; MS_LOG(INFO) << "calibDataPath = " << this->_flags->calibDataPath; if (this->_flags->loopCount < 1) { diff --git a/mindspore/lite/tools/benchmark/benchmark.h b/mindspore/lite/tools/benchmark/benchmark.h index 555ae0ca31c..d39cacc82cc 100644 --- a/mindspore/lite/tools/benchmark/benchmark.h +++ b/mindspore/lite/tools/benchmark/benchmark.h @@ -63,6 +63,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { // MarkPerformance AddFlag(&BenchmarkFlags::loopCount, "loopCount", "Run loop count", 10); AddFlag(&BenchmarkFlags::numThreads, "numThreads", "Run threads number", 2); + AddFlag(&BenchmarkFlags::fp16Priority, "fp16Priority", "Priority float16", false); AddFlag(&BenchmarkFlags::warmUpLoopCount, "warmUpLoopCount", "Run warm up loop", 3); // MarkAccuracy AddFlag(&BenchmarkFlags::calibDataPath, "calibDataPath", "Calibration data file path", ""); @@ -88,6 +89,7 @@ class MS_API BenchmarkFlags : public virtual FlagParser { // MarkPerformance int loopCount; int numThreads; + bool fp16Priority; int warmUpLoopCount; // MarkAccuracy std::string calibDataPath;