mindspore/third_party/patch/onednn/0002-fix-pool-nthr-bug.patch

335 lines
14 KiB
Diff

diff --git a/src/cpu/nchw_pooling.cpp b/src/cpu/nchw_pooling.cpp
index b678200a1..09736ccae 100644
--- a/src/cpu/nchw_pooling.cpp
+++ b/src/cpu/nchw_pooling.cpp
@@ -609,10 +609,12 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
int od_end = min(OD, 1 + (padF + ID - 1) / SD);
dim_t c_blk = pd()->channel_block_size_;
- int c_blk_tail = C % c_blk;
+ dim_t c_blk_tail = C % c_blk;
+ const int nthr = pd()->nthr_;
+
if (alg == alg_kind::pooling_max) {
- parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
- [&](int ithr, int, int mb, int cb) {
+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
+ [&](int ithr, int, dim_t mb, dim_t cb) {
bool is_last_c_block
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
@@ -649,8 +651,8 @@ status_t nchw_pooling_bwd_t<data_type::bf16>::execute_backward(
diff_src_fp32, src_sp_size * curr_c_block);
});
} else {
- parallel_nd_ext(0, MB, utils::div_up(C, c_blk),
- [&](int ithr, int, int mb, int cb) {
+ parallel_nd_ext(nthr, MB, utils::div_up(C, c_blk),
+ [&](int ithr, int, dim_t mb, dim_t cb) {
bool is_last_c_block
= c_blk_tail > 0 && (cb + 1) * c_blk > C;
int curr_c_block = is_last_c_block ? c_blk_tail : c_blk;
diff --git a/src/cpu/nchw_pooling.hpp b/src/cpu/nchw_pooling.hpp
index 9d649f3f5..2a73f6ae6 100644
--- a/src/cpu/nchw_pooling.hpp
+++ b/src/cpu/nchw_pooling.hpp
@@ -139,6 +139,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
ws_md_ = *hint_fwd_pd_->workspace_md();
}
+ nthr_ = dnnl_get_max_threads();
calculate_channel_block_size();
init_scratchpad();
@@ -146,6 +147,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
}
dim_t channel_block_size_;
+ int nthr_; // Thread count saved at setup; execute must not exceed it.
private:
void init_scratchpad() {
@@ -153,13 +155,12 @@ struct nchw_pooling_bwd_t : public primitive_t {
if (diff_dst_md()->data_type == data_type::bf16) {
size_t dst_sz_ = OD() * OH() * OW();
size_t src_sz_ = ID() * IH() * IW();
- size_t nthrs = dnnl_get_max_threads();
auto scratchpad = scratchpad_registry().registrar();
scratchpad.template book<float>(key_pool_src_bf16cvt,
- src_sz_ * nthrs * channel_block_size_);
+ src_sz_ * nthr_ * channel_block_size_);
scratchpad.template book<float>(key_pool_dst_bf16cvt,
- dst_sz_ * nthrs * channel_block_size_);
+ dst_sz_ * nthr_ * channel_block_size_);
}
}
@@ -169,8 +170,7 @@ struct nchw_pooling_bwd_t : public primitive_t {
// spatial
dim_t dst_sz_ = OD() * OH() * OW();
dim_t src_sz_ = ID() * IH() * IW();
- dim_t nthrs = dnnl_get_max_threads();
- dim_t C_per_thr = nstl::min(MB() * C() / nthrs, C());
+ dim_t C_per_thr = nstl::min(MB() * C() / nthr_, C());
const dim_t max_block_size
= platform::get_per_core_cache_size(1) / 2;
dim_t data_size_per_ch = (dst_sz_ + src_sz_) * 6; // f32 + bf16
diff --git a/src/cpu/nhwc_pooling.cpp b/src/cpu/nhwc_pooling.cpp
index 48d9e1240..efe3083f7 100644
--- a/src/cpu/nhwc_pooling.cpp
+++ b/src/cpu/nhwc_pooling.cpp
@@ -378,8 +378,9 @@ status_t nhwc_pooling_fwd_t<data_type::bf16>::execute_forward(
return OSP * OC * mb + OSP * oc + SP * od + OW * oh + ow;
};
const bool are_postops_set = !(pd()->attr()->post_ops_.entry_.empty());
+ const int nthr = pd()->nthr_;
- parallel_nd_ext(0, MB, OD, OH, OW,
+ parallel_nd_ext(nthr, MB, OD, OH, OW,
[&](int ithr, int, int mb, int od, int oh, int ow) {
const size_t dst_offset_init = strided_offset(mb, dst_n_stride,
od, dst_d_stride, oh, dst_h_stride, ow, dst_w_stride);
@@ -682,8 +683,9 @@ status_t nhwc_pooling_bwd_t<data_type::bf16>::execute_backward(
auto apply_offset = [=](int index, int offset) {
return (index > offset) ? index - offset : 0;
};
+ const int nthr = pd()->nthr_;
- parallel_nd_ext(0, MB, ID, IH, IW,
+ parallel_nd_ext(nthr, MB, ID, IH, IW,
[&](int ithr, int, int mb, int id, int ih, int iw) {
size_t src_offset_init = strided_offset(mb, diff_src_n_stride,
id, diff_src_d_stride, ih, diff_src_h_stride, iw,
diff --git a/src/cpu/nhwc_pooling.hpp b/src/cpu/nhwc_pooling.hpp
index c65196a94..c16e840a2 100644
--- a/src/cpu/nhwc_pooling.hpp
+++ b/src/cpu/nhwc_pooling.hpp
@@ -73,16 +73,19 @@ struct nhwc_pooling_fwd_t : public primitive_t {
init_default_ws();
}
+ nthr_ = dnnl_get_max_threads();
init_scratchpad();
return status::success;
}
+ int nthr_; // Thread count saved at setup; execute must not exceed it.
+
private:
void init_scratchpad() {
using namespace memory_tracking::names;
if (src_md()->data_type == data_type::bf16) {
- const size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
+ const size_t bf16cvt_sz_ = C() * nthr_;
auto scratchpad = scratchpad_registry().registrar();
scratchpad.template book<float>(
key_pool_src_bf16cvt, bf16cvt_sz_);
@@ -148,16 +151,19 @@ struct nhwc_pooling_bwd_t : public primitive_t {
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}
+ nthr_ = dnnl_get_max_threads();
init_scratchpad();
return status::success;
}
+ int nthr_; // Thread count saved at setup; execute must not exceed it.
+
private:
void init_scratchpad() {
using namespace memory_tracking::names;
if (diff_src_md()->data_type == data_type::bf16) {
- size_t bf16cvt_sz_ = C() * dnnl_get_max_threads();
+ size_t bf16cvt_sz_ = C() * nthr_;
auto scratchpad = scratchpad_registry().registrar();
scratchpad.template book<float>(
key_pool_src_bf16cvt, bf16cvt_sz_);
diff --git a/src/cpu/x64/jit_primitive_conf.hpp b/src/cpu/x64/jit_primitive_conf.hpp
index a2a181cfa..5befb81ac 100644
--- a/src/cpu/x64/jit_primitive_conf.hpp
+++ b/src/cpu/x64/jit_primitive_conf.hpp
@@ -672,6 +672,7 @@ struct jit_pool_conf_t {
bool with_postops;
bool with_eltwise;
bool with_binary;
+ int nthr;
};
struct jit_pool_call_s {
diff --git a/src/cpu/x64/jit_uni_pool_kernel.cpp b/src/cpu/x64/jit_uni_pool_kernel.cpp
index 36d129e6d..ebd4f3af1 100644
--- a/src/cpu/x64/jit_uni_pool_kernel.cpp
+++ b/src/cpu/x64/jit_uni_pool_kernel.cpp
@@ -76,8 +76,7 @@ jit_uni_pool_kernel<isa>::jit_uni_pool_kernel(
template <cpu_isa_t isa>
status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
- int nthreads) {
+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd) {
const auto &pd = *ppd->desc();
const memory_desc_wrapper src_d(
@@ -87,6 +86,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
const int ndims = src_d.ndims();
+ jpp.nthr = dnnl_get_max_threads();
jpp.is_training = pd.prop_kind == prop_kind::forward_training;
jpp.is_backward = pd.prop_kind == prop_kind::backward_data;
@@ -248,7 +248,7 @@ status_t jit_uni_pool_kernel<isa>::init_conf(jit_pool_conf_t &jpp,
? (ndims == 5 && jpp.simple_alg ? jpp.od : 1)
: (ndims == 5 ? jpp.od : jpp.oh);
work *= jpp.mb * nb2_c;
- auto eff = (float)work / utils::rnd_up(work, nthreads);
+ auto eff = (float)work / utils::rnd_up(work, jpp.nthr);
if (eff > best_eff) {
best_eff = eff;
diff --git a/src/cpu/x64/jit_uni_pool_kernel.hpp b/src/cpu/x64/jit_uni_pool_kernel.hpp
index d5d5f25a2..57ce6f43d 100644
--- a/src/cpu/x64/jit_uni_pool_kernel.hpp
+++ b/src/cpu/x64/jit_uni_pool_kernel.hpp
@@ -46,8 +46,7 @@ struct jit_uni_pool_kernel : public jit_generator {
DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)
static status_t init_conf(jit_pool_conf_t &jbp,
- memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd,
- int nthreads);
+ memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd);
private:
using Xmm = Xbyak::Xmm;
diff --git a/src/cpu/x64/jit_uni_pooling.cpp b/src/cpu/x64/jit_uni_pooling.cpp
index b2055f2a9..29987f70c 100644
--- a/src/cpu/x64/jit_uni_pooling.cpp
+++ b/src/cpu/x64/jit_uni_pooling.cpp
@@ -612,6 +612,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
(*kernel_)(&arg);
};
+ const int nthr = jpp.nthr;
+
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
parallel_nd(jpp.mb, jpp.oh, nb2_c, [&](int n, int oh, int b2_c) {
@@ -622,7 +624,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
} else {
if (trans_src || trans_dst) {
// ncsp format
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](int ithr, int nthr, int n, int b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
@@ -635,7 +637,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward(const data_t *src,
});
} else {
// nChw16c, nChw8c format
- parallel(0, [&](std::size_t ithr, std::size_t nthr) {
+ parallel(nthr, [&](int ithr, int nthr) {
const std::size_t work_amount
= static_cast<std::size_t>(jpp.mb) * jpp.nb_c * jpp.oh;
if (ithr >= work_amount) return;
@@ -739,6 +741,8 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
(*kernel_)(&arg);
};
+ const int nthr = jpp.nthr;
+
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
parallel_nd(jpp.mb, jpp.od, nb2_c, [&](int n, int od, int b2_c) {
@@ -757,7 +761,7 @@ void jit_uni_pooling_fwd_t<isa, d_type>::execute_forward_3d(const data_t *src,
});
} else {
if (trans_src || trans_dst) {
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](int ithr, int nthr, int n, int b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
@@ -948,7 +952,9 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward(
transpose_facade.execute_transpose_output(ithr, n, b_c);
};
- parallel(0, [&](int ithr, int nthr) {
+ const int nthr = jpp.nthr;
+
+ parallel(nthr, [&](int ithr, int nthr) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
const std::size_t work_amount
= static_cast<std::size_t>(jpp.mb) * nb2_c;
@@ -1098,6 +1104,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
}
};
+ const int nthr = jpp.nthr;
+
if (jpp.simple_alg) {
if (jpp.tag_kind == jit_memory_tag_kind_t::nspc) {
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
@@ -1109,7 +1117,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
} else {
assert(jpp.ur_bc == 1);
if (trans_src || trans_dst) {
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](int ithr, int nthr, int n, int b_c) {
if (trans_src)
transpose_facade.execute_transpose_input(
@@ -1142,7 +1150,7 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
if (!trans_src) {
const size_t chunk_size
= (size_t)jpp.id * jpp.ih * jpp.iw * jpp.c_block;
- parallel_nd_ext(0, jpp.mb, jpp.nb_c,
+ parallel_nd_ext(nthr, jpp.mb, jpp.nb_c,
[&](int ithr, int nthr, int n, int b_c) {
const size_t offset
= ((size_t)n * jpp.nb_c + b_c) * chunk_size;
@@ -1155,8 +1163,8 @@ void jit_uni_pooling_bwd_t<isa, d_type>::execute_backward_3d(
const auto nb2_c = utils::div_up(jpp.nb_c, jpp.ur_bc);
if (trans_src || trans_dst) {
- parallel_nd_ext(
- 0, jpp.mb, nb2_c, [&](int ithr, int nthr, int n, int b2_c) {
+ parallel_nd_ext(nthr, jpp.mb, nb2_c,
+ [&](int ithr, int nthr, int n, int b2_c) {
const auto b_c = b2_c * jpp.ur_bc;
if (trans_dst) {
diff --git a/src/cpu/x64/jit_uni_pooling.hpp b/src/cpu/x64/jit_uni_pooling.hpp
index ec4b04a2b..e25d9ce05 100644
--- a/src/cpu/x64/jit_uni_pooling.hpp
+++ b/src/cpu/x64/jit_uni_pooling.hpp
@@ -66,8 +66,9 @@ struct jit_uni_pooling_fwd_t : public primitive_t {
init_default_ws();
auto scratchpad = scratchpad_registry().registrar();
- return jit_uni_pool_kernel<isa>::init_conf(
- jpp_, scratchpad, this, dnnl_get_max_threads());
+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
+
+ return status::success;
}
jit_pool_conf_t jpp_;
@@ -130,9 +131,11 @@ struct jit_uni_pooling_bwd_t : public primitive_t {
init_default_ws();
if (!compare_ws(hint_fwd_pd_)) return status::unimplemented;
}
+
auto scratchpad = scratchpad_registry().registrar();
- return jit_uni_pool_kernel<isa>::init_conf(
- jpp_, scratchpad, this, dnnl_get_max_threads());
+ CHECK(jit_uni_pool_kernel<isa>::init_conf(jpp_, scratchpad, this));
+
+ return status::success;
}
jit_pool_conf_t jpp_;