forked from mindspore-Ecosystem/mindspore
add multinomial cpu kernel
This commit is contained in:
parent
962c3f4ae3
commit
0a3e584eec
|
@ -0,0 +1,122 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/cpu/multinomial_gpu_kernel.h"
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
#include "runtime/device/cpu/cpu_device_address.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
|
||||
void MultinomialGpuKernel::InitKernel(const CNodePtr &kernel_node) {
|
||||
input_shape_ = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
|
||||
// The dimensions of input tensor must be 1 or 2, with data type of float32.
|
||||
if (input_shape_.size() == 1) {
|
||||
workspace_size_list_.push_back(input_shape_[0] * sizeof(float));
|
||||
} else if (input_shape_.size() == 2) {
|
||||
workspace_size_list_.push_back(input_shape_[1] * sizeof(float));
|
||||
}
|
||||
|
||||
seed_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed")));
|
||||
seed2_ = static_cast<int>(GetValue<int64_t>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2")));
|
||||
}
|
||||
|
||||
bool MultinomialGpuKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs) {
|
||||
if (inputs.size() != 2) {
|
||||
MS_LOG(EXCEPTION) << "Invalid input numbers, expect input number 2, but actual input number " << inputs.size();
|
||||
}
|
||||
if (workspace.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid workspace numbers, expect workspace number 1, actual workspace number "
|
||||
<< workspace.size();
|
||||
}
|
||||
if (outputs.size() != 1) {
|
||||
MS_LOG(EXCEPTION) << "Invalid output numbers, expect output number 1, actual output number " << outputs.size();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(inputs[0]);
|
||||
MS_EXCEPTION_IF_NULL(inputs[1]);
|
||||
MS_EXCEPTION_IF_NULL(workspace[0]);
|
||||
MS_EXCEPTION_IF_NULL(outputs[0]);
|
||||
|
||||
float *input_tensor = reinterpret_cast<float *>(inputs[0]->addr);
|
||||
int num_sample = reinterpret_cast<int *>(inputs[1]->addr)[0];
|
||||
int *output = reinterpret_cast<int *>(outputs[0]->addr);
|
||||
float *cumulative_value = reinterpret_cast<float *>(workspace[0]->addr);
|
||||
|
||||
MS_EXCEPTION_IF_NULL(input_tensor);
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
MS_EXCEPTION_IF_NULL(cumulative_value);
|
||||
|
||||
int num_row = 1;
|
||||
if (input_shape_.size() == 2) {
|
||||
num_row = input_shape_[0];
|
||||
}
|
||||
int num_col = input_shape_[input_shape_.size() - 1];
|
||||
|
||||
for (int i = 0; i < num_row; ++i) {
|
||||
// Compute the cumulative array.
|
||||
cumulative_value[i * num_col] = input_tensor[i * num_col];
|
||||
for (int j = 1; j < num_col; ++j) {
|
||||
size_t index = i * num_col + j;
|
||||
cumulative_value[index] = cumulative_value[index - 1] + input_tensor[index];
|
||||
}
|
||||
|
||||
// Normalize the cumulative array.
|
||||
float sum = cumulative_value[num_col - 1];
|
||||
if (sum != 0) {
|
||||
for (int k = 1; k < num_col; ++k) {
|
||||
cumulative_value[k] /= sum;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize random generator.
|
||||
std::uniform_real_distribution<float> dist(0.0, 1.0);
|
||||
int RNG_seed = 0;
|
||||
if (seed2_ > 0) {
|
||||
RNG_seed = seed2_;
|
||||
} else if (seed_ > 0) {
|
||||
RNG_seed = seed_;
|
||||
} else {
|
||||
std::random_device rd;
|
||||
RNG_seed = static_cast<int>(rd());
|
||||
}
|
||||
std::default_random_engine rng{RNG_seed};
|
||||
|
||||
// Sample data from cumulative array.
|
||||
for (int n = 0; n < num_sample; ++n) {
|
||||
auto rand_prob = dist(rng);
|
||||
int begin = 0;
|
||||
int end = num_col;
|
||||
|
||||
while (end - begin > 0) {
|
||||
int pivot = begin + (end - begin) / 2;
|
||||
rand_prob = cumulative_value[i * num_col + pivot];
|
||||
if (pivot > rand_prob) {
|
||||
end = pivot;
|
||||
} else {
|
||||
begin = pivot + 1;
|
||||
}
|
||||
}
|
||||
output[i * num_col + n] = begin;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
||||
#include <memory>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
|
||||
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
|
||||
#include "nnacl/base/tile_base.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class MultinomialGpuKernel : public CPUKernel {
|
||||
public:
|
||||
MultinomialGpuKernel() = default;
|
||||
~MultinomialGpuKernel() override = default;
|
||||
|
||||
void InitKernel(const CNodePtr &kernel_node) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_shape_;
|
||||
int seed_{0};
|
||||
int seed2_{0};
|
||||
};
|
||||
|
||||
MS_REG_CPU_KERNEL(
|
||||
Multinomial,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
MultinomialGpuKernel)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_TILE_CPU_KERNEL_H_
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, sample, replacement, seed=0):
|
||||
super(Net, self).__init__()
|
||||
self.sample = sample
|
||||
self.replacement = replacement
|
||||
self.seed = seed
|
||||
|
||||
def construct(self, x):
|
||||
return C.multinomial(x, self.sample, self.replacement, self.seed)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_multinomial_net():
|
||||
x0 = Tensor(np.array([0.9, 0.2]).astype(np.float32))
|
||||
x1 = Tensor(np.array([[0.9, 0.2], [0.9, 0.2]]).astype(np.float32))
|
||||
net0 = Net(1, True, 20)
|
||||
net1 = Net(2, True, 20)
|
||||
net2 = Net(6, True, 20)
|
||||
out0 = net0(x0)
|
||||
out1 = net1(x0)
|
||||
out2 = net2(x1)
|
||||
assert out0.asnumpy().shape == (1,)
|
||||
assert out1.asnumpy().shape == (2,)
|
||||
assert out2.asnumpy().shape == (2, 6)
|
Loading…
Reference in New Issue