!48367 fix cpu dropoutnd

Merge pull request !48367 from 王禹程/drop_ratio
This commit is contained in:
i-robot 2023-02-03 01:37:10 +00:00 committed by Gitee
commit a6b7631174
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 38 additions and 3 deletions

View File

@ -135,9 +135,9 @@ bool DropoutNdCpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
auto per_output = output + start * inner_size;
auto per_mask = mask + start * inner_size;
for (size_t i = start; i < end; ++i) {
bool drop = static_cast<float>(distribution_(generator_)) <= keep_prob_;
if (drop) {
std::fill(per_mask, per_mask + inner_size, drop);
bool keep = distribution_(generator_);
if (keep) {
std::fill(per_mask, per_mask + inner_size, keep);
if constexpr (std::is_same<T, float>::value) {
DropoutFp32(per_input, scale_, SizeToInt(inner_size), per_output);
} else {

View File

@ -16,6 +16,7 @@
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@ -132,6 +133,40 @@ def test_op1():
assert np.allclose(mask1.asnumpy(), mask2.asnumpy())
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_op2():
"""
Feature: test Dropout2D.
Description: dropout.
Expectation: No exception.
"""
input_np = np.ones((1000, 1000, 20, 5)).astype(np.float32)
input_x = Tensor(input_np, mindspore.float32)
data_size = 1000 * 1000 * 20 * 5
dropout = ops.Dropout2D(keep_prob=0.0)
output_ms, _ = dropout(input_x)
ans = np.sum(np.where(output_ms.asnumpy(), 0, 1))
assert ans == data_size
dropout = ops.Dropout2D(keep_prob=0.2)
output_ms, _ = dropout(input_x)
ans = np.sum(np.where(output_ms.asnumpy(), 0, 1))
assert data_size * 0.75 <= ans <= data_size * 0.85
dropout = ops.Dropout2D(keep_prob=0.8)
output_ms, _ = dropout(input_x)
ans = np.sum(np.where(output_ms.asnumpy(), 0, 1))
assert data_size * 0.15 <= ans <= data_size * 0.25
dropout = ops.Dropout2D(keep_prob=1.0)
output_ms, _ = dropout(input_x)
ans = np.sum(np.where(output_ms.asnumpy(), 0, 1))
assert ans == 0
if __name__ == '__main__':
test_net()
test_net1()