forked from mindspore-Ecosystem/mindspore
!47606 fix wrong result of Multinomial op
Merge pull request !47606 from xulei/multinomial
This commit is contained in:
commit
df345e4cd9
|
@ -55,19 +55,16 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
// setup seed
|
||||
int64_t final_seed = 0;
|
||||
auto attr_seed = ctx.GetAttr("seed");
|
||||
if (attr_seed != nullptr) {
|
||||
auto attr_seed2 = ctx.GetAttr("seed2");
|
||||
if (attr_seed2 != nullptr) {
|
||||
final_seed = attr_seed2->GetInt();
|
||||
} else if (attr_seed != nullptr) {
|
||||
final_seed = attr_seed->GetInt();
|
||||
} else {
|
||||
std::random_device r;
|
||||
final_seed = r();
|
||||
}
|
||||
if (final_seed == 0) {
|
||||
auto attr_seed2 = ctx.GetAttr("seed2");
|
||||
if (attr_seed2 != nullptr) {
|
||||
final_seed = attr_seed2->GetInt();
|
||||
}
|
||||
}
|
||||
|
||||
// setup random engine
|
||||
std::random_device r;
|
||||
final_seed = final_seed ? final_seed : r();
|
||||
RNG_Engine rng;
|
||||
rng.seed(final_seed);
|
||||
|
||||
|
@ -76,8 +73,6 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
auto total_num = output->NumElements();
|
||||
|
||||
if (total_num < kParallelDataNums) {
|
||||
double max_logit = std::numeric_limits<double>::lowest();
|
||||
|
||||
auto cur_out = output_data;
|
||||
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
|
@ -86,20 +81,10 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
double running_total = 0;
|
||||
auto row_start = input_0_data + i * num_classes;
|
||||
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
if (isfinite(*(row_start + j))) {
|
||||
max_logit = std::max(max_logit, static_cast<double>(*(row_start + j)));
|
||||
}
|
||||
}
|
||||
|
||||
// calculate cdf (Cumulative Distribution Function)
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
cumulative_distribution_function[j] = std::exp(static_cast<double>(*(row_start + j)) - max_logit);
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
if (isfinite(*(row_start + j))) {
|
||||
running_total += cumulative_distribution_function[j];
|
||||
running_total += row_start[j];
|
||||
}
|
||||
cumulative_distribution_function[j] = running_total;
|
||||
}
|
||||
|
@ -110,8 +95,9 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
auto found_iter = std::upper_bound(cumulative_distribution_function,
|
||||
cumulative_distribution_function + num_classes, rand * running_total);
|
||||
|
||||
*cur_out = static_cast<T_out>(std::distance(cumulative_distribution_function, found_iter));
|
||||
cur_out = cur_out + 1;
|
||||
*cur_out =
|
||||
static_cast<T_out>(std::min(num_classes - 1, std::distance(cumulative_distribution_function, found_iter)));
|
||||
++cur_out;
|
||||
}
|
||||
if (cumulative_distribution_function != nullptr) {
|
||||
delete[] cumulative_distribution_function;
|
||||
|
@ -121,27 +107,16 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
double *rand_list = new double[total_num];
|
||||
auto shard = [&](size_t start_outer, size_t end_outer) {
|
||||
double *cumulative_distribution_function = new double[num_classes];
|
||||
double max_logit = std::numeric_limits<double>::lowest();
|
||||
RNG_Engine rng_outer = rng;
|
||||
rng_outer.discard(start_outer * num_samples);
|
||||
double running_total = 0;
|
||||
|
||||
auto row_start = input_0_data + start_outer * num_classes;
|
||||
// Takes an along-class maximum (for numerical stability).
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
if (isfinite(*(row_start + j))) {
|
||||
max_logit = std::max(max_logit, static_cast<double>(*(row_start + j)));
|
||||
}
|
||||
}
|
||||
|
||||
// calculate cdf (Cumulative Distribution Function)
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
cumulative_distribution_function[j] = std::exp(static_cast<double>(*(row_start + j)) - max_logit);
|
||||
}
|
||||
|
||||
for (int64_t j = 0; j < num_classes; ++j) {
|
||||
if (isfinite(*(row_start + j))) {
|
||||
running_total += cumulative_distribution_function[j];
|
||||
running_total += row_start[j];
|
||||
}
|
||||
cumulative_distribution_function[j] = running_total;
|
||||
}
|
||||
|
@ -161,7 +136,8 @@ uint32_t Generate(Tensor *&input_0, Tensor *&input_1, Tensor *&output, CpuKernel
|
|||
auto found_iter = std::upper_bound(cumulative_distribution_function,
|
||||
cumulative_distribution_function + num_classes, rand * running_total);
|
||||
|
||||
*(cur_output + j) = static_cast<T_out>(std::distance(cumulative_distribution_function, found_iter));
|
||||
*(cur_output + j) =
|
||||
static_cast<T_out>(std::min(num_classes - 1, std::distance(cumulative_distribution_function, found_iter)));
|
||||
}
|
||||
};
|
||||
CpuKernelUtils::ParallelFor(ctx, ceil((double)num_samples / kNumPerThread), 1, shard_inner);
|
||||
|
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2023 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sample):
|
||||
super(Net, self).__init__()
|
||||
self.sample = sample
|
||||
self.multinomial = P.Multinomial()
|
||||
|
||||
def construct(self, x):
|
||||
return self.multinomial(x, self.sample)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_multinomial_net():
|
||||
"""
|
||||
Feature: test Multinomial op.
|
||||
Description: test Multinomial op.
|
||||
Expectation: success.
|
||||
"""
|
||||
x0 = Tensor(np.array([[2, 0], [0, 9]]).astype(np.float32))
|
||||
x1 = Tensor(np.array([[0, 0.1, 0], [0, 0, 1000]]).astype(np.float32))
|
||||
net0 = Net(1)
|
||||
net1 = Net(6)
|
||||
out0 = net0(x0)
|
||||
out1 = net1(x1)
|
||||
expect_result0 = np.array([[0], [1]]).astype(np.int32)
|
||||
expect_result1 = np.array([[1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2, 2]]).astype(np.int32)
|
||||
np.array_equal(expect_result0, out0.asnumpy())
|
||||
np.array_equal(expect_result1, out1.asnumpy())
|
Loading…
Reference in New Issue