forked from mindspore-Ecosystem/mindspore
add multinomial cpu kernel
This commit is contained in:
parent
962c3f4ae3
commit
0a3e584eec
|
@ -0,0 +1,122 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include "backend/kernel_compiler/cpu/multinomial_gpu_kernel.h"
|
||||||
|
#include <algorithm>
|
||||||
|
#include <random>
|
||||||
|
#include "runtime/device/cpu/cpu_device_address.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
|
||||||
|
void MultinomialGpuKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||||
|
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||||
|
|
||||||
|
// The dimensions of input tensor must be 1 or 2, with data type of float32.
|
||||||
|
if (input_shape_.size() == 1) {
|
||||||
|
workspace_size_list_.push_back(input_shape_[0] * sizeof(float));
|
||||||
|
} else if (input_shape_.size() == 2) {
|
||||||
|
workspace_size_list_.push_back(input_shape_[1] * sizeof(float));
|
||||||
|
}
|
||||||
|
|
||||||
|
seed_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")));
|
||||||
|
seed2_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")));
|
||||||
|
}
|
||||||
|
|
||||||
|
bool MultinomialGpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||||
|
const std::vector<kernel::AddressPtr> &workspace,
|
||||||
|
const std::vector<kernel::AddressPtr> &outputs) {
|
||||||
|
if (inputs.size() != 2) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid input numbers, expect input number 2, but actual input number " << inputs.size();
|
||||||
|
}
|
||||||
|
if (workspace.size() != 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid workspace numbers, expect workspace number 1, actual workspace number "
|
||||||
|
<< workspace.size();
|
||||||
|
}
|
||||||
|
if (outputs.size() != 1) {
|
||||||
|
MS_LOG(EXCEPTION) << "Invalid output numbers, expect output number 1, actual output number " << outputs.size();
|
||||||
|
}
|
||||||
|
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||||
|
MS_EXCEPTION_IF_NULL(inputs[1]);
|
||||||
|
MS_EXCEPTION_IF_NULL(workspace[0]);
|
||||||
|
MS_EXCEPTION_IF_NULL(outputs[0]);
|
||||||
|
|
||||||
|
float *input_tensor = reinterpret_cast<float *>(inputs[0]->addr);
|
||||||
|
int num_sample = reinterpret_cast<int *>(inputs[1]->addr)[0];
|
||||||
|
int *output = reinterpret_cast<int *>(outputs[0]->addr);
|
||||||
|
float *cumulative_value = reinterpret_cast<float *>(workspace[0]->addr);
|
||||||
|
|
||||||
|
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||||
|
MS_EXCEPTION_IF_NULL(output);
|
||||||
|
MS_EXCEPTION_IF_NULL(cumulative_value);
|
||||||
|
|
||||||
|
int num_row = 1;
|
||||||
|
if (input_shape_.size() == 2) {
|
||||||
|
num_row = input_shape_[0];
|
||||||
|
}
|
||||||
|
int num_col = input_shape_[input_shape_.size() - 1];
|
||||||
|
|
||||||
|
for (int i = 0; i < num_row; ++i) {
|
||||||
|
// Compute the cumulative array.
|
||||||
|
cumulative_value[i * num_col] = input_tensor[i * num_col];
|
||||||
|
for (int j = 1; j < num_col; ++j) {
|
||||||
|
size_t index = i * num_col + j;
|
||||||
|
cumulative_value[index] = cumulative_value[index - 1] + input_tensor[index];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Normalize the cumulative array.
|
||||||
|
float sum = cumulative_value[num_col - 1];
|
||||||
|
if (sum != 0) {
|
||||||
|
for (int k = 1; k < num_col; ++k) {
|
||||||
|
cumulative_value[k] /= sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize random generator.
|
||||||
|
std::uniform_real_distribution<float> dist(0.0, 1.0);
|
||||||
|
int RNG_seed = 0;
|
||||||
|
if (seed2_ > 0) {
|
||||||
|
RNG_seed = seed2_;
|
||||||
|
} else if (seed_ > 0) {
|
||||||
|
RNG_seed = seed_;
|
||||||
|
} else {
|
||||||
|
std::random_device rd;
|
||||||
|
RNG_seed = static_cast<int>(rd());
|
||||||
|
}
|
||||||
|
std::default_random_engine rng{RNG_seed};
|
||||||
|
|
||||||
|
// Sample data from cumulative array.
|
||||||
|
for (int n = 0; n < num_sample; ++n) {
|
||||||
|
auto rand_prob = dist(rng);
|
||||||
|
int begin = 0;
|
||||||
|
int end = num_col;
|
||||||
|
|
||||||
|
while (end - begin > 0) {
|
||||||
|
int pivot = begin + (end - begin) / 2;
|
||||||
|
rand_prob = cumulative_value[i * num_col + pivot];
|
||||||
|
if (pivot > rand_prob) {
|
||||||
|
end = pivot;
|
||||||
|
} else {
|
||||||
|
begin = pivot + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output[i * num_col + n] = begin;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
||||||
|
/**
|
||||||
|
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
*
|
||||||
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
* you may not use this file except in compliance with the License.
|
||||||
|
* You may obtain a copy of the License at
|
||||||
|
*
|
||||||
|
* http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
*
|
||||||
|
* Unless required by applicable law or agreed to in writing, software
|
||||||
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
* See the License for the specific language governing permissions and
|
||||||
|
* limitations under the License.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||||
|
#include <memory>
|
||||||
|
#include <unordered_map>
|
||||||
|
#include <vector>
|
||||||
|
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||||
|
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||||
|
#include "nnacl/base/tile_base.h"
|
||||||
|
|
||||||
|
namespace mindspore {
|
||||||
|
namespace kernel {
|
||||||
|
class MultinomialGpuKernel : public CPUKernel {
|
||||||
|
public:
|
||||||
|
MultinomialGpuKernel() = default;
|
||||||
|
~MultinomialGpuKernel() override = default;
|
||||||
|
|
||||||
|
void InitKernel(const CNodePtr &kernel_node) override;
|
||||||
|
|
||||||
|
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||||
|
const std::vector<AddressPtr> &outputs) override;
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::vector<size_t> input_shape_;
|
||||||
|
int seed_{0};
|
||||||
|
int seed2_{0};
|
||||||
|
};
|
||||||
|
|
||||||
|
MS_REG_CPU_KERNEL(
|
||||||
|
Multinomial,
|
||||||
|
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||||
|
MultinomialGpuKernel)
|
||||||
|
} // namespace kernel
|
||||||
|
} // namespace mindspore
|
||||||
|
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import mindspore.context as context
|
||||||
|
import mindspore.nn as nn
|
||||||
|
from mindspore import Tensor
|
||||||
|
from mindspore.ops import composite as C
|
||||||
|
|
||||||
|
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||||
|
|
||||||
|
|
||||||
|
class Net(nn.Cell):
|
||||||
|
def __init__(self, sample, replacement, seed=0):
|
||||||
|
super(Net, self).__init__()
|
||||||
|
self.sample = sample
|
||||||
|
self.replacement = replacement
|
||||||
|
self.seed = seed
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
return C.multinomial(x, self.sample, self.replacement, self.seed)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.level0
|
||||||
|
@pytest.mark.platform_x86_cpu
|
||||||
|
@pytest.mark.env_onecard
|
||||||
|
def test_multinomial_net():
|
||||||
|
x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
|
||||||
|
x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
|
||||||
|
net0 = Net(1, True, 20)
|
||||||
|
net1 = Net(2, True, 20)
|
||||||
|
net2 = Net(6, True, 20)
|
||||||
|
out0 = net0(x0)
|
||||||
|
out1 = net1(x0)
|
||||||
|
out2 = net2(x1)
|
||||||
|
assert out0.asnumpy().shape == (1,)
|
||||||
|
assert out1.asnumpy().shape == (2,)
|
||||||
|
assert out2.asnumpy().shape == (2, 6)
|
Loading…
Reference in New Issue