!38371 add gpu ops logit and logitgrad

Merge pull request !38371 from 高攀/logit_logitgrad
This commit is contained in:
i-robot 2022-07-27 04:21:11 +00:00 committed by Gitee
commit 841b9df521
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
32 changed files with 1377 additions and 0 deletions

View File

@ -169,6 +169,7 @@ functional算子是经过初始化后的Primitive可以直接作为函数使
mindspore.ops.logical_and
mindspore.ops.logical_not
mindspore.ops.logical_or
mindspore.ops.logit
mindspore.ops.log_matrix_determinant
mindspore.ops.matrix_determinant
mindspore.ops.mul

View File

@ -969,6 +969,32 @@ mindspore.Tensor
- **TypeError** - `x` 不是Tensor。
- **TypeError** - `x` 的数据类型非float16或float32。
.. py:method:: logit(eps=None)
逐元素计算张量的logit值当 eps 不是 None 时, `x` 中的元素被截断到范围[eps, 1-eps]内。
当 eps 为 None 时,输入 `x` 不进行数值截断。
`x` 指的当前 Tensor。
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} & \text{if eps is None} \\
\text{eps} & \text{if } x_{i} < \text{eps} \\
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
\end{cases}
参数:
- **eps** (float) - epsilon值。输入的数值界限被定义[eps, 1-eps]。 默认值None。
返回:
Tensor具有与 `x` 相同的shape。
异常:
- **TypeError** - `eps` 不是float类型。
- **TypeError** - `x` 不是Tensor类型。
- **TypeError** - `x` 的数据类型不是float16、float32或float64。
.. py:method:: log_matrix_determinant()
计算一个或多个平方矩阵行列式绝对值的对数的符号和绝对值的对数。

View File

@ -0,0 +1,17 @@
mindspore.ops.Logit
===================
.. py:class:: mindspore.ops.Logit(eps=-1.0)
逐元素计算张量的logit值。 `x` 中的元素被截断到范围[eps, 1-eps]内。
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} & \text{if eps is None} \\
\text{eps} & \text{if } x_{i} < \text{eps} \\
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
\end{cases}
更多参考详见 :func:`mindspore.ops.logit`

View File

@ -0,0 +1,28 @@
mindspore.ops.logit
===================
.. py:function:: mindspore.ops.logit(x, eps=None)
逐元素计算张量的logit值当 eps 不是 None 时, `x` 中的元素被截断到范围[eps, 1-eps]内。
当 eps 为 None 时,输入 `x` 不进行数值截断。
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} & \text{if eps is None} \\
\text{eps} & \text{if } x_{i} < \text{eps} \\
x_{i} & \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} & \text{if } x_{i} > 1 - \text{eps}
\end{cases}
参数:
- **x** (Tensor) - 张量输入。
- **eps** (float) - epsilon值。输入的数值界限被定义[eps, 1-eps]。 默认值None。
返回:
Tensor具有与 `x` 相同的shape。
异常:
- **TypeError** - `eps` 不是float类型。
- **TypeError** - `x` 不是Tensor类型。
- **TypeError** - `x` 的数据类型不是float16、float32或float64。

View File

@ -172,6 +172,7 @@ Element-by-Element Operations
mindspore.ops.logical_and
mindspore.ops.logical_not
mindspore.ops.logical_or
mindspore.ops.logit
mindspore.ops.log_matrix_determinant
mindspore.ops.matrix_determinant
mindspore.ops.mul

View File

@ -212,6 +212,7 @@ BuiltInTypeMap &GetMethodMap() {
{"inplace_update", std::string("inplace_update")}, // P.InplaceUpdate
{"lerp", std::string("lerp")}, // lerp()
{"log1p", std::string("log1p")}, // P.Log1p()
{"logit", std::string("logit")}, // Logit()
{"log_matrix_determinant", std::string("log_matrix_determinant")}, // log_matrix_determinant()
{"matrix_determinant", std::string("matrix_determinant")}, // log_matrix_determinant()
{"max", std::string("max")}, // P.reduce_max()

View File

@ -0,0 +1,107 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_GRAD_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_GRAD_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logit_grad_impl.cuh"
namespace mindspore {
namespace cukernel {
class LogitGradAttr : public GpuKernelAttrBase {
public:
LogitGradAttr() = default;
~LogitGradAttr() override = default;
float eps;
};
template <typename T, typename S>
class LogitGradHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit LogitGradHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
eps_ = -1.0;
is_null_input_ = false;
}
virtual ~LogitGradHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM = 1;
constexpr size_t OUTPUT_NUM = 1;
ResetResource();
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag != 0) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
eps_ = attr_ptr_->eps;
return 0;
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *input_grad_ptr = nullptr;
T *input_x_ptr = nullptr;
T *output_dx_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_grad_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &input_x_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_dx_ptr);
if (flag != 0) {
return flag;
}
// call cuda kernel
CalLogitGrad(input_grad_ptr, input_x_ptr, eps_, output_dx_ptr, input_size_list_[0] / sizeof(T), device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<LogitGradAttr>(kernel_attr);
}
private:
std::shared_ptr<LogitGradAttr> attr_ptr_;
float eps_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_GRAD_HELPER_H_

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/logit_impl.cuh"
namespace mindspore {
namespace cukernel {
class LogitAttr : public GpuKernelAttrBase {
public:
LogitAttr() = default;
~LogitAttr() override = default;
float eps;
};
template <typename T, typename S>
class LogitHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit LogitHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
eps_ = -1.0;
is_null_input_ = false;
}
virtual ~LogitHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM = 1;
constexpr size_t OUTPUT_NUM = 1;
ResetResource();
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag != 0) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
eps_ = attr_ptr_->eps;
return 0;
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *input_ptr = nullptr;
T *output_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
// call cuda kernel
CalLogit(input_ptr, static_cast<T>(static_cast<T>(1.0) - static_cast<T>(eps_)), eps_, output_ptr,
input_size_list_[0] / sizeof(T), device_id_, reinterpret_cast<cudaStream_t>(cuda_stream));
return 0;
}
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
attr_ptr_ = std::dynamic_pointer_cast<LogitAttr>(kernel_attr);
}
private:
std::shared_ptr<LogitAttr> attr_ptr_;
float eps_;
T up_bound_;
bool is_null_input_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_LOGIT_HELPER_H_

View File

@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "logit_grad_impl.cuh"
#include <limits>
#include "include/cuda_fp16.h"
template <typename T>
__global__ void LogitGradLessZero(const T *grad, const T *input, const float eps, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = (input[i] < T(0) || input[i] > T(1)) ? std::numeric_limits<T>::quiet_NaN()
: (grad[i] / input[i] / (T(1) - input[i]));
}
return;
}
template <>
__global__ void LogitGradLessZero(const half *grad, const half *input, const float eps, half *output,
const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = (input[i] < half(0) || input[i] > half(1))
? half(std::numeric_limits<float>::quiet_NaN())
: half(static_cast<float>(grad[i]) / static_cast<float>(input[i])
/ (static_cast<float>(1) - static_cast<float>(input[i])));
}
return;
}
template <typename T>
__global__ void LogitGradGreaterZero(const T *grad, const T *input, const float eps, T *output, const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = (input[i] < static_cast<T>(eps) || input[i] > T(1) - static_cast<T>(eps))
? T(0)
: (grad[i] / input[i] / (T(1) - input[i]));
}
return;
}
template <>
__global__ void LogitGradGreaterZero(const half *grad, const half *input, const float eps, half *output,
const size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
output[i] = (input[i] < static_cast<half>(eps) || input[i] > half(1) - static_cast<half>(eps))
? half(0)
: half(static_cast<float>(grad[i]) / static_cast<float>(input[i])
/ (static_cast<float>(1) - static_cast<float>(input[i])));
}
return;
}
template <typename T>
void CalLogitGrad(const T *grad, const T *input, const float eps, T *output, const size_t count,
const uint32_t &device_id, cudaStream_t cuda_stream) {
if (eps < 0) {
LogitGradLessZero<<<CUDA_BLOCKS(device_id, count), CUDA_THREADS(device_id), 0, cuda_stream>>>(grad, input, eps,
output, count);
} else {
LogitGradGreaterZero<<<CUDA_BLOCKS(device_id, count), CUDA_THREADS(device_id), 0, cuda_stream>>>(grad, input, eps,
output, count);
}
return;
}
template CUDA_LIB_EXPORT void CalLogitGrad<half>(const half *grad, const half *input, const float eps, half *output,
const size_t count, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogitGrad<float>(const float *grad, const float *input, const float eps, float *output,
const size_t count, const uint32_t &device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogitGrad<double>(const double *grad, const double *input, const float eps,
double *output, const size_t count, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_GRAD_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_GRAD_IMPL_CUH_
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalLogitGrad(const T *grad, const T *input, const float eps, T *output, const size_t count,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_GRAD_IMPL_CUH_

View File

@ -0,0 +1,89 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "logit_impl.cuh"
#include "include/cuda_fp16.h"
template <typename T>
__global__ void LogitGreaterZero(const T *input, const T up_bound, const T eps, T *output, const size_t count) {
T one = T(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T z;
T x = input[i];
z = x < eps ? eps : (x > up_bound ? up_bound : x);
output[i] = log(z / (one - z));
}
return;
}
template <>
__global__ void LogitGreaterZero(const half *input, const half up_bound, const half eps, half *output,
const size_t count) {
half one = half(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
half z;
half x = input[i];
z = x < eps ? eps : (x > up_bound ? up_bound : x);
output[i] = hlog(z / (one - z));
}
return;
}
template <typename T>
__global__ void LogitLessZero(const T *input, const T up_bound, const T eps, T *output, const size_t count) {
T one = T(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
T x = input[i];
output[i] = log(x / (one - x));
}
return;
}
template <>
__global__ void LogitLessZero(const half *input, const half up_bound, const half eps, half *output,
const size_t count) {
half one = half(1);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
half x = input[i];
output[i] = hlog(x / (one - x));
}
return;
}
template <typename T>
void CalLogit(const T *input, const T up_bound, const float eps, T *output, const size_t count,
const uint32_t &device_id, cudaStream_t cuda_stream) {
T eps_value;
eps_value = T(eps);
if (eps < 0) {
LogitLessZero<<<CUDA_BLOCKS(device_id, count), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, up_bound,
eps_value, output,
count);
} else {
LogitGreaterZero<<<CUDA_BLOCKS(device_id, count), CUDA_THREADS(device_id), 0, cuda_stream>>>(input, up_bound,
eps_value, output,
count);
}
return;
}
template CUDA_LIB_EXPORT void CalLogit<half>(const half *input, const half up_bound, const float eps, half *output,
const size_t count, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogit<float>(const float *input, const float up_bound, const float eps, float *output,
const size_t count, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalLogit<double>(const double *input, const double up_bound, const float eps,
double *output, const size_t count, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,26 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_IMPL_CUH_
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalLogit(const T *input, const T up_bound, const float eps, T *output, const size_t count,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_LOGIT_IMPL_CUH_

View File

@ -0,0 +1,101 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/logit_gpu_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateLogitKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::LogitHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using LogitPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, LogitPtrCreatorFunc>> kernel_attr = {
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CreateLogitKernelPtr<double, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), CreateLogitKernelPtr<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), CreateLogitKernelPtr<half, float>}};
} // namespace
bool LogitGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool LogitGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::make_shared<ops::Logit>(base_operator->GetPrim());
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
attr_ptr_->eps = kernel_ptr->get_eps();
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
return true;
}
int LogitGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
std::vector<int64_t> output_shape = outputs[0]->GetShapeVector();
input_shapes.emplace_back(input_shape);
output_shapes.emplace_back(output_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
return KRET_OK;
}
std::vector<KernelAttr> LogitGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LogitPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Logit, LogitGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GPU_KERNEL_H_
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindspore/core/ops/logit.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/logit_helper.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class LogitGpuKernelMod : public NativeGpuKernelMod {
public:
LogitGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::LogitAttr>(); }
~LogitGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::LogitAttr> attr_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GPU_KERNEL_H_

View File

@ -0,0 +1,103 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/logit_grad_gpu_kernel.h"
#include <map>
#include <memory>
#include <string>
#include <utility>
namespace mindspore {
namespace kernel {
namespace {
template <typename T, typename S>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateLogitGradKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::LogitGradHelperGpuKernel<T, S>>(kernel_name, device_id);
}
using LogitGradPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, LogitGradPtrCreatorFunc>> kernel_attr = {
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CreateLogitGradKernelPtr<double, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CreateLogitGradKernelPtr<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
CreateLogitGradKernelPtr<half, float>}};
} // namespace
bool LogitGradGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool LogitGradGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::make_shared<ops::LogitGrad>(base_operator->GetPrim());
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
attr_ptr_->eps = kernel_ptr->get_eps();
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
helper_ptr_->SetKernelParam(attr_ptr_);
return true;
}
int LogitGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> input_shape = inputs[0]->GetShapeVector();
std::vector<int64_t> output_shape = outputs[0]->GetShapeVector();
input_shapes.emplace_back(input_shape);
output_shapes.emplace_back(output_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
return KRET_OK;
}
std::vector<KernelAttr> LogitGradGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, LogitGradPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogitGrad, LogitGradGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GRAD_GPU_KERNEL_H_
#include <algorithm>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "mindspore/core/ops/grad/logit_grad.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/logit_grad_helper.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class LogitGradGpuKernelMod : public NativeGpuKernelMod {
public:
LogitGradGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::LogitGradAttr>(); }
~LogitGradGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
std::shared_ptr<cukernel::LogitGradAttr> attr_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_LOGIT_GRAD_GPU_KERNEL_H_

View File

@ -76,6 +76,8 @@ constexpr auto kReciprocal = "Reciprocal";
constexpr auto kInv = "Inv";
constexpr auto kReduceStd = "ReduceStd";
constexpr auto kLog = "Log";
constexpr auto kLogit = "Logit";
constexpr auto kLogitGrad = "LogitGrad";
constexpr auto kLogicalXor = "LogicalXor";
constexpr auto kSelect = "Select";
constexpr auto kAdd = "Add";
@ -1023,6 +1025,8 @@ GVAR_DEF(PrimitivePtr, kPrimRound, std::make_shared<Primitive>("Round"));
GVAR_DEF(PrimitivePtr, kPrimExp, std::make_shared<Primitive>(kExp));
GVAR_DEF(PrimitivePtr, kPrimExpm1, std::make_shared<Primitive>("Expm1"));
GVAR_DEF(PrimitivePtr, kPrimLog, std::make_shared<Primitive>(kLog));
GVAR_DEF(PrimitivePtr, kPrimLogit, std::make_shared<Primitive>(kLogit));
GVAR_DEF(PrimitivePtr, kPrimLogitGrad, std::make_shared<Primitive>(kLogitGrad));
GVAR_DEF(PrimitivePtr, kPrimRsqrt, std::make_shared<Primitive>("Rsqrt"));
GVAR_DEF(PrimitivePtr, kPrimRsqrtGrad, std::make_shared<Primitive>("RsqrtGrad"));
GVAR_DEF(PrimitivePtr, kPrimLinSpace, std::make_shared<Primitive>("LinSpace"));

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/logit_grad.h"
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr LogitGradInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr LogitGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
auto x_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
return input_args[kInputIndex0]->BuildType();
}
} // namespace
void LogitGrad::Init(const float eps) { set_eps(eps); }
void LogitGrad::set_eps(const float eps) { (void)this->AddAttr(kEps, api::MakeValue(eps)); }
float LogitGrad::get_eps() const {
auto value_ptr = GetAttr(kEps);
return GetValue<float>(value_ptr);
}
AbstractBasePtr LogitGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const size_t input_num = 2;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = LogitGradInferType(primitive, input_args);
auto shapes = LogitGradInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
MIND_API_OPERATOR_IMPL(LogitGrad, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(LogitGrad, prim::kPrimLogitGrad, LogitGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LOGIT_GRAD_H_
#define MINDSPORE_CORE_OPS_LOGIT_GRAD_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLogitGrad = "LogitGrad";
class MIND_API LogitGrad : public BaseOperator {
public:
MIND_API_BASE_MEMBER(LogitGrad);
LogitGrad() : BaseOperator(kNameLogitGrad) { InitIOName({"grad", "input"}, {"dx"}); }
void Init(const float eps = -1.0);
void set_eps(const float eps);
float get_eps() const;
};
abstract::AbstractBasePtr LogitGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LOGIT_GRAD_H_

View File

@ -0,0 +1,75 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/logit.h"
#include <algorithm>
#include <map>
#include <set>
#include <string>
#include <vector>
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr LogitInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr LogitInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
auto x_type = input_args[kInputIndex0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
return input_args[kInputIndex0]->BuildType();
}
} // namespace
void Logit::Init(const float eps) { set_eps(eps); }
void Logit::set_eps(const float eps) { (void)this->AddAttr(kEps, api::MakeValue(eps)); }
float Logit::get_eps() const {
auto value_ptr = GetAttr(kEps);
return GetValue<float>(value_ptr);
}
AbstractBasePtr LogitInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const size_t input_num = 1;
(void)CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto types = LogitInferType(primitive, input_args);
auto shapes = LogitInferShape(primitive, input_args);
return abstract::MakeAbstract(shapes, types);
}
MIND_API_OPERATOR_IMPL(Logit, BaseOperator);
REGISTER_PRIMITIVE_EVAL_IMPL(Logit, prim::kPrimLogit, LogitInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,50 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_LOGIT_H_
#define MINDSPORE_CORE_OPS_LOGIT_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameLogit = "Logit";
/// \brief Returns logit of a tensor element-wise.
/// Refer to Python API @ref mindspore.ops.Logit for more details.
class MIND_API Logit : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Logit);
/// \brief Constructor.
Logit() : BaseOperator(kNameLogit) { InitIOName({"x"}, {"y"}); }
/// \brief Init. Refer to the parameters of Python API @ref mindspore.ops.Logit for the inputs.
void Init(const float eps = -1.0);
/// \brief Set epsilon.
void set_eps(const float eps);
/// \brief Get epsilon.
///
/// \return epsilon.
float get_eps() const;
};
abstract::AbstractBasePtr LogitInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_LOGIT_H_

View File

@ -2221,6 +2221,29 @@ def log1p(x):
return F.log1p(x)
def logit(x, eps=None):
r"""
Calculate the logit of a tensor element-wise. When eps is not None, element in 'x' is clamped to [eps, 1-eps].
When eps is None, input 'x' is not clamped.
`x` refer to self tensor.
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} &amp; \text{if eps is None} \\
\text{eps} &amp; \text{if } x_{i} &lt; \text{eps} \\
x_{i} &amp; \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} &amp; \text{if } x_{i} &gt; 1 - \text{eps}
\end{cases}
"""
if eps is None:
eps = -1.0
check_value_type('eps', eps, (float,), 'Tensor.logit')
return F.logit(x, eps)
def log_matrix_determinant(x):
"""Computes the sign and the log of the absolute value of the determinant of one or more square matrices."""
return F.log_matrix_determinant(x)

View File

@ -1341,6 +1341,49 @@ class Tensor(Tensor_):
self._init_check()
return tensor_operator_registry.get('log1p')(self)
def logit(self, eps=None):
r"""
Calculate the logit of a tensor element-wise. When eps is not None, element in 'x' is clamped to [eps, 1-eps].
When eps is None, input 'x' is not clamped.
`x` refer to self tensor.
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} &amp; \text{if eps is None} \\
\text{eps} &amp; \text{if } x_{i} &lt; \text{eps} \\
x_{i} &amp; \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} &amp; \text{if } x_{i} &gt; 1 - \text{eps}
\end{cases}
Args:
eps (float, optional): The epsilon. The input clamp bound is defined as [eps, 1-eps]. Default: None.
Returns:
Tensor, with the same shape as the `x`.
Raises:
TypeError: If `eps` is not a float.
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not float16, float32 or float64.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
>>> output = x.logit(eps=1e-5)
>>> print(output)
[-2.1972246 -1.3862944 -0.8472978]
"""
self._init_check()
if eps is None:
eps = -1.0
validator.check_value_type('eps', eps, (float,), 'Tensor.logit')
return tensor_operator_registry.get('logit')(self, eps)
def log_matrix_determinant(self):
r"""
Computes the sign and the log of the absolute value of the determinant of one or more square matrices.

View File

@ -40,6 +40,7 @@ from ..operations.math_ops import MatrixSolve
from ..operations.math_ops import Betainc
from ..operations.math_ops import CholeskySolve
from ..operations.math_ops import AddV2
from ..operations.math_ops import Logit
transpose = P.Transpose()
@ -62,6 +63,18 @@ def get_bprop_acos(self):
return bprop
@bprop_getters.register(Logit)
def get_bprop_logit(self):
"""Grad definition for `Logit` operation."""
logitgrad = G.LogitGrad(self.eps)
def bprop(x, out, dout):
dx = logitgrad(dout, x)
return (dx,)
return bprop
@bprop_getters.register(P.Cdist)
def get_bprop_cdist(self):
"""Generate bprop for Cdist"""

View File

@ -785,6 +785,7 @@ get_unop_vmap_rule = vmap_rules_getters.register(P.Exp)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Expm1)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Floor)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Log)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(math_ops.Logit)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Log1p)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.LogicalNot)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(P.Mish)(get_unop_vmap_rule)
@ -817,3 +818,4 @@ get_unop_vmap_rule = vmap_rules_getters.register(BesselK1)(get_unop_vmap_rule)
get_unop_vmap_rule = vmap_rules_getters.register(BesselK1e)(get_unop_vmap_rule)
# UnaryGrad vmap
get_unary_grad_vmap_rule = vmap_rules_getters.register(G.InvGrad)(get_unary_grad_vmap_rule)
get_unary_grad_vmap_rule = vmap_rules_getters.register(G.LogitGrad)(get_unary_grad_vmap_rule)

View File

@ -176,6 +176,7 @@ from .math_func import (
maximum,
logaddexp,
logaddexp2,
logit,
std,
ldexp,
mv,

View File

@ -27,6 +27,7 @@ from mindspore.ops import composite as C
from mindspore.ops.operations._inner_ops import Cummin
from mindspore.ops.operations.math_ops import STFT
from mindspore.ops.operations.math_ops import ReduceStd
from mindspore.ops.operations.math_ops import Logit
from mindspore.nn import layer
from mindspore._checkparam import check_is_number
from ..operations.math_ops import (
@ -2345,6 +2346,47 @@ def ldexp(x, other):
return out
def logit(x, eps=None):
r"""
Calculate the logit of a tensor element-wise. When eps is not None, element in 'x' is clamped to [eps, 1-eps].
When eps is None, input 'x' is not clamped.
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} &amp; \text{if eps is None} \\
\text{eps} &amp; \text{if } x_{i} &lt; \text{eps} \\
x_{i} &amp; \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} &amp; \text{if } x_{i} &gt; 1 - \text{eps}
\end{cases}
Args:
x (Tensor): The input tensor.
eps (float, optional): The epsilon. The input clamp bound is defined as [eps, 1-eps]. Default: None.
Returns:
Tensor, with the same shape as the `x`.
Raises:
TypeError: If `eps` is not a float.
TypeError: If `x` is not a Tensor.
TypeError: If dtype of `x` is not float16, float32 or float64.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
>>> output = ops.logit(x, eps=1e-5)
>>> print(output)
[-2.1972246 -1.3862944 -0.8472978]
"""
if eps is None:
eps = -1.0
logit_ = _get_cache_prim(Logit)(eps)
return logit_(x)
#####################################
# Comparison Operation Functions.
#####################################
@ -5103,6 +5145,7 @@ __all__ = [
'logical_not',
'logical_or',
'logical_and',
'logit',
'logsumexp',
'ldexp',
'sin',

View File

@ -803,6 +803,7 @@ tensor_operator_registry.register('ceil', P.Ceil)
tensor_operator_registry.register('fill', P.Fill)
tensor_operator_registry.register('tile', P.Tile)
tensor_operator_registry.register('logical_not', P.LogicalNot)
tensor_operator_registry.register('logit', logit)
tensor_operator_registry.register('sum', P.ReduceSum)
tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('select', P.Select)

View File

@ -51,6 +51,21 @@ class ACosGrad(Primitive):
self.init_prim_io_names(inputs=['y', 'dy'], outputs=['z'])
class LogitGrad(Primitive):
"""
Computes LogitGrad of input element-wise.
Returns:
Tensor, has the same type as input.
"""
@prim_attr_register
def __init__(self, eps=-1.0):
"""Initialize Exp"""
self.init_prim_io_names(inputs=['grad', 'input'], outputs=['dx'])
validator.check_value_type("eps", eps, [float], self.name)
self.add_prim_attr('eps', eps)
class AcoshGrad(Primitive):
"""Performs grad of Acosh operation."""

View File

@ -2322,6 +2322,40 @@ class Exp(Primitive):
self.add_prim_attr("shift", 0.0)
class Logit(Primitive):
r"""
Calculate the logit of a tensor element-wise. Element in `x` is clamped to [eps, 1-eps].
.. math::
y_{i} = \ln(\frac{z_{i}}{1 - z_{i}}) \\
z_{i} = \begin{cases}
x_{i} &amp; \text{if eps is None} \\
\text{eps} &amp; \text{if } x_{i} &lt; \text{eps} \\
x_{i} &amp; \text{if } \text{eps} \leq x_{i} \leq 1 - \text{eps} \\
1 - \text{eps} &amp; \text{if } x_{i} &gt; 1 - \text{eps}
\end{cases}
Refer to :func:`mindspore.ops.logit` for more detail.
Supported Platforms:
``GPU``
Examples:
>>> x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
>>> op = ops.Logit(eps=1e-5)
>>> output = op(x)
>>> print(output)
[-2.1972246 -1.3862944 -0.8472978]
"""
@prim_attr_register
def __init__(self, eps=-1.0):
"""Initialize Exp"""
self.add_prim_attr("eps", eps)
validator.check_value_type("eps", eps, [float], self.name)
self.init_prim_io_names(inputs=['x'], outputs=['y'])
class ReduceStd(Primitive):
"""
Returns the standard-deviation and mean of each row of the input tensor in the dimension `axis`.

View File

@ -0,0 +1,71 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G
class Net(nn.Cell):
def __init__(self, eps=-1.0):
super(Net, self).__init__()
self.grad = G.LogitGrad(eps)
def construct(self, dy, x):
return self.grad(dy, x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logit_grad_graph_float32():
"""
Feature: LogitGrad gpu TEST.
Description: 1d test case for LogitGrad with GRAPH_MODE
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
dy = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
expect = np.array([11.11111164, 6.24999952, 4.76190472]).astype(np.float32)
net = Net()
output = net(dy, x)
diff = output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1e-4
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logit_grad_pynative_float32():
"""
Feature: LogitGrad gpu TEST.
Description: 1d test case for LogitGrad with PYNATIVE_MODE
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
dy = Tensor(np.array([1.0, 1.0, 1.0]).astype(np.float32))
expect = np.array([11.11111164, 6.24999952, 4.76190472]).astype(np.float32)
logitgrad = G.LogitGrad()
output = logitgrad(dy, x)
diff = output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1e-4
assert np.all(diff < error)

View File

@ -0,0 +1,55 @@
import numpy as np
import pytest
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore import ops
class Net(nn.Cell):
def __init__(self):
super(Net, self).__init__()
self.logit = ops.logit
def construct(self, x, eps):
return self.logit(x, eps)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logit_graph():
"""
Feature: Logit gpu TEST.
Description: 1d test case for Logit with GRAPH_MODE
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
eps = 1e-5
net = Net()
output = net(x, eps)
expect = np.array([-2.19722462, -1.38629436, -0.84729779]).astype(np.float32)
diff = output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1e-4
assert np.all(diff < error)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_logit_pynative():
"""
Feature: Logit gpu TEST.
Description: 1d test case for Logit with PYNATIVE_MODE
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
x = Tensor(np.array([0.1, 0.2, 0.3]).astype(np.float32))
eps = 1e-5
logit = ops.logit
output = logit(x, eps)
expect = np.array([-2.19722462, -1.38629436, -0.84729779]).astype(np.float32)
diff = output.asnumpy() - expect
error = np.ones(shape=expect.shape) * 1e-4
assert np.all(diff < error)