forked from mindspore-Ecosystem/mindspore

!47535 fix cpu dropout seed
Merge pull request !47535 from 王禹程/fix_drop_seed

commit 2543b35970
@@ -3,7 +3,7 @@ mindspore.ops.dropout

 .. py:function:: mindspore.ops.dropout(x, p=0.5, seed0=0, seed1=0)

-    During training, randomly zeroes some values of the input Tensor with probability `p`, following a Bernoulli distribution, which reduces neuron correlation and avoids overfitting. The meaning of this probability is opposite to that in `ops.dropout` and `nn.dropout`.
+    During training, randomly zeroes some values of the input Tensor with probability `p`, following a Bernoulli distribution, which reduces neuron correlation and avoids overfitting. The meaning of this probability is opposite to that in `ops.Dropout` and `nn.Dropout`.

     Parameters:
         - **x** (Tensor) - The input of dropout, a Tensor of any shape with data type of float16 or float32.
@@ -51,6 +51,14 @@ bool DropoutCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::
   if (keep_prob_ <= 0.0 || keep_prob_ > 1.0) {
     MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'keep_prob' must be in (0.0, 1.0], but got " << keep_prob_;
   }
+  auto seed = GetValue<int64_t>(base_operator->GetAttr("Seed0"));
+  if (seed == 0) {
+    seed = GetValue<int64_t>(base_operator->GetAttr("Seed1"));
+    if (seed == 0) {
+      seed = time(nullptr);
+    }
+  }
+  seed_ = static_cast<uint64_t>(seed);
   return MatchKernelFunc(base_operator, inputs, outputs);
 }
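The fallback above prefers the 'Seed0' attribute, falls back to 'Seed1', and only when both are zero takes the wall clock, so explicitly seeded runs stay reproducible. A minimal standalone sketch of that logic, assuming nothing beyond the hunk (the helper name ResolveSeed is hypothetical, not kernel code):

#include <cstdint>
#include <ctime>

// Hypothetical helper mirroring the Init() fallback: Seed0, then Seed1,
// then wall-clock time as a last, non-reproducible resort.
uint64_t ResolveSeed(int64_t seed0, int64_t seed1) {
  int64_t seed = seed0;
  if (seed == 0) {
    seed = seed1;
    if (seed == 0) {
      seed = time(nullptr);  // both seeds are 0: fall back to current time
    }
  }
  return static_cast<uint64_t>(seed);
}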
@@ -78,8 +86,7 @@ bool DropoutCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
   auto *output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto mask_addr = reinterpret_cast<T *>(outputs[1]->addr);
   T scale = static_cast<T>(1.f / keep_prob_);
-  std::random_device rd;
-  std::default_random_engine generator(rd());
+  std::default_random_engine generator(seed_ + seed_offset_);
   std::uniform_real_distribution<float> uniform(0.f, 1.f);
   auto task = [input_addr, output_addr, mask_addr, scale, &uniform, &generator, this](size_t start, size_t end) {
     for (size_t i = start; i < end; i++) {
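The body of the element loop lies outside this hunk. Given scale = 1/keep_prob and the separate mask output, the per-element work is presumably standard inverted dropout: keep an element with probability keep_prob and rescale it so the output's expectation matches the input. A self-contained sketch of that technique (the function name and vector-based signature are illustrative, not the kernel's):

#include <random>
#include <vector>

// Inverted dropout over a flat buffer: mask[i] is 1 where the element is
// kept, and kept elements are scaled by 1/keep_prob.
void InvertedDropout(const std::vector<float> &in, std::vector<float> *out,
                     std::vector<float> *mask, float keep_prob, uint64_t seed) {
  std::default_random_engine generator(seed);
  std::uniform_real_distribution<float> uniform(0.f, 1.f);
  const float scale = 1.f / keep_prob;
  for (size_t i = 0; i < in.size(); i++) {
    const float keep = uniform(generator) < keep_prob ? 1.f : 0.f;
    (*mask)[i] = keep;
    (*out)[i] = in[i] * keep * scale;  // expectation of out equals in
  }
}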
@@ -88,6 +95,7 @@ bool DropoutCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &in
     }
   };
   ParallelLaunchAutoSearch(task, tensor_size_, this, &parallel_search_info_);
+  seed_offset_ += 1;
   return true;
 }
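Seeding the engine with seed_ + seed_offset_ and bumping the offset after every launch is what the new tests exercise: a fixed seed reproduces the same mask across runs, while consecutive steps still draw different masks. A standalone sketch of that behavior (plain C++, no MindSpore types):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <random>

int main() {
  const uint64_t seed = 42;  // stands in for seed_
  uint64_t offset = 0;       // stands in for seed_offset_
  std::uniform_real_distribution<float> uniform(0.f, 1.f);

  // Same seed and offset: the two engines produce identical draws.
  std::default_random_engine g1(seed + offset);
  std::default_random_engine g2(seed + offset);
  assert(uniform(g1) == uniform(g2));

  // Next step: the incremented offset selects a different sequence.
  offset += 1;
  std::default_random_engine g3(seed + offset);
  std::cout << uniform(g3) << '\n';
  return 0;
}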
@@ -53,6 +53,8 @@ class DropoutCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper<

   ShapeVector input_shape_;
   float keep_prob_{0.0};
+  uint64_t seed_{0};
+  uint64_t seed_offset_{0};
   size_t tensor_size_{1};
 };
 }  // namespace kernel
@@ -1064,7 +1064,7 @@ def dropout(x, p=0.5, seed0=0, seed1=0):
     During training, randomly zeroes some of the elements of the input tensor
     with probability `p` from a Bernoulli distribution. It plays the role of
     reducing neuron correlation and avoiding overfitting. The meaning of probability
-    here is opposite to that in `ops.dropout` and `nn.dropout`.
+    here is opposite to that in `ops.Dropout` and `nn.Dropout`.

     Args:
         x (Tensor): The input of Dropout, a Tensor of any shape with data type of float16 or float32.
@@ -20,6 +20,7 @@ import mindspore.context as context
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import operations as P
+from mindspore import ops

 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
@@ -87,7 +88,52 @@ def test_net2():
     print(mask)


+class Net3(nn.Cell):
+    def __init__(self):
+        super(Net3, self).__init__()
+        self.dropout = P.Dropout(keep_prob=0.5)
+
+    def construct(self, x):
+        return self.dropout(x)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_net3():
+    """
+    Feature: test that dropout masks differ from step to step.
+    Description: dropout.
+    Expectation: No exception.
+    """
+    x = np.arange(0, 12).reshape(3, 4).astype(np.float16)
+    dropout = Net3()
+    output1, mask1 = dropout(Tensor(x))
+    output2, mask2 = dropout(Tensor(x))
+    assert not np.allclose(mask1.asnumpy(), mask2.asnumpy())
+    assert not np.allclose(output1.asnumpy(), output2.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_op1():
+    """
+    Feature: test that dropout masks are equal for equal seeds.
+    Description: dropout.
+    Expectation: No exception.
+    """
+    x = Tensor(np.arange(0, 12).reshape(3, 4).astype(np.float16))
+    output1, mask1 = ops.dropout(x, p=0.5, seed0=1, seed1=100)
+    output2, mask2 = ops.dropout(x, p=0.5, seed0=1, seed1=100)
+
+    assert mask1.shape == mask2.shape
+    assert np.allclose(output1.asnumpy(), output2.asnumpy())
+    assert np.allclose(mask1.asnumpy(), mask2.asnumpy())
+
+
 if __name__ == '__main__':
     test_net()
     test_net1()
     test_net2()
+    test_op1()