[feat][assistant][I5EWL4] add new gpu operator Hypot
This commit is contained in:
parent
f308449ce8
commit
31a22a633d
|
@ -0,0 +1,155 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cukernel {
|
||||
// Number of entries in the padded shape vectors (lhs/rhs/output) that HypotHelperGpuKernel builds for broadcasting.
constexpr int MAX_DIMS = 7;
|
||||
template <typename T>
|
||||
class HypotHelperGpuKernel : public GpuKernelHelperBase {
|
||||
public:
|
||||
explicit HypotHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||
: GpuKernelHelperBase(kernel_name, device_id) {
|
||||
is_null_input_ = false;
|
||||
need_broadcast_ = false;
|
||||
}
|
||||
|
||||
virtual ~HypotHelperGpuKernel() = default;
|
||||
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||
const std::vector<std::vector<int64_t>> &output_shapes) override {
|
||||
constexpr size_t INPUT_NUM = 2;
|
||||
constexpr size_t OUTPUT_NUM = 1;
|
||||
ResetResource();
|
||||
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
|
||||
if (inp_flag == -1) {
|
||||
return inp_flag;
|
||||
}
|
||||
int out_flag =
|
||||
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
|
||||
if (out_flag == -1) {
|
||||
return out_flag;
|
||||
}
|
||||
is_null_input_ = (inp_flag == 1 || out_flag == 1);
|
||||
|
||||
auto inputx_shape = input_shapes[0];
|
||||
auto inputy_shape = input_shapes[1];
|
||||
auto output_shape = output_shapes[0];
|
||||
|
||||
ProcessScalar(&inputx_shape, &inputy_shape, &output_shape);
|
||||
|
||||
for (size_t i = 0; i < inputx_shape.size(); i++) {
|
||||
if (inputx_shape[i] != inputy_shape[i]) {
|
||||
need_broadcast_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
lhs_shape_.resize(MAX_DIMS, 1);
|
||||
rhs_shape_.resize(MAX_DIMS, 1);
|
||||
output_shape_.resize(MAX_DIMS, 1);
|
||||
output_num_ = 1;
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
if (need_broadcast_) {
|
||||
output_shape_[i] = output_shape[i];
|
||||
}
|
||||
output_num_ *= output_shape[i];
|
||||
}
|
||||
int lhs_offset = output_shape.size() - inputx_shape.size();
|
||||
for (size_t j = 0; j < inputx_shape.size(); j++) {
|
||||
if (need_broadcast_) {
|
||||
if ((j + lhs_offset) >= 0 && (j + lhs_offset) < MAX_DIMS) {
|
||||
lhs_shape_[j + lhs_offset] = inputx_shape[j];
|
||||
}
|
||||
}
|
||||
}
|
||||
int rhs_offset = output_shape.size() - inputy_shape.size();
|
||||
for (size_t k = 0; k < inputy_shape.size(); k++) {
|
||||
if (need_broadcast_) {
|
||||
if ((k + rhs_offset) >= 0 && (k + rhs_offset) < MAX_DIMS) {
|
||||
rhs_shape_[k + rhs_offset] = inputy_shape[k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return CheckKernelParam();
|
||||
}
|
||||
|
||||
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
T *inputx_ptr = nullptr;
|
||||
T *inputy_ptr = nullptr;
|
||||
T *output_ptr = nullptr;
|
||||
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &inputx_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &inputy_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
// call cuda kernel
|
||||
if (need_broadcast_) {
|
||||
BroadcastHypot(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
} else {
|
||||
CalHypot(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
void ProcessScalar(std::vector<int64_t> *x1_shape, std::vector<int64_t> *x2_shape, std::vector<int64_t> *y_shape) {
|
||||
// If there is a scalar in the inputs, its shape will be [], so it will be treated as [1].
|
||||
if (x1_shape->size() == 0) {
|
||||
x1_shape->insert(x1_shape->begin(), 1);
|
||||
}
|
||||
if (x2_shape->size() == 0) {
|
||||
x2_shape->insert(x2_shape->begin(), 1);
|
||||
}
|
||||
if (y_shape->size() == 0) {
|
||||
y_shape->insert(y_shape->begin(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> lhs_shape_;
|
||||
std::vector<size_t> rhs_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
bool need_broadcast_;
|
||||
bool is_null_input_;
|
||||
size_t output_num_;
|
||||
};
|
||||
} // namespace cukernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
|
|
@ -0,0 +1,144 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
|
||||
|
||||
// Per-launch stride tables in constant memory, five entries each (for dimensions 0..4).
// BroadcastHypot fills them from CalShapeData results via a copy-to-symbol before launching
// BroadcastHypotKernel, which reads them: start_cal for x1, end_cal for x2, output_cal for y.
__constant__ size_t start_cal[5];
|
||||
__constant__ size_t end_cal[5];
|
||||
__constant__ size_t output_cal[5];
|
||||
|
||||
// Binary functor returning sqrt(a^2 + b^2) through the single-precision hypotf routine.
// The result of hypotf is converted to T on return.
template <typename T>
struct HypotFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) {
    const T result = hypotf(x1, x2);
    return result;
  }
};

// Double-precision specialization: forwards to hypot so no precision is lost.
template <>
struct HypotFunc<double> {
  __device__ __host__ __forceinline__ double operator()(const double &x1, const double &x2) {
    const double result = hypot(x1, x2);
    return result;
  }
};
|
||||
|
||||
// Element-wise kernel: y[pos] = Func()(x1[pos], x2[pos]) for every pos in [0, size).
// Each thread starts at its global index and advances by the total number of launched threads,
// so any 1-D launch configuration covers the whole range.
template <typename T, typename Func>
__global__ void CalHypotKernel(size_t size, const T *x1, const T *x2, T *y) {
  // blockIdx/blockDim/gridDim are 32-bit; widen before multiplying so the start index and the
  // stride cannot wrap around for very large launches.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    y[pos] = Func()(x1[pos], x2[pos]);
  }
}
|
||||
|
||||
// Maps an output coordinate onto an input coordinate for one dimension:
// an input dimension of extent 1 always yields coordinate 0, otherwise the coordinate is kept.
__device__ __forceinline__ size_t Index(const size_t &index, const size_t &dim) {
  if (dim == 1) {
    return 0;
  }
  return index;
}
|
||||
|
||||
// Broadcasting kernel over a 7-dimensional output.
//   l0..l6 : extents of x1, r0..r6 : extents of x2, d0..d6 : extents of y.
// Strides of dimensions 0..4 are read from the constant tables start_cal (x1), end_cal (x2)
// and output_cal (y); the strides of dimensions 5 and 6 are derived from the *6 extents.
// For each output position the 7 coordinates are recovered, collapsed to 0 on input dimensions
// of extent 1 (see Index), and y[pos] = Func()(x1[l_index], x2[r_index]) is stored.
template <typename T, typename Func>
__global__ void BroadcastHypotKernel(
    const size_t l0, const size_t l1, const size_t l2, const size_t l3,
    const size_t l4, const size_t l5, const size_t l6, const size_t r0,
    const size_t r1, const size_t r2, const size_t r3, const size_t r4,
    const size_t r5, const size_t r6, const size_t d0, const size_t d1,
    const size_t d2, const size_t d3, const size_t d4, const size_t d5,
    const size_t d6, const T *x1, const T *x2, T *y) {
  // Loop-invariant: compute the element count once instead of on every iteration.
  const size_t total = d0 * d1 * d2 * d3 * d4 * d5 * d6;
  // blockIdx/blockDim/gridDim are 32-bit; widen before multiplying to avoid wrap-around.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < total; pos += stride) {
    // Coordinates of `pos` in the output.
    size_t i = pos / output_cal[0] % d0;
    size_t j = pos / output_cal[1] % d1;
    size_t k = pos / output_cal[2] % d2;
    size_t l = pos / output_cal[3] % d3;
    size_t m = pos / output_cal[4] % d4;
    size_t n = pos / d6 % d5;
    size_t o = pos % d6;

    // Flat offset into x1.
    size_t l_index = Index(i, l0) * start_cal[0];
    l_index += Index(j, l1) * start_cal[1];
    l_index += Index(k, l2) * start_cal[2];
    l_index += Index(l, l3) * start_cal[3];
    l_index += Index(m, l4) * start_cal[4];
    l_index += Index(n, l5) * l6;
    l_index += Index(o, l6);

    // Flat offset into x2.
    size_t r_index = Index(i, r0) * end_cal[0];
    r_index += Index(j, r1) * end_cal[1];
    r_index += Index(k, r2) * end_cal[2];
    r_index += Index(l, r3) * end_cal[3];
    r_index += Index(m, r4) * end_cal[4];
    r_index += Index(n, r5) * r6;
    r_index += Index(o, r6);

    y[pos] = Func()(x1[l_index], x2[r_index]);
  }
}
|
||||
|
||||
// Host launcher for the element-wise path: enqueues CalHypotKernel on `cuda_stream`, with the
// grid and block sizes taken from CUDA_BLOCKS / CUDA_THREADS for `device_id`.
template <typename T>
void CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id, cudaStream_t cuda_stream) {
  using Functor = HypotFunc<T>;
  CalHypotKernel<T, Functor><<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
    size, x1, x2, y);
  return;
}
|
||||
|
||||
// Writes the strides of dimensions 0..4 of a 7-entry shape into output[0..4]:
// output[k] is the product of start_shape[k + 1] .. start_shape[6].
// Entry 0 of start_shape is not read; start_shape must have at least 7 entries and
// output at least 5.
void CalShapeData(const std::vector<size_t> &start_shape, size_t *output) {
  size_t running = start_shape[6];
  // Walk from the innermost stored stride (slot 4) outwards to slot 0.
  for (size_t slot = 5; slot-- > 0;) {
    running *= start_shape[slot + 1];
    output[slot] = running;
  }
}
|
||||
|
||||
// Host launcher for the broadcasting path.
//
// x1_shape, x2_shape and y_shape must each hold at least 7 extents (the input shapes padded with
// 1s); if one is shorter the call returns without launching anything. The strides of
// dimensions 0..4 are uploaded to the constant tables start_cal / end_cal / output_cal and
// BroadcastHypotKernel is then enqueued on `cuda_stream`.
//
// The uploads are issued asynchronously on `cuda_stream`, so they are ordered after any
// kernel previously enqueued on that stream that may still be reading the tables. A blocking
// copy-to-symbol is not ordered with work on a non-blocking stream.
// NOTE(review): the tables are still shared by all streams, so concurrent launches from
// different streams are not isolated from each other.
//
// If an upload fails, the kernel is not launched (it would index with undefined strides).
template <typename T>
void BroadcastHypot(const std::vector<size_t> &x1_shape,
                    const std::vector<size_t> &x2_shape,
                    const std::vector<size_t> &y_shape, const T *x1,
                    const T *x2, T *y, const uint32_t &device_id,
                    cudaStream_t cuda_stream) {
  constexpr size_t kRank = 7;       // extents consumed by the kernel
  constexpr size_t kStrideNum = 5;  // entries in each constant stride table

  // Everything below indexes the shapes up to [6].
  if (x1_shape.size() < kRank || x2_shape.size() < kRank || y_shape.size() < kRank) {
    return;
  }

  size_t size = 1;
  for (auto d : y_shape) {
    size *= d;
  }

  size_t start_dim[kStrideNum];
  size_t end_dim[kStrideNum];
  size_t output_dim[kStrideNum];
  CalShapeData(x1_shape, start_dim);
  CalShapeData(x2_shape, end_dim);
  CalShapeData(y_shape, output_dim);

  const size_t table_bytes = sizeof(size_t) * kStrideNum;
  if (cudaMemcpyToSymbolAsync(start_cal, start_dim, table_bytes, 0, cudaMemcpyHostToDevice, cuda_stream) !=
        cudaSuccess ||
      cudaMemcpyToSymbolAsync(end_cal, end_dim, table_bytes, 0, cudaMemcpyHostToDevice, cuda_stream) !=
        cudaSuccess ||
      cudaMemcpyToSymbolAsync(output_cal, output_dim, table_bytes, 0, cudaMemcpyHostToDevice, cuda_stream) !=
        cudaSuccess) {
    return;
  }

  BroadcastHypotKernel<T, HypotFunc<T>>
    <<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(
      x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3], x1_shape[4], x1_shape[5], x1_shape[6],
      x2_shape[0], x2_shape[1], x2_shape[2], x2_shape[3], x2_shape[4], x2_shape[5], x2_shape[6],
      y_shape[0], y_shape[1], y_shape[2], y_shape[3], y_shape[4], y_shape[5], y_shape[6],
      x1, x2, y);
}
|
||||
|
||||
// Explicit instantiations: CalHypot and BroadcastHypot are emitted here for float and double.
template CUDA_LIB_EXPORT void CalHypot<float>(size_t, const float *, const float *,
|
||||
float *, const uint32_t &,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalHypot<double>(size_t, const double *, const double *,
|
||||
double *, const uint32_t &,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void
|
||||
BroadcastHypot<float>(const std::vector<size_t> &, const std::vector<size_t> &,
|
||||
const std::vector<size_t> &, const float *, const float *,
|
||||
float *, const uint32_t &, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void
|
||||
BroadcastHypot<double>(const std::vector<size_t> &, const std::vector<size_t> &,
|
||||
const std::vector<size_t> &, const double *, const double *,
|
||||
double *, const uint32_t &, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
|
||||
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
// Element-wise Hypot launcher.
//   size        : number of elements to process
//   x1, x2      : input buffers, read at indices [0, size)
//   y           : output buffer, written at indices [0, size)
//   device_id   : forwarded to CUDA_BLOCKS / CUDA_THREADS to pick the launch configuration
//   cuda_stream : stream the kernel is enqueued on
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
// Broadcasting Hypot launcher.
//   x1_shape, x2_shape, y_shape : extents of the two inputs and the output; the implementation
//                                 reads entries [0..6] of each, so 7 entries are required
//   x1, x2                      : input buffers
//   y                           : output buffer
//   device_id                   : forwarded to CUDA_BLOCKS / CUDA_THREADS
//   cuda_stream                 : stream the kernel is enqueued on
template <typename T>
|
||||
CUDA_LIB_EXPORT void BroadcastHypot(const std::vector<size_t> &x1_shape, const std::vector<size_t> &x2_shape,
|
||||
const std::vector<size_t> &y_shape, const T *x1, const T *x2, T *y,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
|
|
@ -0,0 +1,100 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/hypot_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateHypotKernelPtr(const std::string &kernel_name,
|
||||
const uint32_t &device_id) {
|
||||
return std::make_unique<cukernel::HypotHelperGpuKernel<T>>(kernel_name, device_id);
|
||||
}
|
||||
// Signature of a helper factory: (kernel name, device id) -> owning pointer to the helper.
using HypotPtrCreatorFunc =
|
||||
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||
|
||||
// Supported (input, input, output) dtype combinations, each paired with the factory for the
// matching helper instantiation. Init() indexes this table with the position returned by
// MatchKernelAttr, and GetOpSupport() exposes the attrs in the same order, so the order of
// entries must stay in sync between the two uses.
const std::vector<std::pair<KernelAttr, HypotPtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateHypotKernelPtr<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateHypotKernelPtr<double>}};
|
||||
} // namespace
|
||||
|
||||
bool HypotGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HypotGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::Hypot>(base_operator);
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
return false;
|
||||
}
|
||||
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||
|
||||
Resize(base_operator, inputs, outputs);
|
||||
return true;
|
||||
}
|
||||
|
||||
int HypotGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<std::vector<int64_t>> output_shapes;
|
||||
std::vector<int64_t> inpx_shape = inputs.at(kIndex0)->GetShapeVector();
|
||||
std::vector<int64_t> inpy_shape = inputs.at(kIndex1)->GetShapeVector();
|
||||
std::vector<int64_t> out_shape = outputs.at(kIndex0)->GetShapeVector();
|
||||
input_shapes.emplace_back(inpx_shape);
|
||||
input_shapes.emplace_back(inpy_shape);
|
||||
output_shapes.emplace_back(out_shape);
|
||||
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> HypotGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, HypotPtrCreatorFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
// Presumably registers HypotGpuKernelMod with the NativeGpuKernelMod factory under the name "Hypot" (macro defined outside this file; confirm).
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Hypot, HypotGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,61 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <cusolverDn.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include <complex>
|
||||
#include "mindspore/core/ops/hypot.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
// GPU kernel module for the Hypot operator. All work is delegated to a type-erased helper
// (helper_ptr_) that is created in Init().
class HypotGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
HypotGpuKernelMod() {}
|
||||
~HypotGpuKernelMod() = default;
|
||||
|
||||
// Creates the helper for the tensors' dtypes and performs an initial Resize().
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
// Forwards the device addresses and the stream to the helper; true when the helper reports 0.
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
// Recomputes the size lists for the current shapes; returns a KRET_* status code.
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
// Returns the supported dtype combinations.
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
// Helper that owns the shape bookkeeping and launches the CUDA kernels; null until Init() selects one.
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
|
|
@ -2635,7 +2635,7 @@ class Hypot(Primitive):
|
|||
ValueError: If shape of two inputs are not broadcastable.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> x1 = Tensor(np.array([3., 5., 7.]))
|
||||
|
|
|
@ -0,0 +1,70 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.context as context
|
||||
import mindspore.ops.operations.math_ops as P
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
|
||||
|
||||
class NetHypot(nn.Cell):
    """Minimal Cell that applies the Hypot primitive to its two inputs."""

    def __init__(self):
        super().__init__()
        self.hypot = P.Hypot()

    def construct(self, x1, x2):
        """Return Hypot(x1, x2)."""
        result = self.hypot(x1, x2)
        return result
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hypot_fp32():
    """
    Feature: Hypot
    Description: run Hypot on two float32 vectors of equal shape
    Expectation: the output matches the expected hypotenuse values
    """
    lhs = Tensor(np.array([3, 4]).astype(np.float32))
    rhs = Tensor(np.array([4, 3]).astype(np.float32))
    expected = np.array([5.0, 5.0]).astype(np.float32)
    actual = NetHypot()(lhs, rhs)
    assert np.allclose(actual.asnumpy(), expected)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hypot_fp64():
    """
    Feature: Hypot
    Description: run Hypot on two float64 vectors of equal shape
    Expectation: the output matches the expected hypotenuse values
    """
    lhs = Tensor(np.array([1.2, 3.4, 2.4, 1.3]).astype(np.float64))
    rhs = Tensor(np.array([2.3, 1.1, 0.9, 0.3]).astype(np.float64))
    expected = np.array([2.59422435, 3.57351368, 2.56320112, 1.33416641]).astype(np.float64)
    actual = NetHypot()(lhs, rhs)
    assert np.allclose(actual.asnumpy(), expected)
|
Loading…
Reference in New Issue