forked from mindspore-Ecosystem/mindspore
335 lines
14 KiB
Diff
335 lines
14 KiB
Diff
diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp
|
|
index b678200a1..09736ccae 100644
|
|
--- a/src/cpu/nchw_pooling.cpp
|
|
+++ b/src/cpu/nchw_pooling.cpp
|
|
@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
|
|
int od_end = min(OD, 1 + (padF + ID - 1) / SD);
|
|
|
|
dim_t c_blk = pd()->channel_block_size_;
|
|
- int c_blk_tail = C % c_blk;
|
|
+ dim_t c_blk_tail = C % c_blk;
|
|
+ const int nthr = pd()->nthr_;
|
|
+
|
|
if (alg == alg_kind::pooling_max) {
|
|
- parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
|
|
- [&](int ithr, int, int mb, int cb) {
|
|
+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
|
|
+ [&](int ithr, int, dim_t mb, dim_t cb) {
|
|
bool is_last_c_block
|
|
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
|
|
int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
|
|
@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
|
|
diff_src_fp32, src_sp_size * curr_c_block);
|
|
});
|
|
} else {
|
|
- parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
|
|
- [&](int ithr, int, int mb, int cb) {
|
|
+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
|
|
+ [&](int ithr, int, dim_t mb, dim_t cb) {
|
|
bool is_last_c_block
|
|
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
|
|
int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
|
|
diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp
|
|
index 9d649f3f5..2a73f6ae6 100644
|
|
--- a/src/cpu/nchw_pooling.hpp
|
|
+++ b/src/cpu/nchw_pooling.hpp
|
|
@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
|
|
ws_md_ = *hint_fwd_pd_->workspace_md();
|
|
}
|
|
|
|
+ nthr_ = dnnl_get_max_threads();
|
|
calculate_channel_block_size();
|
|
init_scratchpad();
|
|
|
|
@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
|
|
}
|
|
|
|
dim_t channel_block_size_;
|
|
+ int nthr_; // Set at init to size the scratchpad; execute() reuses it.
|
|
|
|
private:
|
|
void init_scratchpad() {
|
|
@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t {
|
|
if (diff_dst_md()->data_type == data_type::bf16) {
|
|
size_t dst_sz_ = OD() * OH() * OW();
|
|
size_t src_sz_ = ID() * IH() * IW();
|
|
- size_t nthrs = dnnl_get_max_threads();
|
|
auto scratchpad = scratchpad_registry().registrar();
|
|
|
|
scratchpad.template book<float>(key_pool_src_bf16cvt,
|
|
- src_sz_ * nthrs * channel_block_size_);
|
|
+ src_sz_ * nthr_ * channel_block_size_);
|
|
scratchpad.template book<float>(key_pool_dst_bf16cvt,
|
|
- dst_sz_ * nthrs * channel_block_size_);
|
|
+ dst_sz_ * nthr_ * channel_block_size_);
|
|
}
|
|
}
|
|
|
|
@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
|
|
// spatial
|
|
dim_t dst_sz_ = OD() * OH() * OW();
|
|
dim_t src_sz_ = ID() * IH() * IW();
|
|
- dim_t nthrs = dnnl_get_max_threads();
|
|
- dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C());
|
|
+ dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C());
|
|
const dim_t max_block_size
|
|
= platform::get_per_core_cache_size(1) / 2;
|
|
dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
|
|
diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp
|
|
index 48d9e1240..efe3083f7 100644
|
|
--- a/src/cpu/nhwc_pooling.cpp
|
|
+++ b/src/cpu/nhwc_pooling.cpp
|
|
@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
|
|
return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
|
|
};
|
|
const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
|
|
+ const int nthr = pd()->nthr_;
|
|
|
|
- parallel_nd_ext(0, MB, OD, OH, OW,
|
|
+ parallel_nd_ext(nthr, MB, OD, OH, OW,
|
|
[&](int ithr, int, int mb, int od, int oh, int ow) {
|
|
const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
|
|
od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
|
|
@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
|
|
auto apply_offset = [=](int index, int offset) {
|
|
return (index > offset) ? index - offset : 0;
|
|
};
|
|
+ const int nthr = pd()->nthr_;
|
|
|
|
- parallel_nd_ext(0, MB, ID, IH, IW,
|
|
+ parallel_nd_ext(nthr, MB, ID, IH, IW,
|
|
[&](int ithr, int, int mb, int id, int ih, int iw) {
|
|
size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
|
|
id, diff_src_d_stride, ih, diff_src_h_stride, iw,
|
|
diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp
|
|
index c65196a94..c16e840a2 100644
|
|
--- a/src/cpu/nhwc_pooling.hpp
|
|
+++ b/src/cpu/nhwc_pooling.hpp
|
|
@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t {
|
|
init_default_ws();
|
|
}
|
|
|
|
+ nthr_ = dnnl_get_max_threads();
|
|
init_scratchpad();
|
|
|
|
return status::success;
|
|
}
|
|
|
|
+ int nthr_; // Set at init to size the scratchpad; execute() reuses it.
|
|
+
|
|
private:
|
|
void init_scratchpad() {
|
|
using namespace memory_tracking::names;
|
|
if (src_md()->data_type == data_type::bf16) {
|
|
- const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
|
|
+ const size_t bf16cvt_sz_ = C() * nthr_;
|
|
auto scratchpad = scratchpad_registry().registrar();
|
|
scratchpad.template book<float>(
|
|
key_pool_src_bf16cvt, bf16cvt_sz_);
|
|
@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t {
|
|
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
|
|
}
|
|
|
|
+ nthr_ = dnnl_get_max_threads();
|
|
init_scratchpad();
|
|
|
|
return status::success;
|
|
}
|
|
|
|
+ int nthr_; // Set at init to size the scratchpad; execute() reuses it.
|
|
+
|
|
private:
|
|
void init_scratchpad() {
|
|
using namespace memory_tracking::names;
|
|
if (diff_src_md()->data_type == data_type::bf16) {
|
|
- size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
|
|
+ size_t bf16cvt_sz_ = C() * nthr_;
|
|
auto scratchpad = scratchpad_registry().registrar();
|
|
scratchpad.template book<float>(
|
|
key_pool_src_bf16cvt, bf16cvt_sz_);
|
|
diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp
|
|
index a2a181cfa..5befb81ac 100644
|
|
--- a/src/cpu/x64/jit_primitive_conf.hpp
|
|
+++ b/src/cpu/x64/jit_primitive_conf.hpp
|
|
@@ -672,6 +672,7 @@ struct jit_pool_conf_t {
|
|
bool with_postops;
|
|
bool with_eltwise;
|
|
bool with_binary;
|
|
+ int nthr;
|
|
};
|
|
|
|
struct jit_pool_call_s {
|
|
diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp
|
|
index 36d129e6d..ebd4f3af1 100644
|
|
--- a/src/cpu/x64/jit_uni_pool_kernel.cpp
|
|
+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp
|
|
@@ -76,8 +76,7 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
|
|
|
|
template <cpu_isa_t isa>
|
|
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
|
|
- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
|
|
- int nthreads) {
|
|
+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) {
|
|
|
|
const auto &pd = *ppd->desc();
|
|
const memory_desc_wrapper src_d(
|
|
@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
|
|
|
|
const int ndims = src_d.ndims();
|
|
|
|
+ jpp.nthr = dnnl_get_max_threads();
|
|
jpp.is_training = pd.prop_kind == prop_kind::forward_training;
|
|
jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
|
|
|
|
@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
|
|
? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
|
|
: (ndims == 5 ? jpp.od : jpp.oh);
|
|
work *= jpp.mb * nb2_c;
|
|
- auto eff = (float)work / utils::rnd_up(work, nthreads);
|
|
+ auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
|
|
if (eff > best_eff) {
|
|
|
|
best_eff = eff;
|
|
diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp
|
|
index d5d5f25a2..57ce6f43d 100644
|
|
--- a/src/cpu/x64/jit_uni_pool_kernel.hpp
|
|
+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp
|
|
@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator {
|
|
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)
|
|
|
|
static status_t init_conf(jit_pool_conf_t &jbp,
|
|
- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
|
|
- int nthreads);
|
|
+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd);
|
|
|
|
private:
|
|
using Xmm = Xbyak::Xmm;
|
|
diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp
|
|
index b2055f2a9..29987f70c 100644
|
|
--- a/src/cpu/x64/jit_uni_pooling.cpp
|
|
+++ b/src/cpu/x64/jit_uni_pooling.cpp
|
|
@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
|
|
(*kernel_)(&arg);
|
|
};
|
|
|
|
+ const int nthr = jpp.nthr;
|
|
+
|
|
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
|
|
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
|
|
parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) {
|
|
@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
|
|
} else {
|
|
if (trans_src || trans_dst) {
|
|
// ncsp format
|
|
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
|
|
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
|
|
[&](int ithr, int nthr, int n, int b_c) {
|
|
if (trans_src)
|
|
transpose_facade.execute_transpose_input(
|
|
@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
|
|
});
|
|
} else {
|
|
// nChw16c, nChw8c format
|
|
- parallel(0, [&](std::size_t ithr, std::size_t nthr) {
|
|
+ parallel(nthr, [&](int ithr, int nthr) {
|
|
const std::size_t work_amount
|
|
= static_cast<std::size_t>(jpp.mb) * jpp.nb_c * jpp.oh;
|
|
if (ithr >= work_amount) return;
|
|
@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
|
|
(*kernel_)(&arg);
|
|
};
|
|
|
|
+ const int nthr = jpp.nthr;
|
|
+
|
|
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
|
|
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
|
|
parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) {
|
|
@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
|
|
});
|
|
} else {
|
|
if (trans_src || trans_dst) {
|
|
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
|
|
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
|
|
[&](int ithr, int nthr, int n, int b_c) {
|
|
if (trans_src)
|
|
transpose_facade.execute_transpose_input(
|
|
@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward(
|
|
transpose_facade.execute_transpose_output(ithr, n, b_c);
|
|
};
|
|
|
|
- parallel(0, [&](int ithr, int nthr) {
|
|
+ const int nthr = jpp.nthr;
|
|
+
|
|
+ parallel(nthr, [&](int ithr, int nthr) {
|
|
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
|
|
const std::size_t work_amount
|
|
= static_cast<std::size_t>(jpp.mb) * nb2_c;
|
|
@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
|
|
}
|
|
};
|
|
|
|
+ const int nthr = jpp.nthr;
|
|
+
|
|
if (jpp.simple_alg) {
|
|
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
|
|
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
|
|
@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
|
|
} else {
|
|
assert(jpp.ur_bc == 1);
|
|
if (trans_src || trans_dst) {
|
|
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
|
|
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
|
|
[&](int ithr, int nthr, int n, int b_c) {
|
|
if (trans_src)
|
|
transpose_facade.execute_transpose_input(
|
|
@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
|
|
if (!trans_src) {
|
|
const size_t chunk_size
|
|
= (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
|
|
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
|
|
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
|
|
[&](int ithr, int nthr, int n, int b_c) {
|
|
const size_t offset
|
|
= ((size_t)n * jpp.nb_c + b_c) * chunk_size;
|
|
@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
|
|
|
|
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
|
|
if (trans_src || trans_dst) {
|
|
- parallel_nd_ext(
|
|
- 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) {
|
|
+ parallel_nd_ext(nthr, jpp.mb, nb2_c,
|
|
+ [&](int ithr, int nthr, int n, int b2_c) {
|
|
const auto b_c = b2_c * jpp.ur_bc;
|
|
|
|
if (trans_dst) {
|
|
diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp
|
|
index ec4b04a2b..e25d9ce05 100644
|
|
--- a/src/cpu/x64/jit_uni_pooling.hpp
|
|
+++ b/src/cpu/x64/jit_uni_pooling.hpp
|
|
@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t {
|
|
init_default_ws();
|
|
|
|
auto scratchpad = scratchpad_registry().registrar();
|
|
- return jit_uni_pool_kernel<isa>::init_conf(
|
|
- jpp_, scratchpad, this, dnnl_get_max_threads());
|
|
+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
|
|
+
|
|
+ return status::success;
|
|
}
|
|
|
|
jit_pool_conf_t jpp_;
|
|
@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t {
|
|
init_default_ws();
|
|
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
|
|
}
|
|
+
|
|
auto scratchpad = scratchpad_registry().registrar();
|
|
- return jit_uni_pool_kernel<isa>::init_conf(
|
|
- jpp_, scratchpad, this, dnnl_get_max_threads());
|
|
+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
|
|
+
|
|
+ return status::success;
|
|
}
|
|
|
|
jit_pool_conf_t jpp_;
|