forked from mindspore-Ecosystem/mindspore
add gpu Col2Im impl and tensor & function op
This commit is contained in:
parent
0f27520ddb
commit
404816a5f4
|
@ -301,6 +301,7 @@ Array操作
|
|||
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.col2im
|
||||
mindspore.ops.diag
|
||||
mindspore.ops.expand_dims
|
||||
mindspore.ops.gather
|
||||
|
|
|
@ -76,6 +76,28 @@ mindspore.Tensor
|
|||
|
||||
Tensor。如果在指定轴方向上所有Tensor元素都为True,则其值为True,否则其值为False。如果轴为None或空元组,则默认降维。
|
||||
|
||||
.. py:method:: col2im(output_size, kernel_size, dilation, padding_value, stride)
|
||||
|
||||
将一组滑动的局部块组合成一个大张量。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **output_size** (Tensor) - 输出张量的后两维的shape。
|
||||
- **kernel_size** (Union[int, tuple[int], list[int]]) - 滑动窗口的大小。
|
||||
- **dilation** (Union[int, tuple[int], list[int]]) - 滑动窗口扩张的大小。
|
||||
- **padding_value** (Union[int, tuple[int], list[int]]) - 填充的大小。
|
||||
- **stride** (Union[int, tuple[int], list[int]]) - 步长的大小。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,输出的张量,维度和类型和输入一致。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `kernel_size`,`dilation`,`padding_value`,`stride` 不属于 Union[int, tuple[int], list[int]]。
|
||||
- **ValueError** - 如果 `kernel_size`,`dilation`,`stride` 值小于等于0或者个数大于2。
|
||||
- **ValueError** - 如果 `padding_value` 值小于0或者个数大于2。
|
||||
|
||||
.. py:method:: argmax(axis=None)
|
||||
|
||||
返回指定轴上最大值的索引。
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
mindspore.ops.col2im
|
||||
====================
|
||||
|
||||
.. py:function:: mindspore.ops.col2im(input_x, output_size, kernel_size, dilation, padding_value, stride)
|
||||
|
||||
将一组滑动的局部块组合成一个大张量。
|
||||
|
||||
**参数:**
|
||||
|
||||
- **input_x** (Tensor) - 输入的批量的滑动局部块,
|
||||
- **output_size** (Tensor) - 输出张量的后两维的shape。
|
||||
- **kernel_size** (Union[int, tuple[int], list[int]]) - 滑动窗口的大小。
|
||||
- **dilation** (Union[int, tuple[int], list[int]]) - 滑动窗口扩张的大小。
|
||||
- **padding_value** (Union[int, tuple[int], list[int]]) - 填充的大小。
|
||||
- **stride** (Union[int, tuple[int], list[int]]) - 步长的大小。
|
||||
|
||||
**返回:**
|
||||
|
||||
Tensor,输出的张量,维度和类型和输入一致。
|
||||
|
||||
**异常:**
|
||||
|
||||
- **TypeError** - 如果 `kernel_size`,`dilation`,`padding_value`,`stride` 不属于 Union[int, tuple[int], list[int]]。
|
||||
- **ValueError** - 如果 `kernel_size`,`dilation`,`stride` 值小于等于0或者个数大于2。
|
||||
- **TypeError** - 如果 `padding_value` 值小于0或者个数大于2。
|
||||
- **TypeError** - 如果 input_x.dims(2) 不等于 kernel_size[0] * kernel_size[1]。
|
||||
- **TypeError** - 如果 input_x.dims(3) 与计算出的滑动块数量不匹配。
|
|
@ -300,6 +300,7 @@ Array Operation
|
|||
|
||||
mindspore.ops.adaptive_max_pool2d
|
||||
mindspore.ops.broadcast_to
|
||||
mindspore.ops.col2im
|
||||
mindspore.ops.diag
|
||||
mindspore.ops.expand_dims
|
||||
mindspore.ops.gather
|
||||
|
|
|
@ -255,6 +255,7 @@ BuiltInTypeMap &GetMethodMap() {
|
|||
{"pdist", std::string("pdist")}, // F.pdist()
|
||||
{"adaptive_avgpool2d", std::string("adaptive_avgpool2d")}, // P.AdaptiveAvgPool2D
|
||||
{"adaptive_max_pool2d", std::string("adaptive_max_pool2d")}, // P.AdaptiveMaxPool2D
|
||||
{"col2im", std::string("col2im")}, // P.Col2Im
|
||||
}},
|
||||
{kObjectTypeRowTensorType,
|
||||
{
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh"
|
||||
#include <stdio.h>
|
||||
#include <stdint.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <algorithm>
|
||||
#include "include/cuda_fp16.h"
|
||||
|
||||
template <typename T>
|
||||
__global__ void Col2ImKernel(const T *input, T *output, const size_t num_kernels, const size_t per_batch_size,
|
||||
const size_t per_channel_size, const size_t per_col_batch_size, const size_t out_height,
|
||||
const size_t out_width, const size_t in_height, const size_t in_width,
|
||||
const size_t kernel_height, const size_t kernel_width, const size_t pad_height,
|
||||
const size_t pad_width, const size_t stride_height, const size_t stride_width,
|
||||
const size_t dilation_height, const size_t dilation_width) {
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_kernels; i += blockDim.x * gridDim.x) {
|
||||
T val = static_cast<T>(0);
|
||||
size_t w_id = i % out_height + pad_width;
|
||||
size_t h_id = i % per_batch_size / out_width % out_height + pad_height;
|
||||
size_t c_id = i % per_batch_size / per_channel_size;
|
||||
size_t n_col_offset = i / per_batch_size * per_col_batch_size;
|
||||
size_t kernel_expand_h = (kernel_height - 1) * dilation_height + 1;
|
||||
size_t kernel_expand_w = (kernel_width - 1) * dilation_width + 1;
|
||||
// range coordinates
|
||||
size_t out_height_start = h_id < kernel_expand_h ? 0 : (h_id - kernel_expand_h) / stride_height + 1;
|
||||
size_t out_width_start = w_id < kernel_expand_w ? 0 : (w_id - kernel_expand_w) / stride_width + 1;
|
||||
size_t out_height_end = min(h_id / stride_height + 1, in_height);
|
||||
size_t out_width_end = min(w_id / stride_width + 1, in_width);
|
||||
|
||||
for (size_t height = out_height_start; height < out_height_end; ++height) {
|
||||
for (size_t width = out_width_start; width < out_width_end; ++width) {
|
||||
size_t kernel_h = (h_id - height * stride_height);
|
||||
size_t kernel_w = (w_id - width * stride_width);
|
||||
if (kernel_h % dilation_height == 0 && kernel_w % dilation_width == 0) {
|
||||
kernel_h /= dilation_height;
|
||||
kernel_w /= dilation_width;
|
||||
size_t data_index =
|
||||
n_col_offset +
|
||||
(((c_id * kernel_height + kernel_h) * kernel_width + kernel_w) * in_height + height) * in_width + width;
|
||||
val += input[data_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
output[i] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void Col2Im(const T *input, const size_t batch_size, const size_t channels, const size_t out_height,
|
||||
const size_t out_width, const size_t in_height, const size_t in_width, const size_t kernel_height,
|
||||
const size_t kernel_width, const size_t pad_height, const size_t pad_width, const size_t stride_height,
|
||||
const size_t stride_width, const size_t dilation_height, const size_t dilation_width, T *output,
|
||||
cudaStream_t cuda_stream) {
|
||||
size_t per_channel_size = out_height * out_width;
|
||||
size_t per_batch_size = channels * per_channel_size;
|
||||
size_t num_kernels = batch_size * per_batch_size;
|
||||
size_t per_col_batch_size = channels * in_height * in_width * kernel_width * kernel_height;
|
||||
Col2ImKernel<<<GET_BLOCKS(num_kernels), GET_THREADS, 0, cuda_stream>>>(
|
||||
input, output, num_kernels, per_batch_size, per_channel_size, per_col_batch_size, out_height, out_width, in_height,
|
||||
in_width, kernel_height, kernel_width, pad_height, pad_width, stride_height, stride_width, dilation_height,
|
||||
dilation_width);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void Col2Im<float>(const float *input, const size_t batch_size, const size_t channels,
|
||||
const size_t out_height, const size_t out_width, const size_t in_height,
|
||||
const size_t in_width, const size_t kernel_height,
|
||||
const size_t kernel_width, const size_t pad_height, const size_t pad_width,
|
||||
const size_t stride_height, const size_t stride_width,
|
||||
const size_t dilation_height, const size_t dilation_width, float *output,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void Col2Im<half>(const half *input, const size_t batch_size, const size_t channels,
|
||||
const size_t out_height, const size_t out_width, const size_t in_height,
|
||||
const size_t in_width, const size_t kernel_height, const size_t kernel_width,
|
||||
const size_t pad_height, const size_t pad_width, const size_t stride_height,
|
||||
const size_t stride_width, const size_t dilation_height,
|
||||
const size_t dilation_width, half *output, cudaStream_t cuda_stream);
|
|
@ -0,0 +1,28 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_
|
||||
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_common.h"
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void Col2Im(const T *input, const size_t batch_size, const size_t channels, const size_t out_height,
|
||||
const size_t out_width, const size_t in_height, const size_t in_width,
|
||||
const size_t kernel_height, const size_t kernel_width, const size_t pad_height,
|
||||
const size_t pad_width, const size_t stride_height, const size_t stride_width,
|
||||
const size_t dilation_height, const size_t dilation_width, T *output,
|
||||
cudaStream_t cuda_stream);
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_COL2IM_IMPL_CUH_
|
|
@ -0,0 +1,131 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/nn/col2im_gpu_kernel.h"
|
||||
#include "mindspore/core/abstract/utils.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/col2im_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr int kCol2ImInputsNum = 2;
|
||||
constexpr int kPaddingDirection = 2;
|
||||
} // namespace
|
||||
|
||||
void Col2ImFwdGpuKernelMod::ResetResource() noexcept {
|
||||
batch_size_ = 0;
|
||||
channels_ = 0;
|
||||
out_height_ = 0;
|
||||
out_width_ = 0;
|
||||
in_height_ = 0;
|
||||
in_width_ = 0;
|
||||
pad_height_ = 0;
|
||||
pad_width_ = 0;
|
||||
kernel_height_ = 0;
|
||||
kernel_width_ = 0;
|
||||
stride_height_ = 0;
|
||||
stride_width_ = 0;
|
||||
dilation_height_ = 0;
|
||||
dilation_width_ = 0;
|
||||
is_null_input_ = false;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool Col2ImFwdGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs,
|
||||
const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) {
|
||||
T *input_addr = GetDeviceAddress<T>(inputs, kIndex0);
|
||||
T *output_addr = GetDeviceAddress<T>(outputs, kIndex0);
|
||||
Col2Im(input_addr, batch_size_, channels_, out_height_, out_width_, in_height_, in_width_, kernel_height_,
|
||||
kernel_width_, pad_height_, pad_width_, stride_height_, stride_width_, dilation_height_, dilation_width_,
|
||||
output_addr, reinterpret_cast<cudaStream_t>(cuda_stream_));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Col2ImFwdGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' does not support this kernel type: " << kernel_attr;
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
return true;
|
||||
}
|
||||
|
||||
int Col2ImFwdGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
ResetResource();
|
||||
int ret = KernelMod::Resize(base_operator, inputs, outputs, inputsOnHost);
|
||||
if (ret != 0) {
|
||||
return ret;
|
||||
}
|
||||
if (input_size_list_.size() != kCol2ImInputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' input size must be equal 2.";
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
auto input_shape = inputs[kIndex0]->GetShapeVector();
|
||||
auto output_shape = outputs[kIndex0]->GetShapeVector();
|
||||
batch_size_ = input_shape[kIndex0];
|
||||
channels_ = input_shape[kIndex1];
|
||||
out_height_ = output_shape[kIndex2];
|
||||
out_width_ = output_shape[kIndex3];
|
||||
auto kernel_size = GetValue<std::vector<int64_t>>(base_operator->GetAttr("kernel_size"));
|
||||
auto dilation = GetValue<std::vector<int64_t>>(base_operator->GetAttr("dilation"));
|
||||
auto padding = GetValue<std::vector<int64_t>>(base_operator->GetAttr("padding"));
|
||||
auto stride = GetValue<std::vector<int64_t>>(base_operator->GetAttr("stride"));
|
||||
pad_height_ = padding[kIndex0];
|
||||
pad_width_ = padding[kIndex1];
|
||||
kernel_height_ = kernel_size[kIndex0];
|
||||
kernel_width_ = kernel_size[kIndex1];
|
||||
stride_height_ = stride[kIndex0];
|
||||
stride_width_ = stride[kIndex1];
|
||||
dilation_height_ = dilation[kIndex0];
|
||||
dilation_width_ = dilation[kIndex1];
|
||||
in_height_ =
|
||||
(out_height_ + kPaddingDirection * pad_height_ - (dilation_height_ * (kernel_height_ - 1) + 1)) / stride_height_ +
|
||||
1;
|
||||
in_width_ =
|
||||
(out_width_ + kPaddingDirection * pad_width_ - (dilation_width_ * (kernel_width_ - 1) + 1)) / stride_width_ + 1;
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, Col2ImFwdGpuKernelMod::Col2ImFunc>> Col2ImFwdGpuKernelMod::func_list_ = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
|
||||
&Col2ImFwdGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat16),
|
||||
&Col2ImFwdGpuKernelMod::LaunchKernel<half>}};
|
||||
|
||||
std::vector<KernelAttr> Col2ImFwdGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, Col2ImFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Col2Im, Col2ImFwdGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,86 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/pad_impl.cuh"
|
||||
#include "plugin/device/gpu/kernel/kernel_constants.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class Col2ImFwdGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
Col2ImFwdGpuKernelMod() { ResetResource(); }
|
||||
~Col2ImFwdGpuKernelMod() = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
cuda_stream_ = stream_ptr;
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void ResetResource() noexcept;
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs);
|
||||
using Col2ImFunc =
|
||||
std::function<bool(Col2ImFwdGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
|
||||
private:
|
||||
size_t batch_size_{0};
|
||||
size_t channels_{0};
|
||||
size_t out_height_{0};
|
||||
size_t out_width_{0};
|
||||
size_t in_height_{0};
|
||||
size_t in_width_{0};
|
||||
size_t pad_width_{0};
|
||||
size_t pad_height_{0};
|
||||
size_t kernel_height_{0};
|
||||
size_t kernel_width_{0};
|
||||
size_t stride_height_{0};
|
||||
size_t stride_width_{0};
|
||||
size_t dilation_height_{0};
|
||||
size_t dilation_width_{0};
|
||||
bool is_null_input_{false};
|
||||
void *cuda_stream_{nullptr};
|
||||
Col2ImFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, Col2ImFunc>> func_list_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_NN_COL2IM_GPU_KERNEL_H_
|
|
@ -1667,6 +1667,13 @@ def intopk(x1, x2, k):
|
|||
return F.intopk(x1, x2, k)
|
||||
|
||||
|
||||
def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
|
||||
"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
"""
|
||||
return F.col2im(input_x, output_size, kernel_size, dilation, padding_value, stride)
|
||||
|
||||
|
||||
def narrow(x, axis, start, length):
|
||||
"""
|
||||
Returns a narrowed tensor from input tensor.
|
||||
|
|
|
@ -1572,6 +1572,46 @@ class Tensor(Tensor_):
|
|||
perm = validator.check_transpose_axis(axes, self.ndim)
|
||||
return tensor_operator_registry.get('transpose')()(self, perm)
|
||||
|
||||
def col2im(self, output_size, kernel_size, dilation, padding_value, stride):
|
||||
"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
|
||||
Args:
|
||||
output_size (Tensor): 1D tensor with 2 elements of data type int.
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The size of the dilation, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding_value (Union[int, tuple[int], list[int]]): The size of the padding, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
stride (Union[int, tuple[int], list[int]]): The size of the stride, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 0.
|
||||
|
||||
Returns:
|
||||
A 4D Tensor, with same type as input 'x'.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size`, `dilation`, `padding_value`, `stride` data type is not in
|
||||
Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size`, `dilation`, `stride` value is less than zero or elements
|
||||
number more than 2.
|
||||
ValueError: If :attr:`padding_value` value is not greater than zero or elements number more than 2.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(input_data=np.random.rand(16, 16, 4, 25), dtype=mstype.float32)
|
||||
>>> output_size = Tensor(input_data=[8, 8], dtype=mstype.int32)
|
||||
>>> y = x.col2im(output_size, kernel_size=[2, 2], dilation=[2, 2], padding_value=[2, 2], stride=[2, 2])
|
||||
>>> print(y.shape)
|
||||
(16, 16, 8, 8)
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('col2im')(self, output_size, kernel_size, dilation, padding_value, stride)
|
||||
|
||||
|
||||
|
||||
def reshape(self, *shape):
|
||||
"""
|
||||
Give a new shape to a tensor without changing its data.
|
||||
|
|
|
@ -27,6 +27,7 @@ from .._vmap.vmap_base import vmap_rules_getters, vmap_general_preprocess, _bdim
|
|||
from ..operations.array_ops import Fills
|
||||
from ..operations.array_ops import ScatterNdMax
|
||||
from ..operations.array_ops import UniqueConsecutive
|
||||
from ..operations.array_ops import Col2Im
|
||||
|
||||
|
||||
@vmap_rules_getters.register("Cast")
|
||||
|
@ -947,3 +948,4 @@ def get_gather_vmap_rule(prim, axis_size):
|
|||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(P.Unique)(get_unsupported_dynamic_vmap_rule)
|
||||
get_unsupported_dynamic_vmap_rule =\
|
||||
vmap_rules_getters.register(UniqueConsecutive)(get_unsupported_dynamic_vmap_rule)
|
||||
get_unsupported_dynamic_vmap_rule = vmap_rules_getters.register(Col2Im)(get_unsupported_dynamic_vmap_rule)
|
||||
|
|
|
@ -87,6 +87,7 @@ from .array_func import (
|
|||
fills,
|
||||
broadcast_to,
|
||||
adaptive_max_pool2d,
|
||||
col2im,
|
||||
)
|
||||
from .parameter_func import (
|
||||
assign,
|
||||
|
|
|
@ -22,6 +22,7 @@ from mindspore.ops.primitive import constexpr
|
|||
from ..operations.array_ops import UniqueConsecutive
|
||||
from ..operations.array_ops import NonZero, MatrixDiagV3
|
||||
from ..operations.array_ops import Fills
|
||||
from ..operations.array_ops import Col2Im
|
||||
from ..operations.array_ops import ScatterNdMax
|
||||
from ..operations.array_ops import ScatterNdMul
|
||||
from ..operations.nn_ops import AdaptiveMaxPool2D
|
||||
|
@ -2909,6 +2910,51 @@ def diag(input_x):
|
|||
return diag_(input_x)
|
||||
|
||||
|
||||
def col2im(input_x, output_size, kernel_size, dilation, padding_value, stride):
|
||||
"""
|
||||
Combines an array of sliding local blocks into a large containing tensor.
|
||||
|
||||
Args:
|
||||
input_x (Tensor): 4D tensor with data type float16 or float.
|
||||
output_size (Tensor): 1D tensor with 2 elements of data type int.
|
||||
kernel_size (Union[int, tuple[int], list[int]]): The size of the kernel, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Must be specified.
|
||||
dilation (Union[int, tuple[int], list[int]]): The size of the dilation, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
padding_value (Union[int, tuple[int], list[int]]): The size of the padding, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 1.
|
||||
stride (Union[int, tuple[int], list[int]]): The size of the stride, should be two int
|
||||
for height and width. If type is int, it means that height equal with width. Default: 0.
|
||||
|
||||
Returns:
|
||||
A 4D Tensor, with same type as 'input_x'.
|
||||
|
||||
Raises:
|
||||
TypeError: If :attr:`kernel_size`, `dilation`, `padding_value`, `stride` data type is not in
|
||||
Union[int, tuple[int], list[int]].
|
||||
ValueError: If :attr:`kernel_size`, `dilation`, `padding_value`, `stride` value is not
|
||||
greater than zero or elements number more than 2.
|
||||
ValueError: If :attr:`padding_value` value is less than zero or elements number more than 2.
|
||||
ValueError: If input_x.shape[2] != kernel_size[0] * kernel_size[1].
|
||||
ValueError: If input_x.shape[3] does not match the calculated number of sliding blocks.
|
||||
|
||||
Supported Platforms:
|
||||
``GPU``
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> import mindspore.ops as ops
|
||||
>>> from mindspore import Tensor
|
||||
>>> x = Tensor(input_data=np.random.rand(16, 16, 4, 25), dtype=mstype.float32)
|
||||
>>> output_size = Tensor(input_data=[8, 8], dtype=mstype.int32)
|
||||
>>> output = ops.col2im(x, output_size, [2, 2], [2, 2], [2, 2], [2, 2])
|
||||
>>> print(output.shape)
|
||||
(16, 16, 8, 8)
|
||||
"""
|
||||
c2i = Col2Im(kernel_size, dilation, padding_value, stride)
|
||||
return c2i(input_x, output_size)
|
||||
|
||||
|
||||
__all__ = [
|
||||
'unique',
|
||||
'unique_consecutive',
|
||||
|
@ -2971,5 +3017,6 @@ __all__ = [
|
|||
'adaptive_max_pool2d',
|
||||
'meshgrid',
|
||||
'broadcast_to',
|
||||
'col2im'
|
||||
]
|
||||
__all__.sort()
|
||||
|
|
|
@ -970,6 +970,7 @@ tensor_operator_registry.register('inplace_update', P.InplaceUpdate)
|
|||
tensor_operator_registry.register('inplace_add', P.InplaceAdd)
|
||||
tensor_operator_registry.register('inplace_sub', P.InplaceSub)
|
||||
tensor_operator_registry.register('adaptive_avgpool2d', adaptive_avgpool2d)
|
||||
tensor_operator_registry.register('col2im', col2im)
|
||||
# ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
|
|
|
@ -530,9 +530,13 @@ class Col2Im(Primitive):
|
|||
ValueError: If x.shape[3] does not match the calculated number of sliding blocks.
|
||||
|
||||
Examples:
|
||||
>>> import numpy as np
|
||||
>>> from mindspore import Tensor
|
||||
>>> from mindspore import dtype as mstype
|
||||
>>> from mindspore.ops.operations.array_ops import Col2Im
|
||||
>>> x = Tensor(input_data=np.random.rand(16, 16, 4, 25), dtype=mstype.float32)
|
||||
>>> output_size = Tensor(input_data=[8, 8], dtype=mstype.int32)
|
||||
>>> col2im = P.Col2Im(kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
>>> col2im = Col2Im(kernel_size=[2, 2], dilation=[2, 2], padding=[2, 2], stride=[2, 2])
|
||||
>>> y = col2im(x, output_size)
|
||||
>>> print(y.shape)
|
||||
(16, 16, 8, 8)
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations.array_ops import Col2Im
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common import dtype as mstype
|
||||
|
||||
np.random.seed(1)
|
||||
|
||||
|
||||
class Col2ImTest(nn.Cell):
|
||||
def __init__(self, kernel_size, dilation, padding, stride):
|
||||
super(Col2ImTest, self).__init__()
|
||||
self.c2i = Col2Im(kernel_size, dilation, padding, stride)
|
||||
|
||||
def construct(self, x, output_size):
|
||||
return self.c2i(x, output_size)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
@pytest.mark.parametrize("mode", [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_col2im_op(mode):
|
||||
"""
|
||||
Feature: Celu cpu kernel
|
||||
Description: test the celu alpha = 1.0.
|
||||
Expectation: match to np benchmark.
|
||||
"""
|
||||
context.set_context(mode=mode, device_target='GPU')
|
||||
x = Tensor(np.random.rand(16, 16, 4, 25).astype(np.float32))
|
||||
output_size = Tensor([8, 8], dtype=mstype.int32)
|
||||
kernel_size = [2, 2]
|
||||
dilation = [2, 2]
|
||||
padding = [2, 2]
|
||||
stride = [2, 2]
|
||||
expect_shape = (16, 16, 8, 8)
|
||||
col2im = Col2ImTest(kernel_size=kernel_size, dilation=dilation, padding=padding, stride=stride)
|
||||
output = col2im(x, output_size)
|
||||
assert output.shape == expect_shape
|
||||
|
||||
output_func = F.col2im(x, output_size, kernel_size, dilation, padding, stride)
|
||||
assert output_func.shape == expect_shape
|
||||
|
||||
assert x.col2im(output_size, kernel_size, dilation, padding, stride).shape == expect_shape
|
Loading…
Reference in New Issue