forked from mindspore-Ecosystem/mindspore
!35403 [feat] [assistant] [ops] [I4ZZUC,I4ZZUD] New GPU operator implementation, include Mvlgamma,MvlgammaGrad
Merge pull request !35403 from 路雄博/Mvlgamma
commit f6e5506f52
@@ -0,0 +1,114 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/mvlgamma_gpu_kernel.h"

namespace mindspore {
namespace kernel {
bool MvlgammaGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr_ = std::dynamic_pointer_cast<ops::Mvlgamma>(base_operator);
  kernel_name_ = kernel_ptr_->name();
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_
                  << "', the kernel type should be in [float32, float64], but got: " << kernel_attr << ".";
    return false;
  }
  kernel_func_ = func_list_[index].second;
  unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
  p_ = kernel_ptr_->get_p();
  if (p_ < 1) {
    MS_LOG(ERROR) << "For " << kernel_name_ << ", the attr 'p' has to be greater than or equal to 1, "
                  << "but got " << p_ << ".";
    return false;
  }
  return true;
}

int MvlgammaGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs,
                                 const std::map<uint32_t, tensor::TensorPtr> &) {
  for (const auto &input : inputs) {
    // If any input shape contains -1, the shape is still dynamic, so return and do nothing.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  ResetResource();
  std::vector<int64_t> input_shape = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
                                                          inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
  input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
  if (input_elements_ == 0) {
    is_null_input_ = true;
  }
  int64_t input_dims = input_shape.size();
  if (input_dims < 1) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got "
                  << input_dims << "-D.";
    return KRET_RESIZE_FAILED;
  }
  size_t input_size = input_elements_ * unit_size_;
  input_size_list_.push_back(input_size);
  output_size_list_.push_back(input_size);
  workspace_size_list_.push_back(sizeof(bool));
  return KRET_OK;
}
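
// Validity-check protocol used by LaunchKernel: a single bool flag in workspace
// memory starts as true, is cleared on the device if any element violates
// x > (p - 1) / 2, and is copied back and inspected before the result is trusted.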
template <typename T>
bool MvlgammaGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                                        const std::vector<AddressPtr> &outputs) {
  T *input = GetDeviceAddress<T>(inputs, 0);
  T *output = GetDeviceAddress<T>(outputs, 0);
  bool *valid_d = GetDeviceAddress<bool>(workspace, 0);
  bool valid = true;
  bool *valid_h = &valid;
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(valid_d, valid_h, sizeof(bool), cudaMemcpyHostToDevice,
                                                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                     "cudaMemcpyAsync valid Host to Device failed.");
  CalMvlgamma(valid_d, input_elements_, input, p_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaMemcpyAsync(valid_h, valid_d, sizeof(bool), cudaMemcpyDeviceToHost,
                                                     reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                     "cudaMemcpyAsync valid Device to Host failed.");
  // The copy above is asynchronous; synchronize the stream before reading the flag on the host.
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaStreamSynchronize(reinterpret_cast<cudaStream_t>(cuda_stream_)),
                                     "cudaStreamSynchronize failed.");
  if (!*valid_h) {
    MS_LOG(ERROR) << "For " << kernel_name_ << ", all elements must be greater than (p-1)/2.";
    return false;
  }
  return true;
}

std::vector<std::pair<KernelAttr, MvlgammaGpuKernelMod::MvlgammaFunc>> MvlgammaGpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &MvlgammaGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &MvlgammaGpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> MvlgammaGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, MvlgammaFunc> &pair) { return pair.first; });
  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mvlgamma, MvlgammaGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,89 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_

#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "abstract/utils.h"
#include "mindspore/core/ops/mvlgamma.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MvlgammaGpuKernelMod : public NativeGpuKernelMod {
 public:
  MvlgammaGpuKernelMod() { ResetResource(); }
  ~MvlgammaGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void ResetResource() noexcept {
    is_null_input_ = false;
    input_elements_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
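  // kernel_func_ stores the LaunchKernel<T> instantiation that Init() selects
  // from func_list_ after matching the runtime KernelAttr (float32 or float64).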
  using MvlgammaFunc =
    std::function<bool(MvlgammaGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

 private:
  int p_{0};
  size_t unit_size_{1};
  size_t input_elements_;
  MvlgammaFunc kernel_func_{};
  bool is_null_input_{false};
  void *cuda_stream_{nullptr};
  static std::vector<std::pair<KernelAttr, MvlgammaFunc>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GPU_KERNEL_H_
@@ -0,0 +1,98 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/mvlgamma_grad_gpu_kernel.h"

namespace mindspore {
namespace kernel {
bool MvlgammaGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                    const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr_ = std::dynamic_pointer_cast<ops::MvlgammaGrad>(base_operator);
  kernel_name_ = kernel_ptr_->name();
  if (inputs.empty() || outputs.empty()) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', got empty inputs or outputs, which is invalid.";
    return false;
  }
  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
  if (!is_match) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr << ".";
    return false;
  }
  kernel_func_ = func_list_[index].second;
  unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
  p_ = kernel_ptr_->get_p();
  return true;
}
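
// Resize below sizes all buffers from the second input (x, at kIndex1); the incoming
// gradient y_grad (kIndex0) is expected to have the same shape, so input_size is
// pushed once per input.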
int MvlgammaGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                     const std::vector<KernelTensorPtr> &outputs,
                                     const std::map<uint32_t, tensor::TensorPtr> &) {
  for (const auto &input : inputs) {
    // If any input shape contains -1, the shape is still dynamic, so return and do nothing.
    auto input_shape = input->GetShapeVector();
    if (!IsValidShape(input_shape)) {
      return KRET_UNKNOWN_SHAPE;
    }
  }
  ResetResource();
  std::vector<int64_t> input_shape = std::vector<int64_t>(inputs.at(kIndex1)->GetDeviceShapeAdaptively().begin(),
                                                          inputs.at(kIndex1)->GetDeviceShapeAdaptively().end());
  input_elements_ = std::accumulate(input_shape.begin(), input_shape.end(), 1, std::multiplies<int64_t>());
  if (input_elements_ == 0) {
    is_null_input_ = true;
  }
  int64_t input_dims = input_shape.size();
  if (input_dims < 1) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' should be at least 1-D, but got "
                  << input_dims << "-D.";
    return KRET_RESIZE_FAILED;
  }
  size_t input_size = input_elements_ * unit_size_;

  input_size_list_.push_back(input_size);
  input_size_list_.push_back(input_size);
  output_size_list_.push_back(input_size);
  return KRET_OK;
}

template <typename T>
bool MvlgammaGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                            const std::vector<AddressPtr> &workspace,
                                            const std::vector<AddressPtr> &outputs) {
  T *y_grad = GetDeviceAddress<T>(inputs, 0);
  T *x = GetDeviceAddress<T>(inputs, 1);
  T *output = GetDeviceAddress<T>(outputs, 0);
  CalMvlgammaGrad(input_elements_, y_grad, x, p_, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
  return true;
}

std::vector<std::pair<KernelAttr, MvlgammaGradGpuKernelMod::MvlgammaGradFunc>> MvlgammaGradGpuKernelMod::func_list_ = {
  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
   &MvlgammaGradGpuKernelMod::LaunchKernel<float>},
  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
   &MvlgammaGradGpuKernelMod::LaunchKernel<double>}};

std::vector<KernelAttr> MvlgammaGradGpuKernelMod::GetOpSupport() {
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, MvlgammaGradFunc> &pair) { return pair.first; });
  return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MvlgammaGrad, MvlgammaGradGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
@@ -0,0 +1,86 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_
#include <algorithm>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "abstract/utils.h"
#include "mindspore/core/ops/grad/mvlgamma_grad.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/mvlgamma_grad_impl.cuh"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/factory/ms_factory.h"

namespace mindspore {
namespace kernel {
class MvlgammaGradGpuKernelMod : public NativeGpuKernelMod {
 public:
  MvlgammaGradGpuKernelMod() { ResetResource(); }
  ~MvlgammaGradGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    if (is_null_input_) {
      return true;
    }
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;

 protected:
  void ResetResource() noexcept {
    is_null_input_ = false;
    input_elements_ = 0;
    input_size_list_.clear();
    output_size_list_.clear();
  }

  std::vector<KernelAttr> GetOpSupport() override;

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);
  using MvlgammaGradFunc =
    std::function<bool(MvlgammaGradGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
                       const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;

 private:
  int p_{0};
  size_t input_elements_;
  size_t unit_size_{1};
  MvlgammaGradFunc kernel_func_{};
  bool is_null_input_{false};
  void *cuda_stream_{nullptr};
  static std::vector<std::pair<KernelAttr, MvlgammaGradFunc>> func_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_MVLGAMMA_GRAD_GPU_KERNEL_H_
@@ -0,0 +1,83 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <limits>
#include <algorithm>
#include "mvlgamma_grad_impl.cuh"
#define PI 3.141592653589793

__constant__ double kLanczosCoefficientsd[8] = {
  676.520368121885098567009190444019, -1259.13921672240287047156078755283,
  771.3234287776530788486528258894, -176.61502916214059906584551354,
  12.507343278686904814458936853, -0.13857109526572011689554707,
  9.984369578019570859563e-6, 1.50563273514931155834e-7};
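// A(x) = c0 + sum_j c_j / (x + j + 1) is the Lanczos series for the gamma function
// (g = 7, c0 = 0.99999999999980993...). CalNumDivDenom returns A'(x) / A(x), the
// series term of digamma(x + 1) = log(x + g + 0.5) - g / (x + g + 0.5) + A'(x) / A(x).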
template <typename T>
__device__ __forceinline__ T CalNumDivDenom(T x) {
  T num = 0;
  T denom = 0.99999999999980993227684700473478;
  for (int j = 0; j < 8; ++j) {
    num -= kLanczosCoefficientsd[j] / ((x + j + 1) * (x + j + 1));
    denom += kLanczosCoefficientsd[j] / (x + j + 1);
  }
  return num / denom;
}
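
// d/dx mvlgamma(x, p) = sum_{i=0}^{p-1} digamma(x - i / 2). Arguments below 0.5 go
// through the reflection formula digamma(x) = digamma(1 - x) - pi / tan(pi * x), and
// non-positive integer arguments produce NaN. The chain rule multiplies the sum by
// the incoming gradient y_grad.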
template <typename T>
__global__ void MvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    T kLanczosGamma = 7;
    T log_lanczos_gamma_plus_one_half = std::log(7.5);
    T temp = 0;
    T cur_input = 0;
    T num_div_denom = 0;
    for (int i = 0; i < p; i++) {
      cur_input = x[pos] - 0.5 * i;
      if (cur_input < 0 && cur_input == std::floor(cur_input)) {
        temp += std::numeric_limits<T>::quiet_NaN();
        break;
      }
      if (cur_input < 0.5) {
        num_div_denom = CalNumDivDenom(-cur_input);
        // log1p (rather than the float-only log1pf) keeps full precision for T = double.
        temp += (log_lanczos_gamma_plus_one_half + log1p((-cur_input) / (kLanczosGamma + 0.5))) + num_div_denom -
                kLanczosGamma / (kLanczosGamma + 0.5 - cur_input);
        temp -= PI / std::tan(PI * (cur_input + std::abs(std::floor(cur_input + 0.5))));
      } else {
        num_div_denom = CalNumDivDenom(cur_input - 1);
        temp += (log_lanczos_gamma_plus_one_half + log1p((cur_input - 1) / (kLanczosGamma + 0.5))) + num_div_denom -
                kLanczosGamma / (kLanczosGamma + 0.5 + cur_input - 1);
      }
    }
    output[pos] = temp * y_grad[pos];
  }
}
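
// The launcher caps the grid at one resident block per SM; the grid-stride loop in
// the kernel then covers any elements beyond block_num * thread_num.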
template <typename T>
void CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, const uint32_t &device_id,
                     cudaStream_t cuda_stream) {
  int thread_num = 256 < size ? 256 : size;
  cudaDeviceProp prop;
  (void)cudaGetDeviceProperties(&prop, device_id);
  int max_blocks = prop.multiProcessorCount;
  int block_num = std::min(static_cast<int>(((size - 1) / thread_num) + 1), max_blocks);
  MvlgammaGrad<<<block_num, thread_num, 0, cuda_stream>>>(size, y_grad, x, p, output);
  return;
}

template CUDA_LIB_EXPORT void CalMvlgammaGrad<float>(const size_t size, const float *y_grad, const float *x,
                                                     const int p, float *output, const uint32_t &device_id,
                                                     cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMvlgammaGrad<double>(const size_t size, const double *y_grad, const double *x,
                                                      const int p, double *output, const uint32_t &device_id,
                                                      cudaStream_t cuda_stream);
@@ -0,0 +1,25 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_
#include <math.h>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T>
void CalMvlgammaGrad(const size_t size, const T *y_grad, const T *x, const int p, T *output, const uint32_t &device_id,
                     cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_GRAD_IMPL_CUH_
@@ -0,0 +1,60 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "mvlgamma_impl.cuh"
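
// Valid clears the shared flag as soon as any element violates x > (p - 1) / 2.
// The unsynchronized concurrent writes are benign: every writer stores false.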
template <typename T>
__global__ void Valid(bool *valid, const size_t size, const T *input, const int p) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    if (input[pos] <= (0.5 * (p - 1))) {
      *valid = false;
      return;
    }
  }
  return;
}
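
// Multivariate log-gamma:
//   mvlgamma(x, p) = (p * (p - 1) / 4) * log(pi) + sum_{i=1}^{p} lgamma(x - (i - 1) / 2).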
template <typename T>
__global__ void Mvlgamma(const size_t size, const T *input, const int p, T *output) {
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
    T temp = 0;
    for (int i = 1; i <= p; i++) {
      temp += lgamma(input[pos] - static_cast<T>((i - 1) * 0.5));
    }
    output[pos] = temp + static_cast<T>(p * (p - 1) * 0.25 * log(M_PI));
  }
  return;
}

template <typename T>
void CalMvlgamma(bool *valid, const size_t size, const T *input, const int p, T *output, const uint32_t &device_id,
                 cudaStream_t cuda_stream) {
  Valid<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(valid, size, input, p);
  bool host_valid = true;
  // `valid` is an ordinary device pointer, not a __device__ symbol, so it is read back
  // with cudaMemcpyAsync on the same stream (not cudaMemcpyFromSymbol), and the stream
  // is synchronized before the flag is inspected on the host.
  (void)cudaMemcpyAsync(&host_valid, valid, sizeof(bool), cudaMemcpyDeviceToHost, cuda_stream);
  (void)cudaStreamSynchronize(cuda_stream);
  if (!host_valid) {
    return;
  }
  Mvlgamma<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input, p, output);
  return;
}

template CUDA_LIB_EXPORT void CalMvlgamma<float>(bool *valid, const size_t size, const float *input, const int p,
                                                 float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalMvlgamma<double>(bool *valid, const size_t size, const double *input, const int p,
                                                  double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
@@ -0,0 +1,24 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"

template <typename T>
void CalMvlgamma(bool *valid, const size_t size, const T *input, const int p, T *output, const uint32_t &device_id,
                 cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_MVLGAMMA_IMPL_CUH_
@@ -13,16 +13,14 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include <map>
#include <set>
#include <string>

#include "ops/grad/mvlgamma_grad.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {

@@ -30,21 +28,41 @@ namespace {
abstract::ShapePtr MvlgammaGradInferShape(const PrimitivePtr &primitive,
                                          const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }

  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];

  auto first_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
  auto first_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
  auto second_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kMinShape];
  auto second_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kMaxShape];
  if (first_input_shape_min.empty() || first_input_shape_max.empty() || second_input_shape_min.empty() ||
      second_input_shape_max.empty()) {
    return std::make_shared<abstract::Shape>(x_shape);
  }
  ShapeVector min_shape = {first_input_shape_min[0], second_input_shape_min[0]};
  ShapeVector max_shape = {first_input_shape_max[0], second_input_shape_max[0]};
  return std::make_shared<abstract::Shape>(x_shape, min_shape, max_shape);
}
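
// MvlgammaGradInferType below requires both inputs to share one of the supported
// floating types; CheckTensorTypeSame returns that common type (float32 or float64).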
TypePtr MvlgammaGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  std::map<std::string, TypePtr> types;
  (void)types.emplace("y_grad", input_args[0]->BuildType());
  (void)types.emplace("x", input_args[1]->BuildType());
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  return CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, prim->name());
}
}  // namespace

MIND_API_OPERATOR_IMPL(MvlgammaGrad, BaseOperator);
void MvlgammaGrad::Init(const int64_t p) { set_p(p); }

void MvlgammaGrad::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }

int64_t MvlgammaGrad::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }

AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);

@@ -54,7 +72,7 @@ AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const Pri
  auto infer_shape = MvlgammaGradInferShape(primitive, input_args);
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(MvlgammaGrad, prim::kPrimMvlgammaGrad, MvlgammaGradInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -1,5 +1,5 @@
/**
 * Copyright 2020-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.

@@ -16,9 +16,14 @@
#ifndef MINDSPORE_CORE_OPS_MVLGAMMA_GRAD_H_
#define MINDSPORE_CORE_OPS_MVLGAMMA_GRAD_H_
#include <map>
#include <string>
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

@@ -29,10 +34,16 @@ class MIND_API MvlgammaGrad : public BaseOperator {
 public:
  MIND_API_BASE_MEMBER(MvlgammaGrad);
  MvlgammaGrad() : BaseOperator(kNameMvlgammaGrad) { InitIOName({"y_grad", "x"}, {"x_grad"}); }
  void Init(const int64_t p = 0);
  /// \brief Set p.
  void set_p(const int64_t p);
  int64_t get_p() const;
};

abstract::AbstractBasePtr MvlgammaGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args);
using PrimMvlgammaGradPtr = std::shared_ptr<MvlgammaGrad>;
}  // namespace ops
}  // namespace mindspore
@@ -14,34 +14,46 @@
 * limitations under the License.
 */

#include <algorithm>
#include <set>
#include <string>

#include "ops/mvlgamma.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr MvlgammaInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  MS_EXCEPTION_IF_NULL(input_args[0]);
  auto in_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
  auto first_input_shape_min = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMinShape];
  auto first_input_shape_max = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kMaxShape];
  if (first_input_shape_min.empty() || first_input_shape_max.empty()) {
    return std::make_shared<abstract::Shape>(in_shape);
  }
  ShapeVector min_shape = {first_input_shape_min[0]};
  ShapeVector max_shape = {first_input_shape_max[0]};
  return std::make_shared<abstract::Shape>(in_shape, min_shape, max_shape);
}

TypePtr MvlgammaInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto input_type = input_args[0]->BuildType();
  const std::set<TypePtr> valid_types = {kFloat32, kFloat64};
  return CheckAndConvertUtils::CheckTensorTypeValid("x", input_type, valid_types, primitive->name());
}
}  // namespace

MIND_API_OPERATOR_IMPL(Mvlgamma, BaseOperator);
void Mvlgamma::Init(const int64_t p) { set_p(p); }

void Mvlgamma::set_p(const int64_t p) { (void)this->AddAttr(kP, api::MakeValue(p)); }

int64_t Mvlgamma::get_p() const { return GetValue<int64_t>(GetAttr(kP)); }

AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                              const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);

@@ -52,6 +64,7 @@ AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const Primiti
  return abstract::MakeAbstract(infer_shape, infer_type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Mvlgamma, prim::kPrimMvlgamma, MvlgammaInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
@@ -16,9 +16,14 @@
#ifndef MINDSPORE_CORE_OPS_MVLGAMMA_H_
#define MINDSPORE_CORE_OPS_MVLGAMMA_H_
#include <map>
#include <string>
#include <vector>
#include <memory>

#include "ops/primitive_c.h"
#include "ops/op_utils.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
#include "ops/base_operator.h"
#include "mindapi/base/types.h"

@@ -32,10 +37,15 @@ class MIND_API Mvlgamma : public BaseOperator {
  MIND_API_BASE_MEMBER(Mvlgamma);
  /// \brief Constructor.
  Mvlgamma() : BaseOperator(kNameMvlgamma) { InitIOName({"x"}, {"y"}); }
  /// \brief Init.
  void Init(const int64_t p = 0);
  /// \brief Set p.
  void set_p(const int64_t p);
  int64_t get_p() const;
};

abstract::AbstractBasePtr MvlgammaInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args);
using PrimMvlgammaPtr = std::shared_ptr<Mvlgamma>;
}  // namespace ops
}  // namespace mindspore
@@ -157,6 +157,7 @@ constexpr auto kOutputPaddings = "output_paddings";
constexpr auto kOutputType = "output_type";
constexpr auto kOutIdx = "out_idx";
constexpr auto kOutQuantized = "out_quantized";
constexpr auto kMvlgammaP = "mvlgamma_p";
constexpr auto kP = "p";
constexpr auto kPad = "pad";
constexpr auto kPadding = "padding";
@@ -47,7 +47,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                        EmbeddingLookup, Unique, GatherD, Identity, Range, MaskedFill, MaskedSelect, SearchSorted,
                        TensorScatterUpdate, TensorScatterMax, TensorScatterMin, TensorScatterAdd, TensorScatterSub,
                        TensorScatterMul, TensorScatterDiv, ExtractVolumePatches, LowerBound,
                        UpperBound, Cummax, Mvlgamma)
from .comm_ops import (AllGather, AllReduce, NeighborExchange, NeighborExchangeV2, AlltoAll, _AllSwap, ReduceScatter,
                       Broadcast,
                       _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,

@@ -275,6 +275,7 @@ __all__ = [
    'Multinomial',
    'Gamma',
    'RandomGamma',
    'Mvlgamma',
    'Poisson',
    'UniformInt',
    'UniformReal',
@@ -3893,7 +3893,7 @@ class Mvlgamma(Primitive):
        ValueError: If any element of `x` is not greater than (p-1)/2.

    Supported Platforms:
        ``Ascend`` ``CPU`` ``GPU``

    Examples:
        >>> x = Tensor(np.array([[3, 4, 5], [4, 2, 6]]), mindspore.float32)
@@ -0,0 +1,67 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import torch
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations._grad_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
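

# Reference values come from PyTorch: torch.mvlgamma is differentiated by autograd,
# and the resulting dx is compared elementwise against the MvlgammaGrad GPU kernel.
# The analytic gradient is y_grad * sum_{i=0}^{p-1} digamma(x - i / 2).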
class MvlgammaGradNet(nn.Cell):
    def __init__(self, nptype, p):
        super(MvlgammaGradNet, self).__init__()
        self.mvlgamma_grad = P.MvlgammaGrad(p=p)
        self.y_grad_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
        self.y_grad = Tensor(self.y_grad_np)
        self.x_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
        self.x = Tensor(self.x_np)

    @ms_function
    def construct(self):
        return self.mvlgamma_grad(self.y_grad, self.x)


def mvlgamma_grad_torch(y_grad_np, x_np, p):
    x_torch = torch.tensor(x_np, requires_grad=True)
    grad_torch = torch.tensor(y_grad_np)
    out_torch = torch.mvlgamma(x_torch, p=p)
    out_torch.backward(grad_torch)
    dx = x_torch.grad
    return dx.numpy()


def mvlgamma_grad(nptype, p):
    mvlgamma_ = MvlgammaGradNet(nptype, p)
    mvlgamma_output = mvlgamma_().asnumpy()
    mvlgamma_expect = mvlgamma_grad_torch(mvlgamma_.y_grad_np, mvlgamma_.x_np, p).astype(nptype)
    assert np.allclose(mvlgamma_output, mvlgamma_expect, 1e-4, 1e-4)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_graph_float32():
    """
    Feature: ALL To ALL
    Description: test cases for MvlgammaGrad
    Expectation: the result matches the PyTorch reference
    """
    mvlgamma_grad(np.float32, 2)
@@ -0,0 +1,80 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import torch
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
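

# The operator under test implements the multivariate log-gamma function
#   mvlgamma(x, p) = (p * (p - 1) / 4) * log(pi) + sum_{i=1}^{p} lgamma(x - (i - 1) / 2),
# and torch.mvlgamma supplies the reference values.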
class MvlgammaNet(nn.Cell):
    def __init__(self, nptype, p):
        super(MvlgammaNet, self).__init__()
        self.mvlgamma = P.Mvlgamma(p=p)
        self.a_np = np.array([[3, 4, 5], [4, 2, 6]]).astype(nptype)
        self.a = Tensor(self.a_np)

    @ms_function
    def construct(self):
        return self.mvlgamma(self.a)


def mvlgamma_torch(a, d):
    return torch.mvlgamma(torch.tensor(a), d).numpy()


def mvlgamma(nptype, p):
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    mvlgamma_ = MvlgammaNet(nptype, p)
    mvlgamma_output = mvlgamma_().asnumpy()
    mvlgamma_expect = mvlgamma_torch(mvlgamma_.a_np, p).astype(nptype)
    assert np.allclose(mvlgamma_output, mvlgamma_expect)


def mvlgamma_pynative(nptype, p):
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    mvlgamma_ = MvlgammaNet(nptype, p)
    mvlgamma_output = mvlgamma_().asnumpy()
    mvlgamma_expect = mvlgamma_torch(mvlgamma_.a_np, p).astype(nptype)
    assert np.allclose(mvlgamma_output, mvlgamma_expect)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_graph_float32():
    """
    Feature: ALL To ALL
    Description: test cases for Mvlgamma
    Expectation: the result matches the PyTorch reference
    """
    mvlgamma(np.float32, 3)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_mvlgamma_pynative_float64():
    """
    Feature: ALL To ALL
    Description: test cases for Mvlgamma
    Expectation: the result matches the PyTorch reference
    """
    mvlgamma_pynative(np.float64, 3)