!41630 Transpose: support input_perm as a tensor

Merge pull request !41630 from huoxinyou/0907transpose
i-robot 2022-09-17 02:44:39 +00:00 committed by Gitee
commit 0ff063e2a7
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
6 changed files with 390 additions and 226 deletions
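With this change, Transpose can receive its permutation either through the "perm" attribute or as a second int32/int64 tensor input that is only read at launch time. A minimal standalone C++ sketch of the axis normalization both backends perform (negative axes count from the end, out-of-range axes are rejected); the names here are illustrative, not MindSpore APIs:

#include <cstdint>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative sketch, not the MindSpore API: normalize a permutation that may
// contain negative axes, as CheckPermValue (CPU) and GetPermValue (GPU) do in
// the diffs below. Negative entries count from the end; anything still out of
// range after normalization is rejected.
std::vector<int64_t> NormalizePerm(const std::vector<int64_t> &perm) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  std::vector<int64_t> out;
  out.reserve(perm.size());
  for (int64_t p : perm) {
    const int64_t q = (p >= 0) ? p : p + rank;
    if (q < 0 || q >= rank) {
      throw std::out_of_range("perm value must be in [-" + std::to_string(rank) + ", " +
                              std::to_string(rank - 1) + "], but got " + std::to_string(p));
    }
    out.push_back(q);
  }
  return out;
}

int main() {
  // {0, -1, 1, 2} over a 4-D input normalizes to {0, 3, 1, 2}, i.e. NHWC -> NCHW.
  const auto perm = NormalizePerm({0, -1, 1, 2});
  return perm == std::vector<int64_t>{0, 3, 1, 2} ? 0 : 1;
}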

View File

@@ -26,6 +26,7 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTransposeInputsNum = 1;
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kTransposeOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;
@@ -41,12 +42,7 @@ using complex128 = std::complex<double>;
constexpr size_t kMaxTransposeSerialSize = 50331648;
} // namespace
void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
perm_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
void TransposeFwdCpuKernelMod::CheckPermValue() {
for (auto &p : perm_) {
p = (p >= 0) ? p : (SizeToLong(perm_.size()) + p);
if (p < 0) {
@@ -54,20 +50,45 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
<< (perm_.size() - 1) << "], but got " << p;
}
}
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
if (!IsDynamicRank(input_shape_) && perm_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm's size must be equal to input_shape's size, but got "
<< perm_.size() << " vs " << input_shape_.size();
}
if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
<< perm_.size() << "D.";
}
}
template <typename T>
void TransposeFwdCpuKernelMod::InitPerm(const std::vector<kernel::AddressPtr> &inputs) {
auto cnode = cnode_ptr_.lock();
auto perm_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex1);
auto perm_ptr = static_cast<T *>(inputs[kIndex1]->addr);
std::vector<T> perm{perm_ptr, perm_ptr + perm_shape[0]};
(void)std::transform(perm.begin(), perm.end(), std::back_inserter(perm_),
[](const T &value) { return static_cast<int64_t>(value); });
CheckPermValue();
}
void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
cnode_ptr_ = kernel_node;
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1 && input_num != kDynamicPermInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
<< ", but got " << input_num;
}
if (output_shape_.size() != input_shape_.size()) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the output_shape's size must be equal to input_shape's size, but got "
<< output_shape_.size() << " vs " << input_shape_.size();
}
if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) {
MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
<< perm_.size() << "D.";
}
dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
num_axes_ = input_shape_.size();
if (num_axes_ == 0) {
MS_LOG(EXCEPTION) << "Transpose's input shape is empty.";
@@ -100,13 +121,25 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
} else {
MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
}
if (input_num == kDynamicPermInputNum) {
perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
return;
}
perm_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
CheckPermValue();
}
bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> &,
const std::vector<kernel::AddressPtr> &outputs) {
CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
if (inputs.size() == kDynamicPermInputNum) {
if (perm_type_ == kNumberTypeInt32) {
InitPerm<int32_t>(inputs);
} else {
InitPerm<int64_t>(inputs);
}
}
launch_func_(this, inputs, outputs);
return true;
}

View File

@@ -51,7 +51,47 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod {
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128)};
KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt32)
.AddOutputAttr(kNumberTypeComplex128),
KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex64)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex64),
KernelAttr()
.AddInputAttr(kNumberTypeComplex128)
.AddInputAttr(kNumberTypeInt64)
.AddOutputAttr(kNumberTypeComplex128)};
return support_list;
}
@@ -77,9 +117,15 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod {
template <typename T>
void TransposeDims(const T *in_data, T *out_data, int64_t task_id, int64_t thread_num) const;
void CheckPermValue();
template <typename T>
void InitPerm(const std::vector<kernel::AddressPtr> &inputs);
std::vector<int64_t> input_shape_;
std::vector<int64_t> output_shape_;
TypeId dtype_{kTypeUnknown};
TypeId perm_type_{kNumberTypeInt64};
std::vector<int64_t> perm_;
size_t num_axes_{0};
size_t data_num_{0};

View File

@@ -22,33 +22,184 @@ namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
TransposeFwdGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
TransposeFwdGpuKernelMod, Complex<double>)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
TransposeFwdGpuKernelMod, bool)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
TransposeFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
TransposeFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
TransposeFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
TransposeFwdGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
TransposeFwdGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
TransposeFwdGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
TransposeFwdGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
TransposeFwdGpuKernelMod, uint64_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
TransposeFwdGpuKernelMod, uint32_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
TransposeFwdGpuKernelMod, uint16_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
TransposeFwdGpuKernelMod, uint8_t)
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kDimSize4 = 4;
constexpr size_t kAxisZero = 0;
constexpr size_t kAxis1st = 1;
constexpr size_t kAxis2nd = 2;
constexpr size_t kAxis3rd = 3;
constexpr size_t kAxisIndexZero = 0;
constexpr size_t kAxisIndex1st = 1;
constexpr size_t kAxisIndex2nd = 2;
constexpr size_t kAxisIndex3rd = 3;
#define STATIC_REGISTER(INPUTX, OUTPUT, T) \
{ KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(OUTPUT), &TransposeGpuKernelMod::LaunchKernel<T> }
#define DYN_REGISTER(INPUTX, PERM, OUTPUT, T) \
{ \
KernelAttr().AddInputAttr(INPUTX).AddInputAttr(PERM).AddOutputAttr(OUTPUT), \
&TransposeGpuKernelMod::LaunchKernel<T> \
}
const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> &TransposeGpuKernelMod::GetFuncList()
const {
static const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> func_list = {
STATIC_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
STATIC_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
STATIC_REGISTER(kNumberTypeBool, kNumberTypeBool, bool),
STATIC_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double),
STATIC_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float),
STATIC_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half),
STATIC_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t),
STATIC_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t),
STATIC_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t),
STATIC_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t),
STATIC_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t),
STATIC_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t),
STATIC_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t),
STATIC_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t),
DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
};
return func_list;
}
template <typename T>
bool TransposeGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
if (is_dynamic_perm_ && !get_dynamic_perm_value_) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic perm!";
}
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
cudaMemcpyAsync(input_axis, &input_perm_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr_)),
"cudaMemcpyAsync input_axis failed");
size_t size = SizeOf(input_shape_);
size_t *h_input_shape = reinterpret_cast<size_t *>(&input_shape_[0]);
size_t *h_input_axis = &input_perm_[0];
if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st &&
h_input_axis[kAxisIndex3rd] == kAxis2nd) {
// nhwc->nchw: 0,3,1,2
CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd &&
h_input_axis[kAxisIndex3rd] == kAxis1st) {
// nchw->nhwc: 0,2,3,1
CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
} else {
CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
reinterpret_cast<cudaStream_t>(stream_ptr_));
}
return true;
}
void TransposeGpuKernelMod::GetPermValue(const std::vector<int64_t> &perm) {
for (size_t j = 0; j < perm.size(); j++) {
auto p = (perm[j] >= 0) ? perm[j] : (static_cast<int64_t>(perm.size()) + perm[j]);
if (p < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm value must be in [-" << perm.size() << ", "
<< (perm.size() - 1) << "], but got " << perm[j];
}
input_perm_.push_back(p);
}
}
bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) {
if (!MatchKernelFunc(base_operator, inputs, outputs)) {
return false;
}
size_t input_num = inputs.size();
size_t output_num = outputs.size();
kernel_name_ = base_operator->name();
if (input_num != 1 && input_num != kDynamicPermInputNum) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
<< ", but got " << input_num;
}
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
}
if (input_num == kDynamicPermInputNum) {
is_dynamic_perm_ = true;
return true;
}
auto attr = base_operator->GetPrim()->GetAttr(kAttrPerm);
if (attr == nullptr) {
MS_LOG(ERROR) << "The attr \"perm\" is not found in kernel 'Transpose'.";
return false;
}
auto perm = GetValue<std::vector<int64_t>>(attr);
GetPermValue(perm);
return true;
}
int TransposeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
std::vector<int64_t> perm;
if (GetDynamicAttrIntValue(inputs, kAxisIndex1st, inputsOnHost, kernel_name_, &perm)) {
GetPermValue(perm);
get_dynamic_perm_value_ = true;
}
if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
return ret;
}
input_shape_ = inputs[kAxisIndexZero]->GetDeviceShapeAdaptively();
shape_size_ = input_shape_.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than "
<< TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_;
}
workspace_size_ = shape_size_ * sizeof(size_t);
workspace_size_list_.push_back(workspace_size_);
workspace_size_list_.push_back(workspace_size_);
return KRET_OK;
}
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Transpose, TransposeGpuKernelMod);
} // namespace kernel
} // namespace mindspore
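LaunchKernel keeps dedicated fast paths for the two 4-D layout swaps (perm {0,3,1,2} for NHWC->NCHW and {0,2,3,1} for NCHW->NHWC) and falls back to the generic CalTranspose otherwise. A minimal host-side C++ sketch of the index mapping those paths implement, assuming flat row-major buffers; illustrative only, not the CUDA kernels:

#include <array>
#include <cstddef>
#include <vector>

// Illustrative host-side sketch: a generic 4-D transpose writes input element
// with multi-index i to the output position whose k-th index is i[perm[k]].
// For perm = {0, 3, 1, 2} this turns NHWC into NCHW, the case the dedicated
// NHWC2NCHW kernel handles above.
std::vector<float> Transpose4D(const std::vector<float> &in, const std::array<size_t, 4> &in_shape,
                               const std::array<size_t, 4> &perm) {
  std::array<size_t, 4> out_shape{};
  for (size_t k = 0; k < 4; ++k) out_shape[k] = in_shape[perm[k]];
  std::vector<float> out(in.size());
  std::array<size_t, 4> i{};  // multi-index into the input, row-major layout
  for (i[0] = 0; i[0] < in_shape[0]; ++i[0])
    for (i[1] = 0; i[1] < in_shape[1]; ++i[1])
      for (i[2] = 0; i[2] < in_shape[2]; ++i[2])
        for (i[3] = 0; i[3] < in_shape[3]; ++i[3]) {
          const size_t src = ((i[0] * in_shape[1] + i[1]) * in_shape[2] + i[2]) * in_shape[3] + i[3];
          const size_t dst =
            ((i[perm[0]] * out_shape[1] + i[perm[1]]) * out_shape[2] + i[perm[2]]) * out_shape[3] + i[perm[3]];
          out[dst] = in[src];
        }
  return out;
}

int main() {
  // A 1x2x2x3 NHWC buffer becomes a 1x3x2x2 NCHW buffer.
  std::vector<float> nhwc(12);
  for (size_t v = 0; v < nhwc.size(); ++v) nhwc[v] = static_cast<float>(v);
  const auto nchw = Transpose4D(nhwc, {1, 2, 2, 3}, {0, 3, 1, 2});
  // NHWC element (n=0, h=1, w=0, c=2) sits at flat index 8 and must land at
  // NCHW position (n=0, c=2, h=1, w=0), i.e. flat index 10.
  return nchw[10] == nhwc[8] ? 0 : 1;
}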

View File

@@ -19,138 +19,52 @@
#include <vector>
#include <algorithm>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kDimSize4 = 4;
constexpr size_t kAxisZero = 0;
constexpr size_t kAxis1st = 1;
constexpr size_t kAxis2nd = 2;
constexpr size_t kAxis3rd = 3;
constexpr size_t kAxisIndexZero = 0;
constexpr size_t kAxisIndex1st = 1;
constexpr size_t kAxisIndex2nd = 2;
constexpr size_t kAxisIndex3rd = 3;
template <typename T>
class TransposeFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class TransposeGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<TransposeGpuKernelMod> {
public:
TransposeFwdGpuKernelMod() { ResetResource(); }
~TransposeFwdGpuKernelMod() = default;
TransposeGpuKernelMod() = default;
~TransposeGpuKernelMod() override = default;
const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;
std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
if (is_null_input_) {
return true;
}
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_shape failed");
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemcpyAsync input_axis failed");
size_t size = input_size_ / sizeof(T);
size_t *h_input_shape = reinterpret_cast<size_t *>(&input_shape_[0]);
size_t *h_input_axis = &input_axis_[0];
if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st &&
h_input_axis[kAxisIndex3rd] == kAxis2nd) {
// nhwc->nchw: 0,3,1,2
CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd &&
h_input_axis[kAxisIndex3rd] == kAxis1st) {
// nchw->nhwc: 0,2,3,1
CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
kernel_node_ = kernel_node;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num;
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
}
input_shape_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
if (is_null_input_) {
InitSizeLists();
return true;
}
shape_size_ = input_shape_.size();
if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than "
<< TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_;
}
input_size_ = sizeof(T) * SizeOf(input_shape_);
output_size_ = input_size_;
std::vector<int64_t> perm = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
for (size_t j = 0; j < perm.size(); j++) {
auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]);
if (p < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the perm value must be in [-" << perm.size() << ", "
<< (perm.size() - 1) << "], but got " << perm;
}
input_axis_.push_back(p);
}
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
shape_size_ = 0;
input_size_ = 0;
output_size_ = 0;
workspace_size_ = 0;
is_null_input_ = false;
input_shape_.clear();
input_axis_.clear();
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
workspace_size_ = shape_size_ * sizeof(size_t);
workspace_size_list_.push_back(workspace_size_);
workspace_size_list_.push_back(workspace_size_);
return;
stream_ptr_ = stream_ptr;
return kernel_func_(this, inputs, workspace, outputs);
}
private:
std::vector<int64_t> input_shape_;
std::vector<size_t> input_axis_;
template <typename T>
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs);
size_t shape_size_;
size_t input_size_;
size_t output_size_;
size_t workspace_size_;
void GetPermValue(const std::vector<int64_t> &perm);
void *stream_ptr_{nullptr};
std::vector<int64_t> input_shape_;
std::vector<size_t> input_perm_;
size_t shape_size_{0};
size_t workspace_size_{0};
bool is_null_input_;
bool is_dynamic_perm_{false};
bool get_dynamic_perm_value_{false};
};
} // namespace kernel
} // namespace mindspore
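The header drops the templated DeprecatedNativeGpuKernelMod in favor of a single TransposeGpuKernelMod whose supported signatures live in the (KernelAttr, member-function-pointer) table returned by GetFuncList(); DYN_REGISTER simply adds one extra entry per dtype whose attr also lists the int32/int64 perm input. A minimal sketch of that dispatch pattern, with hypothetical names rather than the MindSpore MatchKernelHelper API:

#include <string>
#include <utility>
#include <vector>

// Illustrative sketch of the dispatch pattern behind GetFuncList(): each
// supported signature maps to a member-function pointer, and Launch() invokes
// the entry matched at init time.
class DemoTransposeKernel {
 public:
  using RunFunc = bool (DemoTransposeKernel::*)(const std::vector<float> &);

  bool LaunchStaticPerm(const std::vector<float> &) { return true; }   // perm taken from an attribute
  bool LaunchDynamicPerm(const std::vector<float> &) { return true; }  // perm read from a tensor input

  static const std::vector<std::pair<std::string, RunFunc>> &FuncList() {
    static const std::vector<std::pair<std::string, RunFunc>> list = {
      {"float32", &DemoTransposeKernel::LaunchStaticPerm},
      {"float32,int64", &DemoTransposeKernel::LaunchDynamicPerm},
    };
    return list;
  }

  bool Launch(const std::string &attr, const std::vector<float> &inputs) {
    for (const auto &entry : FuncList()) {
      if (entry.first == attr) {
        return (this->*entry.second)(inputs);
      }
    }
    return false;  // no matching signature registered
  }
};

int main() {
  DemoTransposeKernel kernel;
  return kernel.Launch("float32,int64", {1.0f, 2.0f}) ? 0 : 1;
}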

View File

@@ -16,94 +16,116 @@
#include "ops/transpose.h"
#include <vector>
#include <string>
#include <set>
#include <memory>
#include <algorithm>
#include "ops/op_utils.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TransposeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
ShapeVector p_value;
ShapeVector p_value_raw;
if (x_shape[0] == 0) {
MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
}
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
bool CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, ShapeVector *perm_value,
const PrimitivePtr &primitive) {
MS_EXCEPTION_IF_NULL(perm_value);
bool is_dynamic = false;
const std::string &op_name = primitive->name();
if (input_args.size() == 1) {
if (!primitive->HasAttr("perm")) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of 'input_perm' is necessary, but missing it.";
}
ValuePtr perm = primitive->GetAttr("perm");
MS_EXCEPTION_IF_NULL(perm);
auto perm_val = perm->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(perm_val);
auto perm_val_data = perm_val->value();
(void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(p_value_raw),
[](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
} else {
auto perm_value = input_args[1]->BuildValue();
MS_EXCEPTION_IF_NULL(perm_value);
if (perm_value->isa<tensor::Tensor>()) {
p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name);
*perm_value = CheckAndConvertUtils::CheckTupleInt("perm", perm, op_name);
return is_dynamic;
}
auto input_value = input_args[kInputIndex1]->BuildValue();
if (input_args[kInputIndex1]->isa<abstract::AbstractTuple>()) {
*perm_value = CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
} else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
if (input_value->isa<tensor::Tensor>()) {
*perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
} else {
p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
is_dynamic = true;
auto perm_shape = CheckAndConvertUtils::GetTensorInputShape("perm", input_args, 1);
if (perm_shape->shape().size() != 1) {
MS_EXCEPTION(ValueError) << "For 'transpose perm', " << op_name << " must be 1-D, but got"
<< perm_shape->shape().size() << "-D.";
}
}
} else {
MS_LOG(EXCEPTION) << "For '" << op_name
<< "', the second input type should be tensor or scalar, but got invalid abstract type:"
<< input_args[kInputIndex1]->type_name() << ".";
}
for (auto p : p_value_raw) {
p = (p >= 0) ? p : (p_value_raw.size() + p);
p_value.emplace_back(p);
}
if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'input_perm' must be equal, but got "
<< x_shape.size() << " and " << p_value.size() << " respectively.";
}
for (auto i : p_value) {
(void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
}
std::vector<int64_t> tmp(p_value);
for (auto it = tmp.begin(); it != tmp.end();) {
auto dim = *it;
if (!tmp.empty()) {
it = tmp.erase(it);
}
if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong.";
}
}
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
}
std::vector<int64_t> in_shape(p_value);
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; });
return std::make_shared<abstract::Shape>(in_shape);
return is_dynamic;
}
TypePtr TransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(prim);
return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name());
}
} // namespace
class TransposeInfer : public abstract::OpInferBase {
public:
BaseShapePtr InferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
(void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
ShapeVector p_value;
ShapeVector p_value_raw;
if (x_shape[0] == 0) {
MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
}
if (IsDynamicRank(x_shape)) {
return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
}
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
for (const auto &item : input_args) {
MS_EXCEPTION_IF_NULL(item);
bool perm_is_dynamic = CheckAndGetPermValue(input_args, &p_value_raw, primitive);
if (perm_is_dynamic) {
ShapeVector out_shape;
(void)out_shape.insert(out_shape.end(), x_shape.size(), -1);
return std::make_shared<abstract::Shape>(out_shape);
}
for (auto p : p_value_raw) {
p = (p >= 0) ? p : (p_value_raw.size() + p);
p_value.emplace_back(p);
}
if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'perm' must be equal, but got "
<< x_shape.size() << " and " << p_value.size() << " respectively.";
}
for (auto i : p_value) {
(void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
}
std::vector<int64_t> tmp(p_value);
for (auto it = tmp.begin(); it != tmp.end();) {
auto dim = *it;
if (!tmp.empty()) {
it = tmp.erase(it);
}
if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) {
MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong.";
}
}
std::vector<int64_t> in_shape(p_value);
(void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(),
[x_shape](size_t i) { return x_shape[i]; });
return std::make_shared<abstract::Shape>(in_shape);
}
CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputIndex1, primitive->name());
auto type = TransposeInferType(primitive, input_args);
auto shape = TransposeInferShape(primitive, input_args);
return abstract::MakeAbstract(shape, type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
MS_EXCEPTION_IF_NULL(prim);
return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name());
}
std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, false);
} // namespace ops
} // namespace mindspore
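The new TransposeInfer resolves the output shape as out[k] = x_shape[perm[k]] when the permutation value is known, and reports every dimension as unknown (-1) at the input's rank when perm arrives as a tensor whose value is not yet available. A minimal sketch of that rule; InferTransposeShape is a hypothetical helper name, not the MindSpore infer function:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative sketch: output dim k is input dim perm[k] when perm is known;
// when perm comes in as a tensor whose value is not yet available, every
// output dim is reported as unknown (-1) at the input's rank, matching the
// perm_is_dynamic branch above.
std::vector<int64_t> InferTransposeShape(const std::vector<int64_t> &x_shape,
                                         const std::vector<int64_t> *perm) {  // nullptr if perm is unknown
  if (perm == nullptr) {
    return std::vector<int64_t>(x_shape.size(), -1);
  }
  std::vector<int64_t> out(perm->size());
  for (size_t k = 0; k < perm->size(); ++k) {
    out[k] = x_shape[static_cast<size_t>((*perm)[k])];
  }
  return out;
}

int main() {
  const std::vector<int64_t> x_shape{2, 32, 32, 3};
  const std::vector<int64_t> perm{0, 3, 1, 2};
  const auto known = InferTransposeShape(x_shape, &perm);      // {2, 3, 32, 32}
  const auto unknown = InferTransposeShape(x_shape, nullptr);  // {-1, -1, -1, -1}
  return (known == std::vector<int64_t>{2, 3, 32, 32} && unknown.size() == 4) ? 0 : 1;
}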

View File

@@ -35,8 +35,6 @@ class MIND_API Transpose : public BaseOperator {
/// \brief Init.
void Init() const {}
};
abstract::AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore