Transpose: support input perm as a tensor
parent 66ba9e952b
commit e4e4e2359e
@@ -26,6 +26,7 @@ namespace mindspore {
namespace kernel {
namespace {
constexpr size_t kTransposeInputsNum = 1;
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kTransposeOutputsNum = 1;
constexpr size_t kIndex0 = 0;
constexpr size_t kIndex1 = 1;

@@ -41,12 +42,7 @@ using complex128 = std::complex<double>;
constexpr size_t kMaxTransposeSerialSize = 50331648;
}  // namespace

void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  perm_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
void TransposeFwdCpuKernelMod::CheckPermValue() {
  for (auto &p : perm_) {
    p = (p >= 0) ? p : (SizeToLong(perm_.size()) + p);
    if (p < 0) {

@@ -54,20 +50,45 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
                        << (perm_.size() - 1) << "], but got " << p;
    }
  }
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  if (!IsDynamicRank(input_shape_) && perm_.size() != input_shape_.size()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the perm's size must be equal to input_shape's size, but got "
                      << perm_.size() << " vs " << input_shape_.size();
  }

  if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) {
    MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
                      << perm_.size() << "D.";
  }
}

template <typename T>
void TransposeFwdCpuKernelMod::InitPerm(const std::vector<kernel::AddressPtr> &inputs) {
  auto cnode = cnode_ptr_.lock();
  auto perm_shape = common::AnfAlgo::GetPrevNodeOutputInferShape(cnode, kIndex1);
  auto perm_ptr = static_cast<T *>(inputs[kIndex1]->addr);
  std::vector<T> perm{perm_ptr, perm_ptr + perm_shape[0]};
  (void)std::transform(perm.begin(), perm.end(), std::back_inserter(perm_),
                       [](const T &value) { return static_cast<int64_t>(value); });
  CheckPermValue();
}

void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  MS_EXCEPTION_IF_NULL(kernel_node);
  cnode_ptr_ = kernel_node;
  kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
  input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
  output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  if (input_num != 1 && input_num != kDynamicPermInputNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
                      << ", but got " << input_num;
  }
  if (output_shape_.size() != input_shape_.size()) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the output_shape's size must be equal to input_shape's size, but got "
                      << output_shape_.size() << " vs " << input_shape_.size();
  }
  if (perm_.size() > MAX_TRANSPOSE_DIM_SIZE) {
    MS_LOG(EXCEPTION) << "Transpose support max dimension is " << MAX_TRANSPOSE_DIM_SIZE << "D, but got "
                      << perm_.size() << "D.";
  }
  dtype_ = AnfAlgo::GetInputDeviceDataType(kernel_node, 0);
  num_axes_ = input_shape_.size();
  if (num_axes_ == 0) {
    MS_LOG(EXCEPTION) << "Transpose's input shape is empty.";

@@ -100,13 +121,25 @@ void TransposeFwdCpuKernelMod::InitKernel(const CNodePtr &kernel_node) {
  } else {
    MS_LOG(EXCEPTION) << "Unsupported input data type: " << dtype_;
  }
  if (input_num == kDynamicPermInputNum) {
    perm_type_ = AnfAlgo::GetInputDeviceDataType(kernel_node, kIndex1);
    return;
  }
  perm_ = common::AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "perm");
  CheckPermValue();
}

bool TransposeFwdCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
                                      const std::vector<kernel::AddressPtr> &,
                                      const std::vector<kernel::AddressPtr> &outputs) {
  CHECK_KERNEL_INPUTS_NUM(inputs.size(), kTransposeInputsNum, kernel_name_);
  CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kTransposeOutputsNum, kernel_name_);
  if (inputs.size() == kDynamicPermInputNum) {
    if (perm_type_ == kNumberTypeInt32) {
      InitPerm<int32_t>(inputs);
    } else {
      InitPerm<int64_t>(inputs);
    }
  }
  launch_func_(this, inputs, outputs);
  return true;
}
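On the CPU side, the new InitPerm<T> reads the runtime perm tensor (int32 or int64) and widens it to int64_t, and both the attribute path and the tensor path now funnel through CheckPermValue. A minimal standalone sketch of that read-and-normalize step, with hypothetical helper names that are not part of the commit:

#include <cstdint>
#include <stdexcept>
#include <vector>

// Wrap negative axes once (rank + p), then reject out-of-range values.
// CheckPermValue itself only checks the lower bound; the upper bound is
// enforced at infer time, but the sketch combines both for clarity.
std::vector<int64_t> NormalizePerm(std::vector<int64_t> perm) {
  const int64_t rank = static_cast<int64_t>(perm.size());
  for (auto &p : perm) {
    p = (p >= 0) ? p : (rank + p);  // e.g. rank 4: -1 -> 3
    if (p < 0 || p >= rank) {
      throw std::out_of_range("perm value out of [-rank, rank)");
    }
  }
  return perm;
}

// Widen a typed perm buffer (the int32/int64 tensor payload) to int64_t,
// mirroring the std::transform in InitPerm<T>.
template <typename T>
std::vector<int64_t> ReadPermBuffer(const T *data, size_t len) {
  std::vector<int64_t> perm(data, data + len);  // converts element-wise
  return NormalizePerm(perm);
}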
@@ -51,7 +51,47 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod {
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
      KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
      KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128)};
      KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
      KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt8),
      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt16),
      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt8),
      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt16),
      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt32),
      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeUInt64),
      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat64),
      KernelAttr()
        .AddInputAttr(kNumberTypeComplex64)
        .AddInputAttr(kNumberTypeInt32)
        .AddOutputAttr(kNumberTypeComplex64),
      KernelAttr()
        .AddInputAttr(kNumberTypeComplex128)
        .AddInputAttr(kNumberTypeInt32)
        .AddOutputAttr(kNumberTypeComplex128),
      KernelAttr().AddInputAttr(kNumberTypeBool).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeBool),
      KernelAttr().AddInputAttr(kNumberTypeInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt8),
      KernelAttr().AddInputAttr(kNumberTypeInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt16),
      KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt32),
      KernelAttr().AddInputAttr(kNumberTypeInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
      KernelAttr().AddInputAttr(kNumberTypeUInt8).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt8),
      KernelAttr().AddInputAttr(kNumberTypeUInt16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt16),
      KernelAttr().AddInputAttr(kNumberTypeUInt32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt32),
      KernelAttr().AddInputAttr(kNumberTypeUInt64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeUInt64),
      KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat16),
      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat32),
      KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeFloat64),
      KernelAttr()
        .AddInputAttr(kNumberTypeComplex64)
        .AddInputAttr(kNumberTypeInt64)
        .AddOutputAttr(kNumberTypeComplex64),
      KernelAttr()
        .AddInputAttr(kNumberTypeComplex128)
        .AddInputAttr(kNumberTypeInt64)
        .AddOutputAttr(kNumberTypeComplex128)};
    return support_list;
  }

@@ -77,9 +117,15 @@ class TransposeFwdCpuKernelMod : public DeprecatedNativeCpuKernelMod {
  template <typename T>
  void TransposeDims(const T *in_data, T *out_data, int64_t task_id, int64_t thread_num) const;

  void CheckPermValue();

  template <typename T>
  void InitPerm(const std::vector<kernel::AddressPtr> &inputs);

  std::vector<int64_t> input_shape_;
  std::vector<int64_t> output_shape_;
  TypeId dtype_{kTypeUnknown};
  TypeId perm_type_{kNumberTypeInt64};
  std::vector<int64_t> perm_;
  size_t num_axes_{0};
  size_t data_num_{0};
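The enlarged support_list above is a hand-written cross-product: every data type appears once with the perm folded into an attribute, and twice more with an int32 or int64 perm tensor as a second input. A toy generator showing that structure (plain strings stand in for KernelAttr/TypeId so the sketch stays self-contained):

#include <string>
#include <vector>

std::vector<std::string> BuildSupportList(const std::vector<std::string> &dtypes) {
  std::vector<std::string> list;
  for (const auto &t : dtypes) {
    list.push_back(t + " -> " + t);  // perm carried as an attribute
    for (const char *perm_t : {"Int32", "Int64"}) {  // perm as a tensor input
      list.push_back(t + ", perm:" + perm_t + " -> " + t);
    }
  }
  return list;
}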
@@ -22,33 +22,184 @@ namespace kernel {
template <typename T>
using Complex = mindspore::utils::Complex<T>;

MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex64).AddOutputAttr(kNumberTypeComplex64),
                      TransposeFwdGpuKernelMod, Complex<float>)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeComplex128).AddOutputAttr(kNumberTypeComplex128),
                      TransposeFwdGpuKernelMod, Complex<double>)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
                      TransposeFwdGpuKernelMod, bool)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
                      TransposeFwdGpuKernelMod, double)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      TransposeFwdGpuKernelMod, float)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                      TransposeFwdGpuKernelMod, half)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
                      TransposeFwdGpuKernelMod, int64_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
                      TransposeFwdGpuKernelMod, int)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
                      TransposeFwdGpuKernelMod, int16_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
                      TransposeFwdGpuKernelMod, int8_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
                      TransposeFwdGpuKernelMod, uint64_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
                      TransposeFwdGpuKernelMod, uint32_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
                      TransposeFwdGpuKernelMod, uint16_t)
MS_REG_GPU_KERNEL_ONE(Transpose, KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
                      TransposeFwdGpuKernelMod, uint8_t)
constexpr size_t kDynamicPermInputNum = 2;
constexpr size_t kDimSize4 = 4;
constexpr size_t kAxisZero = 0;
constexpr size_t kAxis1st = 1;
constexpr size_t kAxis2nd = 2;
constexpr size_t kAxis3rd = 3;
constexpr size_t kAxisIndexZero = 0;
constexpr size_t kAxisIndex1st = 1;
constexpr size_t kAxisIndex2nd = 2;
constexpr size_t kAxisIndex3rd = 3;

#define STATIC_REGISTER(INPUTX, OUTPUT, T) \
  { KernelAttr().AddInputAttr(INPUTX).AddOutputAttr(OUTPUT), &TransposeGpuKernelMod::LaunchKernel<T> }

#define DYN_REGISTER(INPUTX, PERM, OUTPUT, T)                                    \
  {                                                                              \
    KernelAttr().AddInputAttr(INPUTX).AddInputAttr(PERM).AddOutputAttr(OUTPUT),  \
      &TransposeGpuKernelMod::LaunchKernel<T>                                    \
  }

const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> &TransposeGpuKernelMod::GetFuncList()
  const {
  static const std::vector<std::pair<KernelAttr, TransposeGpuKernelMod::KernelRunFunc>> func_list = {
    STATIC_REGISTER(kNumberTypeComplex64, kNumberTypeComplex64, Complex<float>),
    STATIC_REGISTER(kNumberTypeComplex128, kNumberTypeComplex128, Complex<double>),
    STATIC_REGISTER(kNumberTypeBool, kNumberTypeBool, bool),
    STATIC_REGISTER(kNumberTypeFloat64, kNumberTypeFloat64, double),
    STATIC_REGISTER(kNumberTypeFloat32, kNumberTypeFloat32, float),
    STATIC_REGISTER(kNumberTypeFloat16, kNumberTypeFloat16, half),
    STATIC_REGISTER(kNumberTypeInt64, kNumberTypeInt64, int64_t),
    STATIC_REGISTER(kNumberTypeInt32, kNumberTypeInt32, int32_t),
    STATIC_REGISTER(kNumberTypeInt16, kNumberTypeInt16, int16_t),
    STATIC_REGISTER(kNumberTypeInt8, kNumberTypeInt8, int8_t),
    STATIC_REGISTER(kNumberTypeUInt8, kNumberTypeUInt8, uint8_t),
    STATIC_REGISTER(kNumberTypeUInt16, kNumberTypeUInt16, uint16_t),
    STATIC_REGISTER(kNumberTypeUInt32, kNumberTypeUInt32, uint32_t),
    STATIC_REGISTER(kNumberTypeUInt64, kNumberTypeUInt64, uint64_t),
    DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt32, kNumberTypeComplex64, Complex<float>),
    DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt32, kNumberTypeComplex128, Complex<double>),
    DYN_REGISTER(kNumberTypeBool, kNumberTypeInt32, kNumberTypeBool, bool),
    DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt32, kNumberTypeFloat64, double),
    DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt32, kNumberTypeFloat32, float),
    DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt32, kNumberTypeFloat16, half),
    DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt32, kNumberTypeInt64, int64_t),
    DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt32, kNumberTypeInt32, int32_t),
    DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt32, kNumberTypeInt16, int16_t),
    DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt32, kNumberTypeInt8, int8_t),
    DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt32, kNumberTypeUInt8, uint8_t),
    DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt32, kNumberTypeUInt16, uint16_t),
    DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt32, kNumberTypeUInt32, uint32_t),
    DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt32, kNumberTypeUInt64, uint64_t),
    DYN_REGISTER(kNumberTypeComplex64, kNumberTypeInt64, kNumberTypeComplex64, Complex<float>),
    DYN_REGISTER(kNumberTypeComplex128, kNumberTypeInt64, kNumberTypeComplex128, Complex<double>),
    DYN_REGISTER(kNumberTypeBool, kNumberTypeInt64, kNumberTypeBool, bool),
    DYN_REGISTER(kNumberTypeFloat64, kNumberTypeInt64, kNumberTypeFloat64, double),
    DYN_REGISTER(kNumberTypeFloat32, kNumberTypeInt64, kNumberTypeFloat32, float),
    DYN_REGISTER(kNumberTypeFloat16, kNumberTypeInt64, kNumberTypeFloat16, half),
    DYN_REGISTER(kNumberTypeInt64, kNumberTypeInt64, kNumberTypeInt64, int64_t),
    DYN_REGISTER(kNumberTypeInt32, kNumberTypeInt64, kNumberTypeInt32, int32_t),
    DYN_REGISTER(kNumberTypeInt16, kNumberTypeInt64, kNumberTypeInt16, int16_t),
    DYN_REGISTER(kNumberTypeInt8, kNumberTypeInt64, kNumberTypeInt8, int8_t),
    DYN_REGISTER(kNumberTypeUInt8, kNumberTypeInt64, kNumberTypeUInt8, uint8_t),
    DYN_REGISTER(kNumberTypeUInt16, kNumberTypeInt64, kNumberTypeUInt16, uint16_t),
    DYN_REGISTER(kNumberTypeUInt32, kNumberTypeInt64, kNumberTypeUInt32, uint32_t),
    DYN_REGISTER(kNumberTypeUInt64, kNumberTypeInt64, kNumberTypeUInt64, uint64_t),
  };
  return func_list;
}

template <typename T>
bool TransposeGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                         const std::vector<AddressPtr> &workspace,
                                         const std::vector<AddressPtr> &outputs) {
  T *input = GetDeviceAddress<T>(inputs, 0);
  T *output = GetDeviceAddress<T>(outputs, 0);
  size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
  size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);

  if (is_dynamic_perm_ && !get_dynamic_perm_value_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', fail to get value of the dynamic perm!";
  }

  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync input_shape failed");
  CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
    cudaMemcpyAsync(input_axis, &input_perm_[0], workspace_size_, cudaMemcpyHostToDevice,
                    reinterpret_cast<cudaStream_t>(stream_ptr_)),
    "cudaMemcpyAsync input_axis failed");
  size_t size = SizeOf(input_shape_);
  size_t *h_input_shape = reinterpret_cast<size_t *>(&input_shape_[0]);
  size_t *h_input_axis = &input_perm_[0];
  if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
      h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st &&
      h_input_axis[kAxisIndex3rd] == kAxis2nd) {
    // nhwc->nchw: 0,3,1,2
    CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                          reinterpret_cast<cudaStream_t>(stream_ptr_));
  } else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
             h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd &&
             h_input_axis[kAxisIndex3rd] == kAxis1st) {
    // nchw->nhwc: 0,2,3,1
    CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                          reinterpret_cast<cudaStream_t>(stream_ptr_));
  } else {
    CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
                 reinterpret_cast<cudaStream_t>(stream_ptr_));
  }
  return true;
}

void TransposeGpuKernelMod::GetPermValue(const std::vector<int64_t> &perm) {
  for (size_t j = 0; j < perm.size(); j++) {
    // Use SizeToLong so the arithmetic and the range check stay signed;
    // size_t + int64_t would promote to unsigned and make `p < 0` dead code.
    auto p = (perm[j] >= 0) ? perm[j] : (SizeToLong(perm.size()) + perm[j]);
    if (p < 0) {
      MS_LOG(EXCEPTION) << "the perm value must be in [-" << perm.size() << ", " << (perm.size() - 1) << "], but got "
                        << perm;
    }
    input_perm_.push_back(p);
  }
}

bool TransposeGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                 const std::vector<KernelTensorPtr> &outputs) {
  if (!MatchKernelFunc(base_operator, inputs, outputs)) {
    return false;
  }
  size_t input_num = inputs.size();
  size_t output_num = outputs.size();
  kernel_name_ = base_operator->name();
  if (input_num != 1 && input_num != kDynamicPermInputNum) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1 or " << kDynamicPermInputNum
                      << ", but got " << input_num;
  }
  if (output_num != 1) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
  }
  if (input_num == kDynamicPermInputNum) {
    is_dynamic_perm_ = true;
    return true;
  }

  auto attr = base_operator->GetPrim()->GetAttr(kAttrPerm);
  if (attr == nullptr) {
    MS_LOG(ERROR) << "The attr \"perm\" is not found in kernel 'Transpose'.";
    return false;
  }
  auto perm = GetValue<std::vector<int64_t>>(attr);
  GetPermValue(perm);
  return true;
}

int TransposeGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                  const std::vector<KernelTensorPtr> &outputs,
                                  const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  std::vector<int64_t> perm;
  if (GetDynamicAttrIntValue(inputs, kAxisIndex1st, inputsOnHost, kernel_name_, &perm)) {
    GetPermValue(perm);
    get_dynamic_perm_value_ = true;
  }
  if (int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost); ret != KRET_OK) {
    return ret;
  }

  input_shape_ = inputs[kAxisIndexZero]->GetDeviceShapeAdaptively();
  shape_size_ = input_shape_.size();
  if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of output cannot be greater than "
                      << TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_;
  }

  workspace_size_ = shape_size_ * sizeof(size_t);
  workspace_size_list_.push_back(workspace_size_);
  workspace_size_list_.push_back(workspace_size_);
  return KRET_OK;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Transpose, TransposeGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
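LaunchKernel keeps the existing 4-D fast paths: it inspects the host-side perm before picking a CUDA kernel. A plain restatement of that dispatch test, standalone for clarity (the enum and function names are illustrative, not from the commit):

#include <algorithm>
#include <array>
#include <cstddef>

enum class TransposeKind { kNHWC2NCHW, kNCHW2NHWC, kGeneric };

// Classify a permutation the way LaunchKernel does before choosing between
// the layout-specialized kernels and the generic CalTranspose fallback.
TransposeKind ClassifyPerm(const size_t *axis, size_t rank) {
  constexpr std::array<size_t, 4> kNhwc2Nchw{0, 3, 1, 2};  // nhwc->nchw
  constexpr std::array<size_t, 4> kNchw2Nhwc{0, 2, 3, 1};  // nchw->nhwc
  if (rank != 4) {
    return TransposeKind::kGeneric;
  }
  if (std::equal(kNhwc2Nchw.begin(), kNhwc2Nchw.end(), axis)) {
    return TransposeKind::kNHWC2NCHW;
  }
  if (std::equal(kNchw2Nhwc.begin(), kNchw2Nhwc.end(), axis)) {
    return TransposeKind::kNCHW2NHWC;
  }
  return TransposeKind::kGeneric;
}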
@@ -19,138 +19,52 @@

#include <vector>
#include <algorithm>
#include <utility>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl.cuh"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/transpose_impl_opt.cuh"
namespace mindspore {
namespace kernel {
constexpr size_t kDimSize4 = 4;
constexpr size_t kAxisZero = 0;
constexpr size_t kAxis1st = 1;
constexpr size_t kAxis2nd = 2;
constexpr size_t kAxis3rd = 3;
constexpr size_t kAxisIndexZero = 0;
constexpr size_t kAxisIndex1st = 1;
constexpr size_t kAxisIndex2nd = 2;
constexpr size_t kAxisIndex3rd = 3;

template <typename T>
class TransposeFwdGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class TransposeGpuKernelMod : public NativeGpuKernelMod, public MatchKernelHelper<TransposeGpuKernelMod> {
 public:
  TransposeFwdGpuKernelMod() { ResetResource(); }
  ~TransposeFwdGpuKernelMod() = default;
  TransposeGpuKernelMod() = default;
  ~TransposeGpuKernelMod() override = default;

  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
             const std::vector<KernelTensorPtr> &outputs,
             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) override;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    if (is_null_input_) {
      return true;
    }
    T *input = GetDeviceAddress<T>(inputs, 0);
    T *output = GetDeviceAddress<T>(outputs, 0);
    size_t *input_shape = GetDeviceAddress<size_t>(workspace, 0);
    size_t *input_axis = GetDeviceAddress<size_t>(workspace, 1);
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(input_shape, &input_shape_[0], workspace_size_, cudaMemcpyHostToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync input_shape failed");
    CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
                               cudaMemcpyAsync(input_axis, &input_axis_[0], workspace_size_, cudaMemcpyHostToDevice,
                                               reinterpret_cast<cudaStream_t>(stream_ptr)),
                               "cudaMemcpyAsync input_axis failed");
    size_t size = input_size_ / sizeof(T);

    size_t *h_input_shape = reinterpret_cast<size_t *>(&input_shape_[0]);
    size_t *h_input_axis = &input_axis_[0];
    if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
        h_input_axis[kAxisIndex1st] == kAxis3rd && h_input_axis[kAxisIndex2nd] == kAxis1st &&
        h_input_axis[kAxisIndex3rd] == kAxis2nd) {
      // nhwc->nchw: 0,3,1,2
      CalNHWC2NCHWInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else if (shape_size_ == kDimSize4 && h_input_axis[kAxisIndexZero] == kAxisZero &&
               h_input_axis[kAxisIndex1st] == kAxis2nd && h_input_axis[kAxisIndex2nd] == kAxis3rd &&
               h_input_axis[kAxisIndex3rd] == kAxis1st) {
      // nchw->nhwc: 0,2,3,1
      CalNCHW2NHWCInterface(size, shape_size_, input, h_input_shape, h_input_axis, input_shape, input_axis, output,
                            reinterpret_cast<cudaStream_t>(stream_ptr));
    } else {
      CalTranspose(size, input, input_shape, input_axis, shape_size_, output,
                   reinterpret_cast<cudaStream_t>(stream_ptr));
    }
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    auto kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    kernel_node_ = kernel_node;
    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of inputs must be 1, but got " << input_num;
    }
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
    if (output_num != 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the number of outputs must be 1, but got " << output_num;
    }
    input_shape_ = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
    is_null_input_ = CHECK_SHAPE_NULL(input_shape_, kernel_name, "input");
    if (is_null_input_) {
      InitSizeLists();
      return true;
    }
    shape_size_ = input_shape_.size();
    if (shape_size_ > TRANSPOSE_MAX_DIMENSION) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the dimension of output cannot be greater than "
                        << TRANSPOSE_MAX_DIMENSION << ", but got " << shape_size_;
    }

    input_size_ = sizeof(T) * SizeOf(input_shape_);
    output_size_ = input_size_;
    std::vector<int64_t> perm = GetAttr<std::vector<int64_t>>(kernel_node, "perm");
    for (size_t j = 0; j < perm.size(); j++) {
      auto p = (perm[j] >= 0) ? perm[j] : (perm.size() + perm[j]);
      if (p < 0) {
        MS_LOG(EXCEPTION) << "For '" << kernel_name << "', the perm value must be in [-" << perm.size() << ", "
                          << (perm.size() - 1) << "], but got " << perm;
      }
      input_axis_.push_back(p);
    }
    InitSizeLists();
    return true;
  }

  void ResetResource() noexcept override {
    shape_size_ = 0;
    input_size_ = 0;
    output_size_ = 0;
    workspace_size_ = 0;
    is_null_input_ = false;
    input_shape_.clear();
    input_axis_.clear();
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(output_size_);
    workspace_size_ = shape_size_ * sizeof(size_t);
    workspace_size_list_.push_back(workspace_size_);
    workspace_size_list_.push_back(workspace_size_);
    return;
    stream_ptr_ = stream_ptr;
    return kernel_func_(this, inputs, workspace, outputs);
  }

 private:
  std::vector<int64_t> input_shape_;
  std::vector<size_t> input_axis_;
  template <typename T>
  bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                    const std::vector<AddressPtr> &outputs);

  size_t shape_size_;
  size_t input_size_;
  size_t output_size_;
  size_t workspace_size_;
  void GetPermValue(const std::vector<int64_t> &perm);

  void *stream_ptr_{nullptr};
  std::vector<int64_t> input_shape_;
  std::vector<size_t> input_perm_;

  size_t shape_size_{0};
  size_t workspace_size_{0};
  bool is_null_input_;
  bool is_dynamic_perm_{false};
  bool get_dynamic_perm_value_{false};
};
}  // namespace kernel
}  // namespace mindspore
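The header migrates Transpose from the templated DeprecatedNativeGpuKernelMod to a single TransposeGpuKernelMod built on MatchKernelHelper: one static table of (KernelAttr, member function) pairs, matched once at Init and dispatched through a stored pointer-to-member at Launch. A stripped-down sketch of that pattern (all names here are illustrative, not MindSpore API):

#include <string>
#include <utility>
#include <vector>

class MiniKernel {
 public:
  using RunFunc = bool (MiniKernel::*)(const std::vector<float> &);

  // Match the incoming attribute against the static table once, at init time.
  bool Init(const std::string &attr) {
    for (const auto &entry : FuncList()) {
      if (entry.first == attr) {
        run_ = entry.second;
        return true;
      }
    }
    return false;
  }

  // Every launch goes through the pointer resolved at Init.
  bool Launch(const std::vector<float> &in) { return (this->*run_)(in); }

 private:
  bool RunF32(const std::vector<float> &) { return true; }

  static const std::vector<std::pair<std::string, RunFunc>> &FuncList() {
    static const std::vector<std::pair<std::string, RunFunc>> list = {
        {"float32", &MiniKernel::RunF32},
    };
    return list;
  }

  RunFunc run_{nullptr};
};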
@@ -16,94 +16,116 @@

#include "ops/transpose.h"
#include <vector>
#include <string>
#include <set>
#include <memory>
#include <algorithm>
#include "ops/op_utils.h"
#include "abstract/ops/op_infer.h"
#include "utils/check_convert_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"

namespace mindspore {
namespace ops {
namespace {
abstract::ShapePtr TransposeInferShape(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
  ShapeVector p_value;
  ShapeVector p_value_raw;
  if (x_shape[0] == 0) {
    MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
  }
MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);

bool CheckAndGetPermValue(const std::vector<AbstractBasePtr> &input_args, ShapeVector *perm_value,
                          const PrimitivePtr &primitive) {
  MS_EXCEPTION_IF_NULL(perm_value);
  bool is_dynamic = false;
  const std::string &op_name = primitive->name();

  if (input_args.size() == 1) {
    if (!primitive->HasAttr("perm")) {
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of 'input_perm' is necessary, but it is missing.";
    }
    ValuePtr perm = primitive->GetAttr("perm");
    MS_EXCEPTION_IF_NULL(perm);
    auto perm_val = perm->cast<ValueTuplePtr>();
    MS_EXCEPTION_IF_NULL(perm_val);
    auto perm_val_data = perm_val->value();
    (void)std::transform(std::begin(perm_val_data), std::end(perm_val_data), std::back_inserter(p_value_raw),
                         [](const ValuePtr &e) -> int64_t { return GetValue<int64_t>(e); });
  } else {
    auto perm_value = input_args[1]->BuildValue();
    MS_EXCEPTION_IF_NULL(perm_value);
    if (perm_value->isa<tensor::Tensor>()) {
      p_value_raw = CheckAndConvertUtils::CheckTensorIntValue("perm", perm_value, op_name);
    *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", perm, op_name);
    return is_dynamic;
  }

  auto input_value = input_args[kInputIndex1]->BuildValue();
  if (input_args[kInputIndex1]->isa<abstract::AbstractTuple>()) {
    *perm_value = CheckAndConvertUtils::CheckTupleInt("perm", input_value, op_name);
  } else if (input_args[kInputIndex1]->isa<abstract::AbstractTensor>()) {
    if (input_value->isa<tensor::Tensor>()) {
      *perm_value = CheckAndConvertUtils::CheckTensorIntValue("perm", input_value, op_name);
    } else {
      p_value_raw = CheckAndConvertUtils::CheckTupleInt("input[perm]", perm_value, op_name);
      is_dynamic = true;
      auto perm_shape = CheckAndConvertUtils::GetTensorInputShape("perm", input_args, 1);
      if (perm_shape->shape().size() != 1) {
        MS_EXCEPTION(ValueError) << "For 'transpose perm', " << op_name << " must be 1-D, but got "
                                 << perm_shape->shape().size() << "-D.";
      }
    }
  } else {
    MS_LOG(EXCEPTION) << "For '" << op_name
                      << "', the second input type should be tensor or scalar, but got invalid abstract type:"
                      << input_args[kInputIndex1]->type_name() << ".";
  }
  for (auto p : p_value_raw) {
    p = (p >= 0) ? p : (p_value_raw.size() + p);
    p_value.emplace_back(p);
  }
  if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) {
    MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'input_perm' must be equal, but got "
                             << x_shape.size() << " and " << p_value.size() << " respectively.";
  }
  for (auto i : p_value) {
    (void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
  }
  std::vector<int64_t> tmp(p_value);
  for (auto it = tmp.begin(); it != tmp.end();) {
    auto dim = *it;
    if (!tmp.empty()) {
      it = tmp.erase(it);
    }
    if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) {
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong.";
    }
  }
  if (IsDynamicRank(x_shape)) {
    return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
  }
  std::vector<int64_t> in_shape(p_value);
  (void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(), [x_shape](size_t i) { return x_shape[i]; });
  return std::make_shared<abstract::Shape>(in_shape);
  return is_dynamic;
}

TypePtr TransposeInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(prim);
  return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name());
}
}  // namespace
class TransposeInfer : public abstract::OpInferBase {
 public:
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    auto op_name = primitive->name();
    auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
    (void)CheckAndConvertUtils::CheckInteger("input_x size", SizeToLong(x_shape.size()), kGreaterThan, 0, op_name);
    ShapeVector p_value;
    ShapeVector p_value_raw;
    if (x_shape[0] == 0) {
      MS_EXCEPTION(ValueError) << "For 'Transpose', first dim of input_x's shape can not be 0, but got 0.";
    }
    if (IsDynamicRank(x_shape)) {
      return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
    }

MIND_API_OPERATOR_IMPL(Transpose, BaseOperator);
AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                               const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
    bool perm_is_dynamic = CheckAndGetPermValue(input_args, &p_value_raw, primitive);
    if (perm_is_dynamic) {
      ShapeVector out_shape;
      (void)out_shape.insert(out_shape.end(), x_shape.size(), -1);
      return std::make_shared<abstract::Shape>(out_shape);
    }

    for (auto p : p_value_raw) {
      p = (p >= 0) ? p : (p_value_raw.size() + p);
      p_value.emplace_back(p);
    }
    if (!IsDynamicRank(x_shape) && x_shape.size() != p_value.size()) {
      MS_EXCEPTION(ValueError) << "For '" << op_name << "', the dim of 'input_x' and 'perm' must be equal, but got "
                               << x_shape.size() << " and " << p_value.size() << " respectively.";
    }
    for (auto i : p_value) {
      (void)CheckAndConvertUtils::CheckInteger("perm element", i, kLessThan, SizeToLong(p_value.size()), op_name);
    }
    std::vector<int64_t> tmp(p_value);
    for (auto it = tmp.begin(); it != tmp.end();) {
      auto dim = *it;
      if (!tmp.empty()) {
        it = tmp.erase(it);
      }
      if (std::find(tmp.begin(), tmp.end(), dim) != tmp.end()) {
        MS_EXCEPTION(ValueError) << "For '" << op_name << "', the value of perm is wrong.";
      }
    }
    std::vector<int64_t> in_shape(p_value);
    (void)std::transform(in_shape.begin(), in_shape.end(), in_shape.begin(),
                         [x_shape](size_t i) { return x_shape[i]; });
    return std::make_shared<abstract::Shape>(in_shape);
  }
  CheckAndConvertUtils::CheckInputArgs(input_args, kGreaterEqual, kInputIndex1, primitive->name());
  auto type = TransposeInferType(primitive, input_args);
  auto shape = TransposeInferShape(primitive, input_args);
  return abstract::MakeAbstract(shape, type);
}

REGISTER_PRIMITIVE_EVAL_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, nullptr, true);
  TypePtr InferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(prim);
    return CheckAndConvertUtils::CheckSubClass("input_x", input_args[0]->BuildType(), {kTensorType}, prim->name());
  }

  std::set<int64_t> GetValueDependArgIndices() const override { return {1}; }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(Transpose, prim::kPrimTranspose, TransposeInfer, false);
}  // namespace ops
}  // namespace mindspore
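The infer rule that CheckAndGetPermValue feeds is simple: when the perm is dynamic (a tensor whose value is unknown at compile time) the output keeps the input's rank but every dimension becomes -1; when the perm is concrete, the output shape is the permuted input shape. A standalone restatement (the function name is illustrative):

#include <cstdint>
#include <vector>

// perm == nullptr models the dynamic case: rank known, all dims unknown (-1).
std::vector<int64_t> InferTransposeShape(const std::vector<int64_t> &x_shape,
                                         const std::vector<int64_t> *perm) {
  if (perm == nullptr) {
    return std::vector<int64_t>(x_shape.size(), -1);
  }
  std::vector<int64_t> out;
  out.reserve(perm->size());
  for (int64_t p : *perm) {  // perm is assumed already normalized and checked
    out.push_back(x_shape[static_cast<size_t>(p)]);
  }
  return out;
}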
@@ -35,8 +35,6 @@ class MIND_API Transpose : public BaseOperator {
  /// \brief Init.
  void Init() const {}
};
abstract::AbstractBasePtr TransposeInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                         const std::vector<abstract::AbstractBasePtr> &input_args);
}  // namespace ops
}  // namespace mindspore