!17402 Meshgrid gpu kernel

From: @peilin-wang
Reviewed-by: @robingrosman,@nsyca
Signed-off-by: @robingrosman
This commit is contained in:
mindspore-ci-bot 2021-06-04 07:56:21 +08:00 committed by Gitee
commit bd19b654aa
6 changed files with 327 additions and 4 deletions

View File

@ -0,0 +1,65 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/meshgrid_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
MeshgridGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
MeshgridGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
MeshgridGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(Meshgrid,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
MeshgridGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
MeshgridGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
MeshgridGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(
Meshgrid, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
MeshgridGpuKernel, uint64_t)
MS_REG_GPU_KERNEL_ONE(Meshgrid,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
MeshgridGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(Meshgrid,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
MeshgridGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(Meshgrid,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
MeshgridGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(Meshgrid,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
MeshgridGpuKernel, int64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,155 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MESHGRID_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MESHGRID_GPU_KERNEL_H
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh"
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h"
namespace mindspore {
namespace kernel {
template <typename T>
class MeshgridGpuKernel : public GpuKernel {
public:
MeshgridGpuKernel() { ResetResource(); }
~MeshgridGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *ones_device = GetDeviceAddress<T>(workspace, 0);
CalOnesLike(output_size_, static_cast<T *>(nullptr), ones_device, reinterpret_cast<cudaStream_t>(stream_ptr));
std::vector<size_t> broadcasted_ones_shape(MAX_DIMS, 1);
for (size_t i = 0; i < output_shape_.size(); i++) {
broadcasted_ones_shape[i] = output_shape_[i];
}
for (size_t i = 0; i < outputs.size(); i++) {
T *input_device = GetDeviceAddress<T>(inputs, i);
T *output_device = GetDeviceAddress<T>(outputs, i);
std::vector<size_t> broadcasted_input_shape(MAX_DIMS, 1);
broadcasted_input_shape[i] = input_shapes_[i];
if (swap_indexing_ && i < 2) {
std::swap(broadcasted_input_shape[0], broadcasted_input_shape[1]);
}
BroadcastArith(broadcasted_input_shape, broadcasted_ones_shape, output_shape_, BROADCAST_TYPE_MUL, input_device,
ones_device, output_device, reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::string indexing = GetAttr<std::string>(kernel_node, "indexing");
if (indexing == "xy") {
swap_indexing_ = true;
} else if (indexing == "ij") {
swap_indexing_ = false;
} else {
MS_LOG(ERROR) << "invalid string for argument \"indexing\", must be \"xy\" or \"ij\" but got " << indexing;
}
input_size_ = 1;
input_count_ = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_count_; i++) {
size_t input_shape = AnfAlgo::GetInputDeviceShape(kernel_node, i)[0];
input_shapes_.push_back(input_shape);
input_size_ *= input_shape;
}
output_size_ = 1;
output_count_ = AnfAlgo::GetOutputTensorNum(kernel_node);
// inferred shape swaps output shape for us if needed
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
if (output_count_ != input_count_) {
MS_LOG(ERROR) << "output count is " << output_count_ << ", but MeshgridGpuKernel needs " << input_count_
<< " output(s).";
return false;
}
for (size_t i = 0; i < output_shape_.size(); i++) {
output_size_ *= output_shape_[i];
}
// need to pad output shape with ones for broadcast kernel
for (size_t i = 0; i < output_shape_.size() - MAX_DIMS; i++) {
output_shape_.push_back(1);
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
input_shapes_.clear();
output_shape_.clear();
input_size_ = 0;
input_count_ = 0;
output_size_ = 0;
output_count_ = 0;
swap_indexing_ = true;
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
for (const size_t &input_shape : input_shapes_) {
input_size_list_.push_back(input_shape * sizeof(T));
}
for (size_t i = 0; i < output_count_; i++) {
output_size_list_.push_back(output_size_ * sizeof(T));
}
workspace_size_list_.push_back(output_size_ * sizeof(T));
}
private:
std::vector<size_t> input_shapes_;
std::vector<size_t> output_shape_;
size_t input_size_;
size_t input_count_;
size_t output_size_;
size_t output_count_;
bool swap_indexing_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_MESHGRID_GPU_KERNEL_H

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,11 +16,27 @@
#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
OnesLikeGpuKernel, double)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
OnesLikeGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
OnesLikeGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
OnesLikeGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
OnesLikeGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
OnesLikeGpuKernel, int)
OnesLikeGpuKernel, int32_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
OnesLikeGpuKernel, int64_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
OnesLikeGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
OnesLikeGpuKernel, uint16_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
OnesLikeGpuKernel, uint32_t)
MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
OnesLikeGpuKernel, uint64_t)
} // namespace kernel
} // namespace mindspore

View File

@ -323,6 +323,28 @@ struct FloorModFunc<half2> {
}
};
// the FloorModFunc specializations for uint32_t and uint64_t are there
// because of a 'more than one instance of overloaded function "std::abs"'
// error. I realize the specializations are exactly the same, but I found
// no good alternative.
template <>
struct FloorModFunc<uint32_t> {
__device__ __host__ __forceinline__ int32_t operator()(const int32_t &lhs, const int32_t &rhs) {
int32_t res = lhs - floorf(lhs / rhs) * rhs;
res = (res > 1e-9) && ((res < 0.0) != (rhs < 0.0)) ? res + rhs : res;
return res;
}
};
template <>
struct FloorModFunc<uint64_t> {
__device__ __host__ __forceinline__ int64_t operator()(const int64_t &lhs, const int64_t &rhs) {
int64_t res = lhs - floorf(lhs / rhs) * rhs;
res = (res > 1e-9) && ((res < 0.0) != (rhs < 0.0)) ? res + rhs : res;
return res;
}
};
template <typename T>
struct AbsGradFunc {
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
@ -429,6 +451,12 @@ template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int64_t
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint16_t *x0, const uint16_t *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint32_t *x0, const uint32_t *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const uint64_t *x0, const uint64_t *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
cudaStream_t stream);
// Element-wise ArithMetic
@ -510,6 +538,12 @@ template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int64
cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int16_t *x0, const int16_t *x1, int16_t *y,
cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint16_t *x0, const uint16_t *x1,
uint16_t *y, cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint32_t *x0, const uint32_t *x1,
uint32_t *y, cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const uint64_t *x0, const uint64_t *x1,
uint64_t *y, cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const bool *x0, const bool *x1, bool *y,
cudaStream_t stream);
// Broadcast comparison
@ -628,6 +662,15 @@ template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int16_t *x0,
const int16_t *x1, bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint16_t *x0,
const uint16_t *x1, bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint32_t *x0,
const uint32_t *x1, bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
const uint64_t *x1, bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
const bool *x1, bool *y, cudaStream_t stream);
@ -780,6 +823,15 @@ template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vect
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const int16_t *x0,
const int16_t *x1, int16_t *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint16_t *x0,
const uint16_t *x1, uint16_t *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint32_t *x0,
const uint32_t *x1, uint32_t *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const uint64_t *x0,
const uint64_t *x1, uint64_t *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<size_t> &x0_dims, const std::vector<size_t> &x1_dims,
const std::vector<size_t> &y_dims, enum BroadcastOpType op, const bool *x0,
const bool *x1, bool *y, cudaStream_t stream);

View File

@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -32,6 +32,17 @@ void CalOnesLike(const size_t size, const T* input, T* output, cudaStream_t cuda
return;
}
template void CalOnesLike<double>(const size_t size, const double* input, double* output, cudaStream_t cuda_stream);
template void CalOnesLike<float>(const size_t size, const float* input, float* output, cudaStream_t cuda_stream);
template void CalOnesLike<half>(const size_t size, const half* input, half* output, cudaStream_t cuda_stream);
template void CalOnesLike<int>(const size_t size, const int* input, int* output, cudaStream_t cuda_stream);
template void CalOnesLike<int8_t>(const size_t size, const int8_t* input, int8_t* output, cudaStream_t cuda_stream);
template void CalOnesLike<int16_t>(const size_t size, const int16_t* input, int16_t* output, cudaStream_t cuda_stream);
template void CalOnesLike<int32_t>(const size_t size, const int32_t* input, int32_t* output, cudaStream_t cuda_stream);
template void CalOnesLike<int64_t>(const size_t size, const int64_t* input, int64_t* output, cudaStream_t cuda_stream);
template void CalOnesLike<uint8_t>(const size_t size, const uint8_t* input, uint8_t* output, cudaStream_t cuda_stream);
template void CalOnesLike<uint16_t>(const size_t size, const uint16_t* input, uint16_t* output,
cudaStream_t cuda_stream);
template void CalOnesLike<uint32_t>(const size_t size, const uint32_t* input, uint32_t* output,
cudaStream_t cuda_stream);
template void CalOnesLike<uint64_t>(const size_t size, const uint64_t* input, uint64_t* output,
cudaStream_t cuda_stream);

View File

@ -348,6 +348,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int8_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
BroadcastOpGpuKernel, int8_t)
// uint8
MS_REG_GPU_KERNEL_ONE(
@ -366,6 +369,9 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
NotEqual, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, uint8_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
BroadcastOpGpuKernel, uint8_t)
// int16
MS_REG_GPU_KERNEL_ONE(
@ -381,6 +387,24 @@ MS_REG_GPU_KERNEL_ONE(
MS_REG_GPU_KERNEL_ONE(
LessEqual, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int16_t)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
BroadcastOpGpuKernel, int16_t)
// uint16
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
BroadcastOpGpuKernel, uint16_t)
// uint32
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
BroadcastOpGpuKernel, uint32_t)
// uint64
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
BroadcastOpGpuKernel, uint64_t)
// bool
MS_REG_GPU_KERNEL_ONE(