!42870 [assistant][ops][I56J6C]New GPU operator implementation, include HammingWindow

Merge pull request !42870 from 康渊瑞/HammingWindow
This commit is contained in:
i-robot 2022-09-27 12:28:32 +00:00 committed by Gitee
commit a21ed8e7dc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
8 changed files with 496 additions and 10 deletions

View File

@ -0,0 +1,125 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <math.h>
#include "hamming_window_impl.cuh"
template <typename S>
__global__ void HammingWindowOne(const size_t size, const double N, const double PI,
const float alpha, const float beta, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
output[pos] = static_cast<S>(1);
}
return;
}
template <typename S>
__global__ void HammingWindow(const size_t size, const double N, const double PI,
const float alpha, const float beta, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
double out = alpha - beta *cos((2 * pos * PI) / (N - 1));
output[pos] = static_cast<S>(out);
}
return;
}
template <typename T, typename S>
void HammingWindow(const size_t size, T N, const float alpha, const float beta,
const bool periodic, S *output, const uint32_t &device_id,
cudaStream_t cuda_stream) {
const double PI = acos(-1);
if (N == 1) {
HammingWindowOne<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(size, N, PI, alpha, beta, output);
} else {
N = periodic ? static_cast<double>(N + 1) : static_cast<double>(N);
HammingWindow<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(size, N, PI, alpha, beta, output);
}
return;
}
template
CUDA_LIB_EXPORT void HammingWindow<int8_t, half>(const size_t size, int8_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int16_t, half>(const size_t size, int16_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int32_t, half>(const size_t size, int32_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int64_t, half>(const size_t size, int64_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint8_t, half>(const size_t size, uint8_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint16_t, half>(const size_t size, uint16_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint32_t, half>(const size_t size, uint32_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint64_t, half>(const size_t size, uint64_t N, const float alpha, const float beta,
const bool periodic, half *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int8_t, float>(const size_t size, int8_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int16_t, float>(const size_t size, int16_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int32_t, float>(const size_t size, int32_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int64_t, float>(const size_t size, int64_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint8_t, float>(const size_t size, uint8_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint16_t, float>(const size_t size, uint16_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint32_t, float>(const size_t size, uint32_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint64_t, float>(const size_t size, uint64_t N, const float alpha, const float beta,
const bool periodic, float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int8_t, double>(const size_t size, int8_t N, const float alpha, const float beta,
const bool periodic, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int16_t, double>(const size_t size, int16_t N, const float alpha, const float beta,
const bool periodic, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int32_t, double>(const size_t size, int32_t N, const float alpha, const float beta,
const bool periodic, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<int64_t, double>(const size_t size, int64_t N, const float alpha, const float beta,
const bool periodic, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template
CUDA_LIB_EXPORT void HammingWindow<uint8_t, double>(const size_t size, uint8_t N, const float alpha, const float beta,
const bool periodic, double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void HammingWindow<uint16_t, double>(const size_t size, uint16_t N, const float alpha,
const float beta, const bool periodic, double *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void HammingWindow<uint32_t, double>(const size_t size, uint32_t N, const float alpha,
const float beta, const bool periodic, double *output,
const uint32_t &device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void HammingWindow<uint64_t, double>(const size_t size, uint64_t N, const float alpha,
const float beta, const bool periodic, double *output,
const uint32_t &device_id, cudaStream_t cuda_stream);

View File

@ -0,0 +1,24 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T, typename S>
CUDA_LIB_EXPORT void HammingWindow(const size_t size, T N, const float alpha, const float beta,
const bool periodic, S *output, const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HAMMING_WINDOW_IMPL_CUH_

View File

@ -0,0 +1,153 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/other/hamming_window_gpu_kernel.h"
namespace mindspore {
namespace kernel {
bool HammingWindowGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::HammingWindow>(base_operator);
kernel_name_ = kernel_ptr_->name();
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in [int32, int64], "
<< "but got: " << kernel_attr << ".";
return false;
}
kernel_func_ = func_list_[index].second;
std::vector<int64_t> input_shape = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
int64_t input_dims = input_shape.size();
if (input_dims != 1) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'x' must be 0-D, but got " << input_dims << "-D.";
return false;
}
unit_input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
unit_output_size_ = abstract::TypeIdSize(kernel_attr.GetOutputAttr(kIndex0).first);
periodic_ = kernel_ptr_->get_periodic();
alpha_ = kernel_ptr_->get_alpha();
beta_ = kernel_ptr_->get_beta();
if (input_shape[0] < 0) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the window_length should >0, "
<< "but got: " << input_shape[0] << ".";
return false;
}
return true;
}
int HammingWindowGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
for (const auto &input : inputs) {
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
ResetResource();
std::vector<int64_t> output_shape = std::vector<int64_t>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
output_elements_ = std::accumulate(output_shape.begin(), output_shape.end(), 1, std::multiplies<int64_t>());
if (output_elements_ == 0) {
is_null_input_ = true;
}
size_t output_size = output_elements_ * unit_output_size_;
input_size_list_.push_back(unit_input_size_);
output_size_list_.push_back(output_size);
return KRET_OK;
}
template <typename T, typename S>
bool HammingWindowGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, 0);
S *output = GetDeviceAddress<S>(outputs, 0);
T N = 0;
cudaMemcpyAsync(&N, &input[0], sizeof(T), cudaMemcpyDeviceToHost);
HammingWindow(output_elements_, N, alpha_, beta_, periodic_, output, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
return true;
}
std::vector<std::pair<KernelAttr, HammingWindowGpuKernelMod::Hamming_Func>> HammingWindowGpuKernelMod::func_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<int8_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<int16_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<int32_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<int64_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<uint8_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<uint16_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<uint32_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat16),
&HammingWindowGpuKernelMod::LaunchKernel<uint64_t, half>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<int8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<int16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<int32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<int64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<uint8_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<uint16_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<uint32_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat32),
&HammingWindowGpuKernelMod::LaunchKernel<uint64_t, float>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<int8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<int16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<int32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<int64_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<uint8_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<uint16_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<uint32_t, double>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeFloat64),
&HammingWindowGpuKernelMod::LaunchKernel<uint64_t, double>}};
std::vector<KernelAttr> HammingWindowGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, Hamming_Func> &pair) { return pair.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, HammingWindow, HammingWindowGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,88 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_WINDOW_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <memory>
#include <utility>
#include <algorithm>
#include <functional>
#include <map>
#include "mindspore/core/ops/hamming_window.h"
#include "abstract/utils.h"
#include "plugin/factory/ms_factory.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hamming_window_impl.cuh"
namespace mindspore {
namespace kernel {
class HammingWindowGpuKernelMod : public NativeGpuKernelMod {
public:
HammingWindowGpuKernelMod() { ResetResource(); }
~HammingWindowGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
if (is_null_input_) {
return true;
}
cuda_stream_ = cuda_stream;
return kernel_func_(this, inputs, workspace, outputs);
}
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
void ResetResource() noexcept {
output_elements_ = 0;
is_null_input_ = false;
input_size_list_.clear();
output_size_list_.clear();
}
std::vector<KernelAttr> GetOpSupport() override;
private:
template <typename T, typename S>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
using Hamming_Func =
std::function<bool(HammingWindowGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
private:
bool periodic_{true};
float alpha_{0.54};
float beta_{0.46};
size_t unit_input_size_{1};
size_t unit_output_size_{1};
size_t output_elements_;
Hamming_Func kernel_func_{};
bool is_null_input_{false};
void *cuda_stream_{nullptr};
static std::vector<std::pair<KernelAttr, Hamming_Func>> func_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_OTHER_Hamming_GPU_KERNEL_H_

View File

@ -26,10 +26,10 @@ namespace ops {
namespace {
const int64_t MAX_WINDOW_LEN = 1024 * 1024;
#define WINDOW_LENGTH_CASE(DTYPE, TYPE, LENGTH_VALUE, LENGTH_TENSOR) \
case (DTYPE): { \
LENGTH_VALUE = static_cast<int64_t>(*static_cast<TYPE *>(LENGTH_TENSOR->data_c())); \
break; \
#define WINDOW_LENGTH_CASE(DTYPE, TYPE, LENGTH_VALUE, LENGTH_TENSOR) \
case (DTYPE): { \
LENGTH_VALUE = static_cast<int64_t>(*reinterpret_cast<TYPE *>(LENGTH_TENSOR->data_c())); \
break; \
}
abstract::ShapePtr HammingWindowInferShape(const PrimitivePtr &primitive,
@ -41,8 +41,7 @@ abstract::ShapePtr HammingWindowInferShape(const PrimitivePtr &primitive,
auto length_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto length_size = length_shape.size();
const int64_t length_dim = 1;
(void)CheckAndConvertUtils::CheckInteger("length dim", SizeToLong(length_size), kEqual, length_dim,
primitive->name());
CheckAndConvertUtils::CheckInteger("length dim", length_size, kEqual, length_dim, primitive->name());
if (input_args[0]->isa<abstract::AbstractTensor>() && !input_args[0]->BuildValue()->isa<AnyValue>() &&
!input_args[0]->BuildValue()->isa<None>()) {
auto length = input_args[0]->cast<abstract::AbstractTensorPtr>();
@ -75,7 +74,7 @@ abstract::ShapePtr HammingWindowInferShape(const PrimitivePtr &primitive,
<< TypeIdLabel(input_type_value);
}
}
(void)CheckAndConvertUtils::CheckInteger("length value", length_value, kGreaterEqual, 0, primitive->name());
CheckAndConvertUtils::CheckInteger("length value", length_value, kGreaterEqual, 0, primitive->name());
out_shape.push_back(length_value);
return std::make_shared<abstract::Shape>(out_shape);
} else {
@ -93,7 +92,7 @@ TypePtr HammingWindowInferType(const PrimitivePtr &primitive, const std::vector<
auto input_type = input_args[0]->BuildType();
MS_EXCEPTION_IF_NULL(input_type);
const std::set<TypePtr> valid_input_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32, kUInt64};
(void)CheckAndConvertUtils::CheckTensorTypeValid("length", input_type, valid_input_types, primitive->name());
CheckAndConvertUtils::CheckTensorTypeValid("length", input_type, valid_input_types, primitive->name());
auto dtype_attr = primitive->GetAttr("dtype");
MS_EXCEPTION_IF_NULL(dtype_attr);
int64_t dtype_value = GetValue<int64_t>(dtype_attr);
@ -116,6 +115,19 @@ TypePtr HammingWindowInferType(const PrimitivePtr &primitive, const std::vector<
}
} // namespace
void HammingWindow::set_periodic(const bool periodic) { (void)this->AddAttr(kPeriodic, api::MakeValue(periodic)); }
bool HammingWindow::get_periodic() const { return GetValue<bool>(GetAttr(kPeriodic)); }
void HammingWindow::set_alpha(const float alpha) { (void)this->AddAttr(kAlpha, api::MakeValue(alpha)); }
float HammingWindow::get_alpha() const { return GetValue<float>(GetAttr(kAlpha)); }
void HammingWindow::set_beta(const float beta) { (void)this->AddAttr(kBeta, api::MakeValue(beta)); }
float HammingWindow::get_beta() const { return GetValue<float>(GetAttr(kBeta)); }
void HammingWindow::Init(const bool periodic, const float alpha, const float beta) {
set_periodic(periodic);
set_alpha(alpha);
set_beta(beta);
}
MIND_API_OPERATOR_IMPL(HammingWindow, BaseOperator);
AbstractBasePtr HammingWindowInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {

View File

@ -37,11 +37,17 @@ class MIND_API HammingWindow : public BaseOperator {
MIND_API_BASE_MEMBER(HammingWindow);
/// \brief Constructor.
HammingWindow() : BaseOperator(kNameHammingWindow) { InitIOName({"length"}, {"y"}); }
void Init(const bool periodic = true, const float alpha = 0.54, const float beta = 0.46);
void set_periodic(const bool periodic);
bool get_periodic() const;
void set_alpha(const float alpha);
float get_alpha() const;
void set_beta(const float beta);
float get_beta() const;
};
abstract::AbstractBasePtr HammingWindowInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_HAMMING_WINDOW_H_

View File

@ -7751,7 +7751,7 @@ class HammingWindow(Primitive):
ValueError: If data of `length` is negative.
Supported Platforms:
``Ascend`` ``CPU``
``Ascend`` ``CPU`` ``GPU``
Examples:
>>> # case 1: periodic=True.

View File

@ -0,0 +1,78 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import torch
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.api import ms_function
import mindspore.ops.operations.array_ops as P2
class HammingWindowNet(nn.Cell):
def __init__(self, periodic=True, alpha=0.54, beta=0.46, dtype=mstype.Int):
super(HammingWindowNet, self).__init__()
self.hammingwindow = P2.HammingWindow(periodic=periodic, alpha=alpha, beta=beta, dtype=dtype)
@ms_function
def construct(self, input_x):
return self.hammingwindow(input_x)
def hamming_window(periodic, loss):
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
input_x_np = np.array([10]).astype(np.int32)
input_x_ms = Tensor(input_x_np)
hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32)
hamming_window_output = hamming_window_net(input_x_ms)
hamming_window_expect = torch.hamming_window(10, periodic=periodic)
assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss)
def hamming_window_pynative(periodic, loss):
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
input_x_np = np.array([10]).astype(np.int32)
input_x_ms = Tensor(input_x_np)
hamming_window_net = HammingWindowNet(periodic, 0.54, 0.46, mstype.float32)
hamming_window_output = hamming_window_net(input_x_ms)
hamming_window_expect = torch.hamming_window(10, periodic=periodic)
assert np.allclose(hamming_window_output.asnumpy(), hamming_window_expect.numpy().astype(np.float32), loss, loss)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_blackman_window_graph_int32_true_float32():
"""
Feature: ALL To ALL
Description: test cases for HammingWindow
Expectation: the result match to torch
"""
hamming_window(periodic=True, loss=1.0e-4)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_blackman_window_pynative_int64_false_float64():
"""
Feature: ALL To ALL
Description: test cases for HammingWindow
Expectation: the result match to torch
"""
hamming_window_pynative(periodic=False, loss=1.0e-4)