fix dnnl binary broadcast
parent ff2c44c935
commit 102b74682b
@@ -181,7 +181,7 @@ void ArithmeticCPUKernel::LaunchLess(const std::vector<AddressPtr> &inputs, cons
   T *input2 = reinterpret_cast<T *>(inputs[1]->addr);
   bool *output = reinterpret_cast<bool *>(outputs[0]->addr);

-  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(T)) : 1;
+  size_t lens = outputs[0]->size > 0 ? static_cast<size_t>(outputs[0]->size / sizeof(bool)) : 1;
   auto max_thread_num = std::thread::hardware_concurrency();
   size_t thread_num = lens < 128 * max_thread_num ? std::ceil(lens / 128.0) : max_thread_num;
   MS_LOG(INFO) << "Lens=" << lens << "; use thread_num=" << thread_num << "; max_thread_num: " << max_thread_num;

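A note on the one-line change above: outputs[0]->size is a byte count and the output buffer holds bool elements, so dividing by sizeof(T) undercounts the element number whenever T is wider than one byte. A tiny standalone illustration, assuming the usual sizeof(float) == 4 and sizeof(bool) == 1 (not part of the commit):

#include <cstddef>
#include <iostream>

int main() {
  const std::size_t out_bytes = 8;  // e.g. 8 bool results produced by the Less kernel
  // Old computation: divide the byte count by the input element size T = float.
  std::cout << out_bytes / sizeof(float) << '\n';  // 2 -> too few elements get processed
  // Fixed computation: divide by the size of the actual output element type.
  std::cout << out_bytes / sizeof(bool) << '\n';   // 8 -> the full output is covered
  return 0;
}
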
@@ -67,6 +67,52 @@ void MKLCPUKernel::GetPadding(const CNodePtr &kernel_node, const std::string &pa
   }
 }
+
+bool MKLCPUKernel::BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
+                                   std::vector<size_t> *dst_shape) {
+  MS_EXCEPTION_IF_NULL(src0_shape);
+  MS_EXCEPTION_IF_NULL(src1_shape);
+  MS_EXCEPTION_IF_NULL(dst_shape);
+  bool need_swap = false;
+  if (dst_shape->size() == 0) {
+    dst_shape->emplace_back(1);
+    src0_shape->emplace_back(1);
+    src1_shape->emplace_back(1);
+  }
+  MS_LOG(DEBUG) << "Binary broadcast in: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
+  if (src0_shape->size() != dst_shape->size()) {
+    need_swap = true;
+    for (size_t i = src0_shape->size(); i < dst_shape->size(); ++i) {
+      src0_shape->insert(src0_shape->begin(), 1);
+    }
+  } else if (src1_shape->size() != dst_shape->size()) {
+    for (size_t i = src1_shape->size(); i < dst_shape->size(); ++i) {
+      src1_shape->insert(src1_shape->begin(), 1);
+    }
+  }
+  if (src0_shape->size() == src1_shape->size()) {
+    bool visit_src0 = false;
+    bool visit_src1 = false;
+    for (size_t i = 0; i < src0_shape->size(); ++i) {
+      if (src0_shape->at(i) != src1_shape->at(i)) {
+        if (src0_shape->at(i) == 1 && !visit_src1) {
+          need_swap = true;
+          visit_src0 = true;
+        } else if (src1_shape->at(i) == 1 && !visit_src0) {
+          need_swap = false;
+          visit_src1 = true;
+        } else {
+          MS_LOG(EXCEPTION) << "Invalid broadcast! " << *src0_shape << " vs " << *src1_shape;
+        }
+      }
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "Invalid broadcast! src0: " << *src0_shape << " src1: " << *src1_shape
+                      << " dst: " << *dst_shape;
+  }
+  MS_LOG(DEBUG) << "Binary broadcast out: src0: " << *src0_shape << " src1: " << *src1_shape << " dst: " << *dst_shape;
+  return need_swap;
+}

 dnnl::memory::format_tag MKLCPUKernel::GetDefaultFormatTag(const dnnl::memory::dims &dims) const {
   dnnl::memory::format_tag mem_tag;
   auto dim_size = dims.size();

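For readers who want to trace the new broadcast rule without the MindSpore macros, here is a minimal standalone sketch of the same shape-alignment logic (the free function AlignForBroadcast and the use of std::invalid_argument are hypothetical stand-ins, not part of the commit). The shorter shape is left-padded with 1s to the output rank, differing dimensions must carry a 1 on a consistent side, and the return value reports whether it is the first operand that gets broadcast, which the calling kernels use to decide whether to swap their operands (see the Mul and TensorAdd changes below):

#include <cstddef>
#include <iostream>
#include <stdexcept>
#include <vector>

// Aligns two operand shapes against the output shape and reports whether the
// operands must be swapped so that the broadcast operand comes second.
bool AlignForBroadcast(std::vector<std::size_t> *a, std::vector<std::size_t> *b,
                       std::vector<std::size_t> *dst) {
  bool need_swap = false;
  if (dst->empty()) {  // scalar case: promote everything to shape [1]
    dst->push_back(1);
    a->push_back(1);
    b->push_back(1);
  }
  if (a->size() != dst->size()) {  // left-pad the first operand with 1s
    need_swap = true;
    a->insert(a->begin(), dst->size() - a->size(), 1);
  } else if (b->size() != dst->size()) {  // left-pad the second operand with 1s
    b->insert(b->begin(), dst->size() - b->size(), 1);
  }
  if (a->size() != b->size()) {
    throw std::invalid_argument("invalid broadcast: ranks do not match the output");
  }
  bool visit_a = false;
  bool visit_b = false;
  for (std::size_t i = 0; i < a->size(); ++i) {
    if ((*a)[i] == (*b)[i]) {
      continue;
    }
    if ((*a)[i] == 1 && !visit_b) {  // a is the broadcast operand so far
      need_swap = true;
      visit_a = true;
    } else if ((*b)[i] == 1 && !visit_a) {  // b is the broadcast operand so far
      need_swap = false;
      visit_b = true;
    } else {  // broadcast on both sides, or a mismatched non-1 dimension
      throw std::invalid_argument("invalid broadcast: incompatible dimensions");
    }
  }
  return need_swap;
}

int main() {
  // Mirrors the new test case below: a (4,) operand against a (4, 4) operand.
  std::vector<std::size_t> a = {4};
  std::vector<std::size_t> b = {4, 4};
  std::vector<std::size_t> dst = {4, 4};
  bool need_swap = AlignForBroadcast(&a, &b, &dst);
  // a is now [1, 4], b stays [4, 4]; need_swap is true because the broadcast
  // operand was passed first.
  std::cout << std::boolalpha << need_swap << std::endl;
  return 0;
}
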
@@ -32,6 +32,8 @@ class MKLCPUKernel : public CPUKernel {
   ~MKLCPUKernel() override = default;

  protected:
+  bool BinaryBroadCast(std::vector<size_t> *src0_shape, std::vector<size_t> *src1_shape,
+                       std::vector<size_t> *dst_shape);
   void GetPadding(const CNodePtr &kernel_node, const std::string &pad_mode, const std::vector<size_t> &src_shape,
                   const std::vector<size_t> &kernel_size, int stride, std::vector<int> *padding_l,
                   std::vector<int> *padding_r);

@@ -25,49 +25,7 @@ void MulCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-  if (dst_shape.size() == 0) {
-    dst_shape.emplace_back(1);
-    src0_shape.emplace_back(1);
-    src1_shape.emplace_back(1);
-  }
-  size_t src0_length = 1;
-  size_t src1_length = 1;
-  for (size_t i = 0; i < src0_shape.size(); ++i) {
-    src0_length = src0_length * src0_shape[i];
-  }
-  for (size_t i = 0; i < src1_shape.size(); ++i) {
-    src1_length = src1_length * src1_shape[i];
-  }
-  if (src1_shape.size() != src0_shape.size()) {
-    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
-      need_swap_ = true;
-      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
-        src0_shape.emplace_back(1);
-      }
-    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
-      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
-        src1_shape.emplace_back(1);
-      }
-    } else {
-      MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-    }
-  } else {
-    bool visit_src0 = false;
-    bool visit_src1 = false;
-    for (size_t i = 0; i < src0_shape.size(); ++i) {
-      if (src0_shape[i] != src1_shape[i]) {
-        if (src0_shape[i] == 1 && !visit_src1) {
-          need_swap_ = true;
-          visit_src0 = true;
-        } else if (src1_shape[i] == 1 && !visit_src0) {
-          need_swap_ = false;
-          visit_src1 = true;
-        } else {
-          MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-        }
-      }
-    }
-  }
+  need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
   dnnl::memory::desc src0_desc;
   dnnl::memory::desc src1_desc;
   if (need_swap_) {

@@ -25,49 +25,7 @@ void TensorAddCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   std::vector<size_t> src0_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
   std::vector<size_t> src1_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 1);
   std::vector<size_t> dst_shape = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
-  if (dst_shape.size() == 0) {
-    dst_shape.emplace_back(1);
-    src0_shape.emplace_back(1);
-    src1_shape.emplace_back(1);
-  }
-  size_t src0_length = 1;
-  size_t src1_length = 1;
-  for (size_t i = 0; i < src0_shape.size(); ++i) {
-    src0_length = src0_length * src0_shape[i];
-  }
-  for (size_t i = 0; i < src1_shape.size(); ++i) {
-    src1_length = src1_length * src1_shape[i];
-  }
-  if (src1_shape.size() != src0_shape.size()) {
-    if (src0_length == 1 && src0_shape.size() != dst_shape.size()) {
-      need_swap_ = true;
-      for (size_t i = src0_shape.size(); i < src1_shape.size(); ++i) {
-        src0_shape.emplace_back(1);
-      }
-    } else if (src1_length == 1 && src1_shape.size() != dst_shape.size()) {
-      for (size_t i = src1_shape.size(); i < src0_shape.size(); ++i) {
-        src1_shape.emplace_back(1);
-      }
-    } else {
-      MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-    }
-  } else {
-    bool visit_src0 = false;
-    bool visit_src1 = false;
-    for (size_t i = 0; i < src0_shape.size(); ++i) {
-      if (src0_shape[i] != src1_shape[i]) {
-        if (src0_shape[i] == 1 && !visit_src1) {
-          need_swap_ = true;
-          visit_src0 = true;
-        } else if (src1_shape[i] == 1 && !visit_src0) {
-          need_swap_ = false;
-          visit_src1 = true;
-        } else {
-          MS_LOG(EXCEPTION) << "Invalid broadcast! " << src0_shape << " vs " << src1_shape;
-        }
-      }
-    }
-  }
+  need_swap_ = BinaryBroadCast(&src0_shape, &src1_shape, &dst_shape);
   dnnl::memory::desc src0_desc;
   dnnl::memory::desc src1_desc;
   if (need_swap_) {

@@ -47,6 +47,8 @@ def test_mul():
     y2 = Tensor(2, mstype.float32)
     x3 = Tensor(2, mstype.float32)
     y3 = Tensor(2, mstype.float32)
+    x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
+    y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
     mul = Net()
     out = mul(x0, y0).asnumpy()
     exp = x0.asnumpy() * y0.asnumpy()

@@ -75,3 +77,10 @@ def test_mul():
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = mul(x4, y4).asnumpy()
+    exp = x4.asnumpy() * y4.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape

@@ -45,6 +45,8 @@ def test_tensor_add():
     y2 = Tensor(2, mstype.float32)
     x3 = Tensor(2, mstype.float32)
     y3 = Tensor(2, mstype.float32)
+    x4 = Tensor(np.random.uniform(-2, 2, (4)).astype(np.float32))
+    y4 = Tensor(np.random.uniform(-2, 2, (4, 4)).astype(np.float32))
     add = TensorAdd()
     out = add(x0, y0).asnumpy()
     exp = x0.asnumpy() + y0.asnumpy()

@@ -73,3 +75,10 @@ def test_tensor_add():
     err = np.ones(shape=exp.shape) * 1.0e-5
     assert np.all(diff < err)
     assert out.shape == exp.shape
+
+    out = add(x4, y4).asnumpy()
+    exp = x4.asnumpy() + y4.asnumpy()
+    diff = np.abs(out - exp)
+    err = np.ones(shape=exp.shape) * 1.0e-5
+    assert np.all(diff < err)
+    assert out.shape == exp.shape