[feat][assistant][I5EWL4] add new gpu operator Hypot

li-qiyao 2022-10-09 21:49:18 +08:00
parent f308449ce8
commit 31a22a633d
7 changed files with 562 additions and 1 deletion
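For context: Hypot computes the elementwise hypotenuse sqrt(x1^2 + x2^2) with
NumPy-style broadcasting. A minimal usage sketch against the new GPU kernel
(mirroring the test file at the end of this diff; values are illustrative):

import numpy as np
import mindspore.context as context
import mindspore.ops.operations.math_ops as P
from mindspore import Tensor

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

hypot = P.Hypot()
out = hypot(Tensor(np.array([3.0, 4.0], np.float32)),
            Tensor(np.array([4.0, 3.0], np.float32)))
print(out)  # [5. 5.], i.e. sqrt(3^2 + 4^2) elementwise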

View File: plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h

@@ -0,0 +1,155 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
#include <memory>
#include <string>
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
namespace mindspore {
namespace cukernel {
constexpr int MAX_DIMS = 7;
template <typename T>
class HypotHelperGpuKernel : public GpuKernelHelperBase {
public:
explicit HypotHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
: GpuKernelHelperBase(kernel_name, device_id) {
is_null_input_ = false;
need_broadcast_ = false;
}
virtual ~HypotHelperGpuKernel() = default;
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
const std::vector<std::vector<int64_t>> &output_shapes) override {
constexpr size_t INPUT_NUM = 2;
constexpr size_t OUTPUT_NUM = 1;
ResetResource();
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
if (inp_flag == -1) {
return inp_flag;
}
int out_flag =
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
if (out_flag == -1) {
return out_flag;
}
is_null_input_ = (inp_flag == 1 || out_flag == 1);
auto inputx_shape = input_shapes[0];
auto inputy_shape = input_shapes[1];
auto output_shape = output_shapes[0];
ProcessScalar(&inputx_shape, &inputy_shape, &output_shape);
for (size_t i = 0; i < inputx_shape.size(); i++) {
if (inputx_shape[i] != inputy_shape[i]) {
need_broadcast_ = true;
}
}
lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
output_num_ = 1;
for (size_t i = 0; i < output_shape.size(); i++) {
if (need_broadcast_) {
output_shape_[i] = output_shape[i];
}
output_num_ *= output_shape[i];
}
int64_t lhs_offset = static_cast<int64_t>(output_shape.size()) - static_cast<int64_t>(inputx_shape.size());
for (size_t j = 0; j < inputx_shape.size(); j++) {
  if (need_broadcast_) {
    // j + lhs_offset was unsigned before, so the >= 0 guard was vacuous; use a
    // signed index so a negative offset is actually rejected.
    int64_t pos = static_cast<int64_t>(j) + lhs_offset;
    if (pos >= 0 && pos < MAX_DIMS) {
      lhs_shape_[pos] = inputx_shape[j];
    }
  }
}
int64_t rhs_offset = static_cast<int64_t>(output_shape.size()) - static_cast<int64_t>(inputy_shape.size());
for (size_t k = 0; k < inputy_shape.size(); k++) {
  if (need_broadcast_) {
    int64_t pos = static_cast<int64_t>(k) + rhs_offset;
    if (pos >= 0 && pos < MAX_DIMS) {
      rhs_shape_[pos] = inputy_shape[k];
    }
  }
}
return CheckKernelParam();
}
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
if (is_null_input_) {
return 0;
}
T *inputx_ptr = nullptr;
T *inputy_ptr = nullptr;
T *output_ptr = nullptr;
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &inputx_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(input_ptrs, 1, kernel_name_, &inputy_ptr);
if (flag != 0) {
return flag;
}
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
if (flag != 0) {
return flag;
}
// call cuda kernel
if (need_broadcast_) {
BroadcastHypot(lhs_shape_, rhs_shape_, output_shape_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
} else {
CalHypot(output_num_, inputx_ptr, inputy_ptr, output_ptr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream));
}
return 0;
}
void ProcessScalar(std::vector<int64_t> *x1_shape, std::vector<int64_t> *x2_shape, std::vector<int64_t> *y_shape) {
// If there is a scalar in the inputs, its shape will be [], so it will be treated as [1].
if (x1_shape->size() == 0) {
x1_shape->insert(x1_shape->begin(), 1);
}
if (x2_shape->size() == 0) {
x2_shape->insert(x2_shape->begin(), 1);
}
if (y_shape->size() == 0) {
y_shape->insert(y_shape->begin(), 1);
}
}
private:
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;
bool need_broadcast_;
bool is_null_input_;
size_t output_num_;
};
} // namespace cukernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_HYPOT_HELPER_H_
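A note on the shape handling in CalMemSize above: each input is right-aligned
against the output rank and the 7-slot shape vectors keep trailing ones, which
is what lets the fixed 7-parameter broadcast kernel serve any lower rank. A
hedged Python sketch of the same alignment (pad7 is an illustrative helper,
not part of this commit):

MAX_DIMS = 7

def pad7(shape, out_rank):
    # Right-align `shape` within the output rank, leaving trailing ones up to
    # MAX_DIMS, mirroring the lhs_offset/rhs_offset loops in CalMemSize.
    offset = out_rank - len(shape)
    padded = [1] * MAX_DIMS
    for j, d in enumerate(shape):
        padded[j + offset] = d
    return padded

# x1: (3, 1, 5) broadcast with x2: (4, 5) -> output (3, 4, 5)
print(pad7([3, 1, 5], 3))  # [3, 1, 5, 1, 1, 1, 1]
print(pad7([4, 5], 3))     # [1, 4, 5, 1, 1, 1, 1]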

View File: plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cu

@@ -0,0 +1,144 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh"
__constant__ size_t start_cal[5];
__constant__ size_t end_cal[5];
__constant__ size_t output_cal[5];
template <typename T> struct HypotFunc {
  __device__ __host__ __forceinline__ T operator()(const T &x1, const T &x2) {
    // Only instantiated for float in this file; double is routed to the
    // specialization below, which calls the double-precision hypot.
    return hypotf(x1, x2);
  }
};
template <> struct HypotFunc<double> {
__device__ __host__ __forceinline__ double operator()(const double &x1,
const double &x2) {
return hypot(x1, x2);
}
};
template <typename T, typename Func>
__global__ void CalHypotKernel(size_t size, const T *x1, const T *x2, T *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size;
pos += blockDim.x * gridDim.x) {
y[pos] = Func()(x1[pos], x2[pos]);
}
}
__device__ __forceinline__ size_t Index(const size_t &index,
const size_t &dim) {
return dim == 1 ? 0 : index;
}
template <typename T, typename Func>
__global__ void BroadcastHypotKernel(
const size_t l0, const size_t l1, const size_t l2, const size_t l3,
const size_t l4, const size_t l5, const size_t l6, const size_t r0,
const size_t r1, const size_t r2, const size_t r3, const size_t r4,
const size_t r5, const size_t r6, const size_t d0, const size_t d1,
const size_t d2, const size_t d3, const size_t d4, const size_t d5,
const size_t d6, const T *x1, const T *x2, T *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < d0 * d1 * d2 * d3 * d4 * d5 * d6; pos += blockDim.x * gridDim.x) {
size_t i = pos / output_cal[0] % d0;
size_t j = pos / output_cal[1] % d1;
size_t k = pos / output_cal[2] % d2;
size_t l = pos / output_cal[3] % d3;
size_t m = pos / output_cal[4] % d4;
size_t n = pos / d6 % d5;
size_t o = pos % d6;
size_t l_index = Index(i, l0) * start_cal[0];
l_index += Index(j, l1) * start_cal[1];
l_index += Index(k, l2) * start_cal[2];
l_index += Index(l, l3) * start_cal[3];
l_index += Index(m, l4) * start_cal[4];
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
size_t r_index = Index(i, r0) * end_cal[0];
r_index += Index(j, r1) * end_cal[1];
r_index += Index(k, r2) * end_cal[2];
r_index += Index(l, r3) * end_cal[3];
r_index += Index(m, r4) * end_cal[4];
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
y[pos] = Func()(x1[l_index], x2[r_index]);
}
}
template <typename T>
void CalHypot(size_t size, const T *x1, const T *x2, T *y,
const uint32_t &device_id, cudaStream_t cuda_stream) {
return CalHypotKernel<T, HypotFunc<T>>
<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(size, x1, x2, y);
}
// Precompute row-major strides for the leading five axes of a 7-D shape; the
// last two axes (d5, d6) are handled directly inside the kernel.
void CalShapeData(const std::vector<size_t> &start_shape, size_t *output) {
  output[4] = start_shape[5] * start_shape[6];
  output[3] = output[4] * start_shape[4];
  output[2] = output[3] * start_shape[3];
  output[1] = output[2] * start_shape[2];
  output[0] = output[1] * start_shape[1];
}
template <typename T>
void BroadcastHypot(const std::vector<size_t> &x1_shape,
const std::vector<size_t> &x2_shape,
const std::vector<size_t> &y_shape, const T *x1,
const T *x2, T *y, const uint32_t &device_id,
cudaStream_t cuda_stream) {
size_t size = 1;
for (auto d : y_shape) {
size *= d;
}
size_t start_dim[5];
size_t end_dim[5];
size_t output_dim[5];
CalShapeData(x1_shape, start_dim);
CalShapeData(x2_shape, end_dim);
CalShapeData(y_shape, output_dim);
// Note: these copies go through the default stream, and the __constant__
// symbols are shared process-wide, so concurrent launches of this kernel on
// different streams could race on them.
cudaMemcpyToSymbol(start_cal, start_dim, sizeof(size_t) * 5);
cudaMemcpyToSymbol(end_cal, end_dim, sizeof(size_t) * 5);
cudaMemcpyToSymbol(output_cal, output_dim, sizeof(size_t) * 5);
return BroadcastHypotKernel<T, HypotFunc<T>>
<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0,
cuda_stream>>>(x1_shape[0], x1_shape[1], x1_shape[2], x1_shape[3],
x1_shape[4], x1_shape[5], x1_shape[6], x2_shape[0],
x2_shape[1], x2_shape[2], x2_shape[3], x2_shape[4],
x2_shape[5], x2_shape[6], y_shape[0], y_shape[1],
y_shape[2], y_shape[3], y_shape[4], y_shape[5],
y_shape[6], x1, x2, y);
}
template CUDA_LIB_EXPORT void CalHypot<float>(size_t, const float *, const float *,
float *, const uint32_t &,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void CalHypot<double>(size_t, const double *, const double *,
double *, const uint32_t &,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
BroadcastHypot<float>(const std::vector<size_t> &, const std::vector<size_t> &,
const std::vector<size_t> &, const float *, const float *,
float *, const uint32_t &, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT void
BroadcastHypot<double>(const std::vector<size_t> &, const std::vector<size_t> &,
const std::vector<size_t> &, const double *, const double *,
double *, const uint32_t &, cudaStream_t cuda_stream);
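The broadcast kernel above recovers a 7-D coordinate from the flat output
position using the suffix products that CalShapeData stores in the constant
arrays (output_cal holds the strides of the first five axes; d5 and d6 are
handled inline). A hedged Python sketch checking that this decomposition
matches NumPy's row-major unravel (decompose is illustrative only):

import numpy as np

def decompose(pos, dims):
    # strides[k] = prod(dims[k+1:]) reproduces output_cal[0..4] from CalShapeData.
    strides = [int(np.prod(dims[k + 1:])) for k in range(5)]
    idx = [pos // strides[k] % dims[k] for k in range(5)]
    idx.append(pos // dims[6] % dims[5])  # n = pos / d6 % d5
    idx.append(pos % dims[6])             # o = pos % d6
    return tuple(idx)

dims = (2, 1, 3, 1, 2, 2, 4)
assert decompose(37, dims) == np.unravel_index(37, dims)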

View File: plugin/device/gpu/kernel/cuda_impl/cuda_ops/hypot_impl.cuh

@@ -0,0 +1,31 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalHypot(size_t size, const T *x1, const T *x2, T *y, const uint32_t &device_id,
cudaStream_t cuda_stream);
template <typename T>
CUDA_LIB_EXPORT void BroadcastHypot(const std::vector<size_t> &x1_shape, const std::vector<size_t> &x2_shape,
const std::vector<size_t> &y_shape, const T *x1, const T *x2, T *y,
const uint32_t &device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_HYPOT_IMPL_CUH_

View File: plugin/device/gpu/kernel/math/hypot_gpu_kernel.cc

@@ -0,0 +1,100 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/hypot_gpu_kernel.h"
namespace mindspore {
namespace kernel {
namespace {
template <typename T>
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateHypotKernelPtr(const std::string &kernel_name,
const uint32_t &device_id) {
return std::make_unique<cukernel::HypotHelperGpuKernel<T>>(kernel_name, device_id);
}
using HypotPtrCreatorFunc =
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
const std::vector<std::pair<KernelAttr, HypotPtrCreatorFunc>> kernel_attr = {
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
CreateHypotKernelPtr<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
CreateHypotKernelPtr<double>}};
} // namespace
bool HypotGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
return false;
}
return true;
}
bool HypotGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
auto kernel_ptr = std::dynamic_pointer_cast<ops::Hypot>(base_operator);
if (kernel_ptr == nullptr) {
  MS_LOG(ERROR) << "For 'Hypot', cast from BaseOperator to ops::Hypot failed.";
  return false;
}
kernel_name_ = kernel_ptr->name();
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
if (!is_match) {
return false;
}
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
Resize(base_operator, inputs, outputs);
return true;
}
int HypotGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
for (const auto &input : inputs) {
// A -1 in a shape means the input is still dynamic; defer sizing until real shapes arrive.
auto input_shape = input->GetShapeVector();
if (!IsValidShape(input_shape)) {
return KRET_UNKNOWN_SHAPE;
}
}
std::vector<std::vector<int64_t>> input_shapes;
std::vector<std::vector<int64_t>> output_shapes;
std::vector<int64_t> inpx_shape = inputs.at(kIndex0)->GetShapeVector();
std::vector<int64_t> inpy_shape = inputs.at(kIndex1)->GetShapeVector();
std::vector<int64_t> out_shape = outputs.at(kIndex0)->GetShapeVector();
input_shapes.emplace_back(inpx_shape);
input_shapes.emplace_back(inpy_shape);
output_shapes.emplace_back(out_shape);
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
return KRET_RESIZE_FAILED;
}
input_size_list_ = helper_ptr_->GetInputSizeList();
output_size_list_ = helper_ptr_->GetOutputSizeList();
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
return KRET_OK;
}
std::vector<KernelAttr> HypotGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, HypotPtrCreatorFunc> &item) { return item.first; });
return support_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Hypot, HypotGpuKernelMod);
} // namespace kernel
} // namespace mindspore
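Only the two KernelAttr entries registered above are available, so on GPU this
operator accepts exactly float32 and float64 input pairs; anything else fails
kernel selection in Init. A hedged illustration (the exact error type depends
on frontend checks and is not pinned down here):

import numpy as np
import mindspore.ops.operations.math_ops as P
from mindspore import Tensor

hypot = P.Hypot()
hypot(Tensor(np.float32([3.0])), Tensor(np.float32([4.0])))  # matches entry 1
hypot(Tensor(np.float64([1.2])), Tensor(np.float64([2.3])))  # matches entry 2
# hypot(Tensor(np.int32([1])), Tensor(np.int32([2])))  # no matching KernelAttr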

View File: plugin/device/gpu/kernel/math/hypot_gpu_kernel.h

@@ -0,0 +1,61 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H
#include <cublas_v2.h>
#include <cuda_runtime_api.h>
#include <cusolverDn.h>
#include <cuda_runtime.h>
#include <vector>
#include <string>
#include <memory>
#include <algorithm>
#include <functional>
#include <map>
#include <utility>
#include <complex>
#include "mindspore/core/ops/hypot.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/hypot_helper.h"
namespace mindspore {
namespace kernel {
class HypotGpuKernelMod : public NativeGpuKernelMod {
public:
HypotGpuKernelMod() {}
~HypotGpuKernelMod() = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
std::vector<KernelAttr> GetOpSupport() override;
private:
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_HYPOT_GPU_KERNEL_H

View File: mindspore/python/mindspore/ops/operations/math_ops.py

@@ -2635,7 +2635,7 @@ class Hypot(Primitive):
         ValueError: If shape of two inputs are not broadcastable.
 
     Supported Platforms:
-        ``Ascend`` ``CPU``
+        ``Ascend`` ``GPU`` ``CPU``
 
     Examples:
         >>> x1 = Tensor(np.array([3., 5., 7.]))

View File

@@ -0,0 +1,70 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.nn as nn
from mindspore import Tensor
import mindspore.context as context
import mindspore.ops.operations.math_ops as P

context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")


class NetHypot(nn.Cell):
    def __init__(self):
        super(NetHypot, self).__init__()
        self.hypot = P.Hypot()

    def construct(self, x1, x2):
        return self.hypot(x1, x2)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hypot_fp32():
    """
    Feature: Hypot
    Description: test cases for Hypot of float32
    Expectation: the results are as expected
    """
    x1_np = np.array([3, 4]).astype(np.float32)
    x2_np = np.array([4, 3]).astype(np.float32)
    input_x1 = Tensor(x1_np)
    input_x2 = Tensor(x2_np)
    net = NetHypot()
    output_ms = net(input_x1, input_x2)
    expect_output = np.array([5.0, 5.0]).astype(np.float32)
    assert np.allclose(output_ms.asnumpy(), expect_output)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hypot_fp64():
    """
    Feature: Hypot
    Description: test cases for Hypot of float64
    Expectation: the results are as expected
    """
    x1_np = np.array([1.2, 3.4, 2.4, 1.3]).astype(np.float64)
    x2_np = np.array([2.3, 1.1, 0.9, 0.3]).astype(np.float64)
    input_x1 = Tensor(x1_np)
    input_x2 = Tensor(x2_np)
    net = NetHypot()
    output_ms = net(input_x1, input_x2)
    expect_output = np.array([2.59422435, 3.57351368, 2.56320112, 1.33416641]).astype(np.float64)
    assert np.allclose(output_ms.asnumpy(), expect_output)
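
The two cases above only exercise equal shapes, so the BroadcastHypotKernel
path is not covered here. A hedged sketch of a broadcast case one might add
(test name and shapes are illustrative, not part of this commit):

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_hypot_broadcast_fp32():
    """
    Feature: Hypot
    Description: sketch of a broadcast test case for Hypot of float32
    Expectation: the results match np.hypot with broadcasting
    """
    x1_np = np.random.rand(3, 1, 5).astype(np.float32)
    x2_np = np.random.rand(4, 5).astype(np.float32)
    net = NetHypot()
    output_ms = net(Tensor(x1_np), Tensor(x2_np))
    expect_output = np.hypot(x1_np, x2_np)  # broadcasts to (3, 4, 5)
    assert np.allclose(output_ms.asnumpy(), expect_output)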