From c00fd21b17f61db454d8709b597d3b848e2b0ece Mon Sep 17 00:00:00 2001 From: ZRN Date: Thu, 21 Jul 2022 16:02:39 +0800 Subject: [PATCH] Add GPU Operator CompareAndBitpack --- .../cuda_ops/compare_and_bitpack_impl.cu | 87 ++++++++++++ .../cuda_ops/compare_and_bitpack_impl.cuh | 27 ++++ .../math/compare_and_bitpack_gpu_kernel.cc | 130 ++++++++++++++++++ .../math/compare_and_bitpack_gpu_kernel.h | 78 +++++++++++ mindspore/core/ops/compareAndBitpack.cc | 9 +- .../mindspore/ops/operations/__init__.py | 2 +- .../mindspore/ops/operations/math_ops.py | 2 +- .../st/ops/gpu/test_compare_and_bitpack_op.py | 54 ++++++++ 8 files changed, 383 insertions(+), 6 deletions(-) create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cu create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.cc create mode 100644 mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h create mode 100644 tests/st/ops/gpu/test_compare_and_bitpack_op.py diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cu b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cu new file mode 100644 index 00000000000..5ade1de35b7 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cu @@ -0,0 +1,87 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh" +#include + + +template +__global__ void CompareAndBitpack(const T *x, const T *threshold, uint8_t *output, const size_t output_num) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { + uint8_t res; + res = (x[8 * pos] > *threshold) << 7; + res = res | ((x[8 * pos + 1] > *threshold) << 6); + res = res | ((x[8 * pos + 2] > *threshold) << 5); + res = res | ((x[8 * pos + 3] > *threshold) << 4); + res = res | ((x[8 * pos + 4] > *threshold) << 3); + res = res | ((x[8 * pos + 5] > *threshold) << 2); + res = res | ((x[8 * pos + 6] > *threshold) << 1); + res = res | (x[8 * pos + 7] > *threshold); + output[pos] = res; + } + return; +} + +template <> +__global__ void CompareAndBitpack(const bool *x, const bool *threshold, + uint8_t *output, const size_t output_num) { + for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_num; pos += blockDim.x * gridDim.x) { + uint8_t res; + res = x[8 * pos] << 7; + res = res | (x[8 * pos + 1] << 6); + res = res | (x[8 * pos + 2] << 5); + res = res | (x[8 * pos + 3] << 4); + res = res | (x[8 * pos + 4] << 3); + res = res | (x[8 * pos + 5] << 2); + res = res | (x[8 * pos + 6] << 1); + res = res | x[8 * pos + 7]; + output[pos] = res; + } + return; +} + +template +void CalCompareAndBitpack(const T *x, const T *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream) { + CompareAndBitpack<<>>( + x, threshold, output, output_num); + return; +} + +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const half *x, const half *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const float *x, const float *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const double *x, const double *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const int8_t *x, const int8_t *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const int16_t *x, const int16_t *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const int32_t *x, const int32_t *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const int64_t *x, const int64_t *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +template CUDA_LIB_EXPORT void CalCompareAndBitpack( + const bool *x, const bool *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh new file mode 100644 index 00000000000..b0984047ed9 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh @@ -0,0 +1,27 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_ +#include +#include +#include +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h" + +template +CUDA_LIB_EXPORT void CalCompareAndBitpack(const T *x, const T *threshold, uint8_t *output, const size_t output_num, + const uint32_t &device_id, cudaStream_t cuda_stream); +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COMPARE_AND_BITPACK_IMPL_CUH_ diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.cc new file mode 100644 index 00000000000..1a01edc9670 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.cc @@ -0,0 +1,130 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h" +#include +#include +#include +#include +#include +#include "include/curand.h" +#include "mindspore/core/ops/compareAndBitpack.h" +#include "abstract/utils.h" +#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/compare_and_bitpack_impl.cuh" +#include "include/common/utils/anfalgo.h" + +namespace mindspore { +namespace kernel { +constexpr size_t kBitpack = 8; +bool CompareAndBitpackGpuKernelMod::Init(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs) { + kernel_name_ = base_operator->name(); + if (inputs.empty() || outputs.empty()) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', it got empty inputs or outputs, which is invalid."; + return false; + } + kernel_ptr_ = std::make_shared(base_operator->GetPrim()); + auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs); + auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport()); + if (!is_match) { + MS_LOG(ERROR) << "For '" << kernel_name_ << "', the kernel type should be in " + << "[int8, int16, int32, int64, float16, float32, float64, bool], but got: " << kernel_attr; + return false; + } + kernel_func_ = func_list_[index].second; + x_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first); + threshold_unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex1).first); + cudnn_handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCudnnHandle(); + return true; +} + +int CompareAndBitpackGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, + const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + ResetResource(); + for (const auto &input : inputs) { + auto input_shape = input->GetShapeVector(); + if (!IsValidShape(input_shape)) { + return KRET_UNKNOWN_SHAPE; + } + } + auto x_long_shape = inputs.at(kIndex0)->GetShapeVector(); + std::vector x_shape; + (void)std::transform(x_long_shape.begin(), x_long_shape.end(), std::back_inserter(x_shape), LongToSize); + for (size_t i = 0; i < x_shape.size(); i++) { + x_count_ *= x_shape[i]; + } + y_count_ = x_count_ / kBitpack; + size_t x_size = x_count_ * x_unit_size_; + input_size_list_.emplace_back(x_size); + size_t threshold_size = threshold_unit_size_; + input_size_list_.emplace_back(threshold_size); + size_t output_size = y_count_ * sizeof(uint8_t); + output_size_list_.emplace_back(output_size); + size_t workspace_size = 0; + workspace_size_list_.emplace_back(workspace_size); + return KRET_OK; +} + +void CompareAndBitpackGpuKernelMod::ResetResource() noexcept { + is_null_input_ = false; + x_count_ = 1; + input_size_list_.clear(); + output_size_list_.clear(); + workspace_size_list_.clear(); +} + +template +bool CompareAndBitpackGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *x = GetDeviceAddress(inputs, kIndex0); + T *threshold = GetDeviceAddress(inputs, kIndex1); + uint8_t *y = GetDeviceAddress(outputs, kIndex0); + CalCompareAndBitpack(x, threshold, y, y_count_, device_id_, reinterpret_cast(cuda_stream_)); + return true; +} + +std::vector> + CompareAndBitpackGpuKernelMod::func_list_ = { + {KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeUInt8), + &CompareAndBitpackGpuKernelMod::LaunchKernel}}; + +std::vector CompareAndBitpackGpuKernelMod::GetOpSupport() { + std::vector support_list; + (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list), + [](const std::pair &pair) { return pair.first; }); + return support_list; +} +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, CompareAndBitpack, CompareAndBitpackGpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h new file mode 100644 index 00000000000..13c16543239 --- /dev/null +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/math/compare_and_bitpack_gpu_kernel.h @@ -0,0 +1,78 @@ +/** + * Copyright 2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_ +#include +#include +#include +#include +#include "plugin/device/gpu/kernel/gpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class CompareAndBitpackGpuKernelMod : public NativeGpuKernelMod { + public: + CompareAndBitpackGpuKernelMod() { ResetResource(); } + ~CompareAndBitpackGpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs, void *cuda_stream) override { + if (is_null_input_) { + return true; + } + cuda_stream_ = cuda_stream; + return kernel_func_(this, inputs, workspace, outputs); + } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + std::vector GetOpSupport() override; + + private: + void ResetResource() noexcept; + + void CheckCompareAndBitpackShape(); + + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + using CompareAndBitpackFunc = + std::function &, + const std::vector &, const std::vector &)>; + + private: + size_t x_unit_size_{1}; + size_t threshold_unit_size_{1}; + bool is_null_input_{false}; + size_t x_count_{}; + size_t y_count_{}; + void *cuda_stream_{nullptr}; + BaseOperatorPtr kernel_ptr_{nullptr}; + cudnnHandle_t cudnn_handle_{}; + curandGenerator_t curand_generator_{nullptr}; + CompareAndBitpackFunc kernel_func_{}; + static std::vector> func_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_MATH_COMPARE_AND_BITPACK_GPU_KERNEL_H_ diff --git a/mindspore/core/ops/compareAndBitpack.cc b/mindspore/core/ops/compareAndBitpack.cc index bf5f34c45a6..6178d4d9390 100644 --- a/mindspore/core/ops/compareAndBitpack.cc +++ b/mindspore/core/ops/compareAndBitpack.cc @@ -45,10 +45,11 @@ abstract::ShapePtr CompareAndBitpackInferShape(const PrimitivePtr &primitive, (void)CheckAndConvertUtils::CheckInteger("x's rank'", x_rank, kNotEqual, kShapeSize_, primitive->name()); // check the innermost dimension of `x`'s shape is disvisible by 8. - (void)CheckAndConvertUtils::Check("x innermost dimension % 8", x_shape[x_rank - 1] % divisible_num, kEqual, 0, - primitive->name()); - - ShapeVector out_shape; + if (x_shape[x_rank - 1] != -1) { + (void)CheckAndConvertUtils::Check("x innermost dimension % 8", x_shape[x_rank - 1] % divisible_num, kEqual, 0, + primitive->name()); + } + std::vector out_shape; for (int dim = 0; dim < x_rank - 1; dim = dim + 1) { (void)out_shape.emplace_back(x_shape[dim]); } diff --git a/mindspore/python/mindspore/ops/operations/__init__.py b/mindspore/python/mindspore/ops/operations/__init__.py index 97d8039a2b9..cf757b17142 100644 --- a/mindspore/python/mindspore/ops/operations/__init__.py +++ b/mindspore/python/mindspore/ops/operations/__init__.py @@ -524,7 +524,7 @@ __all__ = [ "Custom", "LuSolve", "CholeskyInverse", - "Cummax", + "Cummax" ] __sponge__ = [ diff --git a/mindspore/python/mindspore/ops/operations/math_ops.py b/mindspore/python/mindspore/ops/operations/math_ops.py index e1e6a350f5d..73e991f39f0 100644 --- a/mindspore/python/mindspore/ops/operations/math_ops.py +++ b/mindspore/python/mindspore/ops/operations/math_ops.py @@ -6873,7 +6873,7 @@ class CompareAndBitpack(Primitive): ValueError: If the innermost dimension of `x`'s shape is not disvisible by 8. Supported Platforms: - ``Ascend`` ``CPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8]), mindspore.float32) diff --git a/tests/st/ops/gpu/test_compare_and_bitpack_op.py b/tests/st/ops/gpu/test_compare_and_bitpack_op.py new file mode 100644 index 00000000000..dcdff6d19e4 --- /dev/null +++ b/tests/st/ops/gpu/test_compare_and_bitpack_op.py @@ -0,0 +1,54 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest +import mindspore.nn as nn +import mindspore.context as context +from mindspore import Tensor +from mindspore.ops.operations.math_ops import CompareAndBitpack +import mindspore.common.dtype as mstype + + +class NetCompareAndBitpack(nn.Cell): + def __init__(self): + super(NetCompareAndBitpack, self).__init__() + self.compare_and_bitpack = CompareAndBitpack() + + def construct(self, x, threshold): + return self.compare_and_bitpack(x, threshold) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_compare_and_bitpack_graph(): + """ + Feature: Compare and bitpack + Description: test case for CompareAndBitpack of float16 + Expectation: The result are as expected + """ + context.set_context(mode=context.GRAPH_MODE, device_target='GPU') + x = Tensor(np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.float16)) + threshold = Tensor(6, dtype=mstype.float16) + net = NetCompareAndBitpack() + output = net(x, threshold) + out_type = output.asnumpy().dtype + out_expect = np.array([3], dtype=np.uint8) + diff0 = output.asnumpy() - out_expect + error0 = np.zeros(shape=out_expect.shape) + assert np.all(diff0 == error0) + assert output.shape == out_expect.shape + assert out_type == 'uint8'