From e4e4e2359e6cfef8a9a4dff743bb2a75d9fd0862 Mon Sep 17 00:00:00 2001 From: huoxinyou Date: Thu, 15 Sep 2022 18:06:05 +0800 Subject: [PATCH] transpose support input perm is tensor --- .../device/cpu/kernel/transpose_cpu_kernel.cc | 57 ++++- .../device/cpu/kernel/transpose_cpu_kernel.h | 48 +++- .../gpu/kernel/arrays/transpose_gpu_kernel.cc | 207 +++++++++++++++--- .../gpu/kernel/arrays/transpose_gpu_kernel.h | 148 +++---------- mindspore/core/ops/transpose.cc | 154 +++++++------ mindspore/core/ops/transpose.h | 2 - 6 files changed, 390 insertions(+), 226 deletions(-) diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.cc index 923fec7c3f2..6edfbbb3248 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.cc @@ -26,6 +26,7 @@ namespace mindspore { namespace kernel { namespace { constexpr size_t kTransposeInputsNum = 1; +constexpr size_t kDynamicPermInputNum = 2; constexpr size_t kTransposeOutputsNum = 1; constexpr size_t kIndex0 = 0; constexpr size_t kIndex1 = 1; @@ -41,12 +42,7 @@ using complex128 = std::complex; constexpr size_t kMaxTransposeSerialSize = 50331648; } // namespace -void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { - MS_EXCEPTION_IF_NULL(kernel_node); - kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); - input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); - output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); - perm_ = common::AnfAlgo::GetNodeAttr>(kernel_node, "perm"); +void TransposeFwdCpuKernelMod::CheckPermValue() { for (auto &p : perm_) { p = (p >= 0) ? p : (SizeToLong(perm_.size()) + p); if (p < 0) { @@ -54,20 +50,45 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { << (perm_.size() - 1) << "], but got " << p; } } - dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); if (!IsDynamicRank(input_shape_) && perm_.size() != input_shape_.size()) { MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm's size must be equal to input_shape's size, but got " << perm_.size() << " vs " << input_shape_.size(); } + + if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) { + MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got " + << perm_.size() << "D."; + } +} + +template +void TransposeFwdCpuKernelMod::InitPerm(const std::vector &inputs) { + auto cnode = cnode_ptr_.lock(); + auto perm_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex1); + auto perm_ptr = static_cast(inputs[kIndex1]->addr); + std::vector perm{perm_ptr, perm_ptr + perm_shape[0]}; + (void)std::transform(perm.begin(), perm.end(), std::back_inserter(perm_), + [](const T &value) { return static_cast(value); }); + CheckPermValue(); +} + +void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { + MS_EXCEPTION_IF_NULL(kernel_node); + cnode_ptr_ = kernel_node; + kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node); + input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0); + output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0); + size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1 && input_num != kDynamicPermInputNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum + << ", but got " << input_num; + } if (output_shape_.size() != input_shape_.size()) { MS_LOG(EXCEPTION) 
<< "For '" << kernel_name_ << "', the output_shape's size must be equal to input_shape's size, but got " << output_shape_.size() << " vs " << input_shape_.size(); } - if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) { - MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got " - << perm_.size() << "D."; - } + dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0); num_axes_ = input_shape_.size(); if (num_axes_ == 0) { MS_LOG(EXCEPTION) << "Transpose's input shape is empty."; @@ -100,13 +121,25 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) { } else { MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_; } + if (input_num == kDynamicPermInputNum) { + perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1); + return; + } + perm_ = common::AnfAlgo::GetNodeAttr>(kernel_node, "perm"); + CheckPermValue(); } bool TransposeFwdCpuKernelMod::Launch(const std::vector &inputs, const std::vector &, const std::vector &outputs) { - CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_); CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_); + if (inputs.size() == kDynamicPermInputNum) { + if (perm_type_ == kNumberTypeInt32) { + InitPerm(inputs); + } else { + InitPerm(inputs); + } + } launch_func_(this, inputs, outputs); return true; } diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.h index 6d67ef8ffef..97429cf22d6 100644 --- a/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/transpose_cpu_kernel.h @@ -51,7 +51,47 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod { KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128)}; + KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), + KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64), + 
KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeComplex64), + KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt32) + .AddOutputAttr(kNumberTypeComplex128), + KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool), + KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8), + KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16), + KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32), + KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), + KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8), + KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16), + KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32), + KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64), + KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16), + KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32), + KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64), + KernelAttr() + .AddInputAttr(kNumberTypeComplex64) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeComplex64), + KernelAttr() + .AddInputAttr(kNumberTypeComplex128) + .AddInputAttr(kNumberTypeInt64) + .AddOutputAttr(kNumberTypeComplex128)}; return support_list; } @@ -77,9 +117,15 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod { template void TransposeDims(const T *in_data, T *out_data, int64_t task_id, int64_t thread_num) const; + void CheckPermValue(); + + template + void InitPerm(const std::vector &inputs); + std::vector input_shape_; std::vector output_shape_; TypeId dtype_{kTypeUnknown}; + TypeId perm_type_{kNumberTypeInt64}; std::vector perm_; size_t num_axes_{0}; size_t data_num_{0}; diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc index ee4176bcbb6..f29631924dd 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.cc @@ -22,33 +22,184 @@ namespace kernel { template using Complex = mindspore::utils::Complex; -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64), - TransposeFwdGpuKernelMod, Complex) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128), - TransposeFwdGpuKernelMod, Complex) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool), - TransposeFwdGpuKernelMod, bool) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64), - TransposeFwdGpuKernelMod, double) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), - TransposeFwdGpuKernelMod, float) 
-MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), - TransposeFwdGpuKernelMod, half) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), - TransposeFwdGpuKernelMod, int64_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), - TransposeFwdGpuKernelMod, int) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), - TransposeFwdGpuKernelMod, int16_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), - TransposeFwdGpuKernelMod, int8_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64), - TransposeFwdGpuKernelMod, uint64_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32), - TransposeFwdGpuKernelMod, uint32_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16), - TransposeFwdGpuKernelMod, uint16_t) -MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), - TransposeFwdGpuKernelMod, uint8_t) +constexpr size_t kDynamicPermInputNum = 2; +constexpr size_t kDimSize4 = 4; +constexpr size_t kAxisZero = 0; +constexpr size_t kAxis1st = 1; +constexpr size_t kAxis2nd = 2; +constexpr size_t kAxis3rd = 3; +constexpr size_t kAxisIndexZero = 0; +constexpr size_t kAxisIndex1st = 1; +constexpr size_t kAxisIndex2nd = 2; +constexpr size_t kAxisIndex3rd = 3; + +#define STATIC_REGISTER(INPUTX, OUTPUT, T) \ + { KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(OUTPUT), &TransposeGpuKernelMod::LaunchKernel } + +#define DYN_REGISTER(INPUTX, PERM, OUTPUT, T) \ + { \ + KernelAttr().AddInputAttr(INPUTX).AddInputAttr(PERM).AddOutputAttr(OUTPUT), \ + &TransposeGpuKernelMod::LaunchKernel \ + } + +const std::vector> &TransposeGpuKernelMod::GetFuncList() + const { + static const std::vector> func_list = { + STATIC_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex), + STATIC_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex), + STATIC_REGISTER(kNumberTypeBool, kNumberTypeBool, bool), + STATIC_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double), + STATIC_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float), + STATIC_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half), + STATIC_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t), + STATIC_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t), + STATIC_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t), + STATIC_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t), + STATIC_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t), + STATIC_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t), + STATIC_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t), + STATIC_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t), + DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex), + DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex), + DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool), + DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double), + DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float), + DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half), + 
DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t), + DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t), + DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t), + DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t), + DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t), + DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t), + DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t), + DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t), + DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex), + DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex), + DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool), + DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double), + DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float), + DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half), + DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t), + DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t), + DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t), + DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t), + DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t), + DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t), + DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t), + DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t), + }; + return func_list; +} + +template +bool TransposeGpuKernelMod::LaunchKernel(const std::vector &inputs, + const std::vector &workspace, + const std::vector &outputs) { + T *input = GetDeviceAddress(inputs, 0); + T *output = GetDeviceAddress(outputs, 0); + size_t *input_shape = GetDeviceAddress(workspace, 0); + size_t *input_axis = GetDeviceAddress(workspace, 1); + + if (is_dynamic_perm_ && !get_dynamic_perm_value_) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic perm!"; + } + + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpyAsync input_shape failed"); + CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE( + cudaMemcpyAsync(input_axis, &input_perm_[0], workspace_size_, cudaMemcpyHostToDevice, + reinterpret_cast(stream_ptr_)), + "cudaMemcpyAsync input_axis failed"); + size_t size = SizeOf(input_shape_); + size_t *h_input_shape = reinterpret_cast(&input_shape_[0]); + size_t *h_input_axis = &input_perm_[0]; + if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero && + h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st && + h_input_axis[kAxisIndex3rd] == kAxis2nd) { + // nhwc->nchw: 0,3,1,2 + CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output, + reinterpret_cast(stream_ptr_)); + } else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero && + h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd && + h_input_axis[kAxisIndex3rd] == kAxis1st) { + // nchw->nhwc: 0,2,3,1 + CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, 
input_shape, input_axis, output, + reinterpret_cast(stream_ptr_)); + } else { + CalTranspose(size, input, input_shape, input_axis, shape_size_, output, + reinterpret_cast(stream_ptr_)); + } + return true; +} + +void TransposeGpuKernelMod::GetPermValue(const std::vector &perm) { + for (size_t j = 0; j < perm.size(); j++) { + auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]); + if (p < 0) { + MS_LOG(EXCEPTION) << "the perm value must be in [-" << perm.size() << ", " << (perm.size() - 1) << "], but got " + << perm; + } + input_perm_.push_back(p); + } +} + +bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + if (!MatchKernelFunc(base_operator, inputs, outputs)) { + return false; + } + size_t input_num = inputs.size(); + size_t output_num = outputs.size(); + kernel_name_ = base_operator->name(); + if (input_num != 1 && input_num != kDynamicPermInputNum) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum + << ", but got " << input_num; + } + if (output_num != 1) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num; + } + if (input_num == kDynamicPermInputNum) { + is_dynamic_perm_ = true; + return true; + } + + auto attr = base_operator->GetPrim()->GetAttr(kAttrPerm); + if (attr == nullptr) { + MS_LOG(ERROR) << "The attr \"perm\" is not found in kernel 'Transpose'."; + return false; + } + auto perm = GetValue>(attr); + GetPermValue(perm); + return true; +} + +int TransposeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) { + std::vector perm; + if (GetDynamicAttrIntValue(inputs, kAxisIndex1st, inputsOnHost, kernel_name_, &perm)) { + GetPermValue(perm); + get_dynamic_perm_value_ = true; + } + if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) { + return ret; + } + + input_shape_ = inputs[kAxisIndexZero]->GetDeviceShapeAdaptively(); + shape_size_ = input_shape_.size(); + if (shape_size_ > TRANSPOSE_MAX_DIMENSION) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than " + << TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_; + } + + workspace_size_ = shape_size_ * sizeof(size_t); + workspace_size_list_.push_back(workspace_size_); + workspace_size_list_.push_back(workspace_size_); + return KRET_OK; +} + +MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Transpose, TransposeGpuKernelMod); } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h index 54c1fd4e815..b59d61646d1 100644 --- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h +++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/transpose_gpu_kernel.h @@ -19,138 +19,52 @@ #include #include +#include +#include #include "plugin/device/gpu/kernel/gpu_kernel.h" #include "plugin/device/gpu/kernel/gpu_kernel_factory.h" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh" #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cuh" namespace mindspore { namespace kernel { -constexpr size_t kDimSize4 = 4; -constexpr size_t kAxisZero = 0; -constexpr size_t kAxis1st = 1; -constexpr size_t kAxis2nd = 2; -constexpr size_t kAxis3rd = 3; -constexpr size_t 
kAxisIndexZero = 0; -constexpr size_t kAxisIndex1st = 1; -constexpr size_t kAxisIndex2nd = 2; -constexpr size_t kAxisIndex3rd = 3; - -template -class TransposeFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod { +class TransposeGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper { public: - TransposeFwdGpuKernelMod() { ResetResource(); } - ~TransposeFwdGpuKernelMod() = default; + TransposeGpuKernelMod() = default; + ~TransposeGpuKernelMod() override = default; + + const std::vector> &GetFuncList() const override; + + std::vector GetOpSupport() override { return OpSupport(); } + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &inputsOnHost) override; bool Launch(const std::vector &inputs, const std::vector &workspace, const std::vector &outputs, void *stream_ptr) override { - if (is_null_input_) { - return true; - } - T *input = GetDeviceAddress(inputs, 0); - T *output = GetDeviceAddress(outputs, 0); - size_t *input_shape = GetDeviceAddress(workspace, 0); - size_t *input_axis = GetDeviceAddress(workspace, 1); - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_shape failed"); - CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_, - cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice, - reinterpret_cast(stream_ptr)), - "cudaMemcpyAsync input_axis failed"); - size_t size = input_size_ / sizeof(T); - - size_t *h_input_shape = reinterpret_cast(&input_shape_[0]); - size_t *h_input_axis = &input_axis_[0]; - if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero && - h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st && - h_input_axis[kAxisIndex3rd] == kAxis2nd) { - // nhwc->nchw: 0,3,1,2 - CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output, - reinterpret_cast(stream_ptr)); - } else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero && - h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd && - h_input_axis[kAxisIndex3rd] == kAxis1st) { - // nchw->nhwc: 0,2,3,1 - CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output, - reinterpret_cast(stream_ptr)); - } else { - CalTranspose(size, input, input_shape, input_axis, shape_size_, output, - reinterpret_cast(stream_ptr)); - } - return true; - } - - bool Init(const CNodePtr &kernel_node) override { - auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node); - kernel_node_ = kernel_node; - size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node); - if (input_num != 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num; - } - size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node); - if (output_num != 1) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num; - } - input_shape_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0); - is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input"); - if (is_null_input_) { - InitSizeLists(); - return true; - } - shape_size_ = input_shape_.size(); - if (shape_size_ > 
TRANSPOSE_MAX_DIMENSION) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than " - << TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_; - } - - input_size_ = sizeof(T) * SizeOf(input_shape_); - output_size_ = input_size_; - std::vector perm = GetAttr>(kernel_node, "perm"); - for (size_t j = 0; j < perm.size(); j++) { - auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]); - if (p < 0) { - MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the perm value must be in [-" << perm.size() << ", " - << (perm.size() - 1) << "], but got " << perm; - } - input_axis_.push_back(p); - } - InitSizeLists(); - return true; - } - - void ResetResource() noexcept override { - shape_size_ = 0; - input_size_ = 0; - output_size_ = 0; - workspace_size_ = 0; - is_null_input_ = false; - input_shape_.clear(); - input_axis_.clear(); - input_size_list_.clear(); - output_size_list_.clear(); - workspace_size_list_.clear(); - } - - protected: - void InitSizeLists() override { - input_size_list_.push_back(input_size_); - output_size_list_.push_back(output_size_); - workspace_size_ = shape_size_ * sizeof(size_t); - workspace_size_list_.push_back(workspace_size_); - workspace_size_list_.push_back(workspace_size_); - return; + stream_ptr_ = stream_ptr; + return kernel_func_(this, inputs, workspace, outputs); } private: - std::vector input_shape_; - std::vector input_axis_; + template + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); - size_t shape_size_; - size_t input_size_; - size_t output_size_; - size_t workspace_size_; + void GetPermValue(const std::vector &perm); + + void *stream_ptr_{nullptr}; + std::vector input_shape_; + std::vector input_perm_; + + size_t shape_size_{0}; + size_t workspace_size_{0}; bool is_null_input_; + bool is_dynamic_perm_{false}; + bool get_dynamic_perm_value_{false}; }; } // namespace kernel } // namespace mindspore diff --git a/mindspore/core/ops/transpose.cc b/mindspore/core/ops/transpose.cc index 4d88e8f7553..8b5b91c6b28 100644 --- a/mindspore/core/ops/transpose.cc +++ b/mindspore/core/ops/transpose.cc @@ -16,94 +16,116 @@ #include "ops/transpose.h" #include +#include +#include #include #include #include "ops/op_utils.h" +#include "abstract/ops/op_infer.h" #include "utils/check_convert_utils.h" #include "abstract/ops/primitive_infer_map.h" #include "mindapi/src/helper.h" namespace mindspore { namespace ops { -namespace { -abstract::ShapePtr TransposeInferShape(const PrimitivePtr &primitive, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - auto op_name = primitive->name(); - auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; - (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name); - ShapeVector p_value; - ShapeVector p_value_raw; - if (x_shape[0] == 0) { - MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0."; - } +MIND_API_OPERATOR_IMPL(Transpose, BaseOperator); + +bool CheckAndGetPermValue(const std::vector &input_args, ShapeVector *perm_value, + const PrimitivePtr &primitive) { + MS_EXCEPTION_IF_NULL(perm_value); + bool is_dynamic = false; + const std::string &op_name = primitive->name(); + if (input_args.size() == 1) { if (!primitive->HasAttr("perm")) { MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of 'input_perm' is necessary, but missing it."; } ValuePtr perm = 
primitive->GetAttr("perm"); MS_EXCEPTION_IF_NULL(perm); - auto perm_val = perm->cast(); - MS_EXCEPTION_IF_NULL(perm_val); - auto perm_val_data = perm_val->value(); - (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(p_value_raw), - [](const ValuePtr &e) -> int64_t { return GetValue(e); }); - } else { - auto perm_value = input_args[1]->BuildValue(); - MS_EXCEPTION_IF_NULL(perm_value); - if (perm_value->isa()) { - p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name); + *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", perm, op_name); + return is_dynamic; + } + + auto input_value = input_args[kInputIndex1]->BuildValue(); + if (input_args[kInputIndex1]->isa()) { + *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name); + } else if (input_args[kInputIndex1]->isa()) { + if (input_value->isa()) { + *perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name); } else { - p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name); + is_dynamic = true; + auto perm_shape = CheckAndConvertUtils::GetTensorInputShape("perm", input_args, 1); + if (perm_shape->shape().size() != 1) { + MS_EXCEPTION(ValueError) << "For 'transpose perm', " << op_name << " must be 1-D, but got" + << perm_shape->shape().size() << "-D."; + } } + } else { + MS_LOG(EXCEPTION) << "For '" << op_name + << "', the second input type should be tensor or scalar, but got invalid abstract type:" + << input_args[kInputIndex1]->type_name() << "."; } - for (auto p : p_value_raw) { - p = (p >= 0) ? p : (p_value_raw.size() + p); - p_value.emplace_back(p); - } - if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) { - MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'input_perm' must be equal, but got " - << x_shape.size() << " and " << p_value.size() << " respectively."; - } - for (auto i : p_value) { - (void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name); - } - std::vector tmp(p_value); - for (auto it = tmp.begin(); it != tmp.end();) { - auto dim = *it; - if (!tmp.empty()) { - it = tmp.erase(it); - } - if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) { - MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong."; - } - } - if (IsDynamicRank(x_shape)) { - return std::make_shared(std::vector{UNKNOWN_RANK}); - } - std::vector in_shape(p_value); - (void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; }); - return std::make_shared(in_shape); + return is_dynamic; } -TypePtr TransposeInferType(const PrimitivePtr &prim, const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(prim); - return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name()); -} -} // namespace +class TransposeInfer : public abstract::OpInferBase { + public: + BaseShapePtr InferShape(const PrimitivePtr &primitive, + const std::vector &input_args) const override { + MS_EXCEPTION_IF_NULL(primitive); + auto op_name = primitive->name(); + auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape]; + (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name); + ShapeVector p_value; + ShapeVector p_value_raw; + if (x_shape[0] == 0) { + MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's 
shape can not be 0, but got 0."; + } + if (IsDynamicRank(x_shape)) { + return std::make_shared(std::vector{UNKNOWN_RANK}); + } -MIND_API_OPERATOR_IMPL(Transpose, BaseOperator); -AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args) { - MS_EXCEPTION_IF_NULL(primitive); - for (const auto &item : input_args) { - MS_EXCEPTION_IF_NULL(item); + bool perm_is_dynamic = CheckAndGetPermValue(input_args, &p_value_raw, primitive); + if (perm_is_dynamic) { + ShapeVector out_shape; + (void)out_shape.insert(out_shape.end(), x_shape.size(), -1); + return std::make_shared(out_shape); + } + + for (auto p : p_value_raw) { + p = (p >= 0) ? p : (p_value_raw.size() + p); + p_value.emplace_back(p); + } + if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) { + MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'perm' must be equal, but got " + << x_shape.size() << " and " << p_value.size() << " respectively."; + } + for (auto i : p_value) { + (void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name); + } + std::vector tmp(p_value); + for (auto it = tmp.begin(); it != tmp.end();) { + auto dim = *it; + if (!tmp.empty()) { + it = tmp.erase(it); + } + if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) { + MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong."; + } + } + std::vector in_shape(p_value); + (void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), + [x_shape](size_t i) { return x_shape[i]; }); + return std::make_shared(in_shape); } - CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputIndex1, primitive->name()); - auto type = TransposeInferType(primitive, input_args); - auto shape = TransposeInferShape(primitive, input_args); - return abstract::MakeAbstract(shape, type); -} -REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true); + TypePtr InferType(const PrimitivePtr &prim, const std::vector &input_args) const override { + MS_EXCEPTION_IF_NULL(prim); + return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name()); + } + + std::set GetValueDependArgIndices() const override { return {1}; } +}; +REGISTER_PRIMITIVE_OP_INFER_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, false); } // namespace ops } // namespace mindspore diff --git a/mindspore/core/ops/transpose.h b/mindspore/core/ops/transpose.h index 7166df93b87..6e4894978bc 100644 --- a/mindspore/core/ops/transpose.h +++ b/mindspore/core/ops/transpose.h @@ -35,8 +35,6 @@ class MIND_API Transpose : public BaseOperator { /// \brief Init. void Init() const {} }; -abstract::AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, - const std::vector &input_args); } // namespace ops } // namespace mindspore
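
Note (illustrative sketch, not part of the patch): the same perm-handling semantics are implemented three times in this change, in TransposeFwdCpuKernelMod::CheckPermValue, TransposeGpuKernelMod::GetPermValue, and TransposeInfer::InferShape. The standalone helpers below (NormalizePerm, ComputeTransposedShape are hypothetical names) summarize that shared contract under the assumption that perm must be a full permutation of the input rank.

#include <algorithm>
#include <cstdint>
#include <set>
#include <stdexcept>
#include <vector>

// Map negative axes into [0, rank), then validate range and uniqueness.
std::vector<int64_t> NormalizePerm(std::vector<int64_t> perm, size_t rank) {
  if (perm.size() != rank) {
    throw std::invalid_argument("perm size must equal the input rank");
  }
  std::set<int64_t> seen;
  for (auto &p : perm) {
    if (p < 0) {
      p += static_cast<int64_t>(rank);  // e.g. -1 means the last axis
    }
    if (p < 0 || p >= static_cast<int64_t>(rank) || !seen.insert(p).second) {
      throw std::invalid_argument("perm must be a permutation of [0, rank)");
    }
  }
  return perm;
}

// The output shape is the input shape gathered by the normalized permutation.
std::vector<int64_t> ComputeTransposedShape(const std::vector<int64_t> &in_shape,
                                            const std::vector<int64_t> &perm) {
  std::vector<int64_t> out_shape(perm.size());
  std::transform(perm.begin(), perm.end(), out_shape.begin(),
                 [&in_shape](int64_t p) { return in_shape[static_cast<size_t>(p)]; });
  return out_shape;
}

For example, a (2, 3, 4) input with perm (1, -1, 0) normalizes to (1, 2, 0) and yields an output shape of (3, 4, 2).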
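
Note (illustrative sketch, not part of the patch): with this change, perm may also arrive as a 1-D int32 or int64 tensor at runtime instead of a "perm" attribute. The hypothetical ReadPermBuffer helper below shows the widening read that TransposeFwdCpuKernelMod::InitPerm<T> performs on the second input's buffer; when only the perm tensor's shape (not its value) is known at infer time, the new TransposeInfer returns an output of the same rank with every dimension set to -1 (unknown), and the kernel resolves the real permutation later in Resize/Launch.

#include <cstddef>
#include <cstdint>
#include <vector>

// T is int32_t or int64_t, matching the dtype the perm input was registered with.
template <typename T>
std::vector<int64_t> ReadPermBuffer(const void *perm_addr, size_t perm_len) {
  const T *data = static_cast<const T *>(perm_addr);
  std::vector<int64_t> perm(perm_len);
  for (size_t i = 0; i < perm_len; ++i) {
    perm[i] = static_cast<int64_t>(data[i]);  // widen to int64 before validation
  }
  return perm;
}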