!40602 [feat] [assistant] [ops] [I5EWJ9] New GPU operator implementation, include Sinc

Merge pull request !40602 from TBD1/Sinc
This commit is contained in:
i-robot 2022-11-17 12:10:17 +00:00 committed by Gitee
commit 8af9d7c038
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 663 additions and 113 deletions

View File

@ -0,0 +1,241 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include <complex>
#include "plugin/device/cpu/kernel/nnacl/op_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sinc_impl.cuh"
constexpr uint elements_per_thread = 4;
constexpr uint threads_per_block = 256;
constexpr uint elements_per_block = elements_per_thread * threads_per_block;
template <typename T, typename S>
struct VectorizedTrait {
static const uint VecSizeT = 4;
static const uint VecSizeS = 4;
};
template <>
struct VectorizedTrait<half, half> {
static const uint VecSizeT = 2;
static const uint VecSizeS = 2;
};
template <typename T, int VecSizeT>
struct alignas(sizeof(T) * VecSizeT) AlignVecIn {
T datain[VecSizeT];
};
template <typename S, int VecSizeS>
struct alignas(sizeof(S) * VecSizeS) AlignVecOut {
S dataout[VecSizeS];
};
template <typename Func, typename T, typename S>
__device__ __forceinline__ void VectorizedCall(Func func, const T *input, S *output) {
constexpr uint vec_size_t = VectorizedTrait<T, S>::VecSizeT;
constexpr uint vec_size_s = VectorizedTrait<T, S>::VecSizeS;
constexpr uint elements_per_loop = elements_per_thread / vec_size_t;
using VecT = AlignVecIn<T, vec_size_t>;
using VecS = AlignVecOut<S, vec_size_s>;
uint tid = threadIdx.x;
auto vec_input = reinterpret_cast<const VecT *>(input);
auto vec_output = reinterpret_cast<VecS *>(output);
for (uint i = 0; i < elements_per_loop; i++) {
uint index = tid + i * threads_per_block;
VecT cache_in = vec_input[index];
VecS cache_out = vec_output[index];
for (uint j = 0; j < vec_size_t; j++) {
cache_out.dataout[j] = func(cache_in.datain[j]);
}
vec_output[index] = cache_out;
}
}
template <typename Func, typename T, typename S>
__device__ __forceinline__ void NormalCall(Func func, const T *input, S *output, uint remaining) {
uint loop = UP_DIV(remaining, elements_per_thread);
for (uint i = threadIdx.x; i < loop; i += blockDim.x) {
for (uint j = 0; j < elements_per_thread; j++) {
uint index = i * elements_per_thread + j;
if (index >= remaining) {
return;
}
output[index] = func(input[index]);
}
}
}
template <typename Func, typename T, typename S>
__global__ void SincVectorized(Func func, const T *input, S *output, uint num_of_elements) {
uint offset = elements_per_block * blockIdx.x;
uint remaining = num_of_elements - offset;
if (blockIdx.x + 1 == gridDim.x && remaining != elements_per_block) {
NormalCall(func, input + offset, output + offset, remaining);
} else {
VectorizedCall(func, input + offset, output + offset);
}
}
template <typename T, typename S>
struct SincFunctor {
__device__ __forceinline__ S operator()(const T input) const {
const double PI = acos(-1.0);
const double zero = static_cast<double>(0.0);
const double one = static_cast<double>(1.0);
double output = zero;
if (static_cast<double>(input) == zero) {
output = one;
} else {
double temp = PI * static_cast<double>(input);
output = sinf(temp) / temp;
}
return static_cast<S>(output);
}
};
template <>
struct SincFunctor <half, half> {
__device__ __forceinline__ half operator()(const half input) const {
const float PI = acos(-1.0);
const float zero = static_cast<float>(0);
const float one = static_cast<float>(1);
float output = zero;
if (__half2float(input) == zero) {
output = one;
} else {
float temp = PI * static_cast<float>(__half2float(input));
output = sinf(temp) / temp;
}
return __float2half(static_cast<float>(output));
}
};
template <>
struct SincFunctor <Complex<float>, Complex<float>> {
__device__ __forceinline__ Complex<float> operator()(const Complex<float> input) const {
const float PI = acos(-1.0);
const float zero = static_cast<float>(0);
float a = input.real();
float b = input.imag();
Complex<float> result;
if (a == zero && b == zero) {
result.real(1.0);
result.imag(0.0);
} else {
float tmp_a = PI * a;
float tmp_b = PI * b;
float A = sinf(tmp_a) * coshf(tmp_b);
float B = cosf(tmp_a) * sinhf(tmp_b);
float T = tmp_a * tmp_a + tmp_b * tmp_b;
float rs_real = (A * tmp_a + B * tmp_b) / T;
float rs_imag = (B * tmp_a - A * tmp_b) / T;
result.real(rs_real);
result.imag(rs_imag);
}
return result;
}
};
template <>
struct SincFunctor <Complex<double>, Complex<double>> {
__device__ __forceinline__ Complex<double> operator()(const Complex<double> input) const {
const double PI = acos(-1.0);
const double zero = static_cast<double>(0);
double a = input.real();
double b = input.imag();
Complex<double> result;
if (a == zero && b == zero) {
result.real(1.0);
result.imag(0.0);
} else {
double tmp_a = PI * a;
double tmp_b = PI * b;
double A = sinf(tmp_a) * coshf(tmp_b);
double B = cosf(tmp_a) * sinhf(tmp_b);
double T = tmp_a * tmp_a + tmp_b * tmp_b;
double rs_real = (A * tmp_a + B * tmp_b) / T;
double rs_imag = (B * tmp_a - A * tmp_b) / T;
result.real(rs_real);
result.imag(rs_imag);
}
return result;
}
};
template <typename T, typename S>
void CalSinc(const size_t size, const T *input, S *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
SincFunctor<T, S> functor{};
auto block_x = threads_per_block;
auto grid_x = UP_DIV(static_cast<uint>(size), elements_per_block);
dim3 block{block_x};
dim3 grid{grid_x};
SincVectorized<<<grid, block, 0, cuda_stream>>>(functor, input, output, size);
}
template
CUDA_LIB_EXPORT void CalSinc<uint8_t, float>(const size_t size, const uint8_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<int8_t, float>(const size_t size, const int8_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<uint16_t, float>(const size_t size, const uint16_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<int16_t, float>(const size_t size, const int16_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<uint32_t, float>(const size_t size, const uint32_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<int32_t, float>(const size_t size, const int32_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<uint64_t, float>(const size_t size, const uint64_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<int64_t, float>(const size_t size, const int64_t *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<bool, float>(const size_t size, const bool *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<half>(const size_t size, const half *input, half *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<float>(const size_t size, const float *input, float *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<double>(const size_t size, const double *input, double *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<Complex<float>>(const size_t size, const Complex<float> *input,
Complex<float> *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void CalSinc<Complex<double>>(const size_t size, const Complex<double> *input,
Complex<double> *output, const uint32_t &device_id,
cudaStream_t cuda_stream);

View File

@ -0,0 +1,25 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SINC_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SINC_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void CalSinc(const size_t size, const T *input, S *output, const uint32_t &device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SINC_IMPL_CUH_

View File

@ -0,0 +1,127 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/sinc_gpu_kernel.h"
#include "mindspore/core/ops/sinc.h"
#include <functional>
#include <utility>
#include <map>
#include <string>
#include <complex>
#include <vector>
#include <algorithm>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
namespace kernel {
bool SincGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::Sinc>(base_operator);
kernel_name_ = kernel_ptr_->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
if (!kernel_ptr_) {
MS_LOG(ERROR) << "cast Sinc ops failed!";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
kernel_func_ = func_list_[index].second;
unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first);
return true;
}
int SincGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &others) {
for (const auto &input : inputs) {
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> output_shape = std::vector<int64_t>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
if (output_elements_ == 0) {
is_null_input_ = true;
}
size_t input_size = output_elements_ * unit_input_size_;
size_t output_size = output_elements_ * unit_output_size_;
input_size_list_.emplace_back(input_size);
output_size_list_.emplace_back(output_size);
return KRET_OK;
}
template <typename T, typename S>
bool SincGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, 0);
S *output = GetDeviceAddress<S>(outputs, 0);
CalSinc(output_elements_, input, output, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, SincGpuKernelMod::SincFunc>> SincGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<uint8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<int8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<uint16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<int16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<uint32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<int32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<uint64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<int64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<bool, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&SincGpuKernelMod::LaunchKernel<half, half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&SincGpuKernelMod::LaunchKernel<float, float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&SincGpuKernelMod::LaunchKernel<double, double>},
{KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&SincGpuKernelMod::LaunchKernel<Complex<float>, Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&SincGpuKernelMod::LaunchKernel<Complex<double>, Complex<double>>}};
std::vector<KernelAttr> SincGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, SincFunc> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sinc, SincGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SINC_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SINC_GPU_KERNEL_H_
#include <algorithm>
#include <complex>
#include <map>
#include <string>
#include <utility>
#include <vector>
#include "mindspore/core/ops/sinc.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sinc_impl.cuh"
namespace mindspore {
namespace kernel {
class SincGpuKernelMod : public NativeGpuKernelMod {
public:
SincGpuKernelMod() { ResetResource(); }
~SincGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
output_elements_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using SincFunc =
std::function<bool(SincGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
size_t output_elements_;
size_t unit_input_size_{1};
size_t unit_output_size_{1};
SincFunc kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, SincFunc>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SINC_GPU_KERNEL_H_

View File

@ -110,6 +110,7 @@ constexpr auto kMatrixSolve = "MatrixSolve";
constexpr auto kMatrixPower = "MatrixPower";
constexpr auto kMatrixDeterminant = "MatrixDeterminant";
constexpr auto kLogMatrixDeterminant = "LogMatrixDeterminant";
constexpr auto kSinc = "Sinc";
constexpr auto kCos = "Cos";
constexpr auto kAsinh = "Asinh";
constexpr auto kAsinhGrad = "AsinhGrad";

View File

@ -1,72 +1,72 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sinc.h"
#include <set>
#include <map>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SincInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr SincInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types_with_complex_and_bool,
prim->name());
auto tensor_type = x_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
TypePtr type = tensor_type->element();
MS_EXCEPTION_IF_NULL(type);
auto type_id = type->type_id();
const std::set<TypeId> valid_types = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeUInt16,
kNumberTypeInt16, kNumberTypeUInt32, kNumberTypeInt32,
kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeBool};
if (valid_types.count(type_id) > 0) {
return std::make_shared<TensorType>(kFloat32);
} else {
return x_dtype;
}
}
} // namespace
MIND_API_OPERATOR_IMPL(Sinc, BaseOperator);
AbstractBasePtr SincInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = SincInferType(primitive, input_args);
auto infer_shape = SincInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Sinc, prim::kPrimSinc, SincInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/sinc.h"
#include <set>
#include <map>
#include <string>
#include <vector>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr SincInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto prim_name = primitive->name();
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
auto x = input_args[0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr SincInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto x_dtype = input_args[0]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_dtype, common_valid_types_with_complex_and_bool,
prim->name());
auto tensor_type = x_dtype->cast<TensorTypePtr>();
MS_EXCEPTION_IF_NULL(tensor_type);
TypePtr type = tensor_type->element();
MS_EXCEPTION_IF_NULL(type);
auto type_id = type->type_id();
const std::set<TypeId> valid_types = {kNumberTypeUInt8, kNumberTypeInt8, kNumberTypeUInt16,
kNumberTypeInt16, kNumberTypeUInt32, kNumberTypeInt32,
kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeBool};
if (valid_types.count(type_id) > 0) {
return std::make_shared<TensorType>(kFloat32);
} else {
return x_dtype;
}
}
} // namespace
MIND_API_OPERATOR_IMPL(Sinc, BaseOperator);
AbstractBasePtr SincInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
const int64_t input_num = 1;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, primitive->name());
auto infer_type = SincInferType(primitive, input_args);
auto infer_shape = SincInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Sinc, prim::kPrimSinc, SincInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -1,40 +1,40 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SINC_H_
#define MINDSPORE_CORE_OPS_SINC_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSinc = "Sinc";
class MIND_API Sinc : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Sinc);
Sinc() : BaseOperator(kNameSinc) { InitIOName({"x"}, {"y"}); }
void Init() const {}
};
abstract::AbstractBasePtr SincInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimSincPtr = std::shared_ptr<Sinc>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SINH_H_
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SINC_H_
#define MINDSPORE_CORE_OPS_SINC_H_
#include <vector>
#include <memory>
#include "ops/base_operator.h"
#include "mindapi/base/types.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSinc = "Sinc";
class MIND_API Sinc : public BaseOperator {
public:
MIND_API_BASE_MEMBER(Sinc);
Sinc() : BaseOperator(kNameSinc) { InitIOName({"x"}, {"y"}); }
void Init() {}
};
abstract::AbstractBasePtr SincInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
using kPrimSincPtr = std::shared_ptr<Sinc>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SINH_H_

View File

@ -3721,9 +3721,13 @@ class Sinc(Primitive):
TypeError: If `x` is not a Tensor.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``GPU`` ``CPU``
Examples:
>>> import mindspore
>>> import numpy as np
>>> import mindspore.ops.operations.math_ops as ops
>>> from mindspore import Tensor, dtype
>>> sinc = ops.Sinc()
>>> x = Tensor(np.array([0.62, 0.28, 0.43, 0.62]), mindspore.float32)
>>> output = sinc(x)

View File

@ -0,0 +1,68 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.math_ops as P
class SincNet(nn.Cell):
def __init__(self):
super(SincNet, self).__init__()
self.sinc = P.Sinc()
def construct(self, x):
return self.sinc(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinc_input_float32_output_float32():
"""
Feature: Sinc gpu TEST.
Description: Test case for Sinc
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_ms = Tensor(np.array([143, 7, 237, 221]).astype(np.float32))
net = SincNet()
output_ms = net(input_ms)
expect = np.array([-5.0647902e-08, -6.0352242e-08, -4.8641517e-08, -2.2676563e-008])
assert np.allclose(output_ms.asnumpy(), expect.astype(np.float32), 0.001, 0.001)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_sinc_input_float64_output_float64():
"""
Feature: Sinc gpu TEST.
Description: Test case for Sinc
Expectation: The value and shape of output are the expected values.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
input_ms = Tensor(np.array([13, 5.3, 6, 10]).astype(np.float64))
net = SincNet()
output_ms = net(input_ms)
expect = np.array([-4.8008e-17, -4.8588e-02, -3.8982e-17, -3.8982e-17])
assert np.allclose(output_ms.asnumpy(), expect.astype(np.float64), 0.0001, 0.0001)