add ops :depthtospace,spacetodepth

Signed-off-by: kanghui <405194527@qq.com>
This commit is contained in:
kanghui 2021-05-14 05:18:07 +00:00
parent 6801ef61e0
commit cb8e1fa3f1
6 changed files with 573 additions and 0 deletions

View File

@ -0,0 +1,89 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/depthtospace_cpu_kernel.h"
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void DepthToSpaceCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
block_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size");
}
template <typename T>
bool DepthToSpaceCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t size = IntToSize(inputs[0]->size / sizeof(T));
std::vector<size_t> input_shape = input_shape_;
std::vector<size_t> output_shape = output_shape_;
size_t block_size = block_size_;
size_t input_dimension = input_shape.size();
size_t output_strides[3] = {1, 1, 1};
for (size_t i = input_dimension - 1; i >= 1; --i) {
for (size_t j = 0; j < i; ++j) {
output_strides[j] *= output_shape[i];
}
}
auto task = [&, input_addr, output_addr](size_t start, size_t end) {
std::vector<size_t> output_pos_array(input_dimension, 0);
for (size_t i = start; i < end; ++i) {
size_t tmp_pos = i;
for (size_t j = 0; j < input_dimension - 1; ++j) {
output_pos_array[j] = tmp_pos / output_strides[j];
tmp_pos %= output_strides[j];
}
output_pos_array.back() = tmp_pos;
size_t input_pos = output_pos_array[0];
input_pos =
(input_pos * input_shape[1]) +
(output_pos_array[1] +
(block_size * (output_pos_array[2] % block_size) + output_pos_array[3] % block_size) * output_shape[1]);
input_pos = (input_pos * input_shape[2]) + (output_pos_array[2] / block_size);
input_pos = (input_pos * input_shape[3]) + (output_pos_array[3] / block_size);
output_addr[i] = input_addr[input_pos];
}
};
CPUKernelUtils::ParallelFor(task, size);
return true;
}
template <typename T>
void DepthToSpaceCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,85 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_
#include <memory>
#include <string>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class DepthToSpaceCPUKernel : public CPUKernel {
public:
DepthToSpaceCPUKernel() = default;
~DepthToSpaceCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
size_t block_size_;
};
MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
DepthToSpaceCPUKernel, float);
MS_REG_CPU_KERNEL_T(
DepthToSpace, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
DepthToSpaceCPUKernel, float16);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
DepthToSpaceCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
DepthToSpaceCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
DepthToSpaceCPUKernel, int);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
DepthToSpaceCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
DepthToSpaceCPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
DepthToSpaceCPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
DepthToSpaceCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(DepthToSpace,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
DepthToSpaceCPUKernel, uint64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_DEPTHTOSPACE_CPU_KERNEL_H_

View File

@ -0,0 +1,91 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/cpu/spacetodepth_cpu_kernel.h"
#include <vector>
#include "runtime/device/cpu/cpu_device_address.h"
namespace mindspore {
namespace kernel {
template <typename T>
void SpaceToDepthCPUKernel<T>::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node);
CheckParam(kernel_node);
input_shape_ = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
output_shape_ = AnfAlgo::GetOutputDeviceShape(kernel_node, 0);
block_size_ = AnfAlgo::GetNodeAttr<int64_t>(kernel_node, "block_size");
}
template <typename T>
bool SpaceToDepthCPUKernel<T>::Launch(const std::vector<kernel::AddressPtr> &inputs,
const std::vector<kernel::AddressPtr> & /*workspace*/,
const std::vector<kernel::AddressPtr> &outputs) {
auto input_addr = reinterpret_cast<T *>(inputs[0]->addr);
auto output_addr = reinterpret_cast<T *>(outputs[0]->addr);
size_t size = IntToSize(inputs[0]->size / sizeof(T));
std::vector<size_t> input_shape = input_shape_;
std::vector<size_t> output_shape = output_shape_;
size_t block_size = block_size_;
size_t input_dimension = input_shape.size();
size_t input_strides[3] = {1, 1, 1};
for (size_t i = input_dimension - 1; i >= 1; --i) {
for (size_t j = 0; j < i; ++j) {
input_strides[j] *= input_shape[i];
}
}
auto task = [&, input_addr, output_addr](size_t start, size_t end) {
std::vector<size_t> input_pos_array(input_dimension, 0);
for (size_t i = start; i < end; ++i) {
size_t tmp_pos = i;
for (size_t j = 0; j < input_dimension - 1; ++j) {
input_pos_array[j] = tmp_pos / input_strides[j];
tmp_pos %= input_strides[j];
}
input_pos_array.back() = tmp_pos;
size_t output_pos = input_pos_array[0];
output_pos =
(output_pos * output_shape[1]) +
(input_pos_array[1] +
(block_size * (input_pos_array[2] % block_size) + input_pos_array[3] % block_size) * input_shape[1]);
output_pos = (output_pos * output_shape[2]) + (input_pos_array[2] / block_size);
output_pos = (output_pos * output_shape[3]) + (input_pos_array[3] / block_size);
output_addr[output_pos] = input_addr[i];
}
};
CPUKernelUtils::ParallelFor(task, size);
return true;
}
template <typename T>
void SpaceToDepthCPUKernel<T>::CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "Input number is " << input_num << ", but DepthToSpaceCPUKerrnel needs 1 input.";
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "Output number is " << output_num << ", but DepthToSpaceCPUKernel needs 1 output.";
}
}
} // namespace kernel
} // namespace mindspore

View File

@ -0,0 +1,84 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_
#include <string>
#include <vector>
#include "backend/kernel_compiler/cpu/cpu_kernel.h"
#include "backend/kernel_compiler/cpu/cpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
template <typename T>
class SpaceToDepthCPUKernel : public CPUKernel {
public:
SpaceToDepthCPUKernel() = default;
~SpaceToDepthCPUKernel() override = default;
void InitKernel(const CNodePtr &kernel_node) override;
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
private:
void CheckParam(const CNodePtr &kernel_node);
std::vector<size_t> input_shape_;
std::vector<size_t> output_shape_;
size_t block_size_;
};
MS_REG_CPU_KERNEL_T(
SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
SpaceToDepthCPUKernel, float);
MS_REG_CPU_KERNEL_T(
SpaceToDepth, KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
SpaceToDepthCPUKernel, float16);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
SpaceToDepthCPUKernel, int8_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
SpaceToDepthCPUKernel, int16_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
SpaceToDepthCPUKernel, int);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
SpaceToDepthCPUKernel, int64_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
SpaceToDepthCPUKernel, uint8_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
SpaceToDepthCPUKernel, uint16_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
SpaceToDepthCPUKernel, uint32_t);
MS_REG_CPU_KERNEL_T(SpaceToDepth,
KernelAttr().SetAllSameAttr(true).AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
SpaceToDepthCPUKernel, uint64_t);
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACETODEPTH_CPU_KERNEL_H_

View File

@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class DepthToSpaceNet(nn.Cell):
def __init__(self, nptype, block_size=2, input_shape=(1, 12, 1, 1)):
super(DepthToSpaceNet, self).__init__()
self.DepthToSpace = P.DepthToSpace(block_size)
input_size = 1
for i in input_shape:
input_size = input_size*i
data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1')
@ms_function
def construct(self):
y1 = self.DepthToSpace(self.x1)
return y1
def DepthToSpace(nptype, block_size=2, input_shape=(1, 12, 1, 1)):
input_size = 1
for i in input_shape:
input_size = input_size*i
expect = np.array([[[[0, 3],
[6, 9]],
[[1, 4],
[7, 10]],
[[2, 5],
[8, 11]]]]).astype(nptype)
dts = DepthToSpaceNet(nptype, block_size, input_shape)
output = dts()
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_float32():
DepthToSpace(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_float16():
DepthToSpace(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_int32():
DepthToSpace(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_int64():
DepthToSpace(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_int8():
DepthToSpace(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_int16():
DepthToSpace(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_uint8():
DepthToSpace(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_uint16():
DepthToSpace(np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_uint32():
DepthToSpace(np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_depthtospace_graph_uint64():
DepthToSpace(np.uint64)

View File

@ -0,0 +1,109 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops.operations.array_ops as P
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
class SpaceToDepthNet(nn.Cell):
def __init__(self, nptype):
super(SpaceToDepthNet, self).__init__()
self.SpaceToDepth = P.SpaceToDepth(2)
data_np = np.array([[[[0, 3],
[6, 9]],
[[1, 4],
[7, 10]],
[[2, 5],
[8, 11]]]]).astype(nptype)
self.data_np = data_np
self.x = Parameter(initializer(Tensor(self.data_np), (1, 3, 2, 2)), name='x')
@ms_function
def construct(self):
return self.SpaceToDepth(self.x)
def SpaceToDepth(nptype):
expect = np.arange(12).reshape((1, 12, 1, 1)).astype(nptype)
std = SpaceToDepthNet(nptype)
output = std()
assert (output.asnumpy() == expect).all()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_float32():
SpaceToDepth(np.float32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_float16():
SpaceToDepth(np.float16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_int32():
SpaceToDepth(np.int32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_int64():
SpaceToDepth(np.int64)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_int8():
SpaceToDepth(np.int8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_int16():
SpaceToDepth(np.int16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint8():
SpaceToDepth(np.uint8)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint16():
SpaceToDepth(np.uint16)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint32():
SpaceToDepth(np.uint32)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_spacetodepth_graph_uint64():
SpaceToDepth(np.uint64)