Add ops test PR 1

This commit is contained in:
ZhengQihao3f3f3f 2021-04-27 16:46:04 +08:00
parent 9e650da6d7
commit e6a41b1938
7 changed files with 526 additions and 9 deletions

View File

@@ -0,0 +1,168 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh"
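// start_index/end_index split an input extent c into b adaptive pooling windows: output index a
// covers the input range [floor(a * c / b), ceil((a + 1) * c / b)), so the windows tile the whole
// input and adjacent windows overlap by at most one element.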
__device__ inline uint start_index(uint a, uint b, uint c) {
return floorf(__uint2float_rn(a * c) / __uint2float_rn(b));
}
__device__ inline uint end_index(uint a, uint b, uint c) {
return ceilf(__uint2float_rn((a + 1) * c) / __uint2float_rn(b));
}
template <typename T>
__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, T *input_data,
T *output_data) {
for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) {
T *input_ptr = input_data + c * input_height * input_width;
T *output_ptr = output_data + c * output_height * output_width;
for (uint oh = 0; oh < output_height; oh++) {
uint ih0 = start_index(oh, output_height, input_height);
uint ih1 = end_index(oh, output_height, input_height);
uint kh = ih1 - ih0;
for (uint ow = 0; ow < output_width; ow++) {
uint iw0 = start_index(ow, output_width, input_width);
uint iw1 = end_index(ow, output_width, input_width);
uint kw = iw1 - iw0;
// compute local average
T sum = 0;
for (uint ih = ih0; ih < ih1; ih++) {
for (uint iw = iw0; iw < iw1; iw++) {
sum += input_ptr[ih * input_width + iw];
}
}
output_ptr[oh * output_width + ow] = sum / kh / kw;
}
}
}
}
template <>
__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, float *input_data,
float *output_data) {
for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) {
float *input_ptr = input_data + c * input_height * input_width;
float *output_ptr = output_data + c * output_height * output_width;
for (uint oh = 0; oh < output_height; oh++) {
uint ih0 = start_index(oh, output_height, input_height);
uint ih1 = end_index(oh, output_height, input_height);
uint kh = ih1 - ih0;
for (uint ow = 0; ow < output_width; ow++) {
uint iw0 = start_index(ow, output_width, input_width);
uint iw1 = end_index(ow, output_width, input_width);
uint kw = iw1 - iw0;
// compute local average
float sum = 0;
for (uint ih = ih0; ih < ih1; ih++) {
for (uint iw = iw0; iw < iw1; iw++) {
sum += input_ptr[ih * input_width + iw];
}
}
output_ptr[oh * output_width + ow] = sum / __uint2float_rn(kh * kw);
}
}
}
}
template <>
__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, half *input_data,
half *output_data) {
for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) {
half *input_ptr = input_data + c * input_height * input_width;
half *output_ptr = output_data + c * output_height * output_width;
for (uint oh = 0; oh < output_height; oh++) {
uint ih0 = start_index(oh, output_height, input_height);
uint ih1 = end_index(oh, output_height, input_height);
uint kh = ih1 - ih0;
for (uint ow = 0; ow < output_width; ow++) {
uint iw0 = start_index(ow, output_width, input_width);
uint iw1 = end_index(ow, output_width, input_width);
uint kw = iw1 - iw0;
// compute local average
half sum = 0;
for (uint ih = ih0; ih < ih1; ih++) {
for (uint iw = iw0; iw < iw1; iw++) {
sum += input_ptr[ih * input_width + iw];
}
}
output_ptr[oh * output_width + ow] = sum / __uint2half_rn(kh * kw);
}
}
}
}
template <>
__global__ void AdaptiveAvgPool2DKernel(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, double *input_data,
double *output_data) {
for (uint c = blockIdx.x * blockDim.x + threadIdx.x; c < size; c += gridDim.x * blockDim.x) {
double *input_ptr = input_data + c * input_height * input_width;
double *output_ptr = output_data + c * output_height * output_width;
for (uint oh = 0; oh < output_height; oh++) {
uint ih0 = start_index(oh, output_height, input_height);
uint ih1 = end_index(oh, output_height, input_height);
uint kh = ih1 - ih0;
for (uint ow = 0; ow < output_width; ow++) {
uint iw0 = start_index(ow, output_width, input_width);
uint iw1 = end_index(ow, output_width, input_width);
uint kw = iw1 - iw0;
// compute local average
double sum = 0;
for (uint ih = ih0; ih < ih1; ih++) {
for (uint iw = iw0; iw < iw1; iw++) {
sum += input_ptr[ih * input_width + iw];
}
}
output_ptr[oh * output_width + ow] = sum / __uint2double_rn(kh * kw);
}
}
}
}
template <typename T>
void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, const uint output_height,
const uint output_width, T *input_data, T *output_data, cudaStream_t cuda_stream) {
AdaptiveAvgPool2DKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, input_height, input_width, output_height, output_width, input_data, output_data);
}
template void ApplyAdaptiveAvgPool2D<float>(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, float *input_data,
float *output_data, cudaStream_t cuda_stream);
template void ApplyAdaptiveAvgPool2D<half>(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, half *input_data,
half *output_data, cudaStream_t cuda_stream);
template void ApplyAdaptiveAvgPool2D<double>(const uint size, const uint input_height, const uint input_width,
const uint output_height, const uint output_width, double *input_data,
double *output_data, cudaStream_t cuda_stream);
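As a quick sanity check on the index helpers used above, the same window partition can be reproduced in a few lines of Python (an illustrative sketch, not part of the changed files):

import math

def adaptive_windows(in_size, out_size):
    # [start, end) input range for each output index, matching start_index/end_index above.
    return [(a * in_size // out_size, math.ceil((a + 1) * in_size / out_size)) for a in range(out_size)]

print(adaptive_windows(9, 5))  # [(0, 2), (1, 4), (3, 6), (5, 8), (7, 9)]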

View File

@@ -0,0 +1,25 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_
#include "runtime/device/gpu/cuda_common.h"
template <typename T>
void ApplyAdaptiveAvgPool2D(const uint size, const uint input_height, const uint input_width, const uint output_height,
const uint output_width, T *input_data, T *output_data, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAPTIVEAVGPOOL2D_IMPL_H_

View File

@@ -0,0 +1,31 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/adaptive_avg_pool2d_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
AdaptiveAvgPool2DKernel, half)
MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
AdaptiveAvgPool2DKernel, float)
MS_REG_GPU_KERNEL_ONE(AdaptiveAvgPool2D,
KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
AdaptiveAvgPool2DKernel, double)
} // namespace kernel
} // namespace mindspore

View File

@@ -0,0 +1,120 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_
#include <vector>
#include <string>
#include <algorithm>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/adaptive_avg_pool2d_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class AdaptiveAvgPool2DKernel : public GpuKernel {
public:
AdaptiveAvgPool2DKernel()
: input_size_(0),
output_size_(0),
len(0),
input_height(0),
input_width(0),
output_height(0),
output_width(0),
size(0) {}
~AdaptiveAvgPool2DKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> & /*workspace*/,
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
ApplyAdaptiveAvgPool2D(size, input_height, input_width, output_height, output_width, input_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
auto shape_addr = AnfAlgo::GetNodeAttr<std::vector<int64_t>>(kernel_node, "output_size");
if (shape_addr.size() == 1) {
output_height = shape_addr[0];
output_width = shape_addr[0];
} else if (shape_addr.size() == 2) {
output_height = static_cast<uint>(shape_addr[0]);
output_width = static_cast<uint>(shape_addr[1]);
} else {
MS_LOG(ERROR) << "Input Error.";
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but adaptive_avg_pool2d needs 1 inputs.";
return false;
}
input_size_ = sizeof(T);
output_size_ = sizeof(T);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
len = static_cast<uint>(input_shape.size());
input_height = static_cast<uint>(input_shape[len - 2]);
input_width = static_cast<uint>(input_shape[len - 1]);
size = static_cast<uint>(len == 3 ? input_shape[0] : input_shape[0] * input_shape[1]);
for (uint i = 0; i < len; i++) {
input_size_ *= input_shape[i];
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
private:
size_t input_size_;
size_t output_size_;
uint len;
uint input_height;
uint input_width;
uint output_height;
uint output_width;
uint size;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_ADAPTIVEAVGPOOL2D_GPU_KERNEL_H_
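To make the shape handling in Init concrete: the kernel treats the input as size planes of input_height x input_width (size is N*C for a 4D input, C for a 3D input) and writes output_height x output_width values per plane. A small Python sketch of that mapping (illustrative only; the helper name is made up for this note):

def launch_params(input_shape, output_size):
    # Mirrors AdaptiveAvgPool2DKernel::Init above.
    out_h, out_w = (output_size, output_size) if isinstance(output_size, int) else output_size
    in_h, in_w = input_shape[-2], input_shape[-1]
    planes = input_shape[0] if len(input_shape) == 3 else input_shape[0] * input_shape[1]
    return planes, in_h, in_w, out_h, out_w

print(launch_params((1, 32, 9, 9), (3, 5)))  # (32, 9, 9, 3, 5)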

View File

@@ -41,7 +41,8 @@ from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter,
from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
TensorSummary, HistogramSummary, Print, Assert)
from .control_ops import GeSwitch, Merge
from .inner_ops import (ScalarCast, Randperm, NoRepeatNGram, LambApplyOptimizerAssign, LambApplyWeightAssign,
MakeRefKey,
FusedWeightScaleApplyMomentum, AdamWeightDecay)
from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, AssignSub, Atan2, BatchMatMul,
@@ -83,7 +84,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
FusedSparseFtrl, FusedSparseProximalAdagrad,
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyAddSign, ApplyPowerSign, ApplyGradientDescent, ApplyProximalGradientDescent,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK, AdaptiveAvgPool2D)
from . import _quant_ops
from ._quant_ops import *
from .other_ops import (Assign, InplaceAssign, IOU, BoundingBoxDecode, BoundingBoxEncode,
@@ -107,7 +108,6 @@ from .sponge_ops import (BondForce, BondEnergy, BondAtomEnergy, BondForceWithAto
GetCenterOfGeometry, MDTemperature, NeighborListUpdate, MDIterationLeapFrogLiujian,
CrdToUintCrd, MDIterationSetupRandState, TransferCrd)
__all__ = [
'Unique',
'ReverseSequence',
@@ -469,6 +469,7 @@ __all__ = [
"CrdToUintCrd",
"MDIterationSetupRandState",
"TransferCrd",
"AdaptiveAvgPool2D"
]

View File

@@ -129,6 +129,74 @@ class Flatten(PrimitiveWithInfer):
return input_x
class AdaptiveAvgPool2D(PrimitiveWithInfer):
r"""
AdaptiveAvgPool2D operation.
This operator applies a 2D adaptive average pooling to an input signal composed of multiple input planes.
That is, for any input size, the size of the specified output is H x W.
The number of output features is equal to the number of input planes.
Args:
output_size (Union[int, tuple]): The target output size H x W. `output_size` can be a tuple (H, W),
or a single int H for an H x H output. H and W can be an int or None. If None, the output size
will be the same as that of the input along the corresponding dimension.
Inputs:
- **input_x** (Tensor) - The input of AdaptiveAvgPool2D, which is a 3D or 4D tensor,
with float16, float32, float64 data type.
Outputs:
Tensor, with the same data type and the same number of dimensions as `input_x`.
Raises:
ValueError: If `output_size` is a tuple and the length of `output_size` is not 2.
TypeError: If `input_x` is not a tensor.
TypeError: If dtype of `input_x` is not float16, float32 or float64.
ValueError: If the dimension of `input_x` is less than or equal to the dimension of `output_size`.
ValueError: If the dimension of `input_x` is larger than 4.
Supported Platforms:
``GPU``
Examples:
>>> input_x = Tensor(np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
...                            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
...                            [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]), mindspore.float32)
>>> adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((2, 2))
>>> output = adaptive_avg_pool_2d(input_x)
>>> print(output)
[[[3. 4.]
  [6. 7.]]
 [[3. 4.]
  [6. 7.]]
 [[3. 4.]
  [6. 7.]]]
"""
@prim_attr_register
def __init__(self, output_size):
validator.check_value_type("output_size", output_size, [int, tuple], self.name)
if isinstance(output_size, tuple):
validator.check_int(len(output_size), 2, Rel.EQ, 'output_size', self.name)
self.output_size = (output_size, output_size) if isinstance(output_size, int) else output_size
def infer_shape(self, x_shape):
if len(x_shape) <= len(self.output_size):
raise ValueError("{} dimension should be larger than {} dimension".format(x_shape, self.output_size))
validator.check_int(len(x_shape), 5, Rel.LT, 'input_x_dimensions', self.name)
for input_x_dimension in x_shape:
validator.check_int(input_x_dimension, 0, Rel.GT, 'input_x dimension', self.name)
zipped = zip(self.output_size, x_shape[-len(self.output_size):])
out_size = [i if i else j for i, j in zipped]
for item in out_size:
validator.check_value_type("item of output_size", item, [int], self.name)
self.add_prim_attr('output_size', (out_size))
output_shape = x_shape[:len(x_shape) - len(out_size)] + out_size
return output_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_dtype_valid("x_dtype", x_dtype, [mstype.float16, mstype.float32, mstype.float64],
self.name)
return x_dtype
class Softmax(Primitive):
r"""
Softmax operation.
@@ -3298,6 +3366,7 @@ class Gelu(PrimitiveWithInfer):
Same as operator GeLU. Gelu will be deprecated in the future.
Please use GeLU instead.
"""
@deprecated("1.1", "GeLU", True)
@prim_attr_register
def __init__(self):
@@ -3354,13 +3423,12 @@ class GeLU(Primitive):
self.init_prim_io_names(inputs=['x'], outputs=['output'])
class FastGelu(PrimitiveWithInfer):
"""
Same as operator FastGeLU. FastGelu will be deprecated in the future.
Please use FastGeLU instead.
"""
@deprecated("1.1", "FastGeLU", True)
@prim_attr_register
def __init__(self):

View File

@@ -0,0 +1,104 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, ops
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
context.set_context(device_target='GPU')
class Net(nn.Cell):
def __init__(self, output_size):
super(Net, self).__init__()
self.adaptive_avg_pool2d = P.AdaptiveAvgPool2D(output_size)
@ms_function
def construct(self, x):
return self.adaptive_avg_pool2d(x)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_normal():
x = np.random.randn(1, 32, 9, 9)
net = Net((3, 5))
output = net(Tensor(x, mindspore.float32))
expect_shape = (1, 32, 3, 5)
assert output.asnumpy().shape == expect_shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_single():
x = np.random.randn(1, 32, 7, 9)
net = Net(5)
output = net(Tensor(x, mindspore.float32))
expect_shape = (1, 32, 5, 5)
assert output.asnumpy().shape == expect_shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_none():
x = np.random.randn(1, 32, 7, 9)
net = Net((None, 5))
output = net(Tensor(x, mindspore.float32))
expect_shape = (1, 32, 7, 5)
assert output.asnumpy().shape == expect_shape
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_value():
x = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
net = Net((2, 2))
output = net(Tensor(x))
expect_shape = (3, 2, 2)
expect_output = np.array([[[3.0, 4.0], [6.0, 7.0]],
[[3.0, 4.0], [6.0, 7.0]],
[[3.0, 4.0], [6.0, 7.0]]])
assert output.asnumpy().shape == expect_shape
assert (output.asnumpy() == expect_output).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
x = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]],
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]])
adaptive_avg_pool_2d = ops.AdaptiveAvgPool2D((2, 2))
output = adaptive_avg_pool_2d(Tensor(x))
expect_shape = (3, 2, 2)
expect_output = np.array([[[3.0, 4.0], [6.0, 7.0]],
[[3.0, 4.0], [6.0, 7.0]],
[[3.0, 4.0], [6.0, 7.0]]])
assert output.asnumpy().shape == expect_shape
assert (output.asnumpy() == expect_output).all()
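The tests above mostly check output shapes; only the square 3 x 3 case checks values. For broader value checks, a small NumPy reference (an illustrative sketch, not part of this PR; it reuses the numpy import at the top of this file) that follows the same floor/ceil window rule as the CUDA kernel could be used to cross-check the operator:

def adaptive_avg_pool2d_ref(x, output_size):
    """NumPy reference for AdaptiveAvgPool2D over the last two axes of a 3D/4D array."""
    out_h, out_w = (output_size, output_size) if isinstance(output_size, int) else output_size
    in_h, in_w = x.shape[-2], x.shape[-1]
    out_h = in_h if out_h is None else out_h
    out_w = in_w if out_w is None else out_w
    out = np.empty(x.shape[:-2] + (out_h, out_w), dtype=x.dtype)
    for oh in range(out_h):
        h0, h1 = oh * in_h // out_h, -((oh + 1) * in_h // -out_h)  # floor and ceil division
        for ow in range(out_w):
            w0, w1 = ow * in_w // out_w, -((ow + 1) * in_w // -out_w)
            out[..., oh, ow] = x[..., h0:h1, w0:w1].mean(axis=(-2, -1))
    return out

# Example: reproduces the documented 3x3 -> 2x2 case, [[3., 4.], [6., 7.]] for each plane.
print(adaptive_avg_pool2d_ref(np.arange(1.0, 10.0).reshape(1, 3, 3), (2, 2)))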