add multinomial cpu kernel

cristoval 2021-08-14 09:35:56 +08:00
parent 962c3f4ae3
commit 0a3e584eec
3 changed files with 223 additions and 0 deletions

@@ -0,0 +1,122 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/multinomial_gpu_kernel.h"
#include <algorithm>
#include <random>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
void MultinomialCpuKernel::InitKernel(const CNodePtr &kernel_node) {
  input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
  // The input tensor must be 1-D or 2-D with data type float32. The workspace holds
  // the cumulative (prefix-sum) array for every row of the input.
  if (input_shape_.size() == 1) {
    workspace_size_list_.push_back(input_shape_[0] * sizeof(float));
  } else if (input_shape_.size() == 2) {
    workspace_size_list_.push_back(input_shape_[0] * input_shape_[1] * sizeof(float));
  }
  seed_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")));
  seed2_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")));
}

bool MultinomialCpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                  const std::vector<kernel::AddressPtr> &workspace,
                                  const std::vector<kernel::AddressPtr> &outputs) {
  if (inputs.size() != 2) {
    MS_LOG(EXCEPTION) << "Invalid input number: expected 2 inputs, but got " << inputs.size();
  }
  if (workspace.size() != 1) {
    MS_LOG(EXCEPTION) << "Invalid workspace number: expected 1 workspace, but got " << workspace.size();
  }
  if (outputs.size() != 1) {
    MS_LOG(EXCEPTION) << "Invalid output number: expected 1 output, but got " << outputs.size();
  }
  MS_EXCEPTION_IF_NULL(inputs[0]);
  MS_EXCEPTION_IF_NULL(inputs[1]);
  MS_EXCEPTION_IF_NULL(workspace[0]);
  MS_EXCEPTION_IF_NULL(outputs[0]);

  float *input_tensor = reinterpret_cast<float *>(inputs[0]->addr);
  int num_sample = reinterpret_cast<int *>(inputs[1]->addr)[0];
  int *output = reinterpret_cast<int *>(outputs[0]->addr);
  float *cumulative_value = reinterpret_cast<float *>(workspace[0]->addr);
  MS_EXCEPTION_IF_NULL(input_tensor);
  MS_EXCEPTION_IF_NULL(output);
  MS_EXCEPTION_IF_NULL(cumulative_value);

  int num_row = 1;
  if (input_shape_.size() == 2) {
    num_row = static_cast<int>(input_shape_[0]);
  }
  int num_col = static_cast<int>(input_shape_[input_shape_.size() - 1]);

  for (int i = 0; i < num_row; ++i) {
    // Compute the cumulative (prefix-sum) array of row i.
    cumulative_value[i * num_col] = input_tensor[i * num_col];
    for (int j = 1; j < num_col; ++j) {
      size_t index = i * num_col + j;
      cumulative_value[index] = cumulative_value[index - 1] + input_tensor[index];
    }

    // Normalize the cumulative array of row i so its last element is 1.
    float sum = cumulative_value[(i + 1) * num_col - 1];
    if (sum != 0) {
      for (int k = 0; k < num_col; ++k) {
        cumulative_value[i * num_col + k] /= sum;
      }
    }

    // Initialize the random generator. seed2 takes precedence over seed; if neither
    // is positive, a non-deterministic seed from std::random_device is used.
    std::uniform_real_distribution<float> dist(0.0, 1.0);
    int RNG_seed = 0;
    if (seed2_ > 0) {
      RNG_seed = seed2_;
    } else if (seed_ > 0) {
      RNG_seed = seed_;
    } else {
      std::random_device rd;
      RNG_seed = static_cast<int>(rd());
    }
    std::default_random_engine rng{static_cast<unsigned int>(RNG_seed)};

    // Sample from the cumulative array: binary search for the first index whose
    // cumulative value is greater than the drawn probability.
    for (int n = 0; n < num_sample; ++n) {
      auto rand_prob = dist(rng);
      int begin = 0;
      int end = num_col;
      while (end - begin > 0) {
        int pivot = begin + (end - begin) / 2;
        float pivot_prob = cumulative_value[i * num_col + pivot];
        if (pivot_prob > rand_prob) {
          end = pivot;
        } else {
          begin = pivot + 1;
        }
      }
      output[i * num_sample + n] = begin;
    }
  }
  return true;
}
} // namespace kernel
} // namespace mindspore
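
For reference, the sampling above is an inverse-CDF lookup: each row's weights become normalized prefix sums, and a uniform draw is mapped to the first index whose cumulative value exceeds it. Below is a minimal NumPy sketch of the same procedure (not part of this commit; the function name and its defaults are illustrative). The random streams differ from the kernel's, so it cross-checks the sampling logic and distribution, not exact values.

import numpy as np

def multinomial_reference(probs, num_sample, seed=20):
    """Sample class indices row by row from unnormalized float32 weights."""
    probs = np.atleast_2d(np.asarray(probs, dtype=np.float32))
    rng = np.random.default_rng(seed)
    num_row, _ = probs.shape
    out = np.zeros((num_row, num_sample), dtype=np.int32)
    for i in range(num_row):
        cumulative = np.cumsum(probs[i])      # prefix sums, as in the kernel workspace
        cumulative /= cumulative[-1]          # normalize so the last entry is 1
        draws = rng.random(num_sample, dtype=np.float32)
        # First index whose cumulative value is greater than the draw,
        # mirroring the kernel's binary search.
        out[i] = np.searchsorted(cumulative, draws, side='right')
    return out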

@@ -0,0 +1,50 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTINOMIAL_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTINOMIAL_CPU_KERNEL_H_

#include <vector>

#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class MultinomialCpuKernel : public CPUKernel {
 public:
  MultinomialCpuKernel() = default;
  ~MultinomialCpuKernel() override = default;

  void InitKernel(const CNodePtr &kernel_node) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs) override;

 private:
  std::vector<size_t> input_shape_;
  int seed_{0};
  int seed2_{0};
};

MS_REG_CPU_KERNEL(
  Multinomial,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
  MultinomialCpuKernel)
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_MULTINOMIAL_CPU_KERNEL_H_
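
The registration binds the CPU kernel to the Multinomial primitive with a float32 probability input, an int32 sample count, and int32 output indices. A hedged usage sketch follows, assuming the primitive is exposed under mindspore.ops.operations and accepts the seed/seed2 attributes read in InitKernel; the wrapper name DirectMultinomial is illustrative only.

import mindspore.nn as nn
from mindspore.ops import operations as P

class DirectMultinomial(nn.Cell):
    """Call the Multinomial primitive directly with the registered dtypes."""
    def __init__(self, seed=20, seed2=0):
        super(DirectMultinomial, self).__init__()
        self.multinomial = P.Multinomial(seed=seed, seed2=seed2)

    def construct(self, probs, num_samples):
        # probs: float32 Tensor of shape (k,) or (n, k); num_samples: int.
        # Returns an int32 Tensor of sampled class indices.
        return self.multinomial(probs, num_samples)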

@@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import composite as C
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')


class Net(nn.Cell):
    def __init__(self, sample, replacement, seed=0):
        super(Net, self).__init__()
        self.sample = sample
        self.replacement = replacement
        self.seed = seed

    def construct(self, x):
        return C.multinomial(x, self.sample, self.replacement, self.seed)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multinomial_net():
    x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
    x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
    net0 = Net(1, True, 20)
    net1 = Net(2, True, 20)
    net2 = Net(6, True, 20)
    out0 = net0(x0)
    out1 = net1(x0)
    out2 = net2(x1)
    assert out0.asnumpy().shape == (1,)
    assert out1.asnumpy().shape == (2,)
    assert out2.asnumpy().shape == (2, 6)
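
Beyond the shape checks, a natural follow-up (not part of this commit) is to assert that every sampled value is a valid class index, which the normalized cumulative array and binary search are expected to guarantee. A sketch using the same Net wrapper:

@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_multinomial_index_range():
    x = Tensor(np.array([[0.9, 0.2], [0.1, 0.8]]).astype(np.float32))
    out = Net(10, True, 20)(x).asnumpy()
    assert out.shape == (2, 10)
    # Only two classes per row, so every sampled index must be 0 or 1.
    assert out.min() >= 0 and out.max() < 2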