forked from mindspore-Ecosystem/mindspore
!24348 imporve performance of CPU LSTMGrad
Merge pull request !24348 from limingqi107/bug_fix
This commit is contained in:
commit
bc5ccaa4ba
|
@ -30,6 +30,9 @@
|
|||
#include "ir/anf.h"
|
||||
#include "runtime/framework/graph_scheduler.h"
|
||||
#include "actor/actormgr.h"
|
||||
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
|
||||
#define PLATFORM_86
|
||||
#endif
|
||||
|
||||
using mindspore::kernel::Address;
|
||||
using mindspore::kernel::AddressPtr;
|
||||
|
|
|
@ -19,6 +19,9 @@
|
|||
#include "utils/ms_utils.h"
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_kernel_engine.h"
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
#ifdef PLATFORM_86
|
||||
#include <pmmintrin.h>
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
|
|
@ -17,13 +17,6 @@
|
|||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_LSTM_CPU_KERNEL_H_
|
||||
|
||||
#if defined(__x86_64__) || defined(__amd64__) || defined(_M_IX86) || defined(_M_X64)
|
||||
#define PLATFORM_86
|
||||
#endif
|
||||
#ifdef PLATFORM_86
|
||||
#include <pmmintrin.h>
|
||||
#endif
|
||||
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include "backend/kernel_compiler/cpu/mkldnn/mkl_cpu_kernel.h"
|
||||
|
|
|
@ -33,12 +33,20 @@
|
|||
#ifndef ENABLE_SECURITY
|
||||
#include "debug/data_dump/dump_json_parser.h"
|
||||
#endif
|
||||
#ifdef PLATFORM_86
|
||||
#include <pmmintrin.h>
|
||||
#endif
|
||||
|
||||
namespace mindspore {
|
||||
namespace device {
|
||||
namespace cpu {
|
||||
using mindspore::kernel::KernelBuildInfo;
|
||||
|
||||
#ifdef PLATFORM_86
|
||||
// Whether need set the flush zero mode in the kernel launch.
|
||||
static bool flush_zero_mode_enable{false};
|
||||
#endif
|
||||
|
||||
void CPUDeviceContext::Initialize() {
|
||||
if (initialized_) {
|
||||
return;
|
||||
|
@ -176,6 +184,14 @@ void CPUDeviceContext::CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
|||
MS_LOG(EXCEPTION) << "Build cpu operator[" << node->fullname_with_scope() << "] failed";
|
||||
}
|
||||
|
||||
#ifdef PLATFORM_86
|
||||
// Some CPU kernels need set the flush zero mode to improve performance.
|
||||
if (!flush_zero_mode_enable &&
|
||||
(kOpNeedSetFlushZeroModeList.find(kernel_name) != kOpNeedSetFlushZeroModeList.end())) {
|
||||
flush_zero_mode_enable = true;
|
||||
}
|
||||
#endif
|
||||
|
||||
cpu_kernel->Init(node);
|
||||
AnfAlgo::SetKernelMod(cpu_kernel, node.get());
|
||||
}
|
||||
|
@ -199,6 +215,14 @@ bool CPUDeviceContext::LaunchKernel(const CNodePtr &kernel, const std::vector<Ad
|
|||
auto cpu_kernel_mod = dynamic_cast<kernel::CPUKernel *>(kernel_mod);
|
||||
MS_EXCEPTION_IF_NULL(cpu_kernel_mod);
|
||||
|
||||
#ifdef PLATFORM_86
|
||||
// Some CPU kernels need set the flush zero mode to improve performance.
|
||||
if (flush_zero_mode_enable) {
|
||||
_MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
|
||||
_MM_SET_DENORMALS_ZERO_MODE(_MM_DENORMALS_ZERO_ON);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Some CPU kernels can't initialize kernel and launch kernel in different thread, so reinitialize the kernels before
|
||||
// launch.
|
||||
if (kOpNotSupportMultiThreadExecList.find(AnfAlgo::GetCNodeName(kernel)) != kOpNotSupportMultiThreadExecList.end()) {
|
||||
|
|
|
@ -279,6 +279,8 @@ constexpr auto kBasicLSTMCellWeightGradOpName = "BasicLSTMCellWeightGrad";
|
|||
constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
|
||||
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
|
||||
constexpr auto kDynamicRNNOpName = "DynamicRNN";
|
||||
constexpr auto kLSTMOpName = "LSTM";
|
||||
constexpr auto kLSTMGradOpName = "LSTMGrad";
|
||||
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
|
||||
constexpr auto kDynamicGRUV2OpName = "DynamicGRUV2";
|
||||
constexpr auto kGRUV2HiddenGradOpName = "GRUV2HiddenGrad";
|
||||
|
@ -657,6 +659,8 @@ const std::set<std::string> kPosteriorOperatorSet = {kPullOpName};
|
|||
const std::set<std::string> kOpCacheBlackList = {kUniformCandidateSamplerOpName, kInitDatasetQueueOpName,
|
||||
kGetNextOpName};
|
||||
|
||||
const std::set<std::string> kOpNeedSetFlushZeroModeList = {kLSTMOpName, kLSTMGradOpName};
|
||||
|
||||
const std::set<std::string> kOpNotSupportMultiThreadExecList = {kAvgPoolOpName, kAvgPoolGradOpName, kMaxPoolOpName,
|
||||
kBatchNorm, kBatchNormGradOpName};
|
||||
|
||||
|
|
Loading…
Reference in New Issue