From 92f1855a798f047451abf07f7e52c6df9aa62833 Mon Sep 17 00:00:00 2001 From: baihuawei Date: Sat, 5 Sep 2020 14:24:52 +0800 Subject: [PATCH] fix categorical in GraphMode --- .../cpu/mkldnn/lstm_grad_cpu_kernel.cc | 10 +-- .../cpu/mkldnn/lstm_grad_cpu_kernel.h | 2 +- .../gpu/cuda_impl/multinomial_impl.cu | 34 ++++---- .../gpu/cuda_impl/multinomial_impl.cuh | 3 + .../gpu/math/multinomial_gpu_kernel.h | 14 ++-- .../probability/distribution/_utils/utils.py | 22 +---- .../probability/distribution/categorical.py | 80 ++++++------------- 7 files changed, 64 insertions(+), 101 deletions(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc index 52642911078..7290dbfef34 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.cc @@ -148,7 +148,7 @@ void LSTMGradCPUKernel::SetArgumentHandleOp(const std::vectoraddr); } -void LSTMGradCPUKernel::Memset_op(const dnnl::memory &mem, string name) { +void LSTMGradCPUKernel::ResetMemory(const dnnl::memory &mem, string name) { if (memset_s(mem.get_data_handle(), mem.get_desc().get_size(), 0, mem.get_desc().get_size())) { MS_LOG(EXCEPTION) << name << " memset error"; } @@ -186,10 +186,10 @@ bool LSTMGradCPUKernel::Launch(const std::vector &inputs, auto user_diff_weights_h_memory = dnnl::memory(dnnl::memory::desc{{weights_h_dims_}, dt::f32, tag::ldgoi}, eng); user_diff_weights_memory.set_data_handle(outputs[3]->addr); user_diff_weights_h_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_); - Memset_op(user_diff_weights_memory, "user weights grad"); - Memset_op(user_diff_weights_h_memory, "user weights iter grad"); - Memset_op(diff_weights_memory, "weights grad"); - Memset_op(diff_weights_h_memory, "weights iter grad"); + ResetMemory(user_diff_weights_memory, "user weights grad"); + ResetMemory(user_diff_weights_h_memory, "user weights iter grad"); + ResetMemory(diff_weights_memory, "weights grad"); + ResetMemory(diff_weights_h_memory, "weights iter grad"); if (has_bias_) { diff_bias_memory.set_data_handle(reinterpret_cast(outputs[3]->addr) + weight_size_ + weight_h_size_); } diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h index 700bc67bea0..b2368fc5518 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/mkldnn/lstm_grad_cpu_kernel.h @@ -42,7 +42,7 @@ class LSTMGradCPUKernel : public MKLCPUKernel { const dnnl::memory &weights_h_memory, const dnnl::memory &bias_memory, const dnnl::memory &diff_weights_memory, const dnnl::memory &diff_weights_h_memory, const dnnl::memory &diff_bias_memory); - void Memset_op(const dnnl::memory &mem, string name); + void ResetMemory(const dnnl::memory &mem, string name); void CheckParam(const CNodePtr &kernel_node); int weight_size_ = 0; int weight_h_size_ = 0; diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu index 57730ca0182..c1cbefe31a8 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cu @@ -16,18 +16,6 @@ #include "multinomial_impl.cuh" -template -__global__ void NormInput(T *input, const size_t distributions, const size_t categories) { - size_t size = distributions * categories; - for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { - if ((pos + 1) % categories != 0) { - int de_pos = (1 + pos / categories) * categories - 1; - input[pos] /= input[de_pos]; - } - } - return; -} - template __global__ void CheckZeroKernel(const size_t distributions, const size_t categories, const T *input, T *out) { out[0] = 0; @@ -61,6 +49,24 @@ void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t cuda CheckNonNegKernel<<>>(size, input, output); } +template +__global__ void NormInputKernel(T *input, const size_t distributions, const size_t categories) { + size_t size = distributions * categories; + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) { + if ((pos + 1) % categories != 0) { + int de_pos = (1 + pos / categories) * categories - 1; + input[pos] /= input[de_pos]; + } + } + return; +} + +template +void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream) { + int count1 = distributions * categories; + NormInputKernel<<>>(input, distributions, categories); +} + template __device__ int BinarySearchForMultinomial(T *start_addr, int size, T rand) { int start = 0; @@ -104,8 +110,6 @@ void Multinomial(int seed, T *input, int num_sample, curandState *globalState, i RNG_seed = time(NULL); } int count = distributions * num_sample; - int count1 = distributions * categories; - NormInput<<>>(input, distributions, categories); MultinomialKernel<<>>(RNG_seed, input, num_sample, globalState, output, distributions, categories); return; @@ -116,3 +120,5 @@ template void Multinomial(int seed, float *input, int num_sample, curandS template void CheckNonNeg(const size_t size, const float *input, float *output, cudaStream_t cuda_stream); template void CheckZero(const size_t distributions, const size_t categories, const float *input, float *output, cudaStream_t cuda_stream); +template void NormInput(float *input, const size_t distributions, const size_t categories, + cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh index 097f8ef9004..2c4153c1279 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/multinomial_impl.cuh @@ -26,4 +26,7 @@ template void CheckNonNeg(const size_t size, const T *input, T *output, cudaStream_t stream); template void CheckZero(const size_t distributions, const size_t categories, const T *input, T *output, cudaStream_t stream); +template +void NormInput(T *input, const size_t distributions, const size_t categories, cudaStream_t cuda_stream); + #endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MULTINOMIAL_IMPL_CUH_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h index 436dc50d78e..1647be70421 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/multinomial_gpu_kernel.h @@ -47,22 +47,23 @@ class MultinomialGpuKernel : public GpuKernel { bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { - void *workspace_addr = GetDeviceAddress(workspace, 0); + void *workspace_addr = GetDeviceAddress(workspace, 1); + T *cum_sum_input = GetDeviceAddress(workspace, 0); curandState *devStates = reinterpret_cast(workspace_addr); int *output_addr = GetDeviceAddress(outputs, 0); T *input_addr = GetDeviceAddress(inputs, 0); int categories = SizeToInt(inputs[0]->size / sizeof(T)) / distributions_; - int num_sample = SizeToInt(outputs[0]->size / sizeof(T)) / distributions_; + int num_sample = SizeToInt(outputs[0]->size / sizeof(int)) / distributions_; // check input - T *cum_sum_input = nullptr; - CHECK_CUDA_RET_WITH_EXCEPT(cudaMalloc(reinterpret_cast(&cum_sum_input), input_size_0_), - "cudaMalloc failed."); CheckPeram(input_addr, cum_sum_input, categories, stream_ptr); if (replacement_) { + NormInput(cum_sum_input, IntToSize(distributions_), IntToSize(categories), + reinterpret_cast(stream_ptr)); + CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamSynchronize(reinterpret_cast(stream_ptr)), + "cudaStreamSynchronize failed."); Multinomial(seed_, cum_sum_input, num_sample, devStates, output_addr, IntToSize(distributions_), IntToSize(categories), reinterpret_cast(stream_ptr)); } - CHECK_CUDA_RET_WITH_EXCEPT(cudaFree(cum_sum_input), "cudaFree failed."); return true; } @@ -145,6 +146,7 @@ class MultinomialGpuKernel : public GpuKernel { input_size_list_.push_back(input_size_0_); input_size_list_.push_back(sizeof(int)); output_size_list_.push_back(output_size_); + workspace_size_list_.push_back(input_size_0_); workspace_size_list_.push_back(workspace_size_); } diff --git a/mindspore/nn/probability/distribution/_utils/utils.py b/mindspore/nn/probability/distribution/_utils/utils.py index 729c42634d7..1139caa9a20 100644 --- a/mindspore/nn/probability/distribution/_utils/utils.py +++ b/mindspore/nn/probability/distribution/_utils/utils.py @@ -271,24 +271,6 @@ def probs_to_logits(probs, is_binary=False): return P.Log()(ps_clamped) -def check_tensor_type(name, inputs, valid_type): - """ - Check if inputs is proper. - - Args: - name: inputs name - inputs: Tensor to be checked. - - Raises: - ValueError: if inputs is not a proper Tensor. - """ - if not isinstance(inputs, Tensor): - raise TypeError(f"{name} should be a Tensor") - input_type = P.DType()(inputs) - if input_type not in valid_type: - raise TypeError(f"{name} dtype is invalid") - - def check_type(data_type, value_type, name): if not data_type in value_type: raise TypeError( @@ -304,6 +286,10 @@ def raise_none_error(name): def raise_probs_logits_error(): raise TypeError("Either 'probs' or 'logits' must be specified, but not both.") +@constexpr +def raise_broadcast_error(shape_a, shape_b): + raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.") + @constexpr def raise_not_impl_error(name): raise ValueError( diff --git a/mindspore/nn/probability/distribution/categorical.py b/mindspore/nn/probability/distribution/categorical.py index 98901058c68..b9c65f1a454 100644 --- a/mindspore/nn/probability/distribution/categorical.py +++ b/mindspore/nn/probability/distribution/categorical.py @@ -17,7 +17,8 @@ from mindspore.ops import operations as P import mindspore.nn as nn from mindspore.common import dtype as mstype from .distribution import Distribution -from ._utils.utils import logits_to_probs, probs_to_logits, check_type, check_tensor_type, cast_to_tensor, raise_probs_logits_error +from ._utils.utils import logits_to_probs, probs_to_logits, check_type, cast_to_tensor, \ + raise_probs_logits_error class Categorical(Distribution): @@ -25,7 +26,7 @@ class Categorical(Distribution): Creates a categorical distribution parameterized by either probs or logits (but not both). Args: - probs (Tensor, list, numpy.ndarray, Parameter, float): event probabilities. + probs (Tensor, list, numpy.ndarray, Parameter): event probabilities. logits (Tensor, list, numpy.ndarray, Parameter, float): event log-odds. seed (int): seed to use in sampling. Default: 0. dtype (mindspore.dtype): type of the distribution. Default: mstype.int32. @@ -77,6 +78,7 @@ class Categorical(Distribution): if (probs is None) == (logits is None): raise_probs_logits_error() self.reduce_sum = P.ReduceSum(keep_dims=True) + self.reduce_sum1 = P.ReduceSum(keep_dims=False) self.log = P.Log() self.exp = P.Exp() self.shape = P.Shape() @@ -88,6 +90,7 @@ class Categorical(Distribution): self.expandim = P.ExpandDims() self.gather = P.GatherNd() self.concat = P.Concat(-1) + self.transpose = P.Transpose() if probs is not None: self._probs = cast_to_tensor(probs, mstype.float32) input_sum = self.reduce_sum(self._probs, -1) @@ -102,8 +105,8 @@ class Categorical(Distribution): self._param = self._logits self._num_events = self.shape(self._param)[-1] self._param2d = self.reshape(self._param, (-1, self._num_events)) - self._batch_shape = self.shape(self._param2d)[0] - + self._batch_shape = self.shape(self._param)[:-1] + self._batch_shape_n = (1,) * len(self._batch_shape) @property def logits(self): @@ -130,72 +133,35 @@ class Categorical(Distribution): Tensor, shape is shape(probs)[:-1] + sample_shape """ self.checktuple(sample_shape, 'shape') - if sample_shape == (): - sample_shape = (1,) num_sample = 1 for i in sample_shape: num_sample *= i probs_2d = self.reshape(self._probs, (-1, self._num_events)) samples = self.mutinomial(probs_2d, num_sample) + samples = self.transpose(samples, (1, 0)) extend_shape = sample_shape if len(self.shape(self._probs)) > 1: extend_shape = sample_shape + self.shape(self._probs)[:-1] return self.cast(self.reshape(samples, extend_shape), self.dtype) - def _broad_cast_shape(self, a, b): - """ - Broadcast Tensor shape. - - Args: - a (Tensor): A Tensor need to Broadcast. - b (Tensor): Another Tensor need to Broadcast. - - Returns: - Tuple, Broadcast shape. - """ - shape_a = self.shape(a) - shape_b = self.shape(b) - size_a = len(shape_a) - size_b = len(shape_b) - if size_a > size_b: - size = size_a - shape_out = list(shape_a) - shape_short = list(shape_b) - diff_size = size_a - size_b - else: - size = size_b - shape_out = list(shape_b) - shape_short = list(shape_a) - diff_size = size_b - size_a - for i in range(diff_size, size): - if shape_out[i] == shape_short[i - diff_size]: - continue - if shape_out[i] == 1 or shape_short[i - diff_size] == 1: - shape_out[i] = shape_out[i] * shape_short[i - diff_size] - else: - raise ValueError(f"Shape {shape_a} and {shape_b} is not broadcastable.") - return tuple(shape_out) - def _log_prob(self, value): r""" Evaluate log probability. Args: - value (Tensor): value to be evaluated. The dtype could be mstype.float32, bool, mstype.int32. + value (Tensor): value to be evaluated. """ - if value is not None: - check_tensor_type("value", value, [mstype.float32, bool, mstype.int32]) - value = self.expandim(self.cast(value, mstype.float32), -1) - broad_shape = self._broad_cast_shape(value, self._logits) - broad = P.BroadcastTo(broad_shape) - logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1])) - value = self.reshape(broad(value)[..., :1], (-1, 1)) - index = nn.Range(0., self.shape(value)[0], 1)() - index = self.reshape(index, (-1, 1)) - value = self.concat((index, value)) - value = self.cast(value, mstype.int32) - return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1]) - return None + value = self._check_value(value, 'value') + value = self.expandim(self.cast(value, mstype.float32), -1) + broad_shape = self.shape(value + self._logits) + broad = P.BroadcastTo(broad_shape) + logits_pmf = self.reshape(broad(self._logits), (-1, broad_shape[-1])) + value = self.reshape(broad(value)[..., :1], (-1, 1)) + index = nn.Range(0., self.shape(value)[0], 1)() + index = self.reshape(index, (-1, 1)) + value = self.concat((index, value)) + value = self.cast(value, mstype.int32) + return self.reshape(self.gather(logits_pmf, value), broad_shape[:-1]) def _entropy(self): r""" @@ -205,7 +171,7 @@ class Categorical(Distribution): H(X) = -\sum(logits * probs) """ p_log_p = self._logits * self._probs - return self.reduce_sum(-p_log_p, -1) + return self.reduce_sum1(-p_log_p, -1) def enumerate_support(self, expand=True): r""" @@ -213,8 +179,8 @@ class Categorical(Distribution): """ num_events = self._num_events values = nn.Range(0., num_events, 1)() - values = self.reshape(values, (num_events, 1)) + values = self.reshape(values, (num_events,) + self._batch_shape_n) if expand: - values = P.BroadcastTo((num_events, self._batch_shape))(values) + values = P.BroadcastTo((num_events,) + self._batch_shape)(values) values = self.cast(values, mstype.int32) return values