!54434 broadcast optimization

Merge pull request !54434 from zong_shuai/broadcast
i-robot 2023-05-23 09:53:51 +00:00 committed by Gitee
commit 651f8ae568
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
42 changed files with 2194 additions and 2938 deletions

View File

@ -31,7 +31,7 @@
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/sequence/sequence_addn_gpu_kernel.cc" "whitespace/indent"
"mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/sequence/sequence_addn_cpu_kernel.cc" "whitespace/indent"
"mindspore/mindspore/ccsrc/pipeline/jit/resource.cc" "readability/fn_size"
"mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/binary_ops_gpu_kernel.cc" "whitespace/indent"
# Modelzoo
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/Yolov4TinyDetection.h" "runtime/references"
"mindspore/model_zoo/official/cv/yolov4_tiny/infer/mxbase/src/PostProcess/Yolov4TinyMindsporePost.h" "runtime/references"

View File

@ -17,7 +17,8 @@
#include <memory>
#include "plugin/device/gpu/kernel/arrays/array_reduce_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/unary_op_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "ops/reduce.h"
#include "plugin/device/gpu/kernel/arrays/cast_gpu_kernel.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"
@ -410,8 +411,10 @@ void ArrayReduceGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr>
cudnnReduceTensor(cudnn_handle_, reduce_tensor_descriptor_, nullptr, 0, workspace_addr, workspace_size_, &alpha,
inputA_descriptor_, input_imag, &beta, outputC_descriptor_, output_imag),
ss.str());
ElewiseComplexArith(output_count, BinaryOpType::kComplex, output_real, output_imag, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int64_t> ele_shape = {static_cast<int64_t>(output_count)};
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kComplex, S, S, T>(false, ele_shape, ele_shape, ele_shape, output_real,
output_imag, output_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr));
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_real);
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(input_imag);
device::gpu::GPUMemoryAllocator::GetInstance().FreeTensorMem(output_real);
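For orientation: the new call packs the separately reduced real/imag buffers into one complex output. A minimal sketch of the same call with an illustrative element count (the 1024 value and float element type are assumptions; the pointer names follow the code above):
std::vector<int64_t> ele_shape = {1024};  // flattened element count of the reduced output
// Same shape for both inputs and the output, so is_broadcast is false and the
// kComplex functor simply builds Complex(real_i, imag_i) element by element.
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kComplex, float, float, Complex<float>>(
    false, ele_shape, ele_shape, ele_shape, output_real, output_imag, output_addr,
    device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));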

View File

@ -15,6 +15,7 @@
*/
#include "plugin/device/gpu/kernel/arrays/broadcast_to_gpu_kernel.h"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
@ -29,59 +30,25 @@ bool BroadcastToGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const s
return false;
}
kernel_func_ = func_list_[index].second;
input_type_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).dtype);
return true;
}
void BroadcastToGpuKernelMod::ResetResource() noexcept {
input_size_ = 1;
output_size_ = 1;
for (size_t i = 0; i < SHAPE_SIZE; ++i) {
input_shape_[i] = 1;
output_shape_[i] = 1;
}
is_null_input_ = false;
}
int BroadcastToGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
ResetResource();
auto input_shapes = inputs[kIndex0]->GetShapeVector();
auto output_shapes = outputs[kIndex0]->GetShapeVector();
auto it_x = std::find_if(input_shapes.begin(), input_shapes.end(), [](int64_t sh) { return sh < 0; });
if (it_x != input_shapes.end()) {
return KRET_UNKNOWN_SHAPE;
int ret = KRET_OK;
if ((ret = KernelMod::Resize(base_operator, inputs, outputs)) != 0) {
return ret;
}
if (input_shapes.size() > SHAPE_SIZE || output_shapes.size() > SHAPE_SIZE) {
auto inp_shape = inputs[kIndex0]->GetShapeVector();
auto out_shape = outputs[kIndex0]->GetShapeVector();
if (inp_shape.size() > SHAPE_SIZE || out_shape.size() > SHAPE_SIZE) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input and output cannot be greater than "
<< SHAPE_SIZE << ", but got the dimension of input: " << input_shapes.size()
<< ", the dimension of output: " << output_shapes.size();
<< SHAPE_SIZE << ", but got the dimension of input: " << inp_shape.size()
<< ", the dimension of output: " << out_shape.size();
}
if (output_shapes.size() < input_shapes.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the dimension of output cannot be less than the dimension of input "
<< ", but got the dimension of input: " << input_shapes.size()
<< ", the dimension of output: " << output_shapes.size();
}
size_t offset = output_shapes.size() - input_shapes.size();
for (size_t i = 0; i < input_shapes.size(); i++) {
input_shape_[i + offset] = LongToSizeClipNeg(input_shapes[i]);
}
for (size_t j = 0; j < output_shapes.size(); j++) {
output_shape_[j] = LongToSizeClipNeg(output_shapes[j]);
}
input_size_ = std::accumulate(input_shape_.begin(), input_shape_.end(), size_t(1), std::multiplies{});
output_size_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies{});
input_size_list_.clear();
output_size_list_.clear();
input_size_list_.push_back(input_size_ * input_type_size_);
output_size_list_.push_back(output_size_ * input_type_size_);
SimplifyBroadcastToShape(inp_shape, out_shape, &simplified_inp_shape_, &simplified_out_shape_);
is_broadcast_ = IsBinaryBroadcast(simplified_inp_shape_, simplified_out_shape_);
return KRET_OK;
}
@ -89,16 +56,17 @@ template <typename T>
bool BroadcastToGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
if (is_null_input_) {
return true;
}
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
if (is_broadcast_) {
BroadcastTo(simplified_inp_shape_, simplified_out_shape_, input_addr, output_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
size_t cpy_size = SizeOf(simplified_out_shape_) * sizeof(T);
cudaMemcpyAsync(output_addr, input_addr, cpy_size, cudaMemcpyDeviceToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
BroadcastTo(input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3], input_shape_[4], input_shape_[5],
input_shape_[6], input_shape_[7], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3],
output_shape_[4], output_shape_[5], output_shape_[6], output_shape_[7], input_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
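A minimal host-side sketch of the call pattern this rewrite converges on (the shapes, float element type, and stream handle are illustrative; SimplifyBroadcastToShape, IsBinaryBroadcast, and BroadcastTo are used exactly as above):
// Hypothetical case: input (1, 3, 1) broadcast to output (2, 3, 4).
std::vector<int64_t> inp_shape{1, 3, 1}, out_shape{2, 3, 4};
std::vector<int64_t> simplified_inp, simplified_out;
SimplifyBroadcastToShape(inp_shape, out_shape, &simplified_inp, &simplified_out);
if (IsBinaryBroadcast(simplified_inp, simplified_out)) {
  // Shapes still differ after simplification: launch the broadcast kernel.
  BroadcastTo(simplified_inp, simplified_out, input_dev, output_dev, device_id, stream);
} else {
  // Shapes collapsed to the same extents: a plain device-to-device copy suffices.
  cudaMemcpyAsync(output_dev, input_dev, SizeOf(simplified_out) * sizeof(float),
                  cudaMemcpyDeviceToDevice, stream);
}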

View File

@ -25,7 +25,7 @@
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_to_impl.cuh"
namespace mindspore {
namespace kernel {
@ -58,14 +58,10 @@ class BroadcastToGpuKernelMod : public NativeGpuKernelMod {
private:
std::string kernel_name_{};
BroadcastToLaunchFunc kernel_func_;
void ResetResource() noexcept;
static std::vector<std::pair<KernelAttr, BroadcastToLaunchFunc>> func_list_;
size_t input_size_;
size_t output_size_;
size_t input_type_size_; // sizeof(T)
std::vector<size_t> input_shape_ = {1, 1, 1, 1, 1, 1, 1, 1};
std::vector<size_t> output_shape_ = {1, 1, 1, 1, 1, 1, 1, 1};
bool is_null_input_ = false;
bool is_broadcast_;
std::vector<int64_t> simplified_inp_shape_;
std::vector<int64_t> simplified_out_shape_;
};
} // namespace kernel
} // namespace mindspore

View File

@ -16,9 +16,9 @@
#include "plugin/device/gpu/kernel/arrays/meshgrid_gpu_kernel.h"
#include <algorithm>
#include "mindspore/core/ops/meshgrid.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/elementwise_op_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
namespace mindspore {
namespace kernel {
@ -76,8 +76,7 @@ int MeshgridGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
output_count_ = static_cast<size_t>(output_size_list_.size());
// inferred shape swaps output shape for us if needed
auto shape_signed = outputs[kIndex0]->GetShapeVector();
output_shape_ = Convert2SizeTClipNeg(shape_signed);
output_shape_ = outputs[kIndex0]->GetShapeVector();
is_null_input_ = CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output");
if (is_null_input_) {
workspace_size_list_.push_back(output_size_ * data_size_);
@ -91,16 +90,7 @@ int MeshgridGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std
return KRET_RESIZE_FAILED;
}
for (size_t i = 0; i < output_shape_.size(); i++) {
output_size_ *= output_shape_[i];
}
// need to pad output shape with ones for broadcast kernel
int need_broadcast_size = MAX_DIMS - output_shape_.size();
for (int i = 0; i < need_broadcast_size; i++) {
output_shape_.push_back(1);
}
output_size_ = SizeOf(output_shape_);
workspace_size_list_.push_back(output_size_ * data_size_);
return KRET_OK;
}
@ -114,52 +104,23 @@ bool MeshgridGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, c
T *ones_device = GetDeviceAddress<T>(workspace, 0);
CalOnesLike(static_cast<T *>(nullptr), ones_device, output_size_, reinterpret_cast<cudaStream_t>(cuda_stream_));
std::vector<size_t> broadcasted_ones_shape(MAX_DIMS, 1);
for (size_t i = 0; i < output_shape_.size(); i++) {
broadcasted_ones_shape[i] = output_shape_[i];
}
std::vector<int64_t> simplified_in0_shape;
std::vector<int64_t> simplified_in1_shape;
std::vector<int64_t> simplified_out_shape;
for (size_t i = 0; i < outputs.size(); i++) {
T *input_device = GetDeviceAddress<T>(inputs, i);
T *output_device = GetDeviceAddress<T>(outputs, i);
std::vector<size_t> broadcasted_input_shape(MAX_DIMS, 1);
std::vector<int64_t> broadcasted_input_shape(input_shapes_.size(), 1);
broadcasted_input_shape[i] = input_shapes_[i];
if (swap_indexing_ && i <= 1) {
std::swap(broadcasted_input_shape[0], broadcasted_input_shape[1]);
}
BroadcastArith(broadcasted_input_shape, broadcasted_ones_shape, output_shape_, BinaryOpType::kMul, input_device,
ones_device, output_device, reinterpret_cast<cudaStream_t>(cuda_stream_));
}
return true;
}
template <typename T, typename S, typename G>
bool MeshgridGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
if (is_null_input_) {
return true;
}
S *ones_device = GetDeviceAddress<S>(workspace, 0);
CalOnesLike(static_cast<S *>(nullptr), ones_device, output_size_, reinterpret_cast<cudaStream_t>(cuda_stream_));
std::vector<size_t> broadcasted_ones_shape(MAX_DIMS, 1);
for (size_t i = 0; i < output_shape_.size(); i++) {
broadcasted_ones_shape[i] = output_shape_[i];
}
for (size_t i = 0; i < outputs.size(); i++) {
T *input_device = GetDeviceAddress<T>(inputs, i);
G *output_device = GetDeviceAddress<G>(outputs, i);
std::vector<size_t> broadcasted_input_shape(MAX_DIMS, 1);
broadcasted_input_shape[i] = input_shapes_[i];
if (swap_indexing_ && i <= 1) {
std::swap(broadcasted_input_shape[0], broadcasted_input_shape[1]);
}
BroadcastComplexArith(broadcasted_input_shape, broadcasted_ones_shape, output_shape_, BinaryOpType::kMul,
input_device, ones_device, output_device, reinterpret_cast<cudaStream_t>(cuda_stream_));
SimplifyBinaryBroadcastShape(broadcasted_input_shape, output_shape_, output_shape_, &simplified_in0_shape,
&simplified_in1_shape, &simplified_out_shape);
bool is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, T, T, T>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, input_device, ones_device,
output_device, device_id_, reinterpret_cast<cudaStream_t>(cuda_stream_));
}
return true;
}
@ -167,13 +128,6 @@ bool MeshgridGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &in
template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, MeshgridGpuKernelMod::MeshgridFunc>> MeshgridGpuKernelMod::complex_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&MeshgridGpuKernelMod::LaunchComplexKernel<Complex<float>, Complex<float>, Complex<float>>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&MeshgridGpuKernelMod::LaunchComplexKernel<Complex<double>, Complex<double>, Complex<double>>},
};
std::vector<std::pair<KernelAttr, MeshgridGpuKernelMod::MeshgridFunc>> MeshgridGpuKernelMod::func_list_ = {
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&MeshgridGpuKernelMod::LaunchKernel<bool>},
@ -199,12 +153,14 @@ std::vector<std::pair<KernelAttr, MeshgridGpuKernelMod::MeshgridFunc>> MeshgridG
&MeshgridGpuKernelMod::LaunchKernel<int32_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&MeshgridGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
&MeshgridGpuKernelMod::LaunchKernel<Complex<float>>},
{KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
&MeshgridGpuKernelMod::LaunchKernel<Complex<double>>},
};
std::vector<KernelAttr> MeshgridGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
(void)std::transform(complex_list_.begin(), complex_list_.end(), std::back_inserter(func_list_),
[](const std::pair<KernelAttr, MeshgridFunc> &item) { return item; });
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, MeshgridFunc> &item) { return item.first; });
return support_list;
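A worked example of the ones-tensor trick used in LaunchKernel above (sizes are illustrative):
// Inputs x (len 2) and y (len 3), so output_shape_ = {2, 3} ("ij" indexing).
// Iteration i = 0 builds broadcasted_input_shape = {2, 1}; i = 1 builds {1, 3}.
// Multiplying each against the ones tensor of shape {2, 3} with
// BinaryOpType::kMul broadcasts input i along every other axis, so output 0
// repeats x across columns and output 1 repeats y across rows; with
// swap_indexing_ ("xy" mode) the first two axes are swapped before the launch.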

View File

@ -62,13 +62,12 @@ class MeshgridGpuKernelMod : public NativeGpuKernelMod {
std::function<bool(MeshgridGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
static std::vector<std::pair<KernelAttr, MeshgridFunc>> func_list_;
static std::vector<std::pair<KernelAttr, MeshgridFunc>> complex_list_;
MeshgridFunc kernel_func_;
void *cuda_stream_{nullptr};
size_t data_size_;
std::vector<size_t> input_shapes_;
std::vector<size_t> output_shape_;
std::vector<int64_t> input_shapes_;
std::vector<int64_t> output_shape_;
size_t input_size_;
size_t input_count_;
size_t output_size_;

View File

@ -16,7 +16,8 @@
#include "plugin/device/gpu/kernel/arrays/sort_key_value_inplace.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/sort_fixed_size.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_to_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/hal/device/gpu_common.h"
constexpr int MAX_DIMS = 8;
@ -120,32 +121,16 @@ bool InitIndexBySlice(const TensorLayoutHelper &t, int64_t axis, K *data, cudaSt
cudaMemcpyAsync(slice_data_device, slice_data_host, slice_size * sizeof(K), cudaMemcpyHostToDevice, cuda_stream),
"Memcpy slice data from host to device failed.");
free(slice_data_host);
int in_size[MAX_DIMS];
int out_size[MAX_DIMS];
for (int i = 0; i < MAX_DIMS; i++) {
in_size[i] = 1;
}
std::vector<int64_t> in_size(MAX_DIMS, 1);
std::vector<int64_t> out_size(MAX_DIMS, 1);
in_size[MAX_DIMS - t.dim_size_ + axis] = t.sizes_[axis];
for (int i = t.dim_size_ - 1; i >= 0; i--) {
out_size[i + MAX_DIMS - t.dim_size_] = t.sizes_[i];
}
for (int i = MAX_DIMS - t.dim_size_ - 1; i >= 0; i--) {
out_size[i] = 1;
}
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
constexpr size_t kIndex2 = 2;
constexpr size_t kIndex3 = 3;
constexpr size_t kIndex4 = 4;
constexpr size_t kIndex5 = 5;
constexpr size_t kIndex6 = 6;
constexpr size_t kIndex7 = 7;
BroadcastTo<K>(in_size[kIndex0], in_size[kIndex1], in_size[kIndex2], in_size[kIndex3], in_size[kIndex4],
in_size[kIndex5], in_size[kIndex6], in_size[kIndex7], out_size[kIndex0], out_size[kIndex1],
out_size[kIndex2], out_size[kIndex3], out_size[kIndex4], out_size[kIndex5], out_size[kIndex6],
out_size[kIndex7], slice_data_device, data, cuda_stream);
std::vector<int64_t> simplified_inp_shape;
std::vector<int64_t> simplified_out_shape;
SimplifyBroadcastToShape(in_size, out_size, &simplified_inp_shape, &simplified_out_shape);
BroadcastTo<K>(simplified_inp_shape, simplified_out_shape, slice_data_device, data, GET_CTX_DEVICE_ID, cuda_stream);
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(cudaFree(slice_data_device), "Free slice data failed.");
return true;
}
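Illustrative values for the shape setup above: with t.dim_size_ = 3, t.sizes_ = {4, 5, 6}, axis = 1, and MAX_DIMS = 8,
// in_size  -> {1, 1, 1, 1, 1, 1, 5, 1}  (only the sorted axis keeps its extent)
// out_size -> {1, 1, 1, 1, 1, 4, 5, 6}  (full shape, right-aligned to MAX_DIMS)
// SimplifyBroadcastToShape then reduces the pair before BroadcastTo replicates
// the per-axis index slice across every other dimension.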

View File

@ -0,0 +1,101 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kAdd, In0_t, In1_t, Out_t> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ Out_t operator()(In0_t val0, In1_t val1) const { return val0 + val1; }
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kAdd);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kAdd);
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kAdd);
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kSub, In0_t, In1_t, Out_t> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ Out_t operator()(In0_t val0, In1_t val1) const { return val0 - val1; }
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kSub);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kSub);
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kSub);
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kPow, In0_t, In1_t, Out_t, typename std::is_floating_point<Out_t>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Out_t operator()(const In0_t &lhs, const In1_t &rhs) const {
return static_cast<Out_t>(pow(lhs, rhs));
}
};
template <>
struct BinaryFunc<BinaryOpType::kPow, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
return __float2half(pow(__half2float(lhs), __half2float(rhs)));
}
};
#define POW_INTEGER_IMPL(T) \
template <> \
struct BinaryFunc<BinaryOpType::kPow, T, T, T> { \
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { \
T ret = 1; \
T base = lhs; \
T exp = rhs; \
while (exp) { \
if (exp & 1) { \
ret *= base; \
} \
base *= base; \
exp /= 2; \
} \
return ret; \
} \
};
POW_INTEGER_IMPL(uint8_t)
POW_INTEGER_IMPL(uint16_t)
POW_INTEGER_IMPL(uint32_t)
POW_INTEGER_IMPL(uint64_t)
POW_INTEGER_IMPL(int8_t)
POW_INTEGER_IMPL(int16_t)
POW_INTEGER_IMPL(int32_t)
POW_INTEGER_IMPL(int64_t)
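A worked trace of the integer-pow loop above (exponentiation by squaring; values are illustrative):
// lhs = 3, rhs = 5 (binary 101), ret = 1, base = 3
//   step 1: exp & 1 == 1 -> ret = 3,   base = 9,   exp = 2
//   step 2: exp & 1 == 0 -> ret = 3,   base = 81,  exp = 1
//   step 3: exp & 1 == 1 -> ret = 243, exp = 0     (3^5, in O(log rhs) multiplies)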
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kPow, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
Complex<Out_t> result;
#if defined(__CUDACC__)
auto thrust_res = thrust::pow(thrust::complex<Out_t>(lhs), thrust::complex<Out_t>(rhs));
result.real(thrust_res.real());
result.imag(thrust_res.imag());
#else
std::complex<Out_t> lhs_complex(lhs);
std::complex<Out_t> rhs_complex(rhs);
std::complex<Out_t> host_res = std::pow(lhs_complex, rhs_complex);
result.real(host_res.real());
result.imag(host_res.imag());
#endif
return result;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kPow);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kPow);
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kPow);

View File

@ -0,0 +1,56 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename T>
struct BinaryFunc<BinaryOpType::kBitwiseAnd, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return (lhs & rhs); }
};
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kBitwiseAnd);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kBitwiseAnd);
template <typename T>
struct BinaryFunc<BinaryOpType::kBitwiseOr, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return (lhs | rhs); }
};
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kBitwiseOr);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kBitwiseOr);
template <typename T>
struct BinaryFunc<BinaryOpType::kBitwiseXor, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return (lhs ^ rhs); }
};
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kBitwiseXor);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kBitwiseXor);
template <>
struct BinaryFunc<BinaryOpType::kLogicalAnd, bool, bool, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ bool operator()(const bool &lhs, const bool &rhs) const { return lhs && rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kLogicalAnd);
template <>
struct BinaryFunc<BinaryOpType::kLogicalOr, bool, bool, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ bool operator()(const bool &lhs, const bool &rhs) const { return lhs || rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kLogicalOr);

View File

@ -0,0 +1,87 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename T>
struct BinaryFunc<BinaryOpType::kMod, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return lhs % rhs; }
};
template <typename T>
struct BinaryFunc<BinaryOpType::kMod, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return fmod(lhs, rhs); }
};
template <>
struct BinaryFunc<BinaryOpType::kMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
return __float2half(fmod(__half2float(lhs), __half2float(rhs)));
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMod);
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorMod, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return lhs - floor(lhs / rhs) * rhs;
}
};
template <>
struct BinaryFunc<BinaryOpType::kFloorMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
float l = __half2float(lhs);
float r = __half2float(rhs);
return __float2half_rn(l - floorf(l / r) * r);
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorMod, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T res = lhs - floor(static_cast<float>(lhs) / static_cast<float>(rhs)) * rhs;
return res;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kFloorMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kFloorMod);
template <typename T>
struct BinaryFunc<BinaryOpType::kTruncateMod, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T res = static_cast<T>(lhs - static_cast<int>(lhs / rhs) * rhs);
return res;
}
};
template <>
struct BinaryFunc<BinaryOpType::kTruncateMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
float l = __half2float(lhs);
float r = __half2float(rhs);
float res = l - static_cast<int>(l / r) * r;
return __float2half_rn(res);
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kTruncateMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kTruncateMod);

View File

@ -0,0 +1,51 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_COMMON_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_COMMON_CUH_
#include <limits.h>
#include <cmath>
#include <type_traits>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_types.cuh"
template <typename T>
__device__ __host__ __forceinline__ T Eps();
template <typename T>
__device__ __host__ __forceinline__ T Eps() {
return 0;
}
template <>
__device__ __host__ __forceinline__ float Eps() {
return 2e-7;
}
template <>
__device__ __host__ __forceinline__ double Eps() {
return 2e-15;
}
template <>
__device__ __host__ __forceinline__ half Eps() {
return 6.105e-5;
}
template <enum BinaryOpType op, typename In0_t, typename In1_t, typename Out_t, typename Enabled = std::true_type>
struct BinaryFunc {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ Out_t operator()(In0_t val0, In1_t val1) const { return Out_t(0.0); }
};
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_COMMON_CUH_
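The trailing Enabled parameter (defaulting to std::true_type) is what lets each op be specialized per type category. A sketch of the pattern with a hypothetical kHypot op restricted to floating-point types (kHypot is not part of BinaryOpType and would first have to be added to binary_types.cuh; the pattern mirrors the kMod/kFloorDiv specializations in this PR):
template <typename T>
struct BinaryFunc<BinaryOpType::kHypot, T, T, T, typename std::is_floating_point<T>::type> {
  // std::is_floating_point<T>::type is std::true_type only for float/double,
  // so this partial specialization exists (and is preferred over the generic
  // fallback above) exactly for those types.
  __device__ __host__ __forceinline__ BinaryFunc() {}
  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
    return sqrt(lhs * lhs + rhs * rhs);
  }
};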

View File

@ -0,0 +1,128 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
#define REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(op) \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, double, double, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, double *input0, double *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, float, float, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, float *input0, float *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, half, half, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, half *input0, half *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, bool, bool, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, bool *input0, bool *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int8_t, int8_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int8_t *input0, int8_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint8_t, uint8_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint8_t *input0, uint8_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int16_t, int16_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int16_t *input0, int16_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint16_t, uint16_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint16_t *input0, uint16_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int32_t, int32_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int32_t *input0, int32_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint32_t, uint32_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint32_t *input0, uint32_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int64_t, int64_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int64_t *input0, int64_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint64_t, uint64_t, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint64_t *input0, uint64_t *input1, bool *output, size_t device_id, \
cudaStream_t cuda_stream)
template <typename T>
struct BinaryFunc<BinaryOpType::kGreater, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs > rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kGreater);
template <typename T>
struct BinaryFunc<BinaryOpType::kLess, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs < rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kLess);
template <typename T>
struct BinaryFunc<BinaryOpType::kEqual, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs == rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kEqual);
template <typename T>
struct BinaryFunc<BinaryOpType::kGreaterEqual, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs >= rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kGreaterEqual);
template <typename T>
struct BinaryFunc<BinaryOpType::kLessEqual, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs <= rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kLessEqual);
template <typename T>
struct BinaryFunc<BinaryOpType::kNotEqual, T, T, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ bool operator()(const T &lhs, const T &rhs) const { return lhs != rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPARE_TYPE(BinaryOpType::kNotEqual);
template <typename T>
struct BinaryFunc<BinaryOpType::kMaximum, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return lhs > rhs ? lhs : rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMaximum);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMaximum);
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kMaximum);
template <typename T>
struct BinaryFunc<BinaryOpType::kMinimum, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return lhs < rhs ? lhs : rhs; }
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMinimum);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMinimum);
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kMinimum);
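Caller-side sketch for one of the comparisons instantiated above (buffers, shape, and stream are hypothetical; the bool output type is fixed by the COMPARE_TYPE instantiations):
std::vector<int64_t> shape{256};
cudaError_t err = BinaryOpWithBroadcastCudaFunc<BinaryOpType::kLess, float, float, bool>(
    /*is_broadcast=*/false, shape, shape, shape, d_x, d_y, d_mask, device_id, stream);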

View File

@ -0,0 +1,159 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kDiv, In0_t, In1_t, Out_t> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ Out_t operator()(In0_t val0, In1_t val1) const { return val0 / val1; }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kDiv);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kDiv);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kDiv);
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kRealDiv, In0_t, In1_t, Out_t> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Out_t operator()(const In0_t &lhs, const In1_t &rhs) const { return (lhs / rhs); }
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kRealDiv);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kRealDiv);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kRealDiv);
// XDivy check if lhs is less than epsilon, XDivy support half, float, double
template <typename T>
struct BinaryFunc<BinaryOpType::kXdivy, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
// default T is float
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return lhs < Eps<T>() && lhs > -Eps<T>() ? 0.0 : (lhs / rhs);
}
};
template <>
struct BinaryFunc<BinaryOpType::kXdivy, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
if (__half2float(lhs) < (0.00007) && __half2float(lhs) > -0.00007) {
return static_cast<half>(0.0);
}
return __float2half_rn(__half2float(lhs) / __half2float(rhs));
}
};
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kXdivy, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
Complex<Out_t> res(0.0, 0.0);
Complex<Out_t> complex_lhs(lhs);
Complex<Out_t> complex_rhs(rhs);
// compute the quotient only when lhs is not (near) zero; otherwise res stays 0
if (complex_lhs.real() >= Eps<float>() || complex_lhs.real() <= -Eps<float>() ||
complex_lhs.imag() >= Eps<float>() || complex_lhs.imag() <= -Eps<float>()) {
res = complex_lhs / complex_rhs;
}
return res;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kXdivy);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kXdivy);
// DivNoNan check if rhs is less than epsilon
template <typename T>
struct BinaryFunc<BinaryOpType::kDivNoNan, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
// default T is float
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return rhs < Eps<T>() && rhs > -Eps<T>() ? 0.0 : (lhs / rhs);
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kDivNoNan, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return rhs == 0 ? 0 : (lhs / rhs);
}
};
template <>
struct BinaryFunc<BinaryOpType::kDivNoNan, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
if (__half2float(rhs) < (0.00001) && __half2float(rhs) > -0.00001) {
return static_cast<half>(0.0);
}
return __float2half_rn(__half2float(lhs) / __half2float(rhs));
}
};
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kDivNoNan, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
Complex<Out_t> complex_rhs(rhs);
// return 0 only when rhs is (near) zero, i.e. both components are within eps
if ((complex_rhs.real() < Eps<float>() && complex_rhs.real() > -Eps<float>()) &&
(complex_rhs.imag() < Eps<float>() && complex_rhs.imag() > -Eps<float>())) {
Complex<Out_t> res(0.0, 0.0);
return res;
}
return lhs / rhs;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kDivNoNan);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kDivNoNan);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kDivNoNan);
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorDiv, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return floor(lhs / rhs); }
};
template <>
struct BinaryFunc<BinaryOpType::kFloorDiv, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
return __float2half_rn(floorf(__half2float(lhs) / __half2float(rhs)));
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorDiv, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return static_cast<T>(floor(static_cast<float>(lhs) / static_cast<float>(rhs)));
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kFloorDiv);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kFloorDiv);
template <typename T>
struct BinaryFunc<BinaryOpType::kTruncateDiv, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return static_cast<T>(trunc(lhs / rhs)); }
};
template <>
struct BinaryFunc<BinaryOpType::kTruncateDiv, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
return __float2half_rn(trunc(__half2float(lhs) / __half2float(rhs)));
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kTruncateDiv, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return static_cast<T>(trunc(static_cast<float>(lhs) / static_cast<float>(rhs)));
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kTruncateDiv);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kTruncateDiv);

View File

@ -0,0 +1,87 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename T>
struct BinaryFunc<BinaryOpType::kMod, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return lhs % rhs; }
};
template <typename T>
struct BinaryFunc<BinaryOpType::kMod, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return fmod(lhs, rhs); }
};
template <>
struct BinaryFunc<BinaryOpType::kMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
return __float2half(fmod(__half2float(lhs), __half2float(rhs)));
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMod);
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorMod, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return lhs - floor(lhs / rhs) * rhs;
}
};
template <>
struct BinaryFunc<BinaryOpType::kFloorMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
float l = __half2float(lhs);
float r = __half2float(rhs);
return __float2half_rn(l - floorf(l / r) * r);
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kFloorMod, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T res = lhs - floor(static_cast<float>(lhs) / static_cast<float>(rhs)) * rhs;
return res;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kFloorMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kFloorMod);
template <typename T>
struct BinaryFunc<BinaryOpType::kTruncateMod, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T res = static_cast<T>(lhs - static_cast<int>(lhs / rhs) * rhs);
return res;
}
};
template <>
struct BinaryFunc<BinaryOpType::kTruncateMod, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
float l = __half2float(lhs);
float r = __half2float(rhs);
float res = l - static_cast<int>(l / r) * r;
return __float2half_rn(res);
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kTruncateMod);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kTruncateMod);

View File

@ -0,0 +1,78 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kMul, In0_t, In1_t, Out_t> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ Out_t operator()(In0_t val0, In1_t val1) const { return val0 * val1; }
};
template <>
struct BinaryFunc<BinaryOpType::kMul, bool, bool, bool> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ bool operator()(bool val0, bool val1) const { return val0 && val1; }
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMul);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMul);
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kMul);
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kMul);
// MulNoNan
template <typename T>
struct BinaryFunc<BinaryOpType::kMulNoNan, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return rhs < Eps<T>() && rhs > -Eps<T>() ? 0.0 : (lhs * rhs);
}
};
template <typename T>
struct BinaryFunc<BinaryOpType::kMulNoNan, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return rhs == 0 ? 0 : (lhs * rhs);
}
};
template <>
struct BinaryFunc<BinaryOpType::kMulNoNan, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
bool bool1 = __half2float(rhs) < (0.00001) && __half2float(rhs) > -0.00001;
if (bool1) {
return static_cast<half>(0.0);
}
return __float2half_rn(__half2float(lhs) * __half2float(rhs));
}
};
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kMulNoNan, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
Complex<Out_t> complex_rhs(rhs);
// return 0 only when rhs is (near) zero, i.e. both components are within eps
if ((complex_rhs.real() < Eps<float>() && complex_rhs.real() > -Eps<float>()) &&
(complex_rhs.imag() < Eps<float>() && complex_rhs.imag() > -Eps<float>())) {
Complex<Out_t> res(0.0, 0.0);
return res;
}
return lhs * rhs;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kMulNoNan);
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kMulNoNan);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kMulNoNan);

View File

@ -0,0 +1,32 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_OPS_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_OPS_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_types.cuh"
template <enum BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc(const bool is_broadcast,
const std::vector<int64_t> &in0_shape,
const std::vector<int64_t> &in1_shape,
const std::vector<int64_t> &out_shape, In0_t *input0,
In1_t *input1, Out_t *output, size_t device_id,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_OPS_IMPL_CUH_
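A minimal caller-side sketch of this entry point, assuming device buffers d_in0/d_in1/d_out and a stream already exist (shapes are illustrative; the Simplify/IsBinaryBroadcast helpers come from broadcast_public.h as used elsewhere in this PR):
std::vector<int64_t> in0{2, 1, 4}, in1{1, 3, 4}, out{2, 3, 4};
std::vector<int64_t> s_in0, s_in1, s_out;
SimplifyBinaryBroadcastShape(in0, in1, out, &s_in0, &s_in1, &s_out);
bool is_broadcast = IsBinaryBroadcast(s_in0, s_in1);
// Broadcasting add over float buffers; the returned cudaError_t surfaces launch failures.
cudaError_t err = BinaryOpWithBroadcastCudaFunc<BinaryOpType::kAdd, float, float, float>(
    is_broadcast, s_in0, s_in1, s_out, d_in0, d_in1, d_out, device_id, stream);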

View File

@ -0,0 +1,167 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_pub_impl.cuh"
// Xlogy check if lhs is less than epsilon, Xlogy support half, float, double
template <typename T>
struct BinaryFunc<BinaryOpType::kXlogy, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
// default T is float
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return lhs < Eps<T>() && lhs > -Eps<T>() ? 0.0 : (lhs * log(rhs));
}
};
template <>
struct BinaryFunc<BinaryOpType::kXlogy, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
half zero = 0.0;
half eps = 6.105e-5;
return (lhs < eps && lhs > -eps) ? zero : __float2half_rn(__half2float(lhs) * log(__half2float(rhs)));
}
};
template <typename IN, typename Out_t>
__device__ __host__ __forceinline__ Out_t CalMid(const IN &inp_val) {
Out_t res(0.5 * log(inp_val * inp_val * 2), 1.0);
return res;
}
template <>
__device__ __host__ __forceinline__ Complex<float> CalMid(const Complex<float> &inp_val) {
Complex<float> res(0.5 * log(inp_val.real() * inp_val.real() + inp_val.imag() * inp_val.imag()),
atan2(inp_val.imag(), inp_val.real()));
return res;
}
template <>
__device__ __host__ __forceinline__ Complex<double> CalMid(const Complex<double> &inp_val) {
Complex<double> res(0.5 * log(inp_val.real() * inp_val.real() + inp_val.imag() * inp_val.imag()),
atan2(inp_val.imag(), inp_val.real()));
return res;
}
template <typename IN>
__device__ __host__ __forceinline__ bool IsZero(const IN &inp_val) {
return inp_val < Eps<IN>() && inp_val > -Eps<IN>();
}
template <>
__device__ __host__ __forceinline__ bool IsZero(const Complex<float> &inp_val) {
return inp_val.real() < Eps<float>() && inp_val.real() > -Eps<float>() && inp_val.imag() < Eps<float>() &&
inp_val.imag() > -Eps<float>();
}
template <>
__device__ __host__ __forceinline__ bool IsZero(const Complex<double> &inp_val) {
return inp_val.real() < Eps<double>() && inp_val.real() > -Eps<double>() && inp_val.imag() < Eps<double>() &&
inp_val.imag() > -Eps<double>();
}
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kXlogy, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
if (IsZero<In0_t>(lhs)) {
Complex<Out_t> res(0.0, 0.0);
return res;
}
Complex<Out_t> mid = CalMid<In1_t, Complex<Out_t>>(rhs);
return lhs * mid;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kXlogy);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kXlogy);
template <typename T>
struct BinaryFunc<BinaryOpType::kSquaredDifference, T, T, T, typename std::is_arithmetic<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T diff = lhs - rhs;
return diff * diff;
}
};
template <>
struct BinaryFunc<BinaryOpType::kSquaredDifference, half, half, half> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
half diff = lhs - rhs;
return diff * diff;
}
};
template <typename In0_t, typename In1_t, typename Out_t>
struct BinaryFunc<BinaryOpType::kSquaredDifference, In0_t, In1_t, Complex<Out_t>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<Out_t> operator()(const In0_t &lhs, const In1_t &rhs) const {
Complex<Out_t> diff = lhs - rhs;
Complex<Out_t> conj_diff(diff.real(), -diff.imag());
return conj_diff * diff;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(BinaryOpType::kSquaredDifference);
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kSquaredDifference);
template <typename T>
struct BinaryFunc<BinaryOpType::kAtan2, T, T, T, typename std::is_floating_point<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const { return atan2(lhs, rhs); }
};
template <typename T>
struct BinaryFunc<BinaryOpType::kAtan2, T, T, T, typename std::is_integral<T>::type> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
return static_cast<T>(atan2(static_cast<float>(lhs), static_cast<float>(rhs)));
}
};
template <>
struct BinaryFunc<BinaryOpType::kAtan2, half, half, half> {
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) const {
float l = __half2float(lhs);
float r = __half2float(rhs);
float res = atan2f(l, r);
return __float2half_rn(res);
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kAtan2);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kAtan2);
template <typename T>
struct BinaryFunc<BinaryOpType::kAbsGrad, T, T, T> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) const {
T zero = 0;
return lhs < -Eps<T>() ? -rhs : lhs > Eps<T>() ? rhs : zero;
}
};
REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(BinaryOpType::kAbsGrad);
REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(BinaryOpType::kAbsGrad);
REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(BinaryOpType::kAbsGrad);
// now only for complex op
#define REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX(op) \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, float, float, Complex<float>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, float *input0, float *input1, Complex<float> *output, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, double, double, Complex<double>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, double *input0, double *input1, Complex<double> *output, size_t device_id, \
cudaStream_t cuda_stream)
template <typename T>
struct BinaryFunc<BinaryOpType::kComplex, T, T, Complex<T>> {
__device__ __host__ __forceinline__ BinaryFunc() {}
__device__ __host__ __forceinline__ Complex<T> operator()(const T &lhs, const T &rhs) const {
return Complex<T>(lhs, rhs);
}
};
REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX(BinaryOpType::kComplex);
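
For orientation, a minimal host-side sketch of how the unified entry point exercised above might be invoked for the Complex op (two float tensors combined into one Complex<float> output). The buffer names, shapes and stream below are illustrative assumptions, not taken from this change; the call signature matches the instantiations emitted by REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX, assuming the declarations from binary_ops_impl.cuh are in scope.

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
// d_real, d_imag: hypothetical float* device buffers of 24 elements each; d_out: hypothetical Complex<float>* device buffer.
std::vector<int64_t> shape = {2, 3, 4};  // identical shapes, so no broadcasting is needed
cudaStream_t stream = nullptr;           // default stream, for the sketch only
cudaError_t err = BinaryOpWithBroadcastCudaFunc<BinaryOpType::kComplex, float, float, Complex<float>>(
    /*is_broadcast=*/false, shape, shape, shape, d_real, d_imag, d_out, /*device_id=*/0, stream);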

View File

@ -0,0 +1,349 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_PUB_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_PUB_CUH_
#include <math.h>
#include <vector>
#include <iostream>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_types.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_common.cuh"
struct BinaryBroadcastStrideInfo {
size_t in0_stride[8];
size_t in1_stride[8];
size_t out_stride[8];
};
template <typename T, size_t VecSize>
struct Vec {
T data[VecSize];
};
constexpr size_t kMaxVecBytes = 128 / 8;
constexpr size_t kMaxVecSize = 4;
constexpr size_t MsMin(size_t a, size_t b) { return a < b ? a : b; }
template <typename T>
constexpr size_t VecSize() {
return MsMin(kMaxVecBytes / sizeof(T), kMaxVecSize);
}
template <typename T, typename U, typename... Args>
constexpr size_t VecSize() {
return MsMin(VecSize<T>(), VecSize<U, Args...>());
}
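// For reference: with kMaxVecBytes = 16 and kMaxVecSize = 4, VecSize<half>() == 4, VecSize<float>() == 4 and
// VecSize<double>() == 2; the variadic overload takes the minimum over all operand types, e.g.
// VecSize<float, float, Complex<float>>() == 2, since Complex<float> holds two floats (8 bytes).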
enum class ScalarOption {
NoScalar = 0,
In0Scalar = 1,
In1Scalar = 2,
};
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t, size_t vec_num>
__device__ void ApplyVec(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option, In0_t *in0_addr,
In1_t *in1_addr, Out_t *out_addr) {
Vec<Out_t, vec_num> out_vec;
if (scalar_option == ScalarOption::NoScalar) {
Vec<In0_t, vec_num> in0_vec = reinterpret_cast<Vec<In0_t, vec_num> *>(in0_addr)[0];
Vec<In1_t, vec_num> in1_vec = reinterpret_cast<Vec<In1_t, vec_num> *>(in1_addr)[0];
#pragma unroll
for (size_t idx = 0; idx < vec_num; ++idx) {
out_vec.data[idx] = func(in0_vec.data[idx], in1_vec.data[idx]);
}
} else if (scalar_option == ScalarOption::In0Scalar) {
In0_t in0_data = in0_addr[0];
Vec<In1_t, vec_num> in1_vec = reinterpret_cast<Vec<In1_t, vec_num> *>(in1_addr)[0];
#pragma unroll
for (size_t idx = 0; idx < vec_num; ++idx) {
out_vec.data[idx] = func(in0_data, in1_vec.data[idx]);
}
} else {
Vec<In0_t, vec_num> in0_vec = reinterpret_cast<Vec<In0_t, vec_num> *>(in0_addr)[0];
In1_t in1_data = in1_addr[0];
#pragma unroll
for (size_t idx = 0; idx < vec_num; ++idx) {
out_vec.data[idx] = func(in0_vec.data[idx], in1_data);
}
}
Vec<Out_t, vec_num> *out_data = reinterpret_cast<Vec<Out_t, vec_num> *>(out_addr);
out_data[0] = out_vec;
}
static __device__ Vec<size_t, 2> CalInposByOutPos(size_t out_pos, size_t dim_size,
const BinaryBroadcastStrideInfo &strides) {
Vec<size_t, 2> in_pos = {0, 0};
size_t tmp_idx = 0;
for (int idx = 0; idx < dim_size; ++idx) {
tmp_idx = out_pos / strides.out_stride[idx];
in_pos.data[0] += tmp_idx * strides.in0_stride[idx];
in_pos.data[1] += tmp_idx * strides.in1_stride[idx];
out_pos -= tmp_idx * strides.out_stride[idx];
}
return in_pos;
}
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
__global__ void BinaryWithBroadcastNoVecCuda(BinaryFunc<OP, In0_t, In1_t, Out_t> func, size_t dim_size,
size_t total_threads, BinaryBroadcastStrideInfo strides, In0_t *in0,
In1_t *in1, Out_t *out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_threads; pos += blockDim.x * gridDim.x) {
Vec<size_t, 2> in_pos = CalInposByOutPos(pos, dim_size, strides);
out[pos] = func(in0[in_pos.data[0]], in1[in_pos.data[1]]);
}
}
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
__global__ void BinaryWithoutBroadcastNoVecCuda(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option,
size_t total_threads, In0_t *in0, In1_t *in1, Out_t *out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_threads; pos += blockDim.x * gridDim.x) {
In0_t in0_data = (scalar_option == ScalarOption::In0Scalar) ? in0[0] : in0[pos];
In1_t in1_data = (scalar_option == ScalarOption::In1Scalar) ? in1[0] : in1[pos];
out[pos] = func(in0_data, in1_data);
}
}
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t, size_t vec_num>
__global__ void BinaryBroadcastVecWithoutTailCuda(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option,
size_t dim_size, size_t total_threads,
BinaryBroadcastStrideInfo strides, In0_t *in0, In1_t *in1,
Out_t *out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_threads; pos += blockDim.x * gridDim.x) {
size_t out_pos = pos * vec_num;
Vec<size_t, 2> in_pos = CalInposByOutPos(out_pos, dim_size, strides);
ApplyVec<OP, In0_t, In1_t, Out_t, vec_num>(func, scalar_option, in0 + in_pos.data[0], in1 + in_pos.data[1],
out + out_pos);
}
}
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t, size_t vec_num>
__global__ void BinaryBroadcastVecWithTailCuda(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option,
size_t dim_size, size_t total_threads, size_t step, size_t tail_num,
BinaryBroadcastStrideInfo strides, In0_t *in0, In1_t *in1, Out_t *out) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < total_threads; pos += blockDim.x * gridDim.x) {
size_t out_pos = pos * vec_num + pos / step * tail_num;
Vec<size_t, 2> in_pos = CalInposByOutPos(out_pos, dim_size, strides);
if ((pos + 1) % step != 0) {
ApplyVec<OP, In0_t, In1_t, Out_t, vec_num>(func, scalar_option, in0 + in_pos.data[0], in1 + in_pos.data[1],
out + out_pos);
} else {
switch (tail_num) {
case 1:
ApplyVec<OP, In0_t, In1_t, Out_t, vec_num + 1>(func, scalar_option, in0 + in_pos.data[0],
in1 + in_pos.data[1], out + out_pos);
break;
case 2:
ApplyVec<OP, In0_t, In1_t, Out_t, vec_num + 2>(func, scalar_option, in0 + in_pos.data[0],
in1 + in_pos.data[1], out + out_pos);
break;
case 3:
ApplyVec<OP, In0_t, In1_t, Out_t, vec_num + 3>(func, scalar_option, in0 + in_pos.data[0],
in1 + in_pos.data[1], out + out_pos);
break;
}
}
}
}
static BinaryBroadcastStrideInfo BinaryBroadcastCalStride(const size_t dim_size, const std::vector<int64_t> &in0_shape,
const std::vector<int64_t> &in1_shape,
const std::vector<int64_t> &out_shape, const size_t vec_num) {
BinaryBroadcastStrideInfo strides;
strides.in0_stride[dim_size - 1] = 1;
strides.in1_stride[dim_size - 1] = 1;
strides.out_stride[dim_size - 1] = 1;
for (int64_t idx = dim_size - 2; idx >= 0; --idx) {
strides.out_stride[idx] = out_shape[idx + 1] * strides.out_stride[idx + 1];
strides.in0_stride[idx] = in0_shape[idx + 1] * strides.in0_stride[idx + 1];
strides.in1_stride[idx] = in1_shape[idx + 1] * strides.in1_stride[idx + 1];
}
for (size_t idx = 0; idx < dim_size; ++idx) {
strides.in0_stride[idx] = (in0_shape[idx] == 1) ? 0 : strides.in0_stride[idx];
strides.in1_stride[idx] = (in1_shape[idx] == 1) ? 0 : strides.in1_stride[idx];
}
return strides;
}
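// Worked example: for in0_shape {2, 3, 1}, in1_shape {1, 3, 4} and out_shape {2, 3, 4}, the loops above give
// out_stride {12, 4, 1}, in0_stride {3, 1, 0} and in1_stride {0, 4, 1} (broadcast dimensions get stride 0).
// CalInposByOutPos then maps out_pos 13, i.e. index (1, 0, 1), to in0 offset 1*3 + 0*1 + 1*0 = 3 and
// in1 offset 1*0 + 0*4 + 1*1 = 1.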
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
cudaError_t BinaryWithBroadcast(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option,
const size_t out_num, const std::vector<int64_t> &in0_shape,
const std::vector<int64_t> &in1_shape, const std::vector<int64_t> &out_shape,
In0_t *in0, In1_t *in1, Out_t *out, size_t device_id, cudaStream_t cuda_stream) {
cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, device_id);
size_t vec_thread_num = prop.multiProcessorCount * 8 * 32;
const size_t dim_size = out_shape.size();
constexpr size_t vec_num = VecSize<In0_t, In1_t, Out_t>();
size_t total_threads = out_num / out_shape.back();
if (out_num > vec_thread_num && vec_num > 1) {
if (out_shape.back() == 2) {
BinaryBroadcastStrideInfo strides = BinaryBroadcastCalStride(dim_size, in0_shape, in1_shape, out_shape, 2);
size_t thread_num = total_threads > 1024 ? 1024 : total_threads;
BinaryBroadcastVecWithoutTailCuda<OP, In0_t, In1_t, Out_t, 2>
<<<CUDA_BLOCKS_CAL(device_id, total_threads, thread_num), thread_num, 0, cuda_stream>>>(
func, scalar_option, dim_size, total_threads, strides, in0, in1, out);
CHECK_CUDA_LAUNCH_SUCCESS();
} else if (out_shape.back() == 3) {
BinaryBroadcastStrideInfo strides = BinaryBroadcastCalStride(dim_size, in0_shape, in1_shape, out_shape, 3);
size_t total_threads = out_shape[0] * strides.out_stride[0];
size_t thread_num = total_threads > 1024 ? 1024 : total_threads;
BinaryBroadcastVecWithoutTailCuda<OP, In0_t, In1_t, Out_t, 3>
<<<CUDA_BLOCKS_CAL(device_id, total_threads, thread_num), thread_num, 0, cuda_stream>>>(
func, scalar_option, dim_size, total_threads, strides, in0, in1, out);
CHECK_CUDA_LAUNCH_SUCCESS();
} else {
BinaryBroadcastStrideInfo strides = BinaryBroadcastCalStride(dim_size, in0_shape, in1_shape, out_shape, vec_num);
size_t step = out_shape.back() / vec_num;
total_threads *= step;
size_t tail_num = out_shape.back() % vec_num;
size_t thread_num = total_threads > 1024 ? 1024 : total_threads;
if (tail_num == 0) {
BinaryBroadcastVecWithoutTailCuda<OP, In0_t, In1_t, Out_t, vec_num>
<<<CUDA_BLOCKS_CAL(device_id, total_threads, thread_num), thread_num, 0, cuda_stream>>>(
func, scalar_option, dim_size, total_threads, strides, in0, in1, out);
CHECK_CUDA_LAUNCH_SUCCESS();
} else {
BinaryBroadcastVecWithTailCuda<OP, In0_t, In1_t, Out_t, vec_num>
<<<CUDA_BLOCKS_CAL(device_id, total_threads, thread_num), thread_num, 0, cuda_stream>>>(
func, scalar_option, dim_size, total_threads, step, tail_num, strides, in0, in1, out);
CHECK_CUDA_LAUNCH_SUCCESS();
}
}
} else {
BinaryBroadcastStrideInfo strides = BinaryBroadcastCalStride(dim_size, in0_shape, in1_shape, out_shape, 1);
total_threads *= out_shape.back();
size_t thread_num = total_threads > 1024 ? 1024 : total_threads;
BinaryWithBroadcastNoVecCuda<OP, In0_t, In1_t, Out_t>
<<<CUDA_BLOCKS_CAL(device_id, total_threads, thread_num), thread_num, 0, cuda_stream>>>(
func, dim_size, total_threads, strides, in0, in1, out);
CHECK_CUDA_LAUNCH_SUCCESS();
}
}
template <BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
cudaError_t BinaryWithoutBroadcast(BinaryFunc<OP, In0_t, In1_t, Out_t> func, ScalarOption scalar_option, size_t nums,
Out_t *out, In0_t *in0, In1_t *in1, size_t device_id, cudaStream_t cuda_stream) {
size_t thread_num = nums > 1024 ? 1024 : nums;
BinaryWithoutBroadcastNoVecCuda<OP, In0_t, In1_t, Out_t>
<<<CUDA_BLOCKS_CAL(device_id, nums, thread_num), thread_num, 0, cuda_stream>>>(func, scalar_option, nums, in0, in1,
out);
CHECK_CUDA_LAUNCH_SUCCESS();
}
template <enum BinaryOpType OP, typename In0_t, typename In1_t, typename Out_t>
cudaError_t BinaryOpWithBroadcastCudaFunc(const bool is_broadcast, const std::vector<int64_t> &in0_shape,
const std::vector<int64_t> &in1_shape, const std::vector<int64_t> &out_shape,
In0_t *in0, In1_t *in1, Out_t *out, size_t device_id,
cudaStream_t cuda_stream) {
BinaryFunc<OP, In0_t, In1_t, Out_t> func;
size_t out_num = 1;
for (auto val : out_shape) {
out_num *= val;
}
ScalarOption scalar_option = ScalarOption::NoScalar;
if (is_broadcast) {
if (in0_shape.back() == 1) {
scalar_option = ScalarOption::In0Scalar;
} else if (in1_shape.back() == 1) {
scalar_option = ScalarOption::In1Scalar;
}
return BinaryWithBroadcast<OP, In0_t, In1_t, Out_t>(func, scalar_option, out_num, in0_shape, in1_shape, out_shape,
in0, in1, out, device_id, cuda_stream);
} else {
if (in0_shape.size() == 1 && in0_shape[0] == 1) {
scalar_option = ScalarOption::In0Scalar;
}
if (in1_shape.size() == 1 && in1_shape[0] == 1) {
scalar_option = ScalarOption::In1Scalar;
}
return BinaryWithoutBroadcast<OP, In0_t, In1_t, Out_t>(func, scalar_option, out_num, out, in0, in1, device_id,
cuda_stream);
}
}
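// Dispatch summary: without broadcasting, the plain element-wise kernel is launched and a shape-{1} operand is
// treated as a scalar. With broadcasting, sufficiently large outputs take the vectorized kernels (innermost dims
// of 2 and 3 are special-cased, and the tail-handling kernel covers innermost dims that are not a multiple of
// vec_num); smaller outputs fall back to the per-element kernel driven by the stride table.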
#define REGISTER_BINARY_OP_CUDA_FUNC_BOOL_TYPE(op) \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, bool, bool, bool>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, bool *in0, bool *in1, bool *out, size_t device_id, \
cudaStream_t cuda_stream);
#define REGISTER_BINARY_OP_CUDA_FUNC_INT_TYPE(op) \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int8_t, int8_t, int8_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int8_t *in0, int8_t *in1, int8_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint8_t, uint8_t, uint8_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint8_t *in0, uint8_t *in1, uint8_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int16_t, int16_t, int16_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int16_t *in0, int16_t *in1, int16_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint16_t, uint16_t, uint16_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint16_t *in0, uint16_t *in1, uint16_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int32_t, int32_t, int32_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int32_t *in0, int32_t *in1, int32_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint32_t, uint32_t, uint32_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint32_t *in0, uint32_t *in1, uint32_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, int64_t, int64_t, int64_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, int64_t *in0, int64_t *in1, int64_t *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, uint64_t, uint64_t, uint64_t>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, uint64_t *in0, uint64_t *in1, uint64_t *out, size_t device_id, \
cudaStream_t cuda_stream)
#define REGISTER_BINARY_OP_CUDA_FUNC_FLOAT_TYPE(op) \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, double, double, double>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, double *in0, double *in1, double *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, float, float, float>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, float *in0, float *in1, float *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, half, half, half>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, half *in0, half *in1, half *out, size_t device_id, \
cudaStream_t cuda_stream);
#define REGISTER_BINARY_OP_CUDA_FUNC_COMPLEX_TYPE(op) \
template CUDA_LIB_EXPORT cudaError_t \
BinaryOpWithBroadcastCudaFunc<op, Complex<float>, Complex<float>, Complex<float>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, Complex<float> *in0, Complex<float> *in1, Complex<float> *out, \
size_t device_id, cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t \
BinaryOpWithBroadcastCudaFunc<op, Complex<double>, Complex<double>, Complex<double>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, Complex<double> *in0, Complex<double> *in1, Complex<double> *out, \
size_t device_id, cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, Complex<float>, float, Complex<float>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, Complex<float> *in0, float *in1, Complex<float> *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, float, Complex<float>, Complex<float>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, float *in0, Complex<float> *in1, Complex<float> *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, double, Complex<double>, Complex<double>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, double *in0, Complex<double> *in1, Complex<double> *out, size_t device_id, \
cudaStream_t cuda_stream); \
template CUDA_LIB_EXPORT cudaError_t BinaryOpWithBroadcastCudaFunc<op, Complex<double>, double, Complex<double>>( \
const bool is_broadcast, const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape, \
const std::vector<int64_t> &out_shape, Complex<double> *in0, double *in1, Complex<double> *out, size_t device_id, \
cudaStream_t cuda_stream)
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_PUB_CUH_

View File

@ -14,44 +14,46 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_FUNC_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_FUNC_CUH_
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_TYPES_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_TYPES_CUH_
#include <limits.h>
enum class BinaryOpType {
// compare
kGreater = 0,
kLess = 1,
kMaximum = 2,
kMinimum = 3,
kPower = 4,
kRealDiv = 5,
kMul = 6,
kSub = 7,
kAdd = 8,
kFloorDiv = 9,
kAbsGrad = 10,
kDiv = 11,
kDivNoNan = 12,
kEqual = 13,
kSquaredDifference = 14,
kMod = 15,
kFloorMod = 16,
kAtan2 = 17,
kGreaterEqual = 18,
kLessEqual = 19,
kNotEqual = 20,
kLogicalAnd = 21,
kLogicalOr = 22,
kEqual = 2,
kGreaterEqual = 3,
kLessEqual = 4,
kNotEqual = 5,
kLogicalAnd = 6,
kLogicalOr = 7,
// math
kMaximum = 8,
kMinimum = 9,
kAdd = 10,
kSub = 11,
kMul = 12,
kDiv = 13,
kPow = 14,
kRealDiv = 15,
kBitwiseAnd = 16,
kBitwiseOr = 17,
kBitwiseXor = 18,
kMod = 19,
kFloorMod = 20,
kSquaredDifference = 21,
kAtan2 = 22,
kTruncateDiv = 23,
kTruncateMod = 24,
kComplex = 25,
kXdivy = 26,
kBitwiseAnd = 27,
kBitwiseOr = 28,
kBitwiseXor = 29,
kMulNoNan = 30,
kXlogy = 31,
kAbsGrad = 25,
kFloorDiv = 26,
kDivNoNan = 27,
kMulNoNan = 28,
kXlogy = 29,
kXdivy = 30,
// complex
kComplex = 31,
kInvalid = INT_MAX,
};
#endif
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BINARY_TYPES_CUH_

View File

@ -1,57 +0,0 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_func.cuh"
template <typename T>
CUDA_LIB_EXPORT void ElewiseCmp(const int &nums, enum BinaryOpType op, const T *x0, const T *x1, bool *y,
cudaStream_t stream);
template <typename T>
CUDA_LIB_EXPORT void ElewiseArith(const int &nums, enum BinaryOpType op, const T *x0, const T *x1, T *y,
cudaStream_t stream);
template <typename T1, typename T2, typename T3>
CUDA_LIB_EXPORT void ElewiseComplexArith(const int &nums, enum BinaryOpType op, const T1 *x0, const T2 *x1,
Complex<T3> *y, cudaStream_t stream);
template <typename T>
CUDA_LIB_EXPORT void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BinaryOpType op, const T *x0, const T *x1,
bool *y, cudaStream_t stream);
template <typename T>
CUDA_LIB_EXPORT void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BinaryOpType op, const T *x0, const T *x1,
T *y, cudaStream_t stream);
template <typename T1, typename T2, typename T3>
CUDA_LIB_EXPORT void BroadcastComplexArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BinaryOpType op, const T1 *x0,
const T2 *x1, Complex<T3> *y, cudaStream_t stream);
template <typename T>
CUDA_LIB_EXPORT void BroadcastTo(const size_t &i0, const size_t &i1, const size_t &i2, const size_t &i3,
const size_t &i4, const size_t &i5, const size_t &i6, const size_t &i7,
const size_t &o0, const size_t &o1, const size_t &o2, const size_t &o3,
const size_t &o4, const size_t &o5, const size_t &o6, const size_t &o7,
const T *input_addr, T *output_addr, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_IMPL_CUH_

View File

@ -0,0 +1,111 @@
/**
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_to_impl.cuh"
#include <math.h>
#include <vector>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
struct UnaryBroadcastStrideInfo {
size_t input_stride[8];
size_t output_stride[8];
};
// copy
template <typename T>
__global__ void BroadcastToCpyCuda(size_t dim_size, size_t output_num, UnaryBroadcastStrideInfo strides, T *input,
T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) {
int64_t cur_out_idx = 0;
size_t cur_pos = pos;
size_t inp_pos = 0;
for (int idx = 0; idx < dim_size; ++idx) {
cur_out_idx = cur_pos / strides.output_stride[idx];
inp_pos += cur_out_idx * strides.input_stride[idx];
cur_pos -= cur_out_idx * strides.output_stride[idx];
}
output[pos] = input[inp_pos];
}
}
template <typename T>
cudaError_t BroadcastTo(const std::vector<int64_t> &inp_shape, const std::vector<int64_t> &out_shape, T *input,
T *output, size_t device_id, cudaStream_t cuda_stream) {
const size_t dim_size = out_shape.size();
UnaryBroadcastStrideInfo strides;
size_t output_num = out_shape.back();
strides.input_stride[dim_size - 1] = 1;
strides.output_stride[dim_size - 1] = 1;
for (int64_t idx = dim_size - 2; idx >= 0; --idx) {
strides.output_stride[idx] = out_shape[idx + 1] * strides.output_stride[idx + 1];
strides.input_stride[idx] = inp_shape[idx + 1] * strides.input_stride[idx + 1];
output_num *= out_shape[idx];
}
for (size_t idx = 0; idx < dim_size; ++idx) {
strides.input_stride[idx] = (inp_shape[idx] == 1) ? 0 : strides.input_stride[idx];
}
size_t thread_num = output_num > 1024 ? 1024 : output_num;
BroadcastToCpyCuda<T><<<CUDA_BLOCKS_CAL(device_id, output_num, thread_num), thread_num, 0, cuda_stream>>>(
dim_size, output_num, strides, input, output);
CHECK_CUDA_LAUNCH_SUCCESS();
}
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<bool>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, bool *input, bool *output,
size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<int8_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, int8_t *input,
int8_t *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<int16_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, int16_t *input,
int16_t *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<int32_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, int32_t *input,
int32_t *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<int64_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, int64_t *input,
int64_t *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<uint8_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, uint8_t *input,
uint8_t *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<uint16_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, uint16_t *input,
uint16_t *output, size_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<uint32_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, uint32_t *input,
uint32_t *output, size_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<uint64_t>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, uint64_t *input,
uint64_t *output, size_t device_id,
cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<half>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, half *input, half *output,
size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<float>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, float *input,
float *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<double>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape, double *input,
double *output, size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<Complex<float>>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape,
Complex<float> *input, Complex<float> *output,
size_t device_id, cudaStream_t cuda_stream);
template CUDA_LIB_EXPORT cudaError_t BroadcastTo<Complex<double>>(const std::vector<int64_t> &inp_shape,
const std::vector<int64_t> &out_shape,
Complex<double> *input, Complex<double> *output,
size_t device_id, cudaStream_t cuda_stream);
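
A minimal host-side sketch of the new BroadcastTo entry point, assuming the input shape has already been padded with leading 1s to the output rank (the stride loop above indexes both shapes with the same dimension count) and that the declaration above is in scope. Buffer names and shapes are illustrative assumptions, not from this change.

// d_in: hypothetical float* device buffer of 4 elements; d_out: hypothetical float* device buffer of 12 elements.
std::vector<int64_t> inp_shape = {1, 4};  // dimension 0 is broadcast
std::vector<int64_t> out_shape = {3, 4};
cudaError_t err = BroadcastTo<float>(inp_shape, out_shape, d_in, d_out, /*device_id=*/0, /*cuda_stream=*/nullptr);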

View File

@ -0,0 +1,25 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_TO_OPT_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_TO_OPT_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
template <typename T>
CUDA_LIB_EXPORT cudaError_t BroadcastTo(const std::vector<int64_t> &inp_shape, const std::vector<int64_t> &out_shape,
T *input, T *output, size_t device_id, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BROADCAST_TO_OPT_IMPL_CUH_

View File

@ -58,15 +58,12 @@ bool AddNFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, co
}
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr_));
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr_));
std::vector<int64_t> ele_shape = {static_cast<int64_t>(outputs[0]->size / sizeof(T))};
for (size_t i = 0; i < num_input_; i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
if constexpr (std::is_same<T, Complex<float>>::value || std::is_same<T, Complex<double>>::value) {
ElewiseComplexArith(outputs[0]->size / sizeof(T), BinaryOpType::kAdd, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
ElewiseArith(outputs[0]->size / sizeof(T), BinaryOpType::kAdd, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr_));
}
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kAdd, T, T, T>(false, ele_shape, ele_shape, ele_shape, input_addr,
work_addr, work_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr_));
}
if (work_addr != output_addr) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(

View File

@ -23,8 +23,8 @@
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"

View File

@ -0,0 +1,315 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/binary_ops_gpu_kernel.h"
#include <memory>
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
namespace mindspore {
namespace kernel {
bool BroadcastOptGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
auto iter = kBroadcastOpMap.find(kernel_name_);
if (iter != kBroadcastOpMap.end()) {
op_type_ = iter->second;
} else {
MS_LOG(ERROR) << "For BroadcastOptGpuKernelMod, it does not support this op: " << kernel_name_;
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = supported_type_map_.find(kernel_name_)->second[index].second;
return true;
}
int BroadcastOptGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
auto in0_shape = inputs[kIndex0]->GetShapeVector();
auto in1_shape = inputs[kIndex1]->GetShapeVector();
auto out_shape = outputs[kIndex0]->GetShapeVector();
if (in0_shape.size() == 0) {
in0_shape.emplace_back(1);
}
if (in1_shape.size() == 0) {
in1_shape.emplace_back(1);
}
if (out_shape.size() == 0) {
out_shape.emplace_back(1);
}
is_null_input_ = CHECK_SHAPE_NULL(in0_shape, kernel_name_, "input_0") ||
CHECK_SHAPE_NULL(in1_shape, kernel_name_, "input_1") ||
CHECK_SHAPE_NULL(out_shape, kernel_name_, "output_0");
if (is_null_input_) {
return KRET_OK;
}
SimplifyBinaryBroadcastShape(in0_shape, in1_shape, out_shape, &simplified_in0_shape_, &simplified_in1_shape_,
&simplified_out_shape_);
auto input0_num = SizeOf(simplified_in0_shape_);
auto input1_num = SizeOf(simplified_in1_shape_);
if (input0_num > 1 && input1_num > 1 && IsBinaryBroadcast(simplified_in0_shape_, simplified_in1_shape_)) {
is_broadcast_ = true;
} else {
is_broadcast_ = false;
}
return KRET_OK;
}
template <BinaryOpType op, typename In0_t, typename In1_t, typename Out_t>
bool BroadcastOptGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto lhs = GetDeviceAddress<In0_t>(inputs, kIndex0);
auto rhs = GetDeviceAddress<In1_t>(inputs, kIndex1);
auto out = GetDeviceAddress<Out_t>(outputs, kIndex0);
auto status = BinaryOpWithBroadcastCudaFunc<op, In0_t, In1_t, Out_t>(is_broadcast_, simplified_in0_shape_,
simplified_in1_shape_, simplified_out_shape_,
lhs, rhs, out, device_id_, cuda_stream_);
CHECK_CUDA_LAUNCH_STATUS(status, kernel_name_);
return true;
}
std::vector<KernelAttr> BroadcastOptGpuKernelMod::GetOpSupport() {
auto iter = supported_type_map_.find(kernel_name_);
std::vector<KernelAttr> support_list;
if (iter != supported_type_map_.end()) {
(void)std::transform(
iter->second.begin(), iter->second.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BroadcastOptGpuKernelMod::BroadCastFunc> &item) { return item.first; });
}
return support_list;
}
#define MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, NUM_TYPE, TYPE) \
{ \
KernelAttr().AddInputAttr(NUM_TYPE).AddInputAttr(NUM_TYPE).AddOutputAttr(NUM_TYPE), \
&BroadcastOptGpuKernelMod::LaunchKernel<OP_TYPE, TYPE, TYPE, TYPE> \
}
#define MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, In0_t_NUM_TYPE, In1_t_NUM_TYPE, OUT_NUM_TYPE, In0_t_TYPE, In1_t_TYPE, \
OUT_TYPE) \
{ \
KernelAttr().AddInputAttr(In0_t_NUM_TYPE).AddInputAttr(In1_t_NUM_TYPE).AddOutputAttr(OUT_NUM_TYPE), \
&BroadcastOptGpuKernelMod::LaunchKernel<OP_TYPE, In0_t_TYPE, In1_t_TYPE, OUT_TYPE> \
}
#define MS_REG_BROADCAST_OP_BOOL_TYPE(OP_TYPE) MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeBool, bool)
#define MS_REG_BROADCAST_OP_INT_TYPE(OP_TYPE) \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeUInt8, uint8_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeUInt16, uint16_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeUInt32, uint32_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeUInt64, uint64_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeInt8, int8_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeInt16, int16_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeInt32, int32_t), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeInt64, int64_t)
#define MS_REG_BROADCAST_OP_FLOAT_TYPE(OP_TYPE) \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeFloat16, half), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeFloat32, float), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeFloat64, double)
#define MS_REG_BROADCAST_OP_MIX_COMPLEX_TYPE(OP_TYPE, FLOAT_NUM_TYPE, COMPLEX_NUM_TYPE, FLOAT_TYPE, COMPLEX_TYPE) \
MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, FLOAT_NUM_TYPE, COMPLEX_NUM_TYPE, COMPLEX_NUM_TYPE, FLOAT_TYPE, COMPLEX_TYPE, \
COMPLEX_TYPE), \
MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, COMPLEX_NUM_TYPE, FLOAT_NUM_TYPE, COMPLEX_NUM_TYPE, COMPLEX_TYPE, \
FLOAT_TYPE, COMPLEX_TYPE)
#define MS_REG_BROADCAST_OP_COMPLEX_TYPE(OP_TYPE) \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeComplex64, Complex<float>), \
MS_REG_BROADCAST_OP_SAME_TYPE(OP_TYPE, kNumberTypeComplex128, Complex<double>), \
MS_REG_BROADCAST_OP_MIX_COMPLEX_TYPE(OP_TYPE, kNumberTypeFloat32, kNumberTypeComplex64, float, Complex<float>), \
MS_REG_BROADCAST_OP_MIX_COMPLEX_TYPE(OP_TYPE, kNumberTypeFloat64, kNumberTypeComplex128, double, Complex<double>)
#define MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, NUM_TYPE, TYPE) \
MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, NUM_TYPE, NUM_TYPE, kNumberTypeBool, TYPE, TYPE, bool)
#define MS_REG_BROADCAST_COMP_OP_INT_TYPE(OP_TYPE) \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeBool, bool), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeUInt8, uint8_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeUInt16, uint16_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeUInt32, uint32_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeUInt64, uint64_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeInt8, int8_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeInt16, int16_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeInt32, int32_t), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeInt64, int64_t)
#define MS_REG_BROADCAST_COMP_OP_FLOAT_TYPE(OP_TYPE) \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeFloat16, half), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeFloat32, float), \
MS_REG_BROADCAST_COMP_OP_TYPE(OP_TYPE, kNumberTypeFloat64, double)
#define MS_REG_BROADCAST_COMPARE_OP_TYPE(OP_TYPE) \
MS_REG_BROADCAST_COMP_OP_INT_TYPE(OP_TYPE), MS_REG_BROADCAST_COMP_OP_FLOAT_TYPE(OP_TYPE)
// now only for op named Complex
#define MS_REG_BROADCAST_OP_COMPLEX(OP_TYPE) \
MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeComplex64, float, float, \
Complex<float>), \
MS_REG_BROADCAST_OP_DIFF_TYPE(OP_TYPE, kNumberTypeFloat64, kNumberTypeFloat64, kNumberTypeComplex128, double, \
double, Complex<double>)
std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastOptGpuKernelMod::BroadCastFunc>>>
BroadcastOptGpuKernelMod::supported_type_map_ = {
{"Add",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kAdd), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kAdd),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kAdd)}},
{"Sub",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kSub), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kSub),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kSub)}},
{"Mul",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kMul), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kMul),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kMul)}},
{"Div",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kDiv), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kDiv),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kDiv)}},
{"Pow",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kPow), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kPow),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kPow)}},
{"Xdivy",
{MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kXdivy), MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kXdivy)}},
{"Xlogy",
{MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kXlogy), MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kXlogy)}},
{"RealDiv",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kRealDiv), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kRealDiv),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kRealDiv)}},
{"MulNoNan",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kMulNoNan), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kMulNoNan),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kMulNoNan)}},
{"Atan2",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kAtan2), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kAtan2)}},
{"AbsGrad",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kAbsGrad), MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kAbsGrad),
MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kAbsGrad)}},
{"BitwiseAnd",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kBitwiseAnd),
MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kBitwiseAnd)}},
{"BitwiseOr",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kBitwiseOr), MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kBitwiseOr)}},
{"BitwiseXor",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kBitwiseXor),
MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kBitwiseXor)}},
{"DivNoNan",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kDivNoNan), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kDivNoNan),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kDivNoNan)}},
{"FloorMod",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kFloorMod), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kFloorMod)}},
{"FloorDiv",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kFloorDiv), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kFloorDiv)}},
{"Mod", {MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kMod), MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kMod)}},
{"Minimum",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kMinimum), MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kMinimum),
MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kMinimum)}},
{"Maximum",
{MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kMaximum), MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kMaximum),
MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kMaximum)}},
{"SquaredDifference",
{MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kSquaredDifference),
MS_REG_BROADCAST_OP_COMPLEX_TYPE(BinaryOpType::kSquaredDifference)}},
{"TruncateDiv",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kTruncateDiv),
MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kTruncateDiv)}},
{"TruncateMod",
{MS_REG_BROADCAST_OP_INT_TYPE(BinaryOpType::kTruncateMod),
MS_REG_BROADCAST_OP_FLOAT_TYPE(BinaryOpType::kTruncateMod)}},
{"Complex", {MS_REG_BROADCAST_OP_COMPLEX(BinaryOpType::kComplex)}},
{"Greater", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kGreater)}},
{"Less", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kLess)}},
{"Equal", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kEqual)}},
{"GreaterEqual", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kGreaterEqual)}},
{"LessEqual", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kLessEqual)}},
{"NotEqual", {MS_REG_BROADCAST_COMPARE_OP_TYPE(BinaryOpType::kNotEqual)}},
{"LogicalAnd", {MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kLogicalAnd)}},
{"LogicalOr", {MS_REG_BROADCAST_OP_BOOL_TYPE(BinaryOpType::kLogicalOr)}},
};
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Add,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Add"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Div,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Div"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Mul,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Mul"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Sub,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Sub"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Atan2,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Atan2"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, AbsGrad,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("AbsGrad"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, BitwiseAnd,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("BitwiseAnd"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, BitwiseOr,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("BitwiseOr"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, BitwiseXor,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("BitwiseXor"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, DivNoNan,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("DivNoNan"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, FloorMod,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("FloorMod"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, FloorDiv,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("FloorDiv"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, MulNoNan,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("MulNoNan"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Mod,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Mod"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Minimum,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Minimum"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Maximum,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Maximum"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Pow,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Pow"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, RealDiv,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("RealDiv"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, TruncateDiv,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("TruncateDiv"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, TruncateMod,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("TruncateMod"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Complex,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Complex"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Xdivy,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Xdivy"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Xlogy,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Xlogy"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Greater,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Greater"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Less,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Less"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, Equal,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("Equal"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, GreaterEqual,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("GreaterEqual"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, LessEqual,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("LessEqual"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, NotEqual,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("NotEqual"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, LogicalAnd,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("LogicalAnd"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, LogicalOr,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("LogicalOr"); });
MS_KERNEL_FACTORY_REG_BY_CREATOR(NativeGpuKernelMod, SquaredDifference,
[]() { return std::make_shared<BroadcastOptGpuKernelMod>("SquaredDifference"); });
} // namespace kernel
} // namespace mindspore
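
To make the table-driven registration above concrete: each MS_REG_BROADCAST_OP_SAME_TYPE entry expands to a KernelAttr/launcher pair. Written out by hand for one case (an illustration, not generated code), MS_REG_BROADCAST_OP_SAME_TYPE(BinaryOpType::kAdd, kNumberTypeFloat32, float) is roughly:

{
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  &BroadcastOptGpuKernelMod::LaunchKernel<BinaryOpType::kAdd, float, float, float>
}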

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BROADCAST_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BROADCAST_GPU_KERNEL_H_
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BINARY_OPS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BINARY_OPS_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
@ -27,61 +27,56 @@
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_func.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_types.cuh"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
constexpr int STRIDE_NUM = 3;
template <typename T>
using Complex = mindspore::utils::Complex<T>;
static const std::map<std::string, BinaryOpType> kBroadcastCmpTypeMap = {
{"Greater", BinaryOpType::kGreater}, {"Less", BinaryOpType::kLess},
{"Equal", BinaryOpType::kEqual}, {"GreaterEqual", BinaryOpType::kGreaterEqual},
{"LessEqual", BinaryOpType::kLessEqual}, {"NotEqual", BinaryOpType::kNotEqual},
{"LogicalAnd", BinaryOpType::kLogicalAnd}, {"LogicalOr", BinaryOpType::kLogicalOr},
};
static const std::map<std::string, BinaryOpType> kBroadcastArithmetricTypeMap = {
static const std::map<std::string, BinaryOpType> kBroadcastOpMap = {
{"Greater", BinaryOpType::kGreater},
{"Less", BinaryOpType::kLess},
{"Equal", BinaryOpType::kEqual},
{"GreaterEqual", BinaryOpType::kGreaterEqual},
{"LessEqual", BinaryOpType::kLessEqual},
{"NotEqual", BinaryOpType::kNotEqual},
{"LogicalAnd", BinaryOpType::kLogicalAnd},
{"LogicalOr", BinaryOpType::kLogicalOr},
{"Maximum", BinaryOpType::kMaximum},
{"Minimum", BinaryOpType::kMinimum},
{"Pow", BinaryOpType::kPower},
{"RealDiv", BinaryOpType::kRealDiv},
{"Mul", BinaryOpType::kMul},
{"Sub", BinaryOpType::kSub},
{"Add", BinaryOpType::kAdd},
{"FloorDiv", BinaryOpType::kFloorDiv},
{"AbsGrad", BinaryOpType::kAbsGrad},
{"Div", BinaryOpType::kDiv},
{"DivNoNan", BinaryOpType::kDivNoNan},
{"MulNoNan", BinaryOpType::kMulNoNan},
{"Mod", BinaryOpType::kMod},
{"FloorMod", BinaryOpType::kFloorMod},
{"Atan2", BinaryOpType::kAtan2},
{"TruncateDiv", BinaryOpType::kTruncateDiv},
{"TruncateMod", BinaryOpType::kTruncateMod},
{"Pow", BinaryOpType::kPow},
{"RealDiv", BinaryOpType::kRealDiv},
{"BitwiseAnd", BinaryOpType::kBitwiseAnd},
{"BitwiseOr", BinaryOpType::kBitwiseOr},
{"BitwiseXor", BinaryOpType::kBitwiseXor},
{"Xdivy", BinaryOpType::kXdivy},
{"Mod", BinaryOpType::kMod},
{"FloorMod", BinaryOpType::kFloorMod},
{"SquaredDifference", BinaryOpType::kSquaredDifference},
{"Atan2", BinaryOpType::kAtan2},
{"TruncateDiv", BinaryOpType::kTruncateDiv},
{"TruncateMod", BinaryOpType::kTruncateMod},
{"AbsGrad", BinaryOpType::kAbsGrad},
{"FloorDiv", BinaryOpType::kFloorDiv},
{"DivNoNan", BinaryOpType::kDivNoNan},
{"MulNoNan", BinaryOpType::kMulNoNan},
{"Xlogy", BinaryOpType::kXlogy},
};
static const std::map<std::string, BinaryOpType> kBroadcastComplexAndRealTypeMap = {
{"RealDiv", BinaryOpType::kRealDiv}, {"Mul", BinaryOpType::kMul}, {"Sub", BinaryOpType::kSub},
{"Add", BinaryOpType::kAdd}, {"Div", BinaryOpType::kDiv}, {"MulNoNan", BinaryOpType::kMulNoNan},
{"Pow", BinaryOpType::kPower}, {"Xdivy", BinaryOpType::kXdivy}, {"Xlogy", BinaryOpType::kXlogy}};
static const std::map<std::string, BinaryOpType> kBroadcastComplexOnlyTypeMap = {
{"Xdivy", BinaryOpType::kXdivy},
{"Complex", BinaryOpType::kComplex},
};
class BroadcastOpGpuKernelMod : public NativeGpuKernelMod {
class BroadcastOptGpuKernelMod : public NativeGpuKernelMod {
public:
BroadcastOpGpuKernelMod() {}
~BroadcastOpGpuKernelMod() override = default;
explicit BroadcastOptGpuKernelMod(const std::string &kernel_name) { kernel_name_ = kernel_name; }
~BroadcastOptGpuKernelMod() override = default;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
@ -101,37 +96,24 @@ class BroadcastOpGpuKernelMod : public NativeGpuKernelMod {
std::vector<KernelAttr> GetOpSupport() override;
private:
bool GetOpType();
template <typename T>
template <BinaryOpType op, typename In0, typename In1, typename OUT>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
template <typename T, typename S, typename G>
bool LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
using BroadCastFunc = std::function<bool(BroadcastOpGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
using BroadCastFunc = std::function<bool(BroadcastOptGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &)>;
std::string GetValidKernelTypes();
BinaryOpType op_type_;
bool need_broadcast_;
bool is_compare_op_;
bool support_complex_{true};
bool support_real_{true};
bool is_broadcast_;
bool is_null_input_;
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;
size_t unit_size_{1};
size_t output_num_{1};
std::vector<int64_t> simplified_in0_shape_;
std::vector<int64_t> simplified_in1_shape_;
std::vector<int64_t> simplified_out_shape_;
cudaStream_t cuda_stream_{nullptr};
BroadCastFunc kernel_func_{};
static std::vector<std::pair<KernelAttr, BroadCastFunc>> real_list_;
static std::vector<std::pair<KernelAttr, BroadCastFunc>> complex_list_;
std::vector<std::pair<KernelAttr, BroadcastOpGpuKernelMod::BroadCastFunc>> func_list_;
BroadCastFunc kernel_func_{nullptr};
static std::map<std::string, std::vector<std::pair<KernelAttr, BroadcastOptGpuKernelMod::BroadCastFunc>>>
supported_type_map_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BROADCAST_GPU_KERNEL_H_
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BINARY_OPS_GPU_KERNEL_H_

View File

@ -1,301 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include <iostream>
namespace mindspore {
namespace kernel {
#define MS_REG_BROADCAST_COMPLEX_GPU_KERNEL1(T0_MS_DTYPE, T1_MS_DTYPE, T0_DTYPE, T1_DTYPE) \
KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
&BroadcastOpGpuKernelMod::LaunchComplexKernel<T0_DTYPE, T0_DTYPE, T0_DTYPE>
#define MS_REG_BROADCAST_COMPLEX_GPU_KERNEL2(T0_MS_DTYPE, T1_MS_DTYPE, T0_DTYPE, T1_DTYPE) \
KernelAttr().AddInputAttr(T0_MS_DTYPE).AddInputAttr(T1_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
&BroadcastOpGpuKernelMod::LaunchComplexKernel<T0_DTYPE, T1_DTYPE, T0_DTYPE>
#define MS_REG_BROADCAST_COMPLEX_GPU_KERNEL3(T0_MS_DTYPE, T1_MS_DTYPE, T0_DTYPE, T1_DTYPE) \
KernelAttr().AddInputAttr(T1_MS_DTYPE).AddInputAttr(T0_MS_DTYPE).AddOutputAttr(T0_MS_DTYPE), \
&BroadcastOpGpuKernelMod::LaunchComplexKernel<T1_DTYPE, T0_DTYPE, T0_DTYPE>
bool BroadcastOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
support_complex_ = false;
if (inputs.empty() || outputs.empty()) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid.";
return false;
}
if (!GetOpType()) {
return false;
}
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
if (!is_match) {
MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
return false;
}
kernel_func_ = func_list_[index].second;
return true;
}
bool BroadcastOpGpuKernelMod::GetOpType() {
auto iter = kBroadcastComplexAndRealTypeMap.find(kernel_name_);
if (iter != kBroadcastComplexAndRealTypeMap.end()) {
op_type_ = iter->second;
support_complex_ = true;
}
iter = kBroadcastComplexOnlyTypeMap.find(kernel_name_);
if (iter != kBroadcastComplexOnlyTypeMap.end()) {
op_type_ = iter->second;
support_complex_ = true;
support_real_ = false;
return true;
}
iter = kBroadcastCmpTypeMap.find(kernel_name_);
if (iter != kBroadcastCmpTypeMap.end()) {
op_type_ = iter->second;
is_compare_op_ = true;
return true;
}
iter = kBroadcastArithmetricTypeMap.find(kernel_name_);
if (iter != kBroadcastArithmetricTypeMap.end()) {
op_type_ = iter->second;
is_compare_op_ = false;
return true;
}
MS_LOG(ERROR) << "For 'BroadcastGpuOp', it only support these types: " << GetValidKernelTypes()
<< " currently, but got " << kernel_name_;
return false;
}
std::string BroadcastOpGpuKernelMod::GetValidKernelTypes() {
std::ostringstream valid_types;
valid_types << "Valid Compare Types: ";
std::for_each(kBroadcastCmpTypeMap.cbegin(), kBroadcastCmpTypeMap.cend(),
[&valid_types](const std::map<std::string, BinaryOpType>::value_type &p) {
valid_types << p.first << std::string(", ");
});
valid_types << "; Valid Arithmetric Types: ";
std::for_each(kBroadcastArithmetricTypeMap.cbegin(), kBroadcastArithmetricTypeMap.cend(),
[&valid_types](const std::map<std::string, BinaryOpType>::value_type &p) {
valid_types << p.first << std::string(", ");
});
valid_types << "; Valid Complex Types: ";
std::for_each(kBroadcastComplexOnlyTypeMap.cbegin(), kBroadcastComplexOnlyTypeMap.cend(),
[&valid_types](const std::map<std::string, BinaryOpType>::value_type &p) {
valid_types << p.first << std::string(", ");
});
return valid_types.str();
}
int BroadcastOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
lhs_shape_ = LongVecToSizeVec(inputs.at(kIndex0)->GetShapeVector());
rhs_shape_ = LongVecToSizeVec(inputs.at(kIndex1)->GetShapeVector());
output_shape_ = LongVecToSizeVec(outputs.at(kIndex0)->GetShapeVector());
output_num_ = std::accumulate(output_shape_.begin(), output_shape_.end(), size_t(1), std::multiplies<size_t>());
is_null_input_ = CHECK_SHAPE_NULL(lhs_shape_, kernel_name_, "input_0") ||
CHECK_SHAPE_NULL(rhs_shape_, kernel_name_, "input_1") ||
CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output_0");
if (is_null_input_) {
return KRET_OK;
}
need_broadcast_ = common::AnfAlgo::IsTensorBroadcast(lhs_shape_, rhs_shape_);
if (!broadcast_utils::AlignedBroadCastShape(MAX_DIMS, &output_shape_, &lhs_shape_, &rhs_shape_)) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input cannot be greater than " << MAX_DIMS
<< ", and output dimension can't less than input; but got x_shape dimension:" << lhs_shape_.size()
<< " ,y_shape dimension:" << rhs_shape_.size()
<< " ,out_shape dimension:" << output_shape_.size();
}
return KRET_OK;
}
template <typename T>
bool BroadcastOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
auto lhs = GetDeviceAddress<T>(inputs, kIndex0);
auto rhs = GetDeviceAddress<T>(inputs, kIndex1);
if (is_compare_op_) {
bool *output = GetDeviceAddress<bool>(outputs, kIndex0);
if (need_broadcast_) {
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
} else {
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
}
} else {
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
}
}
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
MS_LOG(ERROR) << "Cuda calculate error for " << kernel_name_ << ", error desc:" << cudaGetErrorString(err);
return false;
}
return true;
}
template <typename T, typename S, typename G>
bool BroadcastOpGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
T *lhs = GetDeviceAddress<T>(inputs, kIndex0);
S *rhs = GetDeviceAddress<S>(inputs, kIndex1);
G *output = GetDeviceAddress<G>(outputs, kIndex0);
if (need_broadcast_) {
BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output, cuda_stream_);
} else {
ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, cuda_stream_);
}
return true;
}
std::vector<KernelAttr> BroadcastOpGpuKernelMod::GetOpSupport() {
std::vector<KernelAttr> support_list;
func_list_.clear();
if (support_complex_) {
(void)std::transform(complex_list_.begin(), complex_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair.first; });
(void)std::transform(complex_list_.begin(), complex_list_.end(), std::back_inserter(func_list_),
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair; });
}
if (support_real_) {
(void)std::transform(real_list_.begin(), real_list_.end(), std::back_inserter(support_list),
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair.first; });
(void)std::transform(real_list_.begin(), real_list_.end(), std::back_inserter(func_list_),
[](const std::pair<KernelAttr, BroadCastFunc> &pair) { return pair; });
}
return support_list;
}
template <typename T>
using Complex = mindspore::utils::Complex<T>;
std::vector<std::pair<KernelAttr, BroadcastOpGpuKernelMod::BroadCastFunc>> BroadcastOpGpuKernelMod::complex_list_ = {
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL1(kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL2(kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL3(kNumberTypeComplex64, kNumberTypeFloat32, Complex<float>, float)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL1(kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL2(kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL3(kNumberTypeComplex128, kNumberTypeFloat64, Complex<double>, double)},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeComplex64),
&BroadcastOpGpuKernelMod::LaunchComplexKernel<float, float, Complex<float>>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeComplex128),
&BroadcastOpGpuKernelMod::LaunchComplexKernel<double, double, Complex<double>>},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL1(kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>, Complex<float>)},
{MS_REG_BROADCAST_COMPLEX_GPU_KERNEL1(kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>,
Complex<double>)},
};
std::vector<std::pair<KernelAttr, BroadcastOpGpuKernelMod::BroadCastFunc>> BroadcastOpGpuKernelMod::real_list_ = {
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<double>},
{KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
&BroadcastOpGpuKernelMod::LaunchKernel<bool>},
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
&BroadcastOpGpuKernelMod::LaunchKernel<uint8_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
&BroadcastOpGpuKernelMod::LaunchKernel<uint16_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
&BroadcastOpGpuKernelMod::LaunchKernel<uint32_t>},
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
&BroadcastOpGpuKernelMod::LaunchKernel<uint64_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
&BroadcastOpGpuKernelMod::LaunchKernel<int8_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
&BroadcastOpGpuKernelMod::LaunchKernel<int16_t>},
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
&BroadcastOpGpuKernelMod::LaunchKernel<int>},
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
&BroadcastOpGpuKernelMod::LaunchKernel<int64_t>},
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
&BroadcastOpGpuKernelMod::LaunchKernel<half>},
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
&BroadcastOpGpuKernelMod::LaunchKernel<float>},
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
&BroadcastOpGpuKernelMod::LaunchKernel<double>},
};
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Add, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Atan2, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, AbsGrad, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BitwiseAnd, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BitwiseOr, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BitwiseXor, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Div, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DivNoNan, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Equal, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorMod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, FloorDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Greater, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, GreaterEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Less, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LessEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalOr, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, LogicalAnd, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mul, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, MulNoNan, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Mod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Minimum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Maximum, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, NotEqual, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Pow, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, RealDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Sub, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateDiv, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, TruncateMod, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Complex, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Xdivy, BroadcastOpGpuKernelMod);
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Xlogy, BroadcastOpGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,178 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
bool IsBinaryBroadcast(const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape) {
if (in0_shape.size() != in1_shape.size()) {
return true;
}
for (size_t i = 0; i < in0_shape.size(); i++) {
if (in0_shape[i] != in1_shape[i]) {
return true;
}
}
return false;
}
void SimplifyBinaryBroadcastShape(const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape,
const std::vector<int64_t> &out_shape, std::vector<int64_t> *simplified_in0_shape,
std::vector<int64_t> *simplified_in1_shape,
std::vector<int64_t> *simplified_out_shape) {
size_t out_rank = out_shape.size();
size_t l_rank = in0_shape.size();
size_t r_rank = in1_shape.size();
size_t l_offset = out_rank - l_rank;
std::vector<int64_t> aligned_in0_shape(in0_shape);
std::vector<int64_t> aligned_in1_shape(in1_shape);
std::vector<int64_t> aligned_out_shape(out_shape);
if (aligned_in0_shape.size() == 0) {
aligned_in0_shape.emplace_back(1);
}
if (aligned_in1_shape.size() == 0) {
aligned_in1_shape.emplace_back(1);
}
if (aligned_out_shape.size() == 0) {
aligned_out_shape.emplace_back(1);
}
// broadcast shape
if (l_offset > 0) {
std::vector<int64_t> insert_lft(l_offset, 1);
aligned_in0_shape.insert(aligned_in0_shape.begin(), insert_lft.begin(), insert_lft.end());
}
size_t r_offset = out_rank - r_rank;
if (r_offset > 0) {
std::vector<int64_t> insert_rht(r_offset, 1);
aligned_in1_shape.insert(aligned_in1_shape.begin(), insert_rht.begin(), insert_rht.end());
}
// simplify shape
simplified_in0_shape->clear();
simplified_in1_shape->clear();
simplified_out_shape->clear();
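// Per-dimension status from CalStatus below: 0 = both input dims are 1, 1 = only in0 is 1,
// 2 = only in1 is 1, 3 = neither is 1. Adjacent dimensions keep fusing while the status stays
// the same (or the next dimension has status 0), which lowers the rank the CUDA kernel must handle.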
auto CalStatus = [](int64_t in0_val, int64_t in1_val) -> int {
if (in0_val == 1 || in1_val == 1) {
if (in0_val == 1) {
if (in1_val == 1) {
return 0;
} else {
return 1;
}
} else {
return 2;
}
} else {
return 3;
}
};
size_t head_idx = 0;
int head_status = CalStatus(aligned_in0_shape[head_idx], aligned_in1_shape[head_idx]);
while (head_status == 0 && head_idx < aligned_out_shape.size() - 1) {
++head_idx;
head_status = CalStatus(aligned_in0_shape[head_idx], aligned_in1_shape[head_idx]);
}
if (head_idx == aligned_out_shape.size() - 1) {
simplified_in0_shape->emplace_back(aligned_in0_shape.back());
simplified_in1_shape->emplace_back(aligned_in1_shape.back());
simplified_out_shape->emplace_back(aligned_out_shape.back());
return;
}
while (head_idx < aligned_out_shape.size()) {
int64_t in0_merged = aligned_in0_shape[head_idx];
int64_t in1_merged = aligned_in1_shape[head_idx];
int64_t out_merged = aligned_out_shape[head_idx];
size_t tail_idx = head_idx + 1;
while (tail_idx < aligned_out_shape.size()) {
int tail_status = CalStatus(aligned_in0_shape[tail_idx], aligned_in1_shape[tail_idx]);
if (tail_status * head_status == 0 || head_status == tail_status) {
in0_merged *= aligned_in0_shape[tail_idx];
in1_merged *= aligned_in1_shape[tail_idx];
out_merged *= aligned_out_shape[tail_idx];
++tail_idx;
} else {
head_status = tail_status;
break;
}
}
head_idx = tail_idx;
simplified_in0_shape->emplace_back(in0_merged);
simplified_in1_shape->emplace_back(in1_merged);
simplified_out_shape->emplace_back(out_merged);
}
}
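A hand-traced example of the binary simplification, added here for illustration only (it is not part of the new file):
//   in0_shape = {2, 3, 1, 4}, in1_shape = {2, 3, 5, 4}, out_shape = {2, 3, 5, 4}
//   -> simplified_in0_shape = {6, 1, 4}, simplified_in1_shape = {6, 5, 4}, simplified_out_shape = {6, 5, 4}
// The two leading dimensions share status 3 and fuse into a single axis of size 6, so the
// broadcast kernel works on a rank-3 problem instead of a rank-4 one.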
void SimplifyBroadcastToShape(const std::vector<int64_t> &inp_shape, const std::vector<int64_t> &out_shape,
std::vector<int64_t> *simplified_inp_shape, std::vector<int64_t> *simplified_out_shape) {
std::vector<int64_t> aligned_inp_shape(inp_shape);
std::vector<int64_t> aligned_out_shape(out_shape);
if (aligned_inp_shape.size() == 0) {
aligned_inp_shape.emplace_back(1);
}
if (aligned_out_shape.size() == 0) {
aligned_out_shape.emplace_back(1);
}
size_t offset = aligned_out_shape.size() - aligned_inp_shape.size();
// broadcast shape
if (offset > 0) {
std::vector<int64_t> insert_shape(offset, 1);
aligned_inp_shape.insert(aligned_inp_shape.begin(), insert_shape.begin(), insert_shape.end());
}
// simplify shape
simplified_inp_shape->clear();
simplified_out_shape->clear();
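// Per-dimension status from CalStatus below: 0 = input and output dims are both 1,
// 1 = input dim 1 broadcast up to a larger output dim, 2 = input dim copied unchanged.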
auto CalStatus = [](int64_t inp_val, int64_t out_val) -> int {
if (inp_val == 1) {
if (out_val == 1) {
return 0;
} else {
return 1;
}
} else {
return 2;
}
};
size_t head_idx = 0;
int head_status = CalStatus(aligned_inp_shape[head_idx], aligned_out_shape[head_idx]);
while (head_status == 0 && head_idx < aligned_out_shape.size() - 1) {
++head_idx;
head_status = CalStatus(aligned_inp_shape[head_idx], aligned_out_shape[head_idx]);
}
if (head_idx == aligned_out_shape.size() - 1) {
simplified_inp_shape->emplace_back(aligned_inp_shape.back());
simplified_out_shape->emplace_back(aligned_out_shape.back());
return;
}
while (head_idx < aligned_out_shape.size()) {
int64_t inp_merged = aligned_inp_shape[head_idx];
int64_t out_merged = aligned_out_shape[head_idx];
size_t tail_idx = head_idx + 1;
while (tail_idx < aligned_out_shape.size()) {
int tail_status = CalStatus(aligned_inp_shape[tail_idx], aligned_out_shape[tail_idx]);
if (tail_status * head_status == 0 || head_status == tail_status) {
inp_merged *= aligned_inp_shape[tail_idx];
out_merged *= aligned_out_shape[tail_idx];
++tail_idx;
} else {
head_status = tail_status;
break;
}
}
head_idx = tail_idx;
simplified_inp_shape->emplace_back(inp_merged);
simplified_out_shape->emplace_back(out_merged);
}
}
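A hand-traced example for the broadcast-to case, again purely illustrative:
//   inp_shape = {1, 1, 3, 4}, out_shape = {2, 5, 3, 4}
//   -> simplified_inp_shape = {1, 12}, simplified_out_shape = {10, 12}
// The two broadcast axes fuse into one axis of size 10 and the two copied axes fuse into
// one axis of size 12, so BroadcastTo only has to deal with a rank-2 shape.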

View File

@ -0,0 +1,29 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_BINARY_BROADCAST_PUB_H_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_BINARY_BROADCAST_PUB_H_
#include <stdint.h>
#include <vector>
bool IsBinaryBroadcast(const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape);
void SimplifyBinaryBroadcastShape(const std::vector<int64_t> &in0_shape, const std::vector<int64_t> &in1_shape,
const std::vector<int64_t> &out_shape, std::vector<int64_t> *simplified_in0_shape,
std::vector<int64_t> *simplified_in1_shape,
std::vector<int64_t> *simplified_out_shape);
void SimplifyBroadcastToShape(const std::vector<int64_t> &inp_shape, const std::vector<int64_t> &out_shape,
std::vector<int64_t> *simplified_inp_shape, std::vector<int64_t> *simplified_out_shape);
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_BINARY_BROADCAST_PUB_H_
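For orientation, a minimal sketch of how the kernels touched by this change are expected to drive these helpers together with the new BinaryOpWithBroadcastCudaFunc entry point (pulled in via the binary_ops_impl.cuh include elsewhere in this commit). The buffer, device_id and stream names below are placeholders for whatever the calling kernel already holds, not symbols defined by this header:
// Illustrative fragment, assuming the calling kernel already owns:
//   float *in0_dev, *in1_dev, *out_dev;                    // device buffers
//   std::vector<int64_t> in0_shape, in1_shape, out_shape;  // host-side shapes
//   uint32_t device_id;  void *stream_ptr;                 // launch context
std::vector<int64_t> simplified_in0_shape;
std::vector<int64_t> simplified_in1_shape;
std::vector<int64_t> simplified_out_shape;
SimplifyBinaryBroadcastShape(in0_shape, in1_shape, out_shape, &simplified_in0_shape, &simplified_in1_shape,
                             &simplified_out_shape);
bool is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
// Template arguments: the binary op, both input element types and the output element type.
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, float, float, float>(
  is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, in0_dev, in1_dev, out_dev,
  device_id, reinterpret_cast<cudaStream_t>(stream_ptr));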

View File

@ -18,7 +18,8 @@
#include <vector>
#include "mindspore/core/ops/lu_solve_.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/matrix_transpose_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_to_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
namespace mindspore {
@ -149,26 +150,22 @@ bool LuSolveGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, co
MatrixTranspose(b, LongToSize(batch_num_b_ * m_ * k_), SizeToInt(m_), SizeToInt(k_), b_col_major, device_id_,
cuda_stream_);
if (need_broadcast_) {
std::vector<int64_t> simplified_inp_shape;
std::vector<int64_t> simplified_out_shape;
// expand_size :(*,m,m)
auto origin_size = lhs_shape_;
auto expand_size = output_shape_;
expand_size[out_shape_len_ - kIndex1] = m_;
BroadcastTo(origin_size[kIndex0], origin_size[kIndex1], origin_size[kIndex2], origin_size[kIndex3],
origin_size[kIndex4], origin_size[kIndex5], origin_size[kIndex6], origin_size[kIndex7],
expand_size[kIndex0], expand_size[kIndex1], expand_size[kIndex2], expand_size[kIndex3],
expand_size[kIndex4], expand_size[kIndex5], expand_size[kIndex6], expand_size[kIndex7], a_col_major,
a_broadcast, cuda_stream_);
SimplifyBroadcastToShape(origin_size, expand_size, &simplified_inp_shape, &simplified_out_shape);
BroadcastTo(simplified_inp_shape, simplified_out_shape, a_col_major, a_broadcast, device_id_, cuda_stream_);
// expand_size :(*,k,m)
origin_size = rhs_shape_;
expand_size = output_shape_;
std::swap(origin_size[out_shape_len_ - kIndex1], origin_size[out_shape_len_ - kIndex2]);
std::swap(expand_size[out_shape_len_ - kIndex1], expand_size[out_shape_len_ - kIndex2]);
BroadcastTo(origin_size[kIndex0], origin_size[kIndex1], origin_size[kIndex2], origin_size[kIndex3],
origin_size[kIndex4], origin_size[kIndex5], origin_size[kIndex6], origin_size[kIndex7],
expand_size[kIndex0], expand_size[kIndex1], expand_size[kIndex2], expand_size[kIndex3],
expand_size[kIndex4], expand_size[kIndex5], expand_size[kIndex6], expand_size[kIndex7], b_col_major,
b_broadcast, cuda_stream_);
SimplifyBroadcastToShape(origin_size, expand_size, &simplified_inp_shape, &simplified_out_shape);
BroadcastTo(simplified_inp_shape, simplified_out_shape, b_col_major, b_broadcast, device_id_, cuda_stream_);
// origin_size:(*,m,1)
// expand_size :(*,m,1)
@ -176,11 +173,8 @@ bool LuSolveGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, co
origin_size[out_shape_len_ - kIndex1] = 1;
expand_size = output_shape_;
expand_size[out_shape_len_ - kIndex1] = 1;
BroadcastTo(origin_size[kIndex0], origin_size[kIndex1], origin_size[kIndex2], origin_size[kIndex3],
origin_size[kIndex4], origin_size[kIndex5], origin_size[kIndex6], origin_size[kIndex7],
expand_size[kIndex0], expand_size[kIndex1], expand_size[kIndex2], expand_size[kIndex3],
expand_size[kIndex4], expand_size[kIndex5], expand_size[kIndex6], expand_size[kIndex7], piv_array,
piv_broadcast, cuda_stream_);
SimplifyBroadcastToShape(origin_size, expand_size, &simplified_inp_shape, &simplified_out_shape);
BroadcastTo(simplified_inp_shape, simplified_out_shape, piv_array, piv_broadcast, device_id_, cuda_stream_);
} else {
a_broadcast = a_col_major;
b_broadcast = b_col_major;

View File

@ -63,9 +63,9 @@ class LuSolveGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<
int64_t k_{0};
bool is_null_input_{false};
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;
std::vector<int64_t> lhs_shape_;
std::vector<int64_t> rhs_shape_;
std::vector<int64_t> output_shape_;
int64_t a_shape_len_{0};
int64_t b_shape_len_{0};
int64_t out_shape_len_{0};

View File

@ -1,155 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "plugin/device/gpu/kernel/math/squared_difference_kernel.h"
#include <map>
#include <utility>
namespace mindspore {
namespace kernel {
using KernelRunFunc = SquaredDifferenceOpGpuKernelMod::KernelRunFunc;
bool SquaredDifferenceOpGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
return true;
}
int SquaredDifferenceOpGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
auto input_shape1 = Convert2SizeTClipNeg(inputs[0]->GetShapeVector());
auto input_shape2 = Convert2SizeTClipNeg(inputs[1]->GetShapeVector());
auto output_shape = Convert2SizeTClipNeg(outputs[0]->GetShapeVector());
need_broadcast_ = false;
if (input_shape1.size() != input_shape2.size()) {
need_broadcast_ = true;
} else {
for (size_t i = 0; i < input_shape1.size(); i++) {
if (input_shape1[i] != input_shape2[i]) {
need_broadcast_ = true;
}
}
}
if (need_broadcast_ && output_shape.size() > MAX_DIMS) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than " << MAX_DIMS
<< ", but got " << output_shape.size();
}
lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
output_num_ = 1;
for (size_t i = 0; i < output_shape.size(); i++) {
if (need_broadcast_) {
output_shape_[i] = output_shape[i];
}
output_num_ *= static_cast<size_t>(output_shape[i]);
}
size_t lhs_offset = output_shape.size() - input_shape1.size();
for (size_t j = 0; j < input_shape1.size(); j++) {
if (need_broadcast_) {
if ((j + lhs_offset) < MAX_DIMS) {
lhs_shape_[j + lhs_offset] = input_shape1[j];
} else {
auto index = j + lhs_offset;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of input cannot be " << index << ", but got "
<< index;
}
}
}
size_t rhs_offset = output_shape.size() - input_shape2.size();
for (size_t k = 0; k < input_shape2.size(); k++) {
if (need_broadcast_) {
if ((k + rhs_offset) < MAX_DIMS) {
rhs_shape_[k + rhs_offset] = input_shape2[k];
} else {
auto index = k + rhs_offset;
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the index of input cannot be " << index << ", but got "
<< index;
}
}
}
return KRET_OK;
}
template <typename T>
bool SquaredDifferenceOpGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr_));
}
return true;
}
template <typename T>
bool SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastComplexArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
ElewiseComplexArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr_));
}
return true;
}
#define DTYPE_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
&SquaredDifferenceOpGpuKernelMod::LaunchKernel<T> \
}
#define COMPLEX_REGISTER_ATTR(INPUT1, INPUT2, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUT1).AddInputAttr(INPUT2).AddOutputAttr(OUTPUT), \
&SquaredDifferenceOpGpuKernelMod::LaunchComplexKernel<T> \
}
template <typename T>
using Complex = mindspore::utils::Complex<T>;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &SquaredDifferenceOpGpuKernelMod::GetFuncList() const {
static const std::vector<std::pair<KernelAttr, KernelRunFunc>> func_list = {
DTYPE_REGISTER_ATTR(kNumberTypeFloat32, kNumberTypeFloat32, kNumberTypeFloat32, float),
DTYPE_REGISTER_ATTR(kNumberTypeFloat64, kNumberTypeFloat64, kNumberTypeFloat64, double),
COMPLEX_REGISTER_ATTR(kNumberTypeComplex64, kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
COMPLEX_REGISTER_ATTR(kNumberTypeComplex128, kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
DTYPE_REGISTER_ATTR(kNumberTypeFloat16, kNumberTypeFloat16, kNumberTypeFloat16, half),
DTYPE_REGISTER_ATTR(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
DTYPE_REGISTER_ATTR(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int)};
return func_list;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SquaredDifference, SquaredDifferenceOpGpuKernelMod);
} // namespace kernel
} // namespace mindspore

View File

@ -1,75 +0,0 @@
/**
* Copyright 2020-2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SQUARED_DIFFERENCE_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SQUARED_DIFFERENCE_GPU_KERNEL_H_
#include <cuda_runtime_api.h>
#include <vector>
#include <string>
#include <map>
#include <complex>
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
class SquaredDifferenceOpGpuKernelMod : public NativeGpuKernelMod,
public MatchKernelHelper<SquaredDifferenceOpGpuKernelMod> {
public:
SquaredDifferenceOpGpuKernelMod() = default;
~SquaredDifferenceOpGpuKernelMod() override = default;
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
stream_ptr_ = stream_ptr;
return kernel_func_(this, inputs, workspace, outputs);
}
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
private:
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
template <typename T>
bool LaunchComplexKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
BinaryOpType op_type_{BinaryOpType::kSquaredDifference};
bool need_broadcast_;
size_t output_num_;
std::vector<size_t> lhs_shape_;
std::vector<size_t> rhs_shape_;
std::vector<size_t> output_shape_;
void *stream_ptr_{nullptr};
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_SQUARED_DIFFERENCE_GPU_KERNEL_H_

View File

@ -20,9 +20,10 @@
#include <functional>
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/l2normalize_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/clip_by_norm_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
namespace mindspore {
namespace kernel {
@ -141,27 +142,38 @@ bool ClipByNormGpuKernelMod<T, S>::DoLaunch(const std::vector<AddressPtr> &input
l2norm_output_addr),
kernel_name_ + " running cudnnReduceTensor::cudnnReduceTensorNorm2 failed.");
}
auto l2_norm_lhs_shape_size = Convert2SizeTClipNeg(l2_norm_lhs_shape_);
auto l2_norm_rhs_shap_size = Convert2SizeTClipNeg(l2_norm_rhs_shape_);
auto l2_norm_ouths_shape_size = Convert2SizeTClipNeg(l2_norm_ouths_shape_);
std::vector<int64_t> simplified_in0_shape;
std::vector<int64_t> simplified_in1_shape;
std::vector<int64_t> simplified_out_shape;
SimplifyBinaryBroadcastShape(l2_norm_lhs_shape_, l2_norm_rhs_shape_, l2_norm_ouths_shape_, &simplified_in0_shape,
&simplified_in1_shape, &simplified_out_shape);
bool is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
// Calculation std::max(l2_norm, epsilon) to keep numerical stability.
GetMaxWithEpsAndValue(l2_norm_output_size_ / sizeof(float), epsilon_, l2norm_output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
// Running `x/l2_norm(x)` and broadcast output shape to `input_x` shape
BroadcastArith(l2_norm_lhs_shape_size, l2_norm_rhs_shap_size, l2_norm_ouths_shape_size, BinaryOpType::kRealDiv,
x_float_addr, l2norm_output_addr, div_output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kRealDiv, float, float, float>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, x_float_addr, l2norm_output_addr,
div_output_addr, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
// Running `cast(clip_norm)` to the data type of `input_x`
Cast(clip_norm_size_ / sizeof(S), clip_norm_addr, clip_norm_float_addr, reinterpret_cast<cudaStream_t>(stream_ptr),
GET_CTX_DEVICE_ID);
// Running '(x/l2_norm(x)) * clip_norm' and broadcast output shape to `input_x` shape
if (clip_norm_need_broadcast_) {
BroadcastArith(l2_norm_ouths_shape_size, Convert2SizeTClipNeg(clip_norm_rhs_shape_), l2_norm_ouths_shape_size,
BinaryOpType::kMul, div_output_addr, clip_norm_float_addr, clip_norm_mul_output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
SimplifyBinaryBroadcastShape(l2_norm_ouths_shape_, clip_norm_rhs_shape_, l2_norm_ouths_shape_,
&simplified_in0_shape, &simplified_in1_shape, &simplified_out_shape);
is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, float, float, float>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, div_output_addr,
clip_norm_float_addr, clip_norm_mul_output_addr, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(output_size_ / sizeof(float), BinaryOpType::kMul, div_output_addr, clip_norm_float_addr,
clip_norm_mul_output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int64_t> ele_shape = {static_cast<int64_t>(output_size_ / sizeof(float))};
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, float, float, float>(
false, ele_shape, ele_shape, ele_shape, div_output_addr, clip_norm_float_addr, clip_norm_mul_output_addr,
device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
}
// Running compare between `input_x` and `upper output` and cast final output to float type.
CompOp(output_size_ / sizeof(float), x_float_addr, clip_norm_mul_output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
@ -242,7 +254,8 @@ void ClipByNormGpuKernelMod<T, S>::InitAxisAndEpsilon(const ops::ClipByNormPtr &
for (size_t i = 0; i < x_dim_; ++i) {
axis_.emplace_back(i); // Reduce for all dimensions.
}
} else { // Convert negative `axis` to positive `axis` and keep number unique
} else {
// Convert negative `axis` to positive `axis` and keep number unique
int64_t temp_x_dim = SizeToLong(x_dim_);
std::for_each(temp_axis.begin(), temp_axis.end(), [this, &temp_x_dim](const int64_t &value) {
value < 0 ? axis_.emplace_back(LongToSize(value + temp_x_dim)) : axis_.emplace_back(LongToSize(value));

View File

@ -23,8 +23,9 @@
#include <memory>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/l2normalize_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/kernel/kernel_constants.h"
#include "mindspore/core/ops/l2_normalize.h"
@ -88,12 +89,15 @@ class L2NormalizeGpuKernelMod : public NativeGpuKernelMod {
}
GetMaxWithEpsAndValue(workspace_size_list_[0] / sizeof(T), epsilon_, reduce_workspace_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
auto lhs_shape_size = Convert2SizeTClipNeg(lhs_shape_);
auto rhs_shape_size = Convert2SizeTClipNeg(rhs_shape_);
auto output_shape_size = Convert2SizeTClipNeg(output_shape_);
BroadcastArith(lhs_shape_size, rhs_shape_size, output_shape_size, BinaryOpType::kRealDiv, input_addr,
reduce_workspace_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int64_t> simplified_in0_shape;
std::vector<int64_t> simplified_in1_shape;
std::vector<int64_t> simplified_out_shape;
SimplifyBinaryBroadcastShape(lhs_shape_, rhs_shape_, output_shape_, &simplified_in0_shape, &simplified_in1_shape,
&simplified_out_shape);
bool is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kRealDiv, T, T, T>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, input_addr, reduce_workspace_addr,
output_addr, GET_CTX_DEVICE_ID, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@ -122,7 +126,7 @@ class L2NormalizeGpuKernelMod : public NativeGpuKernelMod {
}
ShapeVector outputC_shape = output_shape;
if ((size_t)axis_ >= output_shape.size()) {
if (static_cast<size_t>(axis_) >= output_shape.size()) {
MS_LOG(EXCEPTION) << "For 'L2NormalizeGpuKernelMod', axis_ must be less than the rank of output "
<< "but got axis_: " << axis_ << ", rank of output: " << output_shape.size();
}

View File

@ -71,7 +71,7 @@ int L2NormalizeGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
}
ShapeVector output_reduce_shape = output_shape;
if ((size_t)axis_ >= output_shape.size()) {
if (static_cast<size_t>(axis_) >= output_shape.size()) {
MS_LOG(ERROR) << "For 'L2NormalizeGradGpuKernelMod', axis_ must be less than the rank of output "
<< "but got axis_: " << axis_ << ", rank of output: " << output_shape.size();
return KRET_RESIZE_FAILED;
@ -83,9 +83,9 @@ int L2NormalizeGradGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
output_shape_.resize(MAX_DIMS, 1);
all_match_ = true;
for (size_t i = 0; i < output_shape.size(); i++) {
output_shape_[i] = LongToSizeClipNeg(output_shape[i]);
lhs_shape_[i] = LongToSizeClipNeg(output_shape[i]);
rhs_shape_[i] = LongToSizeClipNeg(output_reduce_shape[i]);
output_shape_[i] = output_shape[i];
lhs_shape_[i] = output_shape[i];
rhs_shape_[i] = output_reduce_shape[i];
if (lhs_shape_[i] != rhs_shape_[i]) {
all_match_ = false;
}
@ -137,8 +137,15 @@ bool L2NormalizeGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
}
GetMaxWithEpsAndValue(workspace_size_list_[0] / sizeof(T), epsilon_, reduce_workspace_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
BroadcastArith(output_shape_, output_shape_, output_shape_, BinaryOpType::kMul, y_addr, dy_addr, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int64_t> simplified_in0_shape;
std::vector<int64_t> simplified_in1_shape;
std::vector<int64_t> simplified_out_shape;
SimplifyBinaryBroadcastShape(output_shape_, output_shape_, output_shape_, &simplified_in0_shape,
&simplified_in1_shape, &simplified_out_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, T, T, T>(false, simplified_in0_shape, simplified_in1_shape,
simplified_out_shape, y_addr, dy_addr, dx_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (all_match_) {
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(reduce_y_dy_workspace_addr, dx_addr, output_size_list_[0], cudaMemcpyDeviceToDevice,
@ -161,13 +168,24 @@ bool L2NormalizeGradGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &in
kernel_name_ + " cudnnReduceTensor failed.");
}
}
BroadcastArith(rhs_shape_, lhs_shape_, output_shape_, BinaryOpType::kMul, reduce_y_dy_workspace_addr, y_addr, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
BroadcastArith(output_shape_, output_shape_, output_shape_, BinaryOpType::kSub, dy_addr, dx_addr, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
BroadcastArith(output_shape_, rhs_shape_, output_shape_, BinaryOpType::kRealDiv, dx_addr, reduce_workspace_addr,
dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
SimplifyBinaryBroadcastShape(rhs_shape_, lhs_shape_, output_shape_, &simplified_in0_shape, &simplified_in1_shape,
&simplified_out_shape);
bool is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, T, T, T>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, reduce_y_dy_workspace_addr, y_addr,
dx_addr, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
SimplifyBinaryBroadcastShape(output_shape_, output_shape_, output_shape_, &simplified_in0_shape,
&simplified_in1_shape, &simplified_out_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kSub, T, T, T>(false, simplified_in0_shape, simplified_in1_shape,
simplified_out_shape, dy_addr, dx_addr, dx_addr,
device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
SimplifyBinaryBroadcastShape(output_shape_, rhs_shape_, output_shape_, &simplified_in0_shape, &simplified_in1_shape,
&simplified_out_shape);
is_broadcast = IsBinaryBroadcast(simplified_in0_shape, simplified_in1_shape);
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kRealDiv, T, T, T>(
is_broadcast, simplified_in0_shape, simplified_in1_shape, simplified_out_shape, dx_addr, reduce_workspace_addr,
dx_addr, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

View File

@ -23,7 +23,8 @@
#include <utility>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/math/broadcast_public.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/l2normalize_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
namespace mindspore {
@ -172,9 +173,9 @@ class L2NormalizeGradGpuKernelMod : public NativeGpuKernelMod {
float epsilon_{0.0};
int axis_origin_{0};
int axis_{0};
std::vector<size_t> lhs_shape_{};
std::vector<size_t> rhs_shape_{};
std::vector<size_t> output_shape_{};
std::vector<int64_t> lhs_shape_{};
std::vector<int64_t> rhs_shape_{};
std::vector<int64_t> output_shape_{};
L2NormalizeGradGpuLaunchFunc kernel_func_;
static std::vector<std::pair<KernelAttr, L2NormalizeGradGpuLaunchFunc>> func_list_;

View File

@ -25,7 +25,7 @@
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/kernel_constants.h"
namespace mindspore {
@ -102,8 +102,10 @@ class PoolingFwdGpuKernelMod : public NativeGpuKernelMod {
CalRealKernelSize(output_shape_exclude_nc_, kernel_size_, edge_kernel_, work_addr, 0,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
ElewiseArith(output_num, BinaryOpType::kMul, output_addr, work_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<int64_t> shape = {static_cast<int64_t>(output_num)};
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, T, T, T>(false, shape, shape, shape, output_addr, work_addr,
output_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}

View File

@ -20,7 +20,7 @@
#include "mindspore/core/ops/grad/pool_grad.h"
#include "mindspore/core/ops/grad/avg_pool_3d_grad.h"
#include "mindspore/core/ops/grad/max_pool_3d_grad.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/broadcast_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/avg_pool3d_helper_impl.cuh"
namespace mindspore {
@ -159,8 +159,10 @@ bool PoolingGradGpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr>
CalRealKernelSize(input_shape_, kernel_size_, edge_kernel_, work_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
}
ElewiseArith(output_num, BinaryOpType::kMul, dy_work_addr, work_addr, dy_work_addr,
reinterpret_cast<cudaStream_t>(cuda_stream_));
std::vector<int64_t> shape = {static_cast<int64_t>(output_num)};
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kMul, T, T, T>(false, shape, shape, shape, dy_work_addr, work_addr,
dy_work_addr, device_id_,
reinterpret_cast<cudaStream_t>(cuda_stream_));
if (cudnn_data_type_ == CUDNN_DATA_DOUBLE) {
CHECK_CUDNN_RET_WITH_EXCEPT_NOTRACE(
cudnnPoolingBackward(cudnn_handle_, pooling_descriptor_, &alpha, y_descriptor_, y, dy_descriptor_, dy_work_addr,

View File

@ -64,15 +64,12 @@ bool SequenceAddNGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &input
size_t element_num = outputs[0]->size / sizeof(T);
FillDeviceArray(outputs[0]->size / sizeof(T), output_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr_));
FillDeviceArray(outputs[0]->size / sizeof(T), work_addr, 0.0f, reinterpret_cast<cudaStream_t>(stream_ptr_));
std::vector<int64_t> ele_shape = {static_cast<int64_t>(element_num)};
for (int64_t i = 0; i < tuple_shape_[0]; i++) {
T *input_addr = element_num * i + input_0;
if constexpr (std::is_same<T, Complex<float>>::value || std::is_same<T, Complex<double>>::value) {
ElewiseComplexArith(outputs[0]->size / sizeof(T), BinaryOpType::kAdd, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
ElewiseArith(outputs[0]->size / sizeof(T), BinaryOpType::kAdd, input_addr, work_addr, work_addr,
reinterpret_cast<cudaStream_t>(stream_ptr_));
}
BinaryOpWithBroadcastCudaFunc<BinaryOpType::kAdd, T, T, T>(false, ele_shape, ele_shape, ele_shape, input_addr,
work_addr, work_addr, device_id_,
reinterpret_cast<cudaStream_t>(stream_ptr_));
}
if (work_addr != output_addr) {

View File

@ -23,7 +23,7 @@
#include <string>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
#include "plugin/device/gpu/kernel/math/broadcast_gpu_kernel.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/binary_ops_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/slice_impl.cuh"
#include "plugin/factory/ms_factory.h"