forked from mindspore-Ecosystem/mindspore
!41895 [assistant][ops][I5EWTB] New GPU operator implementation, including ExtractVolumePatches
Merge pull request !41895 from 黎冠新/Extractvolumepatches_1
commit 223cd9c2ab
@ -0,0 +1,179 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/arrays/extract_volume_patches_gpu_kernel.h"
#include <functional>
#include <utility>
#include <string>
#include <algorithm>
#include <memory>
#include "abstract/utils.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/extract_volume_patches_impl.cuh"

namespace mindspore {
namespace kernel {
constexpr size_t kDimSize5 = 5;
constexpr size_t kFormatNCDHWIndexC = 1;
constexpr size_t kFormatNCDHWIndexD = 2;
constexpr size_t kFormatNCDHWIndexH = 3;
constexpr size_t kFormatNCDHWIndexW = 4;

bool ExtractVolumePatchesGpuKernelMod::Init(const BaseOperatorPtr &base_operator,
                                            const std::vector<KernelTensorPtr> &inputs,
                                            const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::ExtractVolumePatches>(base_operator);
  // Validate the cast before dereferencing the pointer.
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "Cast ExtractVolumePatches ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();
  kernel_size_ = kernel_ptr->get_kernel_size();
  strides_ = kernel_ptr->get_strides();
  padding_ = kernel_ptr->get_padding();
  size_t kernel_size_dims = kernel_size_.size();
  size_t strides_dims = strides_.size();
  if (!MatchKernelFunc(base_operator, inputs, outputs)) {
    return false;  // Init returns bool, not a resize status code.
  }
  if (kernel_size_dims != kDimSize5) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'kernel_size' must be equal to 5, but got "
                  << kernel_size_dims << ".";
    return false;
  }
  if (strides_dims != kDimSize5) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'strides' must be equal to 5, but got "
                  << strides_dims << ".";
    return false;
  }
  if (padding_ != "VALID" && padding_ != "valid" && padding_ != "SAME" && padding_ != "same") {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', 'padding' must be VALID, valid, SAME or same, but got " << padding_
                  << ".";
    return false;
  }
  return true;
}

int ExtractVolumePatchesGpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                             const std::vector<KernelTensorPtr> &inputs,
                                             const std::vector<KernelTensorPtr> &outputs,
                                             const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
  int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
  if (ret != KRET_OK) {
    return ret;
  }
  input_shape_ = inputs[0]->GetShapeVector();
  output_shape_ = outputs[0]->GetShapeVector();
  size_t input_shape_dims = input_shape_.size();
  size_t output_shape_dims = output_shape_.size();
  // Check parameters.
  if (input_shape_dims != kDimSize5) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'input_shape' must be equal to 5, but got "
                  << input_shape_dims << ".";
    return KRET_RESIZE_FAILED;
  }
  if (output_shape_dims != kDimSize5) {
    MS_LOG(ERROR) << "For '" << kernel_name_ << "', the dimension of 'output_shape' must be equal to 5, but got "
                  << output_shape_dims << ".";
    return KRET_RESIZE_FAILED;
  }
  ksize_d_ = kernel_size_[kFormatNCDHWIndexD];
  ksize_h_ = kernel_size_[kFormatNCDHWIndexH];
  ksize_w_ = kernel_size_[kFormatNCDHWIndexW];
  stride_d_ = strides_[kFormatNCDHWIndexD];
  stride_h_ = strides_[kFormatNCDHWIndexH];
  stride_w_ = strides_[kFormatNCDHWIndexW];
  input_channel_ = input_shape_[kFormatNCDHWIndexC];
  input_depth_ = input_shape_[kFormatNCDHWIndexD];
  input_height_ = input_shape_[kFormatNCDHWIndexH];
  input_width_ = input_shape_[kFormatNCDHWIndexW];
  output_depth_ = output_shape_[kFormatNCDHWIndexD];
  output_height_ = output_shape_[kFormatNCDHWIndexH];
  output_width_ = output_shape_[kFormatNCDHWIndexW];

  if (padding_ == "VALID" || padding_ == "valid") {
    pad_head_ = 0;
    pad_top_ = 0;
    pad_left_ = 0;
  }
  if (padding_ == "SAME" || padding_ == "same") {
    constexpr int64_t zero_value = 0;
    constexpr int64_t kMidDividend = 2;
    // SAME padding: the total pad along each axis is (out - 1) * stride + ksize - in,
    // and the head/top/left side gets half of it (rounded down).
    pad_head_ = std::max(zero_value, ((output_depth_ - 1) * stride_d_ + ksize_d_ - input_depth_) / kMidDividend);
    pad_top_ = std::max(zero_value, ((output_height_ - 1) * stride_h_ + ksize_h_ - input_height_) / kMidDividend);
    pad_left_ = std::max(zero_value, ((output_width_ - 1) * stride_w_ + ksize_w_ - input_width_) / kMidDividend);
  }
  output_size_ = 1;
  for (size_t i = 0; i < output_shape_.size(); i++) {
    output_size_ *= output_shape_[i];
  }
  d_stride_ = ksize_h_ * ksize_w_;
  h_stride_ = ksize_h_;
  w_stride_ = ksize_w_;
  patch_stride_ = output_depth_ * output_height_ * output_width_;
  other_stride_ = patch_stride_ * ksize_d_ * ksize_h_ * ksize_w_ * input_channel_;
  chan_input_stride_ = input_depth_ * input_height_ * input_width_;
  dep_input_stride_ = input_height_ * input_width_;
  row_input_stride_ = input_width_;
  patch_input_stride_ = input_channel_ * input_depth_ * input_height_ * input_width_;
  MS_EXCEPTION_IF_ZERO("other stride", other_stride_);
  // More than one batch element iff the flattened output extends past one batch stride.
  need_batch_ = (output_size_ - 1) / static_cast<size_t>(other_stride_) > 0;
  return KRET_OK;
}

template <typename T>
bool ExtractVolumePatchesGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
                                                    const std::vector<AddressPtr> &workspace,
                                                    const std::vector<AddressPtr> &outputs) {
  T *input_ptr = GetDeviceAddress<T>(inputs, kIndex0);
  T *output_ptr = GetDeviceAddress<T>(outputs, kIndex0);
  CalExtractVolumePatches(output_size_, stride_d_, stride_h_, stride_w_, output_depth_, output_height_, output_width_,
                          need_batch_, d_stride_, h_stride_, w_stride_, patch_stride_, other_stride_, input_channel_,
                          input_depth_, input_height_, input_width_, pad_head_, pad_top_, pad_left_,
                          chan_input_stride_, dep_input_stride_, row_input_stride_, patch_input_stride_, input_ptr,
                          output_ptr, reinterpret_cast<cudaStream_t>(cuda_stream_));
  return true;
}

using FuncList = std::vector<std::pair<KernelAttr, ExtractVolumePatchesGpuKernelMod::KernelRunFunc>>;
const FuncList &ExtractVolumePatchesGpuKernelMod::GetFuncList() const {
  static const FuncList func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<double>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<half>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<uint64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<uint32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<uint16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
     &ExtractVolumePatchesGpuKernelMod::LaunchKernel<uint8_t>}};
  return func_list_;
}

MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, ExtractVolumePatches, ExtractVolumePatchesGpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
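
For cross-checking the stride bookkeeping in Resize above, here is a minimal NumPy sketch of the VALID-padding case. The helper name extract_volume_patches_ref and the (kd, kh, kw, c) patch-flattening order are assumptions made for illustration; they are not part of this change.

import numpy as np

def extract_volume_patches_ref(x, ksize, strides):
    """Reference patch extraction for NCDHW input with VALID padding (sketch)."""
    n, c, d, h, w = x.shape
    kd, kh, kw = ksize[2:]      # ksize is an NCDHW 5-tuple; the N and C entries are 1
    sd, sh, sw = strides[2:]
    od = (d - kd) // sd + 1     # VALID output dims: floor((in - k) / stride) + 1
    oh = (h - kh) // sh + 1
    ow = (w - kw) // sw + 1
    out = np.zeros((n, kd * kh * kw * c, od, oh, ow), dtype=x.dtype)
    for b in range(n):
        for zd in range(od):
            for zh in range(oh):
                for zw in range(ow):
                    patch = x[b, :, zd * sd:zd * sd + kd, zh * sh:zh * sh + kh, zw * sw:zw * sw + kw]
                    # Flatten in (kd, kh, kw, c) order, matching the kernel's output layout.
                    out[b, :, zd, zh, zw] = patch.transpose(1, 2, 3, 0).reshape(-1)
    return out
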
@ -0,0 +1,105 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_VOLUME_PATCHES_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_VOLUME_PATCHES_GPU_KERNEL_H_

#include <vector>
#include <string>
#include <memory>
#include <map>
#include <utility>
#include <algorithm>
#include <functional>

#include "mindspore/core/ops/extract_volume_patches.h"
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/extract_volume_patches_impl.cuh"

namespace mindspore {
namespace kernel {
class ExtractVolumePatchesGpuKernelMod : public NativeGpuKernelMod,
                                         public MatchKernelHelper<ExtractVolumePatchesGpuKernelMod> {
 public:
  ExtractVolumePatchesGpuKernelMod() = default;
  ~ExtractVolumePatchesGpuKernelMod() override = default;

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
              const std::vector<AddressPtr> &outputs, void *cuda_stream) override {
    cuda_stream_ = cuda_stream;
    return kernel_func_(this, inputs, workspace, outputs);
  }

  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
            const std::vector<KernelTensorPtr> &outputs) override;

  int Resize(
    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
    const std::vector<KernelTensorPtr> &outputs,
    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;

  const std::vector<std::pair<KernelAttr, KernelRunFunc>> &GetFuncList() const override;

 protected:
  std::vector<KernelAttr> GetOpSupport() override { return OpSupport(); }

 private:
  template <typename T>
  bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
                    const std::vector<kernel::AddressPtr> &outputs);

  void *cuda_stream_{nullptr};
  std::vector<int64_t> kernel_size_;
  std::vector<int64_t> strides_;
  std::string padding_;
  int64_t ksize_d_{1};
  int64_t ksize_h_{1};
  int64_t ksize_w_{1};
  int64_t stride_d_{1};
  int64_t stride_h_{1};
  int64_t stride_w_{1};
  size_t output_size_{1};
  int64_t input_channel_{1};
  int64_t input_depth_{1};
  int64_t input_height_{1};
  int64_t input_width_{1};
  int64_t output_depth_{1};
  int64_t output_height_{1};
  int64_t output_width_{1};

  int64_t d_stride_{1};
  int64_t h_stride_{1};
  int64_t w_stride_{1};
  int64_t patch_stride_{1};
  int64_t other_stride_{1};
  int64_t chan_input_stride_{1};
  int64_t dep_input_stride_{1};
  int64_t row_input_stride_{1};

  int64_t patch_input_stride_{1};
  bool need_batch_{false};  // set in Resize; true when more than one batch element is processed
  int64_t pad_head_{0};
  int64_t pad_top_{0};
  int64_t pad_left_{0};
  bool is_null_input_{false};
  std::vector<int64_t> input_shape_;
  std::vector<int64_t> output_shape_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_VOLUME_PATCHES_GPU_KERNEL_H_
@ -0,0 +1,171 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/extract_volume_patches_impl.cuh"
#include "include/cuda_fp16.h"

template <typename T>
__global__ void ExtractVolumePatches(size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col,
                                     int64_t output_depth, int64_t output_height, int64_t output_width,
                                     bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
                                     int64_t patch_stride, int64_t other_stride, int64_t input_channel,
                                     int64_t input_dep_size, int64_t input_row_size, int64_t input_col_size,
                                     int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
                                     int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride,
                                     const T *input, T *output) {
  size_t pos;
  // Each thread handles one run of w_stride * input_channel contiguous output elements.
  for (size_t w_pos = blockIdx.x * blockDim.x + threadIdx.x; w_pos < output_size / (w_stride * input_channel);
       w_pos += blockDim.x * gridDim.x) {
    pos = static_cast<size_t>(w_pos / patch_stride) * w_stride * input_channel * patch_stride + (w_pos % patch_stride);
    const int64_t batch_index = need_batch ? (static_cast<int64_t>(pos) / other_stride) : 0;
    const int64_t inner_index =
      need_batch ? (static_cast<int64_t>(pos) - batch_index * other_stride) : static_cast<int64_t>(pos);
    // inner index
    const int64_t patch_index = inner_index % patch_stride;
    const int64_t patch_offset = inner_index / patch_stride / input_channel;
    // channel
    const int64_t channel = inner_index / patch_stride % input_channel;
    // depth
    const int64_t dep_index = patch_index / (output_height * output_width);
    const int64_t dep_offset = patch_offset / d_stride;
    const int64_t input_dep = dep_index * stride_dep + dep_offset - pad_head;
    if (input_dep < 0 || input_dep >= input_dep_size) {
      continue;
    }
    // height
    const int64_t row_index = patch_index / output_width % output_height;
    const int64_t row_offset = patch_offset / w_stride % h_stride;
    const int64_t input_row = row_index * stride_row + row_offset - pad_top;
    if (input_row < 0 || input_row >= input_row_size) {
      continue;
    }
    // width
    const int64_t col_index = patch_index % output_width;
    const int64_t col_offset = patch_offset % w_stride;
    const int64_t input_col = col_index * stride_col + col_offset - pad_left;
    // input index
    const int64_t input_index = input_col + input_row * row_input_stride + input_dep * dep_input_stride +
                                channel * chan_input_stride + batch_index * patch_input_stride;
#pragma unroll
    for (int64_t i = 0; i < w_stride; i++) {
      // Columns that fall in the padded region keep the zero written by the memset.
      if (input_col + i < 0) {
        continue;
      }
      if (input_col + i >= input_col_size) {
        break;
      }
#pragma unroll
      for (int64_t j = 0; j < input_channel; j++) {
        output[pos + (i * input_channel + j) * patch_stride] = input[input_index + i + j * chan_input_stride];
      }
    }
  }
  return;
}

template <typename T>
void CalExtractVolumePatches(size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col,
                             int64_t output_depth, int64_t output_height, int64_t output_width, bool need_batch,
                             int64_t d_stride, int64_t h_stride, int64_t w_stride, int64_t patch_stride,
                             int64_t other_stride, int64_t input_channel, int64_t input_dep_size,
                             int64_t input_row_size, int64_t input_col_size, int64_t pad_head, int64_t pad_top,
                             int64_t pad_left, int64_t chan_input_stride, int64_t dep_input_stride,
                             int64_t row_input_stride, int64_t patch_input_stride, const T *input, T *output,
                             cudaStream_t stream) {
  // Zero the output on the same stream so the memset is ordered before the kernel launch.
  cudaMemsetAsync(output, 0, sizeof(T) * output_size, stream);
  ExtractVolumePatches<<<GET_BLOCKS(output_size / (w_stride * input_channel)), GET_THREADS, 0, stream>>>(
    output_size, stride_dep, stride_row, stride_col, output_depth, output_height, output_width, need_batch, d_stride,
    h_stride, w_stride, patch_stride, other_stride, input_channel, input_dep_size, input_row_size, input_col_size,
    pad_head, pad_top, pad_left, chan_input_stride, dep_input_stride, row_input_stride, patch_input_stride, input,
    output);
}

template CUDA_LIB_EXPORT void CalExtractVolumePatches<double>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const double *input, double *output,
  cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<float>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const float *input, float *output,
  cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<half>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const half *input, half *output,
  cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<int64_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const int64_t *input,
  int64_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<int32_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const int32_t *input,
  int32_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<int16_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const int16_t *input,
  int16_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<int8_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const int8_t *input, int8_t *output,
  cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<uint64_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const uint64_t *input,
  uint64_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<uint32_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const uint32_t *input,
  uint32_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<uint16_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const uint16_t *input,
  uint16_t *output, cudaStream_t stream);
template CUDA_LIB_EXPORT void CalExtractVolumePatches<uint8_t>(
  size_t output_size, int64_t stride_dep, int64_t stride_row, int64_t stride_col, int64_t output_depth,
  int64_t output_height, int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride, int64_t w_stride,
  int64_t patch_stride, int64_t other_stride, int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
  int64_t input_col_size, int64_t pad_head, int64_t pad_top, int64_t pad_left, int64_t chan_input_stride,
  int64_t dep_input_stride, int64_t row_input_stride, int64_t patch_input_stride, const uint8_t *input,
  uint8_t *output, cudaStream_t stream);
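
The flattened-index arithmetic in the kernel above is easier to follow in scalar form. The sketch below mirrors it in Python; the names match the kernel parameters, and the helper itself is purely illustrative, not shipped code.

def decompose_output_index(pos, other_stride, patch_stride, input_channel,
                           d_stride, h_stride, w_stride, output_height, output_width,
                           need_batch=True):
    # Mirrors the kernel: split a flattened output index into batch, channel,
    # patch position (depth/row/col), and intra-patch kernel offsets.
    batch_index = pos // other_stride if need_batch else 0
    inner_index = pos - batch_index * other_stride
    patch_index = inner_index % patch_stride
    patch_offset = inner_index // patch_stride // input_channel
    channel = inner_index // patch_stride % input_channel
    dep_index = patch_index // (output_height * output_width)
    row_index = patch_index // output_width % output_height
    col_index = patch_index % output_width
    dep_offset = patch_offset // d_stride             # d_stride == ksize_h * ksize_w
    row_offset = patch_offset // w_stride % h_stride  # h_stride == ksize_h, w_stride == ksize_w
    col_offset = patch_offset % w_stride
    return (batch_index, channel, (dep_index, row_index, col_index),
            (dep_offset, row_offset, col_offset))
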
@ -0,0 +1,32 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EXTRACT_VOLUME_PATCHES_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EXTRACT_VOLUME_PATCHES_IMPL_CUH_
#include <vector>
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"

template <typename T>
CUDA_LIB_EXPORT void CalExtractVolumePatches(size_t output_size, int64_t stride_dep, int64_t stride_row,
                                             int64_t stride_col, int64_t output_depth, int64_t output_height,
                                             int64_t output_width, bool need_batch, int64_t d_stride, int64_t h_stride,
                                             int64_t w_stride, int64_t patch_stride, int64_t other_stride,
                                             int64_t input_channel, int64_t input_dep_size, int64_t input_row_size,
                                             int64_t input_col_size, int64_t pad_head, int64_t pad_top,
                                             int64_t pad_left, int64_t chan_input_stride, int64_t dep_input_stride,
                                             int64_t row_input_stride, int64_t patch_input_stride, const T *input,
                                             T *output, cudaStream_t stream);
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_EXTRACT_VOLUME_PATCHES_IMPL_CUH_
@ -30,7 +30,6 @@ constexpr size_t kIdx4 = 4;
abstract::ShapePtr ExtractVolumePatchesInferShape(const PrimitivePtr &primitive,
                                                  const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  const int MAX_SHAPE = 2048;
  const int d = 2;
  const int w = 4;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, 1, primitive->name());
@ -45,8 +44,6 @@ abstract::ShapePtr ExtractVolumePatchesInferShape(const PrimitivePtr &primitive,
  constexpr int64_t shape_size = 5;
  (void)CheckAndConvertUtils::CheckInteger("input shape", SizeToLong(x_shape.size()), kEqual, shape_size,
                                           primitive->name());
  auto x_v = x_shape[kIdx2] * x_shape[kIdx3] * x_shape[kIdx4];
  (void)CheckAndConvertUtils::CheckInteger("x_d * x_h * x_w", x_v, kLessEqual, MAX_SHAPE, primitive->name());
  std::vector<int64_t> kernel_size = GetValue<std::vector<int64_t>>(primitive->GetAttr(kKernelSize));
  std::vector<int64_t> strides = GetValue<std::vector<int64_t>>(primitive->GetAttr(kStrides));
  (void)CheckAndConvertUtils::CheckInteger("kernel_size_length", SizeToLong(kernel_size.size()), kEqual, shape_size,
@ -6705,7 +6705,7 @@ class ExtractVolumePatches(Primitive):
        ValueError: If x_d * x_h * x_w is greater than 2048.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> kernel_size = (1, 1, 2, 2, 2)
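
For reference, the output spatial dimensions implied by these checks follow the usual TF-style SAME/VALID conventions. A short sketch (the helper name is an illustrative assumption, not part of this change):

import math

def output_spatial_dims(in_dims, ksize, strides, padding):
    # in_dims, ksize, strides: (d, h, w) triples; padding: "VALID" or "SAME".
    if padding.upper() == "SAME":
        return tuple(math.ceil(i / s) for i, s in zip(in_dims, strides))
    return tuple((i - k) // s + 1 for i, k, s in zip(in_dims, ksize, strides))

# The tests below use a 3x3x3 input, 2x2x2 kernel, stride 1:
# output_spatial_dims((3, 3, 3), (2, 2, 2), (1, 1, 1), "VALID") -> (2, 2, 2)
# output_spatial_dims((3, 3, 3), (2, 2, 2), (1, 1, 1), "SAME")  -> (3, 3, 3)
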
@ -0,0 +1,109 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import numpy as np
import pytest

import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
import mindspore.ops.operations.array_ops as ops


class NetExtractVolumePatches(nn.Cell):
    def __init__(self, kernel_size, strides, padding="valid"):
        super(NetExtractVolumePatches, self).__init__()
        self.extractvolumepatches = ops.ExtractVolumePatches(
            kernel_size, strides, padding)

    def construct(self, x):
        return self.extractvolumepatches(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_extractvolumepatches_graph():
    """
    Feature: ExtractVolumePatches
    Description: test ExtractVolumePatches with VALID padding in graph mode for all supported dtypes
    Expectation: the results match the expected output
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
    types = [np.float16, np.float32, np.float64, np.int8, np.int16,
             np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]
    for type_i in types:
        x = Tensor(np.ones([1, 1, 3, 3, 3]).astype(type_i))
        extractvolumepatches = NetExtractVolumePatches(
            [1, 1, 2, 2, 2], [1, 1, 1, 1, 1], "VALID")
        output = extractvolumepatches(x).transpose(0, 2, 3, 4, 1)
        expect_output = np.array([[[[[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.]],
                                    [[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.]]],
                                   [[[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.]],
                                    [[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.]]]]]).astype(type_i)
        assert np.allclose(output.asnumpy(), expect_output)
        assert output.shape == expect_output.shape


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_extractvolumepatches_pynative():
    """
    Feature: ExtractVolumePatches
    Description: test ExtractVolumePatches with SAME padding in PyNative mode for all supported dtypes
    Expectation: the results match the expected output
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
    types = [np.float16, np.float32, np.float64, np.int8, np.int16,
             np.int32, np.int64, np.uint8, np.uint16, np.uint32, np.uint64]
    for type_i in types:
        x = Tensor(np.ones([1, 1, 3, 3, 3]).astype(type_i))
        extractvolumepatches = NetExtractVolumePatches(
            [1, 1, 2, 2, 2], [1, 1, 1, 1, 1], "SAME")
        output = extractvolumepatches(x).transpose(0, 2, 3, 4, 1)
        expect_output = np.array([[[[[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 0., 1., 0., 1., 0., 1., 0.]],
                                    [[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 0., 1., 0., 1., 0., 1., 0.]],
                                    [[1., 1., 0., 0., 1., 1., 0., 0.],
                                     [1., 1., 0., 0., 1., 1., 0., 0.],
                                     [1., 0., 0., 0., 1., 0., 0., 0.]]],
                                   [[[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 0., 1., 0., 1., 0., 1., 0.]],
                                    [[1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 1., 1., 1., 1., 1., 1., 1.],
                                     [1., 0., 1., 0., 1., 0., 1., 0.]],
                                    [[1., 1., 0., 0., 1., 1., 0., 0.],
                                     [1., 1., 0., 0., 1., 1., 0., 0.],
                                     [1., 0., 0., 0., 1., 0., 0., 0.]]],
                                   [[[1., 1., 1., 1., 0., 0., 0., 0.],
                                     [1., 1., 1., 1., 0., 0., 0., 0.],
                                     [1., 0., 1., 0., 0., 0., 0., 0.]],
                                    [[1., 1., 1., 1., 0., 0., 0., 0.],
                                     [1., 1., 1., 1., 0., 0., 0., 0.],
                                     [1., 0., 1., 0., 0., 0., 0., 0.]],
                                    [[1., 1., 0., 0., 0., 0., 0., 0.],
                                     [1., 1., 0., 0., 0., 0., 0., 0.],
                                     [1., 0., 0., 0., 0., 0., 0., 0.]]]]]).astype(type_i)
        assert np.allclose(output.asnumpy(), expect_output)
        assert output.shape == expect_output.shape
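
The zero pattern in the SAME-padding expectation above can be reproduced with a small NumPy cross-check. This is a sketch assuming TF-style padding (head side gets pad // 2) and the (kd, kh, kw, c) patch ordering used by the kernel:

import numpy as np

x = np.ones([1, 1, 3, 3, 3], dtype=np.float32)
# Total SAME pad per axis is (out - 1) * stride + ksize - in = 1; the head gets
# pad // 2 = 0, so the single zero plane lands at the tail of each axis.
padded = np.pad(x, ((0, 0), (0, 0), (0, 1), (0, 1), (0, 1)))
patches = np.zeros([1, 3, 3, 3, 8], dtype=np.float32)
for d in range(3):
    for h in range(3):
        for w in range(3):
            patches[0, d, h, w] = padded[0, 0, d:d + 2, h:h + 2, w:w + 2].reshape(-1)
# 'patches' reproduces expect_output in test_extractvolumepatches_pynative above.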