From 5f8ed2bddfb4f8a395189538cc8632a4b81f02a2 Mon Sep 17 00:00:00 2001
From: BokaiLi
Date: Thu, 15 Sep 2022 15:02:20 +0800
Subject: [PATCH] TileGpuKernelMod support dynamic shape

---
 .../gpu/kernel/arrays/tile_gpu_kernel.cc | 192 ++++++++++++++++--
 .../gpu/kernel/arrays/tile_gpu_kernel.h  | 100 +++------
 2 files changed, 200 insertions(+), 92 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.cc
index c4020e13000..3ecc18ad143 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.cc
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.cc
@@ -14,33 +14,185 @@
  * limitations under the License.
  */
 
+#include <algorithm>
+#include <map>
 #include "plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/complex.h"
 
 namespace mindspore {
 namespace kernel {
+namespace {
+constexpr size_t kStaticInputNum = 1;
+constexpr size_t kDynInputNum = 2;
+constexpr size_t kTileOutputsNum = 1;
+constexpr size_t kIndex0 = 0;
+constexpr size_t kIndex1 = 1;
+}  // namespace
+bool TileGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                            const std::vector<KernelTensorPtr> &outputs) {
+  MS_EXCEPTION_IF_NULL(base_operator);
+  auto prim = base_operator->GetPrim();
+  MS_EXCEPTION_IF_NULL(prim);
+  kernel_name_ = base_operator->name();
+  size_t input_num = inputs.size();
+  if (input_num == kStaticInputNum) {
+    is_dynamic_case_ = false;
+  } else if (input_num == kDynInputNum) {
+    is_dynamic_case_ = true;
+  } else {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs must be 1 or 2, but got " << input_num;
+    return false;
+  }
+
+  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
+  auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
+  if (!is_match) {
+    MS_LOG(ERROR) << "For '" << kernel_name_ << "', it does not support this kernel type: " << kernel_attr;
+    return false;
+  }
+  kernel_func_ = func_list_[index].second;
+  return true;
+}
+
+int TileGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+                             const std::vector<KernelTensorPtr> &outputs,
+                             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
+  workspace_size_list_.clear();
+  if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+    return ret;
+  }
+  auto input_shape = inputs[kIndex0]->GetShapeVector();
+  auto output_shape = outputs[kIndex0]->GetShapeVector();
+  input_shape_.clear();
+  output_shape_.clear();
+  std::transform(input_shape.cbegin(), input_shape.cend(), std::back_inserter(input_shape_), LongToSize);
+  std::transform(output_shape.cbegin(), output_shape.cend(), std::back_inserter(output_shape_), LongToSize);
+  is_null_input_ =
+    CHECK_SHAPE_NULL(input_shape_, kernel_name_, "input") || CHECK_SHAPE_NULL(output_shape_, kernel_name_, "output");
+  if (is_null_input_) {
+    return KRET_OK;
+  }
+  if (output_shape_.size() < kTileOutputsNum) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be less than 1, but got "
+                      << output_shape_.size();
+  }
+  input_size_ = SizeOf(input_shape_);
+  if (output_shape_.size() > TILE_MAX_DIMENSION) {
+    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than "
+                      << TILE_MAX_DIMENSION << ", but got " << output_shape_.size();
+  }
+  shape_size_ = output_shape_.size();
+  output_size_ = SizeOf(output_shape_);
+  if (!is_dynamic_case_) {
+    const std::string kAttrMultiples = "multiples";
+    auto multi_attr = base_operator->GetPrim()->GetAttr(kAttrMultiples);
+    if (multi_attr == nullptr) {
+      return KRET_RESIZE_FAILED;
+    }
+    multiples = GetValue<std::vector<int64_t>>(multi_attr);
+  } else {
+    GetDynamicAttrIntValue(inputs, kIndex1, inputsOnHost, kernel_name_, &multiples);
+  }
+  int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
+  (void)input_shape_.insert(input_shape_.begin(), filling_value, kIndex1);
+  workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
+  workspace_size_list_.push_back(output_shape_.size() * sizeof(size_t));
+  return KRET_OK;
+}
+
+template <typename T>
+bool TileGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                                    const std::vector<AddressPtr> &outputs, void *stream_ptr) {
+  if (is_null_input_) {
+    return true;
+  }
+  T *input = GetDeviceAddress<T>(inputs, kIndex0);
+  size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, kIndex0);
+  size_t *output_shape_ptr = GetDeviceAddress<size_t>(workspace, kIndex1);
+  T *output = GetDeviceAddress<T>(outputs, kIndex0);
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+    cudaMemcpyAsync(input_shape_ptr, &input_shape_[kIndex0], input_shape_.size() * sizeof(size_t),
+                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+    "cudaMemcpyAsync input_shape_ failed");
+  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
+    cudaMemcpyAsync(output_shape_ptr, &output_shape_[kIndex0], output_shape_.size() * sizeof(size_t),
+                    cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
+    "cudaMemcpyAsync output_shape_ failed");
+  CalTile(output_size_, input_size_, shape_size_, input_shape_ptr, output_shape_ptr, input, output,
+          reinterpret_cast<cudaStream_t>(stream_ptr));
+  return true;
+}
+
 template <typename T>
 using Complex = mindspore::utils::Complex<T>;
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
-                      TileGpuKernelMod, Complex<float>)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
-                      TileGpuKernelMod, Complex<double>)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
-                      TileGpuKernelMod, double)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                      TileGpuKernelMod, float)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-                      TileGpuKernelMod, half)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
-                      TileGpuKernelMod, int16_t)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-                      TileGpuKernelMod, int)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
-                      TileGpuKernelMod, int64_t)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
-                      TileGpuKernelMod, int)
-MS_REG_GPU_KERNEL_ONE(Tile, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), TileGpuKernelMod,
-                      bool)
+std::vector<std::pair<KernelAttr, TileGpuKernelMod::TileLaunchFunc>> TileGpuKernelMod::func_list_ = {
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
+   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
+   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+   &TileGpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+   &TileGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+   &TileGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
+   &TileGpuKernelMod::LaunchKernel<int16_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &TileGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), &TileGpuKernelMod::LaunchKernel<bool>},
+  // For dynamic shape case:
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex64),
+   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeComplex128),
+   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
+   &TileGpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
+   &TileGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
+   &TileGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
+   &TileGpuKernelMod::LaunchKernel<int16_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
+   &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
+   &TileGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
+   &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
+   &TileGpuKernelMod::LaunchKernel<bool>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex64),
+   &TileGpuKernelMod::LaunchKernel<Complex<float>>},
+  {KernelAttr().AddInputAttr(kNumberTypeComplex128).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeComplex128),
+   &TileGpuKernelMod::LaunchKernel<Complex<double>>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
+   &TileGpuKernelMod::LaunchKernel<double>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
+   &TileGpuKernelMod::LaunchKernel<float>},
+  {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
+   &TileGpuKernelMod::LaunchKernel<half>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
+   &TileGpuKernelMod::LaunchKernel<int16_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+   &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
+   &TileGpuKernelMod::LaunchKernel<int64_t>},
+  {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
+   &TileGpuKernelMod::LaunchKernel<int>},
+  {KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
+   &TileGpuKernelMod::LaunchKernel<bool>}};
+
+std::vector<KernelAttr> TileGpuKernelMod::GetOpSupport() {
+  std::vector<KernelAttr> support_list;
+  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
+                       [](const std::pair<KernelAttr, TileLaunchFunc> &pair) { return pair.first; });
+  return support_list;
+}
+
+MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Tile, TileGpuKernelMod);
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h
index ca1961d306c..3079ceab9ef 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/tile_gpu_kernel.h
@@ -14,90 +14,43 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
+#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
 
+#include <map>
 #include <vector>
+#include <string>
+#include <utility>
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/tile_impl.cuh"
 
 namespace mindspore {
 namespace kernel {
-template <typename T>
-class TileGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class TileGpuKernelMod : public NativeGpuKernelMod {
  public:
   TileGpuKernelMod() { ResetResource(); }
   ~TileGpuKernelMod() override = default;
 
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
-    if (is_null_input_) {
-      return true;
-    }
-    T *input = GetDeviceAddress<T>(inputs, 0);
-    size_t *input_shape_ptr = GetDeviceAddress<size_t>(workspace, 0);
-    size_t *output_shape_ptr = GetDeviceAddress<size_t>(workspace, 1);
-    T *output = GetDeviceAddress<T>(outputs, 0);
+    return kernel_func_(this, inputs, workspace, outputs, stream_ptr);
+  }
 
-    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
-                               cudaMemcpyAsync(input_shape_ptr, &input_shape_[0], input_shape_.size() * sizeof(size_t),
-                                               cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-                               "cudaMemcpyAsync input_shape_ failed");
-    CHECK_CUDA_RET_WITH_EXCEPT(
-      kernel_node_,
-      cudaMemcpyAsync(output_shape_ptr, &output_shape_[0], output_shape_.size() * sizeof(size_t),
-                      cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
-      "cudaMemcpyAsync output_shape_ failed");
-    CalTile(output_size_, input_size_, shape_size_, input_shape_ptr, output_shape_ptr, input, output,
-            reinterpret_cast<cudaStream_t>(stream_ptr));
-    return true;
-  }
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override;
 
-  bool Init(const CNodePtr &kernel_node) override {
-    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
-    kernel_node_ = kernel_node;
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num;
-    }
-    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
-    }
-    input_shape_ = common::AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-    output_shape_ = common::AnfAlgo::GetOutputInferShape(kernel_node, 0);
-    is_null_input_ =
-      CHECK_SHAPE_NULL(input_shape_, kernel_name, "input") || CHECK_SHAPE_NULL(output_shape_, kernel_name, "output");
-    if (is_null_input_) {
-      InitSizeLists();
-      return true;
-    }
-    if (output_shape_.size() < 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be less than 1, but got "
-                        << output_shape_.size();
-    }
-    input_size_ = SizeOf(input_shape_);
-    if (output_shape_.size() > TILE_MAX_DIMENSION) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than "
-                        << TILE_MAX_DIMENSION << ", but got " << output_shape_.size();
-    }
-    shape_size_ = output_shape_.size();
-    output_size_ = SizeOf(output_shape_);
+  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+             const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
 
-    std::vector<int64_t> multiples = GetAttr<std::vector<int64_t>>(kernel_node, "multiples");
-    int64_t filling_value = static_cast<int64_t>(multiples.size()) - static_cast<int64_t>(input_shape_.size());
-    // input_shape_.size() == output_shape_.size() == shape_size_
-    (void)input_shape_.insert(input_shape_.begin(), filling_value, 1);
-    InitSizeLists();
-    return true;
-  }
+  std::vector<KernelAttr> GetOpSupport() override;
 
-  void ResetResource() noexcept override {
+  void ResetResource() noexcept {
     input_size_ = 1;
     output_size_ = 1;
     shape_size_ = 1;
     is_null_input_ = false;
+    is_dynamic_case_ = false;
     input_shape_.clear();
     output_shape_.clear();
     input_size_list_.clear();
@@ -105,22 +58,25 @@ class TileGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     workspace_size_list_.clear();
   }
 
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_ * sizeof(T));
-    workspace_size_list_.push_back(input_shape_.size() * sizeof(size_t));
-    workspace_size_list_.push_back(output_shape_.size() * sizeof(size_t));
-    output_size_list_.push_back(output_size_ * sizeof(T));
-  }
-
  private:
+  template <typename T>
+  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
+                    const std::vector<AddressPtr> &outputs, void *stream_ptr);
   size_t input_size_;
   size_t output_size_;
   size_t shape_size_;
-  bool is_null_input_;
   ShapeVector input_shape_;
   ShapeVector output_shape_;
+  bool is_null_input_;
+  bool is_dynamic_case_;
+  std::vector<int64_t> multiples;
+  std::string kernel_name_;
+  using TileLaunchFunc =
+    std::function<bool(TileGpuKernelMod *, const std::vector<AddressPtr> &,
+                       const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, void *)>;
+  static std::vector<std::pair<KernelAttr, TileLaunchFunc>> func_list_;
+  TileLaunchFunc kernel_func_;
 };
 }  // namespace kernel
 }  // namespace mindspore
-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_TILE_GPU_KERNEL_H_
+#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_ARRAYS_TILE_GPU_KERNEL_H_
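
The device-side tiling itself lives in CalTile (tile_impl.cuh) and is not touched by this patch. As a reading aid only, the standalone C++ sketch below illustrates the index arithmetic such a tile kernel is expected to perform, assuming a row-major layout and an input shape already left-padded with 1s to the output rank (which is what Resize does via input_shape_.insert). TileReference and the example values are hypothetical names for illustration, not part of the patch or of the MindSpore API.

// Illustrative host-side reference of the tiling index arithmetic (sketch, not the CUDA kernel).
#include <cstddef>
#include <iostream>
#include <vector>

template <typename T>
std::vector<T> TileReference(const std::vector<T> &input, const std::vector<size_t> &input_shape,
                             const std::vector<size_t> &output_shape) {
  // Assumes input_shape.size() == output_shape.size() and output_shape[i] == input_shape[i] * multiples[i].
  size_t output_size = 1;
  for (size_t d : output_shape) output_size *= d;
  std::vector<T> output(output_size);
  for (size_t pos = 0; pos < output_size; ++pos) {
    // Decompose the flat output index into per-dimension coordinates,
    // wrap each coordinate into the input extent, then re-flatten (row-major, last dim fastest).
    size_t rem = pos;
    size_t input_pos = 0;
    size_t input_stride = 1;
    for (size_t i = output_shape.size(); i-- > 0;) {
      size_t coord = rem % output_shape[i];
      rem /= output_shape[i];
      input_pos += (coord % input_shape[i]) * input_stride;
      input_stride *= input_shape[i];
    }
    output[pos] = input[input_pos];
  }
  return output;
}

int main() {
  // Tile a 1x3 input by multiples (2, 2): output shape is 2x6.
  std::vector<int> input = {1, 2, 3};
  auto out = TileReference(input, {1, 3}, {2, 6});
  for (int v : out) std::cout << v << ' ';  // prints: 1 2 3 1 2 3 1 2 3 1 2 3
  std::cout << '\n';
  return 0;
}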