forked from mindspore-Ecosystem/mindspore

rewrite gpu implementation of ops maxpool_with_argmax and maxpool_with_argmax_grad

parent 4477b97465
commit 09b68f9006
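The rewrite turns the backward pass from a gather into a scatter: instead of every input position rescanning the pooled windows that might cover it, dx is first cleared by an InitOutput kernel and each dy element is then atomically added (MsAtomicAdd) at the position recorded in the argmax index. The atomic add is what makes overlapping windows (SAME padding with a stride smaller than the kernel size) safe, since several outputs can map to the same input element. A minimal NumPy sketch of the computation the new kernels perform, assuming NCHW layout and argmax values stored as flat offsets into a single sample; the function name and signature are illustrative, not MindSpore API:

import numpy as np

def maxpool_with_argmax_grad(dy, argmax, x_shape):
    # dy, argmax: (N, C, Hout, Wout); argmax holds flat offsets into one
    # input sample, i.e. values in [0, C*H*W), as recorded by the forward op.
    n = x_shape[0]
    chw = int(np.prod(x_shape[1:]))
    dx = np.zeros((n, chw), dtype=dy.dtype)  # InitOutput: zero the output buffer
    for i in range(n):
        # scatter-add every dy element onto its argmax position (the MsAtomicAdd step)
        np.add.at(dx[i], argmax[i].reshape(-1), dy[i].reshape(-1))
    return dx.reshape(x_shape)

With dy set to ones, this reproduces the hit counts (0, 1, 2, 4) that the new test's expect_dx table encodes for ksize=2, strides=1, SAME pooling.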
@@ -18,130 +18,34 @@
#include "maxpool_with_argmax_grad_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
#include "include/cuda_fp16.h"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"

template <typename T, typename S>
__global__ void MaxPoolWithArgmaxGrad(const T* x,
const T* dy,
__global__ void MaxPoolWithArgmaxGrad(const T* dy,
const S* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
const int dyNCHW,
T* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const S posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
T dySum = static_cast<T>(0.0);
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
S maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += dy[channelStart+curIdx];
}
}
}
dx[pos] = dySum;
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (dyNCHW); pos += blockDim.x * gridDim.x) {
const S idx = index[pos];
const int posn = pos / dyCHW;
MsAtomicAdd(dx + posn*xCHW + static_cast<int>(idx), dy[pos]);
}
return;
}

template <>
__global__ void MaxPoolWithArgmaxGrad(const half* x,
const half* dy,
const int* index,
const int n,
const int c,
const int xHeight,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
const int xNCHW,
const int xCHW,
const int xHW,
const int dyCHW,
const int dyHW,
half* dx) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x;
pos < (xNCHW);
pos += blockDim.x * gridDim.x) {
const int posn = pos / xCHW;
const int posc = pos / xHW % c;
const int posh = pos / xHeight % xHeight;
const int posw = pos % xWidth;
const int posIdx = posh*xWidth + posw;
int hstart = posh+padTop;
if (hstart < windowHeight) {
hstart = 0;
} else {
hstart = (hstart-windowHeight)/strideHeight + 1;
template <typename T>
__global__ void InitOutput(const int size, T *output) {
T zero = 0;
for (size_t id = blockIdx.x * blockDim.x + threadIdx.x; id < size; id += blockDim.x * gridDim.x) {
output[id] = zero;
}
int wstart = posw+padLeft;
if (wstart < windowWidth) {
wstart = 0;
} else {
wstart = (wstart-windowWidth)/strideWidth + 1;
}
const int hend = min((posh+padTop)/strideHeight +1, dyHeight);
const int wend = min((posw+padLeft)/strideWidth +1, dyWidth);
const int channelStart = posn*dyCHW + posc*dyHW;
float dySum = 0.0f;
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
const int curIdx = hcur*dyWidth + wcur;
int maxIdx = index[channelStart+curIdx];
if (maxIdx == posIdx) {
dySum += __half2float(dy[channelStart+curIdx]);
}
}
}
dx[pos] = __float2half(dySum);
}
return;
return;
}

template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x,
const T* dy,
void CalMaxPoolWithArgmaxGrad(const T* dy,
const S* index,
const int n,
const int c,

@@ -149,12 +53,6 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
T* dx,
cudaStream_t cuda_stream) {
const int xHW = xHeight*xWidth;

@@ -162,36 +60,22 @@ void CalMaxPoolWithArgmaxGrad(const T* x,
const int xNCHW = n*xCHW;
const int dyHW = dyHeight*dyWidth;
const int dyCHW = c*dyHW;
MaxPoolWithArgmaxGrad<<<GET_BLOCKS(xNCHW),
const int dyNCHW = n*dyCHW;
InitOutput<<<GET_BLOCKS(xNCHW), GET_THREADS, 0, cuda_stream>>>(xNCHW, dx);
MaxPoolWithArgmaxGrad<<<GET_BLOCKS(dyNCHW),
GET_THREADS,
0,
cuda_stream>>>(
x,
dy,
index,
n,
c,
xHeight,
xWidth,
dyHeight,
dyWidth,
windowHeight,
windowWidth,
strideHeight,
strideWidth,
padTop,
padLeft,
xNCHW,
xCHW,
xHW,
dyCHW,
dyHW,
dyNCHW,
dx);
return;
}

template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
const float* dy,
template void CalMaxPoolWithArgmaxGrad<float, int>(const float* dy,
const int* index,
const int n,
const int c,

@@ -199,16 +83,9 @@ template void CalMaxPoolWithArgmaxGrad<float, int>(const float* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
float* dx,
cudaStream_t cuda_stream);
template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
const half* dy,
template void CalMaxPoolWithArgmaxGrad<half, int>(const half* dy,
const int* index,
const int n,
const int c,

@@ -216,11 +93,5 @@ template void CalMaxPoolWithArgmaxGrad<half, int>(const half* x,
const int xWidth,
const int dyHeight,
const int dyWidth,
const int windowHeight,
const int windowWidth,
const int strideHeight,
const int strideWidth,
const int padTop,
const int padLeft,
half* dx,
cudaStream_t cuda_stream);

@@ -17,9 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_
template <typename T, typename S>
void CalMaxPoolWithArgmaxGrad(const T* x, const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, const int windowHeight,
const int windowWidth, const int strideHeight, const int strideWidth, const int padTop,
const int padLeft, T* dx, cudaStream_t cuda_stream);
void CalMaxPoolWithArgmaxGrad(const T* dy, const S* index, const int n, const int c, const int xHeight,
const int xWidth, const int dyHeight, const int dyWidth, T* dx, cudaStream_t cuda_stream);

#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_CUDA_IMPL_MAXPOOLWITHARGMAX_GRAD_H_

@@ -42,7 +42,7 @@ __global__ void MaxPoolWithArgmax(const T* input,
pos += blockDim.x * gridDim.x) {
const int posn = pos / outputCHW;
const int posc = pos / outputHW % c;
const int posh = pos / outputHeight % outputHeight;
const int posh = pos / outputWidth % outputHeight;
const int posw = pos % outputWidth;
int hstart = posh * strideHeight - padTop;
int wstart = posw * strideWidth - padLeft;

@@ -50,12 +50,12 @@ __global__ void MaxPoolWithArgmax(const T* input,
const int wend = min(wstart + windowWidth, w);
hstart = max(hstart, 0);
wstart = max(wstart, 0);
S inputStart = posn*c*h*w + posc*h*w;
S maxIdx = hstart*w + wstart;
S inputStart = posn*c*h*w;
S maxIdx = posc*h*w + hstart*w + wstart;
T maxData = input[inputStart+maxIdx];
for (int hcur = hstart; hcur < hend; ++hcur) {
for (int wcur = wstart; wcur < wend; ++wcur) {
S inputIdx = hcur*w + wcur;
S inputIdx = posc*h*w + hcur*w + wcur;
T inputData = input[inputStart+inputIdx];
if (inputData > maxData) {
maxIdx = inputIdx;

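Two forward-kernel fixes pair with that index convention: posh is now recovered by dividing by outputWidth rather than outputHeight, and the stored maxIdx includes the channel offset posc*h*w, so the argmax is a flat offset into the whole sample instead of into a single channel plane. A NumPy sketch of that convention, restricted to VALID pooling with no padding purely for illustration (again not MindSpore API):

import numpy as np

def maxpool_with_argmax(x, ksize, stride):
    # x: (N, C, H, W); VALID pooling without padding, for illustration only.
    n, c, h, w = x.shape
    out_h = (h - ksize) // stride + 1
    out_w = (w - ksize) // stride + 1
    out = np.empty((n, c, out_h, out_w), dtype=x.dtype)
    argmax = np.empty((n, c, out_h, out_w), dtype=np.int32)
    for i, j, oh, ow in np.ndindex(n, c, out_h, out_w):
        hs, ws = oh * stride, ow * stride
        window = x[i, j, hs:hs + ksize, ws:ws + ksize]
        kh, kw = np.unravel_index(np.argmax(window), window.shape)
        out[i, j, oh, ow] = window[kh, kw]
        # flat offset within one sample: channel plane, then row, then column
        argmax[i, j, oh, ow] = j * h * w + (hs + kh) * w + (ws + kw)
    return out, argmax

Indices produced this way can be fed straight into the grad sketch near the top of this diff, which scatters dy back to the right channel without any extra bookkeeping.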
@@ -48,12 +48,10 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
T *x_addr = GetDeviceAddress<T>(inputs, 0);
T *dy_addr = GetDeviceAddress<T>(inputs, 1);
S *index_addr = GetDeviceAddress<S>(inputs, 2);
T *dx_addr = GetDeviceAddress<T>(outputs, 0);
CalMaxPoolWithArgmaxGrad(x_addr, dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_,
window_height_, window_width_, stride_height_, stride_width_, pad_top_, pad_left_, dx_addr,
CalMaxPoolWithArgmaxGrad(dy_addr, index_addr, n_, c_, x_height_, x_width_, dy_height_, dy_width_, dx_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}

@@ -95,57 +93,19 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
x_width_ = SizeToInt(x_shape[3]);
dy_height_ = SizeToInt(dy_shape[2]);
dy_width_ = SizeToInt(dy_shape[3]);
std::vector<int> window;
std::vector<int64_t> window_me =
GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("ksize"));
(void)std::transform(window_me.begin(), window_me.end(), std::back_inserter(window),
[](const int64_t &value) { return static_cast<int>(value); });
window_height_ = window[1];
window_width_ = window[2];
std::vector<int> stride;
std::vector<int64_t> stride_me =
GetValue<std::vector<int64_t>>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("strides"));
(void)std::transform(stride_me.begin(), stride_me.end(), std::back_inserter(stride),
[](const int64_t &value) { return static_cast<int>(value); });
stride_height_ = stride[1];
stride_width_ = stride[2];
pad_mode_ = GetValue<std::string>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("padding"));
pad_top_ = 0;
pad_left_ = 0;
if (pad_mode_ == kSamePadModeUpperCase || pad_mode_ == kSamePadModeLowerCase) {
SetPad();
}

InitSizeLists();
return true;
}

protected:
void InitSizeLists() override {
input_size_list_.push_back(x_size_);
input_size_list_.push_back(dy_size_);
input_size_list_.push_back(index_size_);
output_size_list_.push_back(dx_size_);
}

private:
void SetPad() {
pad_height_ = std::max<int>(
0, (((x_height_ / stride_height_) * stride_height_ == x_height_ ? (x_height_ / stride_height_)
: (x_height_ / stride_height_) + 1) -
1) *
stride_height_ +
window_height_ - x_height_);
pad_width_ =
std::max<int>(0, (((x_width_ / stride_width_) * stride_width_ == x_width_ ? (x_width_ / stride_width_)
: (x_width_ / stride_width_) + 1) -
1) *
stride_width_ +
window_width_ - x_width_);
pad_top_ = pad_height_ / 2;
pad_left_ = pad_width_ / 2;
}

std::string pad_mode_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;

@@ -156,14 +116,6 @@ class MaxPoolWithArgmaxGradGpuKernel : public GpuKernel {
int x_width_;
int dy_height_;
int dy_width_;
int window_height_;
int window_width_;
int pad_height_;
int pad_width_;
int pad_top_;
int pad_left_;
int stride_height_;
int stride_width_;

size_t x_size_;
size_t dy_size_;

@@ -16,29 +16,84 @@
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

class Net_Pool(nn.Cell):
    def __init__(self):
        super(Net_Pool, self).__init__()
        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=2, strides=2, padding="VALID")

    def construct(self, x):
        return self.maxpool_fun(x)
import mindspore.ops.operations as P
from mindspore import context, Tensor
from mindspore.nn import Cell
from mindspore.ops import composite as C


class Net_Pool2(nn.Cell):
    def __init__(self):
        super(Net_Pool2, self).__init__()
        self.maxpool_fun = P.MaxPoolWithArgmax(ksize=3, strides=2, padding="SAME")
class MaxPoolWithArgMax_Net(Cell):
    def __init__(self, padding, ksize, strides):
        super(MaxPoolWithArgMax_Net, self).__init__()
        self.maxpool_with_argmax = P.MaxPoolWithArgmax(padding=padding, ksize=ksize, strides=strides)

    def construct(self, x):
        return self.maxpool_fun(x)
    def construct(self, input_data):
        output, argmax = self.maxpool_with_argmax(input_data)
        return output, argmax


class Grad(Cell):
    def __init__(self, network, argmax):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(get_all=True, sens_param=True)
        self.network = network
        self.sens = (Tensor(np.ones(argmax.shape).astype(np.float32)),
                     Tensor(np.ones(argmax.shape).astype(np.int32)))

    def construct(self, input_data):
        gout = self.grad(self.network)(input_data, self.sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_forward_backward():
    x = np.arange(1 * 3 * 3 * 4).reshape(1, 3, 3, 4).astype(np.float32)
    expect_output = np.array([[[[5, 6, 7, 7],
                                [9, 10, 11, 11],
                                [9, 10, 11, 11]],
                               [[17, 18, 19, 19],
                                [21, 22, 23, 23],
                                [21, 22, 23, 23]],
                               [[29, 30, 31, 31],
                                [33, 34, 35, 35],
                                [33, 34, 35, 35]]]]).astype(np.float32)
    expect_argmax = np.array([[[[5, 6, 7, 7],
                                [9, 10, 11, 11],
                                [9, 10, 11, 11]],
                               [[17, 18, 19, 19],
                                [21, 22, 23, 23],
                                [21, 22, 23, 23]],
                               [[29, 30, 31, 31],
                                [33, 34, 35, 35],
                                [33, 34, 35, 35]]]]).astype(np.int32)
    expect_dx = np.array([[[[0, 0, 0, 0],
                            [0, 1, 1, 2],
                            [0, 2, 2, 4]],
                           [[0, 0, 0, 0],
                            [0, 1, 1, 2],
                            [0, 2, 2, 4]],
                           [[0, 0, 0, 0],
                            [0, 1, 1, 2],
                            [0, 2, 2, 4]]]]).astype(np.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = MaxPoolWithArgMax_Net(padding="SAME", ksize=2, strides=1)
    output_tensor, argmax_tensor = net(Tensor(x))
    assert output_tensor.shape == expect_output.shape
    assert argmax_tensor.shape == expect_argmax.shape

    error = np.ones(shape=expect_output.shape) * 1.0e-5
    diff_output = output_tensor.asnumpy() - expect_output
    assert np.all(diff_output < error)

    net_grad = Grad(net, argmax_tensor)
    dx = net_grad(Tensor(x))[0].asnumpy()
    assert dx.shape == expect_dx.shape
    diff = dx - expect_dx
    assert np.all(diff < error)

@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard

@@ -73,8 +128,8 @@ def test_maxpool_with_argmax_2d():
    ]]]))

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    maxpool2d = Net_Pool()
    maxpool2d2 = Net_Pool2()
    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
    output2, index2 = maxpool2d2(x)
    output, index = maxpool2d(x)
    assert (output.asnumpy() == expect_result).all()

@@ -83,8 +138,8 @@ def test_maxpool_with_argmax_2d():
    assert (index2.asnumpy() == expect__index_result2).all()

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    maxpool2d = Net_Pool()
    maxpool2d2 = Net_Pool2()
    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
    output2, index2 = maxpool2d2(x)
    output, index = maxpool2d(x)
    assert (output.asnumpy() == expect_result).all()

@@ -126,8 +181,8 @@ def test_maxpool_with_argmax_2d_fp16():
    ]]]))

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    maxpool2d = Net_Pool()
    maxpool2d2 = Net_Pool2()
    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
    output2, index2 = maxpool2d2(x)
    output, index = maxpool2d(x)
    assert (output.asnumpy() == expect_result).all()

@@ -136,12 +191,11 @@ def test_maxpool_with_argmax_2d_fp16():
    assert (index2.asnumpy() == expect__index_result2).all()

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    maxpool2d = Net_Pool()
    maxpool2d2 = Net_Pool2()
    maxpool2d = MaxPoolWithArgMax_Net(padding="VALID", ksize=2, strides=2)
    maxpool2d2 = MaxPoolWithArgMax_Net(padding="SAME", ksize=3, strides=2)
    output2, index2 = maxpool2d2(x)
    output, index = maxpool2d(x)
    assert (output.asnumpy() == expect_result).all()
    assert (output2.asnumpy() == expect_result2).all()
    assert (index.asnumpy() == expect_index_result).all()
    assert (index2.asnumpy() == expect__index_result2).all()