!5454 fix Categorical log_prob

Merge pull request !5454 from baihuawei/categorical
This commit is contained in:
mindspore-ci-bot 2020-08-29 15:36:09 +08:00 committed by Gitee
commit b4caf21f63
6 changed files with 191 additions and 13 deletions

View File

@ -0,0 +1,26 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/range_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
RangeGPUKernel, float)
MS_REG_GPU_KERNEL_ONE(Range, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
RangeGPUKernel, int)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,89 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/range_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class RangeGPUKernel : public GpuKernel {
public:
RangeGPUKernel() : input_size_(0), output_size_(0), start_(0.), limit_(1.), delta_(1.) {}
~RangeGPUKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
int size = SizeToInt(input_size_ / sizeof(T));
CalRange(size, start_, limit_, delta_, input, output, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but Range needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but Range needs 1 output.";
return false;
}
auto input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
auto shape_size = input_shape.size();
input_size_ = 1;
for (size_t i = 0; i < shape_size; i++) {
input_size_ *= input_shape[i];
}
input_size_ *= sizeof(T);
output_size_ = input_size_;
start_ = GetAttr<float>(kernel_node, "start");
limit_ = GetAttr<float>(kernel_node, "limit");
delta_ = GetAttr<float>(kernel_node, "delta");
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
return;
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input_size_;
size_t output_size_;
float start_;
float limit_;
float delta_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_RANGE_GPU_KERNEL_H_

View File

@ -0,0 +1,39 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "range_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
__global__ void Range(const int size, const float start, const float limit, const float delta, const T *input,
T *output) {
for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = input[pos] * delta + start;
}
}
template <typename T>
void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
cudaStream_t cuda_stream) {
Range<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, start, limit, delta, input, output);
return;
}
template void CalRange<float>(const int size, const float start, const float limit, const float delta,
const float *input, float *output, cudaStream_t cuda_stream);
template void CalRange<int>(const int size, const float start, const float limit, const float delta, const int *input,
int *output, cudaStream_t cuda_stream);

View File

@ -0,0 +1,23 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH_
template <typename T>
void CalRange(const int size, const float start, const float limit, const float delta, const T *input, T *output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANGE_IMPL_CUH

View File

@ -13,8 +13,8 @@
# limitations under the License.
# ============================================================================
"""Categorical Distribution"""
import numpy as np
from mindspore.ops import operations as P
import mindspore.nn as nn
from mindspore.common import dtype as mstype
from .distribution import Distribution
from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error
@ -119,17 +119,19 @@ class Categorical(Distribution):
"""
return self._probs
def _sample(self, sample_shape=(1,)):
def _sample(self, sample_shape=()):
"""
Sampling.
Args:
sample_shape (tuple): shape of the sample. Default: (1,).
sample_shape (tuple): shape of the sample. Default: ().
Returns:
Tensor, shape is shape(probs)[:-1] + sample_shape
"""
self.checktuple(sample_shape, 'shape')
if sample_shape == ():
sample_shape = (1,)
num_sample = 1
for i in sample_shape:
num_sample *= i
@ -184,16 +186,15 @@ class Categorical(Distribution):
if value is not None:
check_tensor_type("value", value, [mstype.float32, bool, mstype.int32])
value = self.expandim(self.cast(value, mstype.float32), -1)
index = cast_to_tensor(np.arange(self.shape(value)[0]).astype(np.float32))
index = self.expandim(index, -1)
logits = self._logits if self._logits.dim() == 1 else self.expandim(self._logits, 0)
broad_shape = self._broad_cast_shape(value, logits)
broad_shape = self._broad_cast_shape(value, self._logits)
broad = P.BroadcastTo(broad_shape)
value = broad(value)[..., :1]
index = broad(index)[..., :1]
logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1]))
value = self.reshape(broad(value)[..., :1], (-1, 1))
index = nn.Range(0., self.shape(value)[0], 1)()
index = self.reshape(index, (-1, 1))
value = self.concat((index, value))
value = self.cast(value, mstype.int32)
return self.gather(logits, value)
return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1])
return None
def _entropy(self):
@ -211,7 +212,7 @@ class Categorical(Distribution):
Enumerate categories.
"""
num_events = self._num_events
values = cast_to_tensor(np.arange(num_events).astype(np.int32), mstype.float32)
values = nn.Range(0., num_events, 1)()
values = self.reshape(values, (num_events, 1))
if expand:
values = P.BroadcastTo((num_events, self._batch_shape))(values)

View File

@ -450,8 +450,8 @@ class Multinomial(PrimitiveWithInfer):
Examples:
>>> input = Tensor([0., 9., 4., 0.], mstype.float32)
>>> multinomial = P.Multinomial(seed=10)
>>> output = multinomial(input, 2, True)
>>> multinomial = P.Multinomial(replacement=True, seed=10)
>>> output = multinomial(input, 2)
"""
@prim_attr_register