forked from mindspore-Ecosystem/mindspore
add ExtractImagePatches GPU
This commit is contained in:
parent
f22e0522fe
commit
b9fa6641e7
|
@ -14,8 +14,8 @@
|
|||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
|
@ -145,4 +145,4 @@ class EmbeddingLookupKernel : public GpuKernel {
|
|||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EMBEDDING_LOOKUP_GPU_KERNEL_H_
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/arrays/extract_image_patches_gpu_kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
MS_REG_GPU_KERNEL_ONE(ExtractImagePatches,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
ExtractImagePatchesKernel, double)
|
||||
MS_REG_GPU_KERNEL_ONE(ExtractImagePatches,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
ExtractImagePatchesKernel, float)
|
||||
MS_REG_GPU_KERNEL_ONE(ExtractImagePatches,
|
||||
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
ExtractImagePatchesKernel, half)
|
||||
MS_REG_GPU_KERNEL_ONE(ExtractImagePatches, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
|
||||
ExtractImagePatchesKernel, int)
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,233 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_IMAGE_PATCHES_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_IMAGE_PATCHES_GPU_KERNEL_H_
|
||||
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
|
||||
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/transpose_impl_opt.cuh"
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/extract_image_patches_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
template <typename T>
|
||||
class ExtractImagePatchesKernel : public GpuKernel {
|
||||
public:
|
||||
ExtractImagePatchesKernel() { ResetResource(); }
|
||||
~ExtractImagePatchesKernel() override = default;
|
||||
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
T *t_input = GetDeviceAddress<T>(workspace, 0);
|
||||
T *t_output = GetDeviceAddress<T>(workspace, 1);
|
||||
size_t *input_shape = GetDeviceAddress<size_t>(workspace, 2);
|
||||
size_t *input_to_nhwc_axis = GetDeviceAddress<size_t>(workspace, 3);
|
||||
size_t *t_output_shape = GetDeviceAddress<size_t>(workspace, 4);
|
||||
size_t *t_output_to_nchw_axis = GetDeviceAddress<size_t>(workspace, 5);
|
||||
|
||||
size_t shape_size = 4 * sizeof(size_t);
|
||||
std::vector<size_t> to_nhwc_axis = {0, 2, 3, 1};
|
||||
std::vector<size_t> to_nchw_axis = {0, 3, 1, 2};
|
||||
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_shape, &input_shape_[0], shape_size, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync input_shape_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(input_to_nhwc_axis, &to_nhwc_axis[0], shape_size, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync to_nhwc_axis failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(t_output_shape, &t_output_shape_[0], shape_size, cudaMemcpyHostToDevice,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync t_output_shape_ failed");
|
||||
CHECK_CUDA_RET_WITH_EXCEPT(kernel_node_,
|
||||
cudaMemcpyAsync(t_output_to_nchw_axis, &to_nchw_axis[0], shape_size,
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream_ptr)),
|
||||
"cudaMemcpyAsync to_nchw_axis failed");
|
||||
CalNCHW2NHWCInterface(input_size_, shape_size / sizeof(size_t), input, &input_shape_[0], &to_nhwc_axis[0],
|
||||
input_shape, input_to_nhwc_axis, t_input, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalExtractImagePatchesNHWC(output_size_, stride_row_, stride_col_, rate_row_, rate_col_, output_cols_, need_batch_,
|
||||
row_stride_, patch_stride_, other_stride_, input_row_size_, input_col_size_,
|
||||
row_padding_top_, col_padding_left_, col_input_stride_, row_input_stride_,
|
||||
patch_input_stride_, output_depth_, t_input, t_output,
|
||||
reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
CalNHWC2NCHWInterface(output_size_, shape_size / sizeof(size_t), t_output, &t_output_shape_[0], &to_nchw_axis[0],
|
||||
t_output_shape, t_output_to_nchw_axis, output, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
kernel_node_ = kernel_node;
|
||||
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but ExtractImagePatches needs 1 inputs.";
|
||||
return false;
|
||||
}
|
||||
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ExtractImagePatches has 1 output.";
|
||||
return false;
|
||||
}
|
||||
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < input_shape.size(); i++) {
|
||||
input_size_ *= input_shape[i];
|
||||
input_shape_.push_back(input_shape[i]);
|
||||
}
|
||||
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < output_shape.size(); i++) {
|
||||
output_size_ *= output_shape[i];
|
||||
}
|
||||
// transposed NHWC shape
|
||||
t_output_shape_ = {output_shape[0], output_shape[2], output_shape[3], output_shape[1]};
|
||||
|
||||
auto padding = GetAttr<std::string>(kernel_node, "padding");
|
||||
auto ksizes = GetAttr<std::vector<int64_t>>(kernel_node, "ksizes");
|
||||
auto strides = GetAttr<std::vector<int64_t>>(kernel_node, "strides");
|
||||
auto rates = GetAttr<std::vector<int64_t>>(kernel_node, "rates");
|
||||
|
||||
ksize_row_ = ksizes[2];
|
||||
ksize_col_ = ksizes[3];
|
||||
stride_row_ = strides[2];
|
||||
stride_col_ = strides[3];
|
||||
rate_row_ = rates[2];
|
||||
rate_col_ = rates[3];
|
||||
|
||||
// transposed NHWC shape
|
||||
std::vector<size_t> t_input_shape = {input_shape_[0], input_shape_[2], input_shape_[3], input_shape_[1]};
|
||||
|
||||
int64_t input_depth = static_cast<int64_t>(t_input_shape[3]);
|
||||
input_col_size_ = static_cast<int64_t>(t_input_shape[2]);
|
||||
input_row_size_ = static_cast<int64_t>(t_input_shape[1]);
|
||||
|
||||
int64_t patch_rows_eff = ksize_row_ + (ksize_row_ - 1) * (rate_row_ - 1);
|
||||
int64_t patch_cols_eff = ksize_col_ + (ksize_col_ - 1) * (rate_col_ - 1);
|
||||
|
||||
if (padding == "VALID") {
|
||||
output_rows_ = std::ceil((input_row_size_ - patch_rows_eff + 1.f) / static_cast<float>(stride_row_));
|
||||
output_cols_ = std::ceil((input_col_size_ - patch_cols_eff + 1.f) / static_cast<float>(stride_col_));
|
||||
row_padding_top_ = std::max(0l, ((output_rows_ - 1) * stride_row_ + patch_rows_eff - input_row_size_) / 2);
|
||||
col_padding_left_ = std::max(0l, ((output_cols_ - 1) * stride_col_ + patch_cols_eff - input_col_size_) / 2);
|
||||
} else if (padding == "SAME") {
|
||||
output_rows_ = std::ceil(input_row_size_ / static_cast<float>(stride_row_));
|
||||
output_cols_ = std::ceil(input_col_size_ / static_cast<float>(stride_col_));
|
||||
row_padding_top_ = ((output_rows_ - 1) * stride_row_ + patch_rows_eff - input_row_size_) / 2;
|
||||
col_padding_left_ = ((output_cols_ - 1) * stride_col_ + patch_cols_eff - input_col_size_) / 2;
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Invalid padding value: " << padding << ".";
|
||||
}
|
||||
|
||||
row_stride_ = ksize_col_;
|
||||
patch_stride_ = row_stride_ * ksize_row_ * input_depth;
|
||||
other_stride_ = patch_stride_ * output_rows_ * output_cols_;
|
||||
col_input_stride_ = input_depth;
|
||||
row_input_stride_ = input_depth * input_col_size_;
|
||||
patch_input_stride_ = input_depth * input_col_size_ * input_row_size_;
|
||||
output_depth_ = input_depth;
|
||||
need_batch_ = (output_size_ - 1) / other_stride_;
|
||||
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
input_size_ = 1;
|
||||
output_size_ = 1;
|
||||
ksize_row_ = 1;
|
||||
ksize_col_ = 1;
|
||||
stride_row_ = 1;
|
||||
stride_col_ = 1;
|
||||
rate_row_ = 1;
|
||||
rate_col_ = 1;
|
||||
output_rows_ = 1;
|
||||
output_cols_ = 1;
|
||||
need_batch_ = 1;
|
||||
row_stride_ = 1;
|
||||
patch_stride_ = 1;
|
||||
other_stride_ = 1;
|
||||
input_row_size_ = 1;
|
||||
input_col_size_ = 1;
|
||||
row_padding_top_ = 1;
|
||||
col_padding_left_ = 1;
|
||||
col_input_stride_ = 1;
|
||||
row_input_stride_ = 1;
|
||||
patch_input_stride_ = 1;
|
||||
output_depth_ = 1;
|
||||
input_shape_.clear();
|
||||
t_output_shape_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_ * sizeof(T));
|
||||
output_size_list_.push_back(output_size_ * sizeof(T));
|
||||
workspace_size_list_.push_back(input_size_ * sizeof(T));
|
||||
workspace_size_list_.push_back(output_size_ * sizeof(T));
|
||||
workspace_size_list_.push_back(4 * sizeof(size_t));
|
||||
workspace_size_list_.push_back(4 * sizeof(size_t));
|
||||
workspace_size_list_.push_back(4 * sizeof(size_t));
|
||||
workspace_size_list_.push_back(4 * sizeof(size_t));
|
||||
}
|
||||
|
||||
private:
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
int64_t ksize_row_;
|
||||
int64_t ksize_col_;
|
||||
int64_t stride_row_;
|
||||
int64_t stride_col_;
|
||||
int64_t rate_row_;
|
||||
int64_t rate_col_;
|
||||
int64_t output_rows_;
|
||||
int64_t output_cols_;
|
||||
bool need_batch_;
|
||||
int64_t row_stride_;
|
||||
int64_t patch_stride_;
|
||||
int64_t other_stride_;
|
||||
int64_t input_row_size_;
|
||||
int64_t input_col_size_;
|
||||
int64_t row_padding_top_;
|
||||
int64_t col_padding_left_;
|
||||
int64_t col_input_stride_;
|
||||
int64_t row_input_stride_;
|
||||
int64_t patch_input_stride_;
|
||||
int64_t output_depth_;
|
||||
std::vector<size_t> input_shape_;
|
||||
std::vector<size_t> t_output_shape_;
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_ARRAYS_EXTRACT_IMAGE_PATCHES_GPU_KERNEL_H_
|
|
@ -0,0 +1,102 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "backend/kernel_compiler/gpu/cuda_impl/extract_image_patches_impl.cuh"
|
||||
|
||||
template <typename T>
|
||||
__global__ void ExtractImagePatches(size_t output_size, int64_t stride_row, int64_t stride_col, int64_t rate_row,
|
||||
int64_t rate_col, int64_t output_cols, bool need_batch, int64_t row_stride,
|
||||
int64_t patch_stride, int64_t other_stride, int64_t input_row_size,
|
||||
int64_t input_col_size, int64_t row_padding_top, int64_t col_padding_left,
|
||||
int64_t col_input_stride, int64_t row_input_stride, int64_t patch_input_stride,
|
||||
int64_t output_depth, const T *input, T *output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size; pos += blockDim.x * gridDim.x) {
|
||||
const int64_t batch_index = need_batch ? (static_cast<int64_t>(pos) / other_stride) : 0;
|
||||
const int64_t inner_index =
|
||||
need_batch ? (static_cast<int64_t>(pos) - batch_index * other_stride) : static_cast<int64_t>(pos);
|
||||
// inner index
|
||||
const int64_t patch_index = inner_index / patch_stride;
|
||||
const int64_t patch_offset = (inner_index - patch_index * patch_stride) / output_depth;
|
||||
// row
|
||||
const int64_t row_index = patch_index / output_cols;
|
||||
const int64_t row_offset = patch_offset / row_stride;
|
||||
const int64_t input_row = row_index * stride_row + row_offset * rate_row - row_padding_top;
|
||||
if (input_row < 0 || input_row >= input_row_size) {
|
||||
output[pos] = static_cast<T>(0);
|
||||
return;
|
||||
}
|
||||
// col
|
||||
const int64_t col_index = patch_index - row_index * output_cols;
|
||||
const int64_t col_offset = patch_offset - row_offset * row_stride;
|
||||
const int64_t input_col = col_index * stride_col + col_offset * rate_col - col_padding_left;
|
||||
if (input_col < 0 || input_col >= input_col_size) {
|
||||
output[pos] = static_cast<T>(0);
|
||||
return;
|
||||
}
|
||||
// depth
|
||||
const int64_t depth = inner_index - (inner_index / output_depth) * output_depth;
|
||||
// input index
|
||||
const int64_t input_index =
|
||||
depth + input_col * col_input_stride + input_row * row_input_stride + batch_index * patch_input_stride;
|
||||
output[pos] = input[static_cast<size_t>(input_index)];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalExtractImagePatchesNHWC(size_t output_size, int64_t stride_row, int64_t stride_col, int64_t rate_row,
|
||||
int64_t rate_col, int64_t output_cols, bool need_batch, int64_t row_stride,
|
||||
int64_t patch_stride, int64_t other_stride, int64_t input_row_size,
|
||||
int64_t input_col_size, int64_t row_padding_top, int64_t col_padding_left,
|
||||
int64_t col_input_stride, int64_t row_input_stride, int64_t patch_input_stride,
|
||||
int64_t output_depth, const T *input, T *output, cudaStream_t stream) {
|
||||
ExtractImagePatches<<<GET_BLOCKS(output_size), GET_THREADS, 0, stream>>>(
|
||||
output_size, stride_row, stride_col, rate_row, rate_col, output_cols, need_batch, row_stride, patch_stride,
|
||||
other_stride, input_row_size, input_col_size, row_padding_top, col_padding_left, col_input_stride, row_input_stride,
|
||||
patch_input_stride, output_depth, input, output);
|
||||
}
|
||||
|
||||
template void CalExtractImagePatchesNHWC<int>(size_t output_size, int64_t stride_row, int64_t stride_col,
|
||||
int64_t rate_row, int64_t rate_col, int64_t output_cols, bool need_batch,
|
||||
int64_t row_stride, int64_t patch_stride, int64_t other_stride,
|
||||
int64_t input_row_size, int64_t input_col_size, int64_t row_padding_top,
|
||||
int64_t col_padding_left, int64_t col_input_stride,
|
||||
int64_t row_input_stride, int64_t patch_input_stride,
|
||||
int64_t output_depth, const int *input, int *output, cudaStream_t stream);
|
||||
template void CalExtractImagePatchesNHWC<float>(size_t output_size, int64_t stride_row, int64_t stride_col,
|
||||
int64_t rate_row, int64_t rate_col, int64_t output_cols,
|
||||
bool need_batch, int64_t row_stride, int64_t patch_stride,
|
||||
int64_t other_stride, int64_t input_row_size, int64_t input_col_size,
|
||||
int64_t row_padding_top, int64_t col_padding_left,
|
||||
int64_t col_input_stride, int64_t row_input_stride,
|
||||
int64_t patch_input_stride, int64_t output_depth, const float *input,
|
||||
float *output, cudaStream_t stream);
|
||||
template void CalExtractImagePatchesNHWC<half>(size_t output_size, int64_t stride_row, int64_t stride_col,
|
||||
int64_t rate_row, int64_t rate_col, int64_t output_cols, bool need_batch,
|
||||
int64_t row_stride, int64_t patch_stride, int64_t other_stride,
|
||||
int64_t input_row_size, int64_t input_col_size, int64_t row_padding_top,
|
||||
int64_t col_padding_left, int64_t col_input_stride,
|
||||
int64_t row_input_stride, int64_t patch_input_stride,
|
||||
int64_t output_depth, const half *input, half *output,
|
||||
cudaStream_t stream);
|
||||
template void CalExtractImagePatchesNHWC<double>(size_t output_size, int64_t stride_row, int64_t stride_col,
|
||||
int64_t rate_row, int64_t rate_col, int64_t output_cols,
|
||||
bool need_batch, int64_t row_stride, int64_t patch_stride,
|
||||
int64_t other_stride, int64_t input_row_size, int64_t input_col_size,
|
||||
int64_t row_padding_top, int64_t col_padding_left,
|
||||
int64_t col_input_stride, int64_t row_input_stride,
|
||||
int64_t patch_input_stride, int64_t output_depth, const double *input,
|
||||
double *output, cudaStream_t stream);
|
|
@ -0,0 +1,31 @@
|
|||
/**
|
||||
* Copyright 2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EXTRACT_IMAGE_PATCHES_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EXTRACT_IMAGE_PATCHES_IMPL_CUH_
|
||||
|
||||
#include <vector>
|
||||
#include "runtime/device/gpu/cuda_common.h"
|
||||
|
||||
template <typename T>
|
||||
void CalExtractImagePatchesNHWC(size_t output_size, int64_t stride_row, int64_t stride_col, int64_t rate_row,
|
||||
int64_t rate_col, int64_t output_cols, bool need_batch, int64_t row_stride,
|
||||
int64_t patch_stride, int64_t other_stride, int64_t input_row_size,
|
||||
int64_t input_col_size, int64_t row_padding_top, int64_t col_padding_left,
|
||||
int64_t col_input_stride, int64_t row_input_stride, int64_t patch_input_stride,
|
||||
int64_t output_depth, const T *input, T *output, cudaStream_t stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_EXTRACT_IMAGE_PATCHES_IMPL_CUH_
|
|
@ -60,7 +60,7 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
"""init"""
|
||||
|
||||
def _check_tuple_or_list(arg_name, arg_val, prim_name):
|
||||
validator.check_value_type(f"{arg_name}s", ksizes, [tuple, list], self.name)
|
||||
validator.check_value_type(f"{arg_name}s", arg_val, [tuple, list], self.name)
|
||||
if len(arg_val) != 4 or arg_val[0] != 1 or arg_val[1] != 1:
|
||||
raise ValueError(f"For \'{prim_name}\' the format of {arg_name}s should be [1, {arg_name}_row, "
|
||||
f"{arg_name}_col, 1], but got {arg_val}.")
|
||||
|
@ -99,6 +99,11 @@ class ExtractImagePatches(PrimitiveWithInfer):
|
|||
out_col = (in_col - 1) // stride_col + 1
|
||||
|
||||
out_shape = [out_batch, out_depth, out_row, out_col]
|
||||
# avoiding empty outputs
|
||||
validator.check("out_batch", out_batch, "", 0, Rel.GT, self.name)
|
||||
validator.check("out_depth", out_depth, "", 0, Rel.GT, self.name)
|
||||
validator.check("out_row", out_row, "", 0, Rel.GT, self.name)
|
||||
validator.check("out_col", out_col, "", 0, Rel.GT, self.name)
|
||||
return out_shape
|
||||
|
||||
def infer_dtype(self, input_x):
|
||||
|
|
|
@ -0,0 +1,83 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import _inner_ops as inner
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, ksizes, strides, rates, padding="valid"):
|
||||
super(Net, self).__init__()
|
||||
self.extractimagepatches = inner.ExtractImagePatches(ksizes, strides, rates, padding)
|
||||
|
||||
def construct(self, input_tensor):
|
||||
return self.extractimagepatches(input_tensor)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_extract_image_patches_valid():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
net = Net([1, 1, 2, 4], [1, 1, 7, 5], [1, 1, 2, 1], "valid")
|
||||
input_tensor = Tensor(np.arange(360).reshape(3, 2, 6, 10).astype(np.float32))
|
||||
output = net(input_tensor)
|
||||
expect = np.array([0., 5., 60., 65., 1., 6., 61., 66., 2., 7., 62., 67., 3., 8.,
|
||||
63., 68., 20., 25., 80., 85., 21., 26., 81., 86., 22., 27., 82., 87.,
|
||||
23., 28., 83., 88., 120., 125., 180., 185., 121., 126., 181., 186., 122., 127.,
|
||||
182., 187., 123., 128., 183., 188., 140., 145., 200., 205., 141., 146., 201., 206.,
|
||||
142., 147., 202., 207., 143., 148., 203., 208., 240., 245., 300., 305., 241., 246.,
|
||||
301., 306., 242., 247., 302., 307., 243., 248., 303., 308., 260., 265., 320., 325.,
|
||||
261., 266., 321., 326., 262., 267., 322., 327., 263., 268., 323., 328.]).reshape((3, 16, 1, 2))
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
net = Net([1, 1, 2, 4], [1, 1, 7, 5], [1, 1, 2, 1], "valid")
|
||||
input_tensor = Tensor(np.arange(360).reshape(3, 2, 6, 10).astype(np.float32))
|
||||
output = net(input_tensor)
|
||||
expect = np.array([0., 5., 60., 65., 1., 6., 61., 66., 2., 7., 62., 67., 3., 8.,
|
||||
63., 68., 20., 25., 80., 85., 21., 26., 81., 86., 22., 27., 82., 87.,
|
||||
23., 28., 83., 88., 120., 125., 180., 185., 121., 126., 181., 186., 122., 127.,
|
||||
182., 187., 123., 128., 183., 188., 140., 145., 200., 205., 141., 146., 201., 206.,
|
||||
142., 147., 202., 207., 143., 148., 203., 208., 240., 245., 300., 305., 241., 246.,
|
||||
301., 306., 242., 247., 302., 307., 243., 248., 303., 308., 260., 265., 320., 325.,
|
||||
261., 266., 321., 326., 262., 267., 322., 327., 263., 268., 323., 328.]).reshape((3, 16, 1, 2))
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_extract_image_patches_same():
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
net = Net([1, 1, 5, 2], [1, 1, 8, 2], [1, 1, 3, 3], "same")
|
||||
input_tensor = Tensor(np.arange(6).reshape(1, 1, 2, 3).astype(np.float32))
|
||||
output = net(input_tensor)
|
||||
expect = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 5., 0., 0.,
|
||||
0., 0., 0., 0., 0., 0., 0.]).reshape((1, 10, 1, 2))
|
||||
assert np.all(output.asnumpy() == expect)
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
net = Net([1, 1, 5, 2], [1, 1, 8, 2], [1, 1, 3, 3], "same")
|
||||
input_tensor = Tensor(np.arange(6).reshape(1, 1, 2, 3).astype(np.float32))
|
||||
output = net(input_tensor)
|
||||
expect = np.array([0., 0., 0., 0., 0., 0., 0., 0., 0., 4., 5., 0., 0., 0.,
|
||||
0., 0., 0., 0., 0., 0.]).reshape((1, 10, 1, 2))
|
||||
assert np.all(output.asnumpy() == expect)
|
Loading…
Reference in New Issue