[feat] [assistant] [ops] [I5EWIZ] New GPU operator impletation, include DataFormatVecPermute
This commit is contained in:
parent
b06f5623de
commit
0cce5cfc70
|
@ -0,0 +1,66 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh"
|
||||
#include "include/cuda_runtime.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void DataFormatVecPermuteKernel1D(const size_t size, const T *input, T *output, int32_t *index) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
output[pos] = input[index[pos]];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void DataFormatVecPermuteKernel2D(const size_t size, const T *input, T *output, int32_t *index) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
|
||||
int32_t dim = static_cast<int32_t>(2);
|
||||
int32_t i = static_cast<int32_t>(pos) / dim;
|
||||
output[dim * i] = input[dim * index[i]];
|
||||
output[dim * i + 1] = input[dim * index[i]+1];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream) {
|
||||
DataFormatVecPermuteKernel1D<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input,
|
||||
output,
|
||||
index);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream) {
|
||||
DataFormatVecPermuteKernel2D<<<CUDA_BLOCKS(device_id, size), CUDA_THREADS(device_id), 0, cuda_stream>>>(size, input,
|
||||
output,
|
||||
index);
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalDataFormatVecPermute1D<int>(const size_t size, const int *input, int *output,
|
||||
int32_t *index, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalDataFormatVecPermute1D<int64_t>(const size_t size, const int64_t *input,
|
||||
int64_t *output, int32_t *index,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalDataFormatVecPermute2D<int>(const size_t size, const int *input, int *output,
|
||||
int32_t *index, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalDataFormatVecPermute2D<int64_t>(const size_t size, const int64_t *input,
|
||||
int64_t *output, int32_t *index,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalDataFormatVecPermute1D(const size_t size, const T *input, T *output, int32_t *index,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalDataFormatVecPermute2D(const size_t size, const T *input, T *output, int32_t *index,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_DATEFORMATEVECPERMUTE_IMPL_CUH_
|
|
@ -0,0 +1,145 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/nn/data_format_vec_permute_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
const std::vector<int32_t> kDataSameFormat = {0, 1, 2, 3};
|
||||
const std::vector<int32_t> kDataNHWC2NCHW = {0, 3, 1, 2};
|
||||
const std::vector<int32_t> kDataNCHW2NHWC = {0, 2, 3, 1};
|
||||
constexpr const size_t k1DElementNum = 4;
|
||||
|
||||
bool DataFormatVecPermuteGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr_ = std::dynamic_pointer_cast<ops::DataFormatVecPermute>(base_operator);
|
||||
kernel_name_ = kernel_ptr_->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
constexpr int INPUT_NUM = 1;
|
||||
constexpr int OUTPUT_NUM = 1;
|
||||
if (inputs.size() != INPUT_NUM || outputs.size() != OUTPUT_NUM) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input and output must be " << INPUT_NUM << " and " << OUTPUT_NUM
|
||||
<< ", but got " << inputs.size() << " and " << outputs.size();
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the kernel type should be in [int32, int64], but got: " << kernel_attr << ".";
|
||||
return false;
|
||||
}
|
||||
src_format = kernel_ptr_->get_src_format();
|
||||
dst_format = kernel_ptr_->get_dst_format();
|
||||
if (src_format == dst_format) {
|
||||
data_map_ = kDataSameFormat;
|
||||
} else if (src_format == "NHWC" && dst_format == "NCHW") {
|
||||
data_map_ = kDataNHWC2NCHW;
|
||||
} else if (src_format == "NCHW" && dst_format == "NHWC") {
|
||||
data_map_ = kDataNCHW2NHWC;
|
||||
} else {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', src_format and dst_format must be 'NCHW' or 'NHWC' "
|
||||
<< ", but got src_format " << src_format << " dst_format " << dst_format;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
unit_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
return true;
|
||||
}
|
||||
|
||||
int DataFormatVecPermuteGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
|
||||
const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
std::vector<int64_t> input_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
std::vector<int64_t> output_shape_ = std::vector<int64_t>(outputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
outputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
std::vector<int64_t> shape1 = {4};
|
||||
std::vector<int64_t> shape2 = {4, 2};
|
||||
if (input_shape_ != shape1 && input_shape_ != shape2) {
|
||||
MS_EXCEPTION(ValueError) << "For " << kernel_name_ << ", input shape must be (4, ) or (4, 2), but got "
|
||||
<< input_shape_ << ".";
|
||||
}
|
||||
auto in_shape_size = input_shape_.size();
|
||||
auto output_shape_size = output_shape_.size();
|
||||
if (in_shape_size != output_shape_size) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', input shape size should be the same as output shape size, but got"
|
||||
<< " input shape size " << in_shape_size << " output shape size" << output_shape_size;
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
output_elements_ = std::accumulate(output_shape_.begin(), output_shape_.end(), 1, std::multiplies<int64_t>());
|
||||
if (output_elements_ == 0) {
|
||||
is_null_input_ = true;
|
||||
}
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DataFormatVecPermuteGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
auto *index = GetDeviceAddress<int32_t>(workspace, kIndex0);
|
||||
|
||||
// code block for sync dim_map
|
||||
{
|
||||
CHECK_CUDA_RET_WITH_EXCEPT_NOTRACE(
|
||||
cudaMemcpyAsync(index, data_map_.data(), kDataFormatNum * sizeof(int32_t), cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_)),
|
||||
"cudaMemcpy failed in DataFormatVecPermuteGpuKernelMod::LaunchKernel.");
|
||||
}
|
||||
|
||||
if (output_elements_ == k1DElementNum) {
|
||||
CalDataFormatVecPermute1D(output_elements_, input, output, index, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
} else {
|
||||
CalDataFormatVecPermute2D(output_elements_, input, output, index, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, DataFormatVecPermuteGpuKernelMod::DataFormatVecPermuteFunc>>
|
||||
DataFormatVecPermuteGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
&DataFormatVecPermuteGpuKernelMod::LaunchKernel<int>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
|
||||
&DataFormatVecPermuteGpuKernelMod::LaunchKernel<int64_t>}};
|
||||
|
||||
std::vector<KernelAttr> DataFormatVecPermuteGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, DataFormatVecPermuteFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, DataFormatVecPermute, DataFormatVecPermuteGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,99 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GPU_KERNEL_H_
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/data_format_vec_permute.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/data_format_vec_permute_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
constexpr const size_t kDataFormatNum = 4;
|
||||
|
||||
class DataFormatVecPermuteGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
DataFormatVecPermuteGpuKernelMod() { ResetResource(); }
|
||||
~DataFormatVecPermuteGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
cuda_stream_ = cuda_stream;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept {
|
||||
output_elements_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
}
|
||||
|
||||
void InitSizeLists() {
|
||||
size_t x_size = output_elements_ * unit_size_;
|
||||
size_t work_size = kDataFormatNum * sizeof(int32_t);
|
||||
input_size_list_.emplace_back(x_size);
|
||||
workspace_size_list_.emplace_back(work_size);
|
||||
output_size_list_.emplace_back(x_size);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
using DataFormatVecPermuteFunc =
|
||||
std::function<bool(DataFormatVecPermuteGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
size_t unit_size_{1};
|
||||
size_t output_elements_{};
|
||||
DataFormatVecPermuteFunc kernel_func_{};
|
||||
BaseOperatorPtr kernel_ptr_{nullptr};
|
||||
std::string src_format{"NHWC"};
|
||||
std::string dst_format{"NCHW"};
|
||||
std::vector<int32_t> data_map_;
|
||||
bool is_null_input_{false};
|
||||
void *cuda_stream_{nullptr};
|
||||
static std::vector<std::pair<KernelAttr, DataFormatVecPermuteFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_GPU_KERNEL_H_
|
|
@ -33,11 +33,15 @@ abstract::ShapePtr DataFormatVecPermuteInferShape(const PrimitivePtr &primitive,
|
|||
auto prim_name = primitive->name();
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
|
||||
auto x_shape_ptr = input_args[kInputIndex0]->BuildShape()->cast<abstract::ShapePtr>();
|
||||
std::vector<int64_t> shape1 = {4};
|
||||
std::vector<int64_t> shape2 = {4, 2};
|
||||
if (x_shape != shape1 && x_shape != shape2) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", input shape must be (4, ) or (4, 2), but got " << x_shape
|
||||
<< ".";
|
||||
if (input_args[kInputIndex0]->isa<abstract::AbstractTensor>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<AnyValue>() &&
|
||||
!input_args[kInputIndex0]->BuildValue()->isa<None>()) {
|
||||
std::vector<int64_t> shape1 = {4};
|
||||
std::vector<int64_t> shape2 = {4, 2};
|
||||
if (x_shape != shape1 && x_shape != shape2) {
|
||||
MS_EXCEPTION(ValueError) << "For " << prim_name << ", input shape must be (4, ) or (4, 2), but got " << x_shape
|
||||
<< ".";
|
||||
}
|
||||
}
|
||||
return x_shape_ptr;
|
||||
}
|
||||
|
@ -50,6 +54,31 @@ TypePtr DataFormatVecPermuteInferType(const PrimitivePtr &prim, const std::vecto
|
|||
}
|
||||
} // namespace
|
||||
|
||||
void DataFormatVecPermute::Init(const std::string &src_format, const std::string &dst_format) {
|
||||
this->set_src_format(src_format);
|
||||
this->set_dst_format(dst_format);
|
||||
}
|
||||
|
||||
void DataFormatVecPermute::set_src_format(const std::string &src_format) {
|
||||
CheckAndConvertUtils::CheckString(kSrcFormat, src_format, {"NHWC", "NCHW"}, this->name());
|
||||
(void)this->AddAttr(kSrcFormat, api::MakeValue(src_format));
|
||||
}
|
||||
|
||||
void DataFormatVecPermute::set_dst_format(const std::string &dst_format) {
|
||||
CheckAndConvertUtils::CheckString(kSrcFormat, dst_format, {"NHWC", "NCHW"}, this->name());
|
||||
(void)this->AddAttr(kDstFormat, api::MakeValue(dst_format));
|
||||
}
|
||||
|
||||
std::string DataFormatVecPermute::get_src_format() const {
|
||||
auto value_ptr = this->GetAttr(kSrcFormat);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
std::string DataFormatVecPermute::get_dst_format() const {
|
||||
auto value_ptr = this->GetAttr(kDstFormat);
|
||||
return GetValue<std::string>(value_ptr);
|
||||
}
|
||||
|
||||
MIND_API_OPERATOR_IMPL(DataFormatVecPermute, BaseOperator);
|
||||
AbstractBasePtr DataFormatVecPermuteInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) {
|
||||
|
|
|
@ -35,6 +35,16 @@ class MIND_API DataFormatVecPermute : public BaseOperator {
|
|||
MIND_API_BASE_MEMBER(DataFormatVecPermute);
|
||||
/// \brief Constructor.
|
||||
DataFormatVecPermute() : BaseOperator(kNameDataFormatVecPermute) { InitIOName({"x"}, {"y"}); }
|
||||
/// \brief Init.
|
||||
void Init(const std::string &src_format = "NHWC", const std::string &dst_format = "NCHW");
|
||||
/// \brief Set src_format.
|
||||
void set_src_format(const std::string &src_format);
|
||||
/// \brief Set dst_format.
|
||||
void set_dst_format(const std::string &dst_format);
|
||||
/// \brief Get src_format.
|
||||
std::string get_src_format() const;
|
||||
/// \brief Get dst_format.
|
||||
std::string get_dst_format() const;
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr DataFormatVecPermuteInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -1405,13 +1405,13 @@ class DataFormatVecPermute(Primitive):
|
|||
ValueError: If input_x shape is not (4, ) or (4, 2).
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> class Net(nn.Cell):
|
||||
... def __init__(self, src_format="NHWC", dst_format="NCHW"):
|
||||
... super().__init__()
|
||||
... self.op = P.DataFormatVecPermute(src_format, dst_format)
|
||||
... self.op = P.nn_ops.DataFormatVecPermute(src_format, dst_format)
|
||||
... def construct(self, x):
|
||||
... return self.op(x)
|
||||
...
|
||||
|
|
|
@ -0,0 +1,69 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class DataFormatVecPermuteNet(nn.Cell):
|
||||
|
||||
def __init__(self, src_format, dst_format):
|
||||
super().__init__()
|
||||
self.op = P.nn_ops.DataFormatVecPermute(src_format, dst_format)
|
||||
|
||||
def construct(self, x):
|
||||
return self.op(x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_data_format_vec_permute_1d_input_int32():
|
||||
"""
|
||||
Feature: DataFormatVecPermute gpu TEST.
|
||||
Description: 1d test case for DataFormatVecPermute, "NHWC" to "NCHW"
|
||||
Expectation: The value and shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
x_ms = Tensor(np.array([1, 2, 3, 4]).astype(np.int32))
|
||||
net = DataFormatVecPermuteNet(src_format="NHWC", dst_format="NCHW")
|
||||
z_ms = net(x_ms)
|
||||
expect = np.array([1, 4, 2, 3]).astype(np.int32)
|
||||
|
||||
assert (z_ms.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_data_format_vec_permute_2d_input_int64():
|
||||
"""
|
||||
Feature: DataFormatVecPermute gpu TEST.
|
||||
Description: 2d test case for DataFormatVecPermute, "NCHW" to "NHWC"
|
||||
Expectation: The value and shape of output are the expected values.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
|
||||
x_ms = Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]).astype(np.int64))
|
||||
net = DataFormatVecPermuteNet(src_format="NCHW", dst_format="NHWC")
|
||||
z_ms = net(x_ms)
|
||||
expect = np.array([[1, 1], [3, 3], [4, 4], [2, 2]]).astype(np.int64)
|
||||
|
||||
assert (z_ms.asnumpy() == expect).all()
|
Loading…
Reference in New Issue