!35403 [feat] [assistant] [ops] [I4ZZUC,I4ZZUD] New GPU operator implementations, including Mvlgamma and MvlgammaGrad

Merge pull request !35403 from 路雄博/Mvlgamma
This commit is contained in:
i-robot 2022-07-21 02:39:41 +00:00 committed by Gitee
commit f6e5506f52
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
17 changed files with 811 additions and 31 deletions

View File

@@ -0,0 +1,114 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h"
namespace mindspore {
namespace kernel {
bool MvlgammaGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::Mvlgamma>(base_operator);
kernel_name_ = kernel_ptr_->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the kernel type should be in [float32, float64], but got: " << kernel_attr << ".";
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
p_ = kernel_ptr_->get_p();
if (p_ < 1) {
MS_LOG(ERROR) << "For " << kernel_name_ << ", the attr 'p' has to be greater than or equal to 1, "
<< "but got " << p_ << ".";
return false;
}
return true;
}
int MvlgammaGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, the input shape is dynamic, so just
// return and do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> input_shape = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
if (input_elements_ == 0) {
is_null_input_ = true;
}
int64_t input_dims = input_shape.size();
if (input_dims < 1) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got " << input_dims
<< "-D.";
return KRET_RESIZE_FAILED;
}
size_t input_size = input_elements_ * unit_size_;
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size);
workspace_size_list_.push_back(sizeof(bool));
return KRET_OK;
}
template <typename T>
bool MvlgammaGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
bool *valid_d = GetDeviceAddress<bool>(workspace, 0);
bool valid = true;
bool *valid_h = &valid;
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(valid_d, valid_h, sizeof(bool), cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync valid Host to Device failed.");
CalMvlgamma(valid_d, input_elements_, input, p_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(valid_h, valid_d, sizeof(bool), cudaMemcpyDeviceToHost,
reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaMemcpyAsync valid Device to Host failed.");
// Synchronize before reading valid_h on the host; the copy above is asynchronous on cuda_stream_.
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_)),
"cudaStreamSynchronize failed.");
if (!*valid_h) {
MS_LOG(ERROR) << "For " << kernel_name_ << ", all element must be greater than (p-1)/2.";
return false;
}
return true;
}
std::vector<std::pair<KernelAttr, MvlgammaGpuKernelMod::MvlgammaFunc>> MvlgammaGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MvlgammaGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&MvlgammaGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MvlgammaGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MvlgammaFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mvlgamma, MvlgammaGpuKernelMod);
} // namespace kernel
} // namespace mindspore
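For reference, the forward computation launched here via CalMvlgamma is the multivariate log-gamma function: Mvlgamma(x, p) = sum_{i=1}^{p} lgamma(x - (i - 1)/2) + p(p - 1)/4 * log(pi), which is only defined when every element satisfies x > (p - 1)/2; that is the condition tracked by the bool workspace flag. A minimal host-side sketch for cross-checking (assumes NumPy and SciPy are available; mvlgamma_reference is an illustrative name, not part of this PR):

```python
import numpy as np
from scipy.special import gammaln, multigammaln

def mvlgamma_reference(x, p):
    """Host-side reference of the GPU kernel's forward computation."""
    x = np.asarray(x, dtype=np.float64)
    if not np.all(x > (p - 1) / 2):
        # Mirrors the bool workspace check: every element must exceed (p - 1) / 2.
        raise ValueError("all elements must be greater than (p - 1) / 2")
    out = sum(gammaln(x - (i - 1) / 2) for i in range(1, p + 1))
    return out + p * (p - 1) / 4 * np.log(np.pi)

x = np.array([[3.0, 4.0, 5.0], [4.0, 2.0, 6.0]])
assert np.allclose(mvlgamma_reference(x, 3), multigammaln(x, 3))
```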

View File

@@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "abstract/utils.h"
#include "mindspore/core/ops/mvlgamma.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MvlgammaGpuKernelMod : public NativeGpuKernelMod {
public:
MvlgammaGpuKernelMod() { ResetResource(); }
~MvlgammaGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
is_null_input_ = false;
input_elements_ = 0;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using MvlgammaFunc =
std::function<bool(MvlgammaGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
int p_{0};
size_t unit_size_{1};
size_t input_elements_;
MvlgammaFunc kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, MvlgammaFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_

View File

@@ -0,0 +1,98 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
bool MvlgammaGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::MvlgammaGrad>(base_operator);
kernel_name_ = kernel_ptr_->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
p_ = kernel_ptr_->get_p();
return true;
}
int MvlgammaGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
// If any input shape contains -1, the input shape is dynamic, so just
// return and do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> input_shape = std::vector<int64_t>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
if (input_elements_ == 0) {
is_null_input_ = true;
}
int64_t input_dims = input_shape.size();
if (input_dims < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got "
<< input_dims << "-D.";
return KRET_RESIZE_FAILED;
}
size_t input_size = input_elements_ * unit_size_;
input_size_list_.push_back(input_size);
input_size_list_.push_back(input_size);
output_size_list_.push_back(input_size);
return KRET_OK;
}
template <typename T>
bool MvlgammaGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *y_grad = GetDeviceAddress<T>(inputs, 0);
T *x = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
CalMvlgammaGrad(input_elements_, y_grad, x, p_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, MvlgammaGradGpuKernelMod::MvlgammaGradFunc>> MvlgammaGradGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&MvlgammaGradGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&MvlgammaGradGpuKernelMod::LaunchKernel<double>}};
std::vector<KernelAttr> MvlgammaGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MvlgammaGradFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MvlgammaGrad, MvlgammaGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore
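For reference, the per-element gradient launched here via CalMvlgammaGrad is d/dx Mvlgamma(x, p) = sum_{i=0}^{p-1} digamma(x - i/2), scaled by the incoming gradient y_grad. A minimal host-side sketch (assumes SciPy is available; mvlgamma_grad_reference is an illustrative name, not part of this PR):

```python
import numpy as np
from scipy.special import digamma

def mvlgamma_grad_reference(y_grad, x, p):
    """Host-side reference of the GPU kernel's backward computation."""
    x = np.asarray(x, dtype=np.float64)
    # d/dx Mvlgamma(x, p) = sum_{i=0}^{p-1} digamma(x - i / 2)
    dx = sum(digamma(x - 0.5 * i) for i in range(p))
    return np.asarray(y_grad, dtype=np.float64) * dx

y_grad = np.ones((2, 3))
x = np.array([[3.0, 4.0, 5.0], [4.0, 2.0, 6.0]])
print(mvlgamma_grad_reference(y_grad, x, p=2))
```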

View File

@@ -0,0 +1,86 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "abstract/utils.h"
#include "mindspore/core/ops/grad/mvlgamma_grad.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"
namespace mindspore {
namespace kernel {
class MvlgammaGradGpuKernelMod : public NativeGpuKernelMod {
public:
MvlgammaGradGpuKernelMod() { ResetResource(); }
~MvlgammaGradGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
is_null_input_ = false;
input_elements_ = 0;
input_size_list_.clear();
output_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using MvlgammaGradFunc =
std::function<bool(MvlgammaGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
int p_{0};
size_t input_elements_;
size_t unit_size_{1};
MvlgammaGradFunc kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, MvlgammaGradFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_

View File

@@ -0,0 +1,83 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <limits>
#include <algorithm>
#include "mvlgamma_grad_impl.cuh"
#define PI 3.141592653589793
__constant__ double kLanczosCoefficientsd[8] = {
676.520368121885098567009190444019, -1259.13921672240287047156078755283,
771.3234287776530788486528258894, -176.61502916214059906584551354,
12.507343278686904814458936853, -0.13857109526572011689554707,
9.984369578019570859563e-6, 1.50563273514931155834e-7};
template <typename T>
__device__ __forceinline__ T CalNumDivDenom(T x) {
T num = 0;
T denom = 0.99999999999980993227684700473478;
for (int j = 0; j < 8; ++j) {
num -= kLanczosCoefficientsd[j] / ((x + j + 1) * (x + j + 1));
denom += kLanczosCoefficientsd[j] / (x + j + 1);
}
return num / denom;
}
template <typename T>
__global__ void MvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
T kLanczosGamma = 7;
T log_lanczos_gamma_plus_one_half = std::log(7.5);
T temp = 0;
T cur_input = 0;
T num_div_denom = 0;
for (int i = 0; i < p; i++) {
cur_input = x[pos] - 0.5 * i;
if (cur_input < 0 && cur_input == std::floor(cur_input)) {
temp += std::numeric_limits<T>::quiet_NaN();
break;
}
if (cur_input < 0.5) {
num_div_denom = CalNumDivDenom(-cur_input);
temp += (log_lanczos_gamma_plus_one_half + std::log1pf((-cur_input) / (kLanczosGamma + 0.5))) + num_div_denom -
kLanczosGamma / (kLanczosGamma + 0.5 - cur_input);
temp -= PI / std::tan(PI * (cur_input + std::abs(std::floor(cur_input + 0.5))));
} else {
num_div_denom = CalNumDivDenom(cur_input - 1);
temp += (log_lanczos_gamma_plus_one_half + std::log1pf((cur_input - 1) / (kLanczosGamma + 0.5))) + num_div_denom
- kLanczosGamma / (kLanczosGamma + 0.5 + cur_input - 1);
}
}
output[pos] = temp * y_grad[pos];
}
}
template <typename T>
void CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
int thread_num = 256 < size ? 256 : size;
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, device_id);
int max_blocks = prop.multiProcessorCount;
int block_num = std::min(static_cast<int>(((size - 1) / thread_num) + 1), max_blocks);
MvlgammaGrad<<<block_num, thread_num, 0, cuda_stream>>>(size, y_grad, x, p, output);
return;
}
template
CUDA_LIB_EXPORT void CalMvlgammaGrad<float>(const size_t size, const float *y_grad, const float *x, const int p,
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalMvlgammaGrad<double>(const size_t size, const double *y_grad, const double *x, const int p,
double *output, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_
#include <math.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
void CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_

View File

@@ -0,0 +1,60 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mvlgamma_impl.cuh"
template <typename T>
__global__ void Valid(bool *valid, const size_t size, const T *input, const int p) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
if (input[pos] <= (0.5 * (p - 1))) {
*valid = false;
return;
}
}
return;
}
template <typename T>
__global__ void Mvlgamma(const size_t size, const T *input, const int p, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
T temp = 0;
for (int i = 1; i <= p; i++) {
temp += lgamma(input[pos] - static_cast<T>((i - 1) * 0.5));
}
output[pos] = temp + static_cast<T>(p * (p - 1) * 0.25 * log(M_PI));
}
return;
}
template <typename T>
void CalMvlgamma(bool *valid, const size_t size, const T *input, const int p, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
Valid<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(valid, size, input, p);
bool host_valid = true;
// valid is a plain device pointer, not a __device__ symbol, so copy it back on the stream
// and synchronize before reading it on the host.
cudaMemcpyAsync(&host_valid, valid, sizeof(bool), cudaMemcpyDeviceToHost, cuda_stream);
cudaStreamSynchronize(cuda_stream);
if (!host_valid) {
return;
}
Mvlgamma<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input, p, output);
return;
}
template
CUDA_LIB_EXPORT void CalMvlgamma<float>(bool *valid, const size_t size, const float *input, const int p,
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalMvlgamma<double>(bool *valid, const size_t size, const double *input, const int p,
double *output, const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@@ -0,0 +1,24 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
void CalMvlgamma(bool *valid, const size_t size, const T *input, const int p, T *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_

View File

@@ -13,16 +13,14 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <map>
#include <string>
#include <algorithm>
#include <set>
#include <string>
#include "ops/grad/mvlgamma_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
@@ -30,21 +28,41 @@ namespace {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto y_grad_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
return std::make_shared<abstract::Shape>(y_grad_shape);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
}
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto first_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
auto first_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
auto second_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kMinShape];
auto second_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kMaxShape];
if (first_input_shape_min.empty() || first_input_shape_max.empty() || second_input_shape_min.empty() ||
second_input_shape_max.empty()) {
return std::make_shared<abstract::Shape>(x_shape);
}
ShapeVector min_shape = {first_input_shape_min[0], second_input_shape_min[0]};
ShapeVector max_shape = {first_input_shape_max[0], second_input_shape_max[0]};
return std::make_shared<abstract::Shape>(x_shape, min_shape, max_shape);
}
TypePtr MvlgammaGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
std::map<std::string, TypePtr> types;
(void)types.emplace("y_grad", input_args[kInputIndex0]->BuildType());
(void)types.emplace("x", input_args[kInputIndex1]->BuildType());
(void)types.emplace("y_grad", input_args[0]->BuildType());
(void)types.emplace("x", input_args[1]->BuildType());
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(MvlgammaGrad, BaseOperator);
void MvlgammaGrad::Init(const int64_t p) { set_p(p); }
void MvlgammaGrad::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
int64_t MvlgammaGrad::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }
AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@@ -54,7 +72,7 @@ AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const Pri
auto infer_shape = MvlgammaGradInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(MvlgammaGrad, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(MvlgammaGrad, prim::kPrimMvlgammaGrad, MvlgammaGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -1,5 +1,5 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -16,9 +16,14 @@
#ifndef MINDSPORE_CORE_OPS_MVLGAMMA_GRAD_H_
#define MINDSPORE_CORE_OPS_MVLGAMMA_GRAD_H_
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
@@ -29,10 +34,16 @@ class MIND_API MvlgammaGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(MvlgammaGrad);
MvlgammaGrad() : BaseOperator(kNameMvlgammaGrad) { InitIOName({"y_grad", "x"}, {"x_grad"}); }
void Init(const int64_t p = 0);
/// \brief Set p.
void set_p(const int64_t p);
int64_t get_p() const;
// ~MvlgammaGrad() = default;
// MS_DECLARE_PARENT(MvlgammaGrad, PrimitiveC);
};
abstract::AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<AbstractBasePtr> &input_args);
using PrimMvlgammaGradPtr = std::shared_ptr<MvlgammaGrad>;
} // namespace ops
} // namespace mindspore

View File

@@ -14,34 +14,46 @@
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <set>
#include <string>
#include "ops/mvlgamma.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "ops/op_utils.h"
#include "mindapi/src/helper.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MvlgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
MS_EXCEPTION_IF_NULL(input_args[kInputIndex0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->GetShapeTrack())[kShape];
return std::make_shared<abstract::Shape>(in_shape);
MS_EXCEPTION_IF_NULL(input_args[0]);
auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
auto first_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
auto first_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
if (first_input_shape_min.empty() || first_input_shape_max.empty()) {
return std::make_shared<abstract::Shape>(in_shape);
}
ShapeVector min_shape = {first_input_shape_min[0]};
ShapeVector max_shape = {first_input_shape_max[0]};
return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}
TypePtr MvlgammaInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
auto input_type = input_args[kInputIndex0]->BuildType();
TypePtr MvlgammaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto input_type = input_args[0]->BuildType();
const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, prim->name());
return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, primitive->name());
}
} // namespace
MIND_API_OPERATOR_IMPL(Mvlgamma, BaseOperator);
void Mvlgamma::Init(const int64_t p) { set_p(p); }
void Mvlgamma::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }
int64_t Mvlgamma::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }
AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
@@ -52,6 +64,7 @@ AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const Primiti
return abstract::MakeAbstract(infer_shape, infer_type);
}
MIND_API_OPERATOR_IMPL(Mvlgamma, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Mvlgamma, prim::kPrimMvlgamma, MvlgammaInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@@ -16,9 +16,14 @@
#ifndef MINDSPORE_CORE_OPS_MVLGAMMA_H_
#define MINDSPORE_CORE_OPS_MVLGAMMA_H_
#include <map>
#include <string>
#include <vector>
#include <memory>
#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
@@ -32,10 +37,15 @@ class MIND_API Mvlgamma : public BaseOperator {
MIND_API_BASE_MEMBER(Mvlgamma);
/// \brief Constructor.
Mvlgamma() : BaseOperator(kNameMvlgamma) { InitIOName({"x"}, {"y"}); }
/// \brief Init.
void Init(const int64_t p = 0);
/// \brief Set p.
void set_p(const int64_t p);
int64_t get_p() const;
};
abstract::AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
const std::vector<AbstractBasePtr> &input_args);
using PrimMvlgammaPtr = std::shared_ptr<Mvlgamma>;
} // namespace ops
} // namespace mindspore

View File

@@ -157,6 +157,7 @@ constexpr auto kOutputPaddings = "output_paddings";
constexpr auto kOutputType = "output_type";
constexpr auto kOutIdx = "out_idx";
constexpr auto kOutQuantized = "out_quantized";
constexpr auto kMvlgammaP = "mvlgamma_p";
constexpr auto kP = "p";
constexpr auto kPad = "pad";
constexpr auto kPadding = "padding";

View File

@@ -47,7 +47,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
TensorScatterMul, TensorScatterDiv, ExtractVolumePatches, LowerBound,
UpperBound, Cummax)
UpperBound, Cummax, Mvlgamma)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
Broadcast,
_MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
@@ -275,6 +275,7 @@ __all__ = [
'Multinomial',
'Gamma',
'RandomGamma',
'Mvlgamma',
'Poisson',
'UniformInt',
'UniformReal',

View File

@@ -3893,7 +3893,7 @@ class Mvlgamma(Primitive):
ValueError: If all elements of `x` are not greater than (p-1)/2.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> x = Tensor(np.array([[3, 4, 5], [4, 2, 6]]), mindspore.float32)
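A minimal usage sketch of the primitive on the newly supported GPU target, mirroring the docstring example above (assumes a GPU-enabled MindSpore build):

```python
import numpy as np
import mindspore
import mindspore.context as context
from mindspore import Tensor, ops

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.array([[3, 4, 5], [4, 2, 6]]), mindspore.float32)
net = ops.Mvlgamma(p=3)
print(net(x))
```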

View File

@@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations._grad_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class MvlgammaGradNet(nn.Cell):
def __init__(self, nptype, p):
super(MvlgammaGradNet, self).__init__()
self.mvlgamma_grad = P.MvlgammaGrad(p=p)
self.y_grad_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
self.y_grad = Tensor(self.y_grad_np)
self.x_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
self.x = Tensor(self.x_np)
@ms_function
def construct(self):
return self.mvlgamma_grad(self.y_grad, self.x)
def mvlgamma_grad_torch(y_grad_np, x_np, p):
x_torch = torch.tensor(x_np, requires_grad=True)
grad_torch = torch.tensor(y_grad_np)
out_torch = torch.mvlgamma(x_torch, p=p)
out_torch.backward(grad_torch)
dx = x_torch.grad
return dx.numpy()
def mvlgamma_grad(nptype, p):
mvlgamma_ = MvlgammaGradNet(nptype, p)
mvlgamma_output = mvlgamma_().asnumpy()
mvlgamma_expect = mvlgamma_grad_torch(mvlgamma_.y_grad_np, mvlgamma_.x_np, p).astype(nptype)
assert np.allclose(mvlgamma_output, mvlgamma_expect, 1e-4, 1e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_graph_float32():
"""
Feature: MvlgammaGrad GPU operator
Description: test cases for MvlgammaGrad
Expectation: the result matches the torch reference
"""
mvlgamma_grad(np.float32, 2)

View File

@@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
class MvlgammaNet(nn.Cell):
def __init__(self, nptype, p):
super(MvlgammaNet, self).__init__()
self.mvlgamma = P.Mvlgamma(p=p)
self.a_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
self.a = Tensor(self.a_np)
@ms_function
def construct(self):
return self.mvlgamma(self.a)
def mvlgamma_torch(a, d):
return torch.mvlgamma(torch.tensor(a), d).numpy()
def mvlgamma(nptype, p):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
mvlgamma_ = MvlgammaNet(nptype, p)
mvlgamma_output = mvlgamma_().asnumpy()
mvlgamma_expect = mvlgamma_torch(mvlgamma_.a_np, p).astype(nptype)
assert np.allclose(mvlgamma_output, mvlgamma_expect)
def mvlgamma_pynative(nptype, p):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
mvlgamma_ = MvlgammaNet(nptype, p)
mvlgamma_output = mvlgamma_().asnumpy()
mvlgamma_expect = mvlgamma_torch(mvlgamma_.a_np, p).astype(nptype)
assert np.allclose(mvlgamma_output, mvlgamma_expect)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_graph_float32():
"""
Feature: Mvlgamma GPU operator
Description: test cases for Mvlgamma
Expectation: the result matches the torch reference
"""
mvlgamma(np.float32, 3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_pynative_float64():
"""
Feature: Mvlgamma GPU operator
Description: test cases for Mvlgamma
Expectation: the result matches the torch reference
"""
mvlgamma_pynative(np.float64, 3)