forked from mindspore-Ecosystem/mindspore
!10985 Refactor cpu ops Concat to support more than 4D inputs
From: @yuanwei66 Reviewed-by: @c_34,@wuxuejian Signed-off-by: @wuxuejian
This commit is contained in:
commit abb86381ac
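The idea behind the refactor: instead of padding every shape to exactly 4-D and dispatching on the axis value, each input shape is flattened to 2-D — the product of the dimensions before the concat axis (rows) by the product of the remaining dimensions (columns) — so concatenation becomes a single row-interleaved copy loop that works for any rank. A minimal NumPy sketch of that reduction (illustration only; concat_by_flatten is a name invented for this sketch, not part of the commit):

import numpy as np

def concat_by_flatten(arrays, axis):
    # Flatten every input to 2-D: rows = prod(dims before axis),
    # cols = prod(dims from axis on). All inputs share the row count.
    axis = axis % arrays[0].ndim
    flats = [a.reshape(int(np.prod(a.shape[:axis])), -1) for a in arrays]
    rows = flats[0].shape[0]
    # Row-interleaved copy: for each row, take every input's row slice in turn.
    out = np.concatenate([f[i] for i in range(rows) for f in flats])
    out_shape = list(arrays[0].shape)
    out_shape[axis] = sum(a.shape[axis] for a in arrays)
    return out.reshape(out_shape)

# Works for any rank, e.g. the 5-D inputs the old kernel rejected:
a = np.arange(720).reshape(2, 3, 4, 5, 6).astype(np.float32)
b = np.zeros((2, 3, 4, 5, 2), np.float32)
assert (concat_by_flatten([a, b], 4) == np.concatenate([a, b], axis=4)).all()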
@@ -19,84 +19,50 @@
 namespace mindspore {
 namespace kernel {
-void ConcatCPUKernel::InitKernel(const CNodePtr &kernel_node) {
+template <typename T>
+void ConcatCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
   CheckParam(kernel_node);
-  axis_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS);
+  axis_ = LongToInt(AnfAlgo::GetNodeAttr<int64_t>(kernel_node, AXIS));
   auto input_1_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
   if (axis_ < 0) {
-    axis_ = axis_ + SizeToLong(input_1_shape.size());
+    axis_ = axis_ + SizeToInt(input_1_shape.size());
   }
-  axis_ += 4 - SizeToLong(input_1_shape.size());
-
-  auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-  for (size_t i = 0; i < input_num; i++) {
-    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
-    CPUKernelUtils::ExpandDimsTo4(&input_shape);
-    input_shape_list_.push_back(input_shape);
-  }
-
-  output_shape_ = AnfAlgo::GetOutputInferShape(kernel_node, 0);
-  CPUKernelUtils::ExpandDimsTo4(&output_shape_);
+  input_num_ = AnfAlgo::GetInputTensorNum(kernel_node);
+  for (size_t i = 0; i < input_num_; i++) {
+    auto input_shape_i = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
+    auto flat_shape = CPUKernelUtils::FlatShapeByAxis(input_shape_i, axis_);
+    input_flat_shape_list_.push_back(flat_shape);
+  }
 }
 
-bool ConcatCPUKernel::Launch(const std::vector<kernel::AddressPtr> &inputs,
-                             const std::vector<kernel::AddressPtr> & /*workspace*/,
-                             const std::vector<kernel::AddressPtr> &outputs) {
-  auto output_addr = reinterpret_cast<float *>(outputs[0]->addr);
+template <typename T>
+bool ConcatCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
+                                const std::vector<kernel::AddressPtr> & /*workspace*/,
+                                const std::vector<kernel::AddressPtr> &outputs) {
+  auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
   auto buff_size = outputs[0]->size;
-  size_t dim0 = output_shape_[0];
-  size_t dim1 = output_shape_[1];
-  size_t dim2 = output_shape_[2];
-
-  if (axis_ == 3) {
-    for (size_t i = 0; i < dim0; ++i) {
-      for (size_t j = 0; j < dim1; ++j) {
-        for (size_t k = 0; k < dim2; ++k) {
-          CopyDataToOutput(inputs, i, j, k, &output_addr, &buff_size);
-        }
-      }
-    }
-  } else if (axis_ == 2) {
-    for (size_t i = 0; i < dim0; ++i) {
-      for (size_t j = 0; j < dim1; ++j) {
-        CopyDataToOutput(inputs, i, j, 0, &output_addr, &buff_size);
-      }
-    }
-  } else if (axis_ == 1) {
-    for (size_t i = 0; i < dim0; ++i) {
-      CopyDataToOutput(inputs, i, 0, 0, &output_addr, &buff_size);
-    }
-  } else if (axis_ == 0) {
-    CopyDataToOutput(inputs, 0, 0, 0, &output_addr, &buff_size);
-  }
+  // each input's row count after flattening is the same
+  auto before_axis = input_flat_shape_list_[0][0];
+  for (size_t i = 0; i < before_axis; ++i) {
+    for (size_t j = 0; j < input_num_; ++j) {
+      auto input_j_addr = reinterpret_cast<T *>(inputs[j]->addr);
+      auto copy_num = input_flat_shape_list_[j][1];
+      auto offset = copy_num * i;
+      auto ret = memcpy_s(output_addr, buff_size, input_j_addr + offset, copy_num * sizeof(T));
+      if (ret != EOK) {
+        MS_LOG(EXCEPTION) << "memcpy failed.";
+      }
+      output_addr += copy_num;
+      buff_size -= copy_num * sizeof(T);
+    }
+  }
   return true;
 }
 
-void ConcatCPUKernel::CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1,
-                                       size_t dim2, float **output_addr, size_t *buff_size) {
-  for (size_t i = 0; i < input_shape_list_.size(); ++i) {
-    auto input_i_shape = input_shape_list_[i];
-    auto input_i_addr = reinterpret_cast<float *>(inputs[i]->addr);
-
-    size_t num = CPUKernelUtils::GetElementNumOnAxis(input_i_shape, axis_);
-    num *= input_i_shape[axis_];
-    auto pos = CPUKernelUtils::CalcOffset(input_i_shape, dim0, dim1, dim2, 0);
-    auto ret = memcpy_s(*output_addr, *buff_size, input_i_addr + pos, num * sizeof(float));
-    if (ret != EOK) {
-      MS_LOG(EXCEPTION) << "memcpy failed.";
-    }
-    *output_addr += num;
-    *buff_size -= num * sizeof(float);
-  }
-}
-
-void ConcatCPUKernel::CheckParam(const CNodePtr &kernel_node) {
-  auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
-  if (input_shape.size() > 4) {
-    MS_LOG(EXCEPTION) << "Input dims is " << input_shape.size() << ", but ConcatCPUKernel olny support 4d or lower.";
-  }
+template <typename T>
+void ConcatCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
   size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
   if (output_num != 1) {
     MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but ConcatCPUKernel needs 1 output.";
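Read the new Launch as a single pass over the output buffer: for each of the before_axis rows, one memcpy_s per input copies that input's row slice, advancing output_addr and shrinking buff_size. The same bookkeeping in Python (a sketch; flat_inputs stands in for inputs[j]->addr paired with input_flat_shape_list_[j]):

def launch(flat_inputs):
    # flat_inputs: list of (buf, [rows, cols]) pairs, standing in for
    # inputs[j]->addr and input_flat_shape_list_[j].
    rows = flat_inputs[0][1][0]  # before_axis: identical for every input
    out = []
    for i in range(rows):
        for buf, (_, copy_num) in flat_inputs:
            offset = copy_num * i  # start of this input's i-th row
            out.extend(buf[offset:offset + copy_num])  # the memcpy_s
    return out

# e.g. two 2x2 matrices concatenated along axis 1:
# launch([([1, 2, 3, 4], [2, 2]), ([5, 6, 7, 8], [2, 2])])
# -> [1, 2, 5, 6, 3, 4, 7, 8]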
@@ -22,9 +22,10 @@
 namespace mindspore {
 namespace kernel {
+template <typename T>
 class ConcatCPUKernel : public CPUKernel {
  public:
-  ConcatCPUKernel() : axis_(0) {}
+  ConcatCPUKernel() = default;
   ~ConcatCPUKernel() override = default;
 
   void InitKernel(const CNodePtr &kernel_node) override;
@@ -34,16 +35,20 @@ class ConcatCPUKernel : public CPUKernel {
  private:
   void CheckParam(const CNodePtr &kernel_node);
-  void CopyDataToOutput(const std::vector<kernel::AddressPtr> &inputs, size_t dim0, size_t dim1, size_t dim2,
-                        float **output_addr, size_t *buff_size);
-  int64_t axis_;
-  std::vector<std::vector<size_t>> input_shape_list_;
-  std::vector<size_t> output_shape_;
+  int axis_ = 0;
+  size_t input_num_ = 1;
+  std::vector<std::vector<size_t>> input_flat_shape_list_;
 };
 
-MS_REG_CPU_KERNEL(Concat,
-                  KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-                  ConcatCPUKernel);
+MS_REG_CPU_KERNEL_T(
+  Concat, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  ConcatCPUKernel, float);
+MS_REG_CPU_KERNEL_T(Concat,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
+                    ConcatCPUKernel, int)
+MS_REG_CPU_KERNEL_T(Concat,
+                    KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeBool).AddOutputAttr(kNumberTypeBool),
+                    ConcatCPUKernel, bool)
 }  // namespace kernel
 }  // namespace mindspore
@@ -98,5 +98,24 @@ void CPUKernelUtils::ParallelFor(const CTask &task, size_t count) {
   }
 }
 
+std::vector<size_t> CPUKernelUtils::FlatShapeByAxis(const std::vector<size_t> &shape, int axis) {
+  if (axis < 0) {
+    axis = axis + SizeToInt(shape.size());
+  }
+  size_t dim_row = 1;
+  size_t dim_col = 1;
+  std::vector<size_t> flat_shape;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (SizeToInt(i) < axis) {
+      dim_row *= shape[i];
+    } else {
+      dim_col *= shape[i];
+    }
+  }
+  flat_shape.push_back(dim_row);
+  flat_shape.push_back(dim_col);
+  return flat_shape;
+}
+
 }  // namespace kernel
 }  // namespace mindspore
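A quick way to see what FlatShapeByAxis computes: everything before axis collapses into the row count, everything from axis onward into the column count. A Python mirror with worked values (sketch only, not part of the commit):

def flat_shape_by_axis(shape, axis):
    # Mirrors CPUKernelUtils::FlatShapeByAxis.
    if axis < 0:
        axis += len(shape)
    dim_row = dim_col = 1
    for i, d in enumerate(shape):
        if i < axis:
            dim_row *= d
        else:
            dim_col *= d
    return [dim_row, dim_col]

assert flat_shape_by_axis([2, 3, 4, 5], 2) == [6, 20]    # 2*3 rows, 4*5 cols
assert flat_shape_by_axis([2, 3, 4, 5], -1) == [24, 5]   # negative axis normalized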
@@ -120,6 +120,7 @@ class CPUKernelUtils {
   static size_t GetElementNumOnAxis(const std::vector<size_t> &shape, int axis);
   static void GetElementNumEveryDim(const std::vector<size_t> &shape, std::vector<size_t> *element_num);
   static void ParallelFor(const CTask &task, size_t count);
+  static std::vector<size_t> FlatShapeByAxis(const std::vector<size_t> &shape, int axis);
 };
 }  // namespace kernel
 }  // namespace mindspore
@@ -19,84 +19,290 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import mindspore.context as context
-from mindspore.common import dtype as mstype
 
 context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
 
-class Concat_Axis0(nn.Cell):
-    def __init__(self):
-        super(Concat_Axis0, self).__init__()
-        self.cat = P.Concat(axis=0)
-
-    def construct(self, x1, x2):
-        return self.cat((x1, x2))
+class ConcatV10(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV10, self).__init__()
+
+        self.cat = P.Concat(axis=2)
+        self.x1 = Tensor(np.array([[[0., 0., 1.],
+                                    [1., 2., 3.]],
+                                   [[2., 4., 5.],
+                                    [3., 6., 7.]]]).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1,))
+
+
+def axis10(nptype):
+    cat = ConcatV10(nptype)
+    output = cat()
+    expect = np.array([[[0., 0., 1.],
+                        [1., 2., 3.]],
+                       [[2., 4., 5.],
+                        [3., 6., 7.]]]).astype(nptype)
+    print(output)
+    assert (output.asnumpy() == expect).all()
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_in2_axis0():
-    x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x2 = Tensor(np.arange(3 * 2 * 2).reshape(3, 2, 2), mstype.float32)
-    cat = Concat_Axis0()
-    output_ms = cat(x1, x2)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=0)
-
-    error = np.ones(shape=output_np.shape) * 10e-6
-    diff = output_ms.asnumpy() - output_np
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
+def test_axis10_float32():
+    axis10(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis10_int32():
+    axis10(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis10_bool():
+    axis10(np.bool)
+
+
+class ConcatV32(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV32, self).__init__()
+
+        self.cat = P.Concat(axis=2)
+        self.x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
+
+def axis32(nptype):
+    cat = ConcatV32(nptype)
+    output = cat()
+    expect = np.array([[[0., 0., 1.],
+                        [1., 2., 3.]],
+                       [[2., 4., 5.],
+                        [3., 6., 7.]]]).astype(nptype)
+    print(output)
+    assert (output.asnumpy() == expect).all()
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis32_float32():
+    axis32(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis32_int32():
+    axis32(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis32_bool():
+    axis32(np.bool)
+
+
+class ConcatV43(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV43, self).__init__()
+
+        self.cat = P.Concat(axis=3)
+        self.x1 = Tensor(np.arange(2 * 2 * 2 * 2).reshape(2, 2, 2, 2).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 2 * 2 * 3).reshape(2, 2, 2, 3).astype(nptype))
+
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
+
+def axis43(nptype):
+    cat = ConcatV43(nptype)
+    output = cat()
+    expect = np.array([[[[0., 1., 0., 1., 2.],
+                         [2., 3., 3., 4., 5.]],
+                        [[4., 5., 6., 7., 8.],
+                         [6., 7., 9., 10., 11.]]],
+                       [[[8., 9., 12., 13., 14.],
+                         [10., 11., 15., 16., 17.]],
+                        [[12., 13., 18., 19., 20.],
+                         [14., 15., 21., 22., 23.]]]]).astype(nptype)
+    assert (output.asnumpy() == expect).all()
+    print(output)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_float32():
+    axis43(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_int32():
+    axis43(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis43_bool():
+    axis43(np.bool)
+
+
+class ConcatV21(nn.Cell):
+    def __init__(self, nptype):
+        super(ConcatV21, self).__init__()
 
-class Concat_Axis1(nn.Cell):
-    def __init__(self):
-        super(Concat_Axis1, self).__init__()
-        self.cat = P.Concat(axis=1)
+        self.cat = P.Concat(axis=1)
+        self.x1 = Tensor(np.arange(2 * 2).reshape(2, 2).astype(nptype))
+        self.x2 = Tensor(np.arange(2 * 3).reshape(2, 3).astype(nptype))
 
-    def construct(self, x1, x2):
-        return self.cat((x1, x2))
+    def construct(self):
+        return self.cat((self.x1, self.x2))
+
+
+def axis21(nptype):
+    cat = ConcatV21(nptype)
+    output = cat()
+    expect = np.array([[0., 1., 0., 1., 2.],
+                       [2., 3., 3., 4., 5.]]).astype(nptype)
+    assert (output.asnumpy() == expect).all()
+    print(output)
 
 @pytest.mark.level0
 @pytest.mark.platform_x86_cpu
 @pytest.mark.env_onecard
-def test_in2_axis1():
-    x1 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x2 = Tensor(np.arange(2 * 3 * 2).reshape(2, 3, 2), mstype.float32)
-    cat = Concat_Axis1()
-    output_ms = cat(x1, x2)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy()), axis=1)
-
-    error = np.ones(shape=output_np.shape) * 10e-6
-    diff = output_ms.asnumpy() - output_np
-    assert np.all(diff < error)
-    assert np.all(-diff < error)
+def test_axis21_float32():
+    axis21(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis21_int32():
+    axis21(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_axis21_bool():
+    axis21(np.bool)
+
 
-class Concat_in3_Axis2(nn.Cell):
+class Concat3INet(nn.Cell):
     def __init__(self):
-        super(Concat_in3_Axis2, self).__init__()
-        self.cat = P.Concat(axis=-1)
+        super(Concat3INet, self).__init__()
+        self.cat = P.Concat(axis=1)
 
     def construct(self, x1, x2, x3):
         return self.cat((x1, x2, x3))
 
-@pytest.mark.level0
-@pytest.mark.platform_x86_cpu
-@pytest.mark.env_onecard
-def test_in3_axis2():
-    x1 = Tensor(np.arange(2 * 2 * 1).reshape(2, 2, 1), mstype.float32)
-    x2 = Tensor(np.arange(2 * 2 * 2).reshape(2, 2, 2), mstype.float32)
-    x3 = Tensor(np.arange(2 * 2 * 3).reshape(2, 2, 3), mstype.float32)
-    cat = Concat_in3_Axis2()
-    output_ms = cat(x1, x2, x3)
-    print("output:\n", output_ms)
-    output_np = np.concatenate((x1.asnumpy(), x2.asnumpy(), x3.asnumpy()), axis=-1)
+
+def concat_3i(nptype):
+    cat = Concat3INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
+    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms)
 
     error = np.ones(shape=output_np.shape) * 10e-6
     diff = output_ms.asnumpy() - output_np
     assert np.all(diff < error)
     assert np.all(-diff < error)
 
-if __name__ == '__main__':
-    test_in2_axis0()
-    test_in2_axis1()
-    test_in3_axis2()
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_3i_float32():
+    concat_3i(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_3i_int32():
+    concat_3i(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_3i_bool():
+    cat = Concat3INet()
+
+    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
+    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool)
+    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool)
+    output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
+
+
+class Concat4INet(nn.Cell):
+    def __init__(self):
+        super(Concat4INet, self).__init__()
+        self.cat = P.Concat(axis=1)
+
+    def construct(self, x1, x2, x3, x4):
+        return self.cat((x1, x2, x3, x4))
+
+
+def concat_4i(nptype):
+    cat = Concat4INet()
+
+    x1_np = np.random.randn(32, 4, 224, 224).astype(nptype)
+    x2_np = np.random.randn(32, 8, 224, 224).astype(nptype)
+    x3_np = np.random.randn(32, 10, 224, 224).astype(nptype)
+    x4_np = np.random.randn(32, 5, 224, 224).astype(nptype)
+    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    x4_ms = Tensor(x4_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
+
+    error = np.ones(shape=output_np.shape) * 10e-6
+    diff = output_ms.asnumpy() - output_np
+    assert np.all(diff < error)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_float32():
+    concat_4i(np.float32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_int32():
+    concat_4i(np.int32)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_cpu
+@pytest.mark.env_onecard
+def test_concat_4i_bool():
+    cat = Concat4INet()
+
+    x1_np = np.random.choice([True, False], (32, 4, 224, 224)).astype(np.bool)
+    x2_np = np.random.choice([True, False], (32, 8, 224, 224)).astype(np.bool)
+    x3_np = np.random.choice([True, False], (32, 10, 224, 224)).astype(np.bool)
+    x4_np = np.random.choice([True, False], (32, 5, 224, 224)).astype(np.bool)
+    output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
+
+    x1_ms = Tensor(x1_np)
+    x2_ms = Tensor(x2_np)
+    x3_ms = Tensor(x3_np)
+    x4_ms = Tensor(x4_np)
+    output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
+
+    assert (output_ms.asnumpy() == output_np).all()
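The tests above stop at 4-D inputs; since the commit's point is lifting the 4-D cap, a natural extra case (hypothetical, not part of this commit; reuses the test file's imports) would concatenate 5-D tensors and compare against NumPy:

class Concat5D(nn.Cell):
    def __init__(self):
        super(Concat5D, self).__init__()
        self.cat = P.Concat(axis=2)

    def construct(self, x1, x2):
        return self.cat((x1, x2))


def test_concat_5d_float32():
    x1_np = np.random.randn(2, 3, 4, 5, 6).astype(np.float32)
    x2_np = np.random.randn(2, 3, 7, 5, 6).astype(np.float32)
    output_np = np.concatenate((x1_np, x2_np), axis=2)
    output_ms = Concat5D()(Tensor(x1_np), Tensor(x2_np))
    # concat only copies data, so exact equality is expected
    assert (output_ms.asnumpy() == output_np).all()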