add space to batch nd cpu kernel

update include files; fix pylint, code check and doc problems; update ops func
parent dd79e3cc13
commit 87a7fb3d4d
@@ -252,6 +252,7 @@ Array Operations

    mindspore.ops.shape
    mindspore.ops.size
    mindspore.ops.tensor_scatter_add
    mindspore.ops.space_to_batch_nd
    mindspore.ops.tile
    mindspore.ops.transpose
    mindspore.ops.unique
@@ -571,6 +571,29 @@ mindspore.Tensor

    - **ValueError** - The argument `side` or `sorter` is invalid.

    .. py:method:: space_to_batch_nd(block_shape, paddings)

        Divides the spatial dimensions into blocks of the given size, then rearranges the tensor along the batch dimension.

        **Parameters:**

        - **block_shape** (list[int], tuple[int], int) - The block shape, describing the number of blocks each spatial dimension is divided into.
        - **paddings** (tuple, list) - The padding sizes for the spatial dimensions.

        **Returns:**

        Tensor, the result after division and rearrangement.

        **Raises:**

        - **TypeError** - If `block_shape` is not a list, tuple or int.
        - **TypeError** - If `paddings` is not a list or tuple.
        - **ValueError** - If `block_shape` is a list or tuple but not one-dimensional.
        - **ValueError** - If the length of `block_shape` is not 2 on the Ascend platform.
        - **ValueError** - If the shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
        - **ValueError** - If the elements of `block_shape` are not integers greater than one.
        - **ValueError** - If the elements of `paddings` are not non-negative integers.

    .. py:method:: shape
        :property:
@@ -0,0 +1,49 @@
mindspore.ops.SpaceToBatchND
=============================

.. py:class:: mindspore.ops.SpaceToBatchND(block_shape, paddings)

    Divides the spatial dimensions into blocks of the given size, then rearranges the tensor along the batch dimension.

    This operator divides the spatial dimensions [1, ..., M] of the input into a grid of blocks of shape `block_shape`,
    and interleaves these blocks into the batch dimension (dimension 0 by default).
    Each slice of the output along the spatial dimensions is thus a grid of slices of the input along the
    corresponding spatial dimensions, and the size of the output batch dimension is the size of the input
    batch dimension multiplied by the number of blocks the spatial dimensions are divided into.
    Before the division, the spatial dimensions of the input are zero-padded according to `paddings`.

    Assume the input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)`;
    then the output shape is :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`,
    where

    .. math::
        \begin{array}{ll} \\
            n' = n*(block\_shape[0]*...*block\_shape[M]) \\
            w'_i = (w_i+paddings[i][0]+paddings[i][1])//block\_shape[i]
        \end{array}

    .. note::
        Ascend supports only 4-D input.

    **Parameters:**

    - **block_shape** (list[int], tuple[int], int) - The block shape, describing the number of blocks each spatial dimension is divided into. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is divided into `block_shape` blocks. On the Ascend backend `M` must be 2.
    - **paddings** (tuple, list) - The padding sizes for the spatial dimensions.

    **Inputs:**

    - **x** (Tensor) - The input of SpaceToBatchND; must be 4-D on the Ascend platform.

    **Outputs:**

    - **y** (Tensor) - Tensor, the result after division and rearrangement.

    **Raises:**

    - **TypeError** - If `block_shape` is not a list, tuple or int.
    - **TypeError** - If `paddings` is not a list or tuple.
    - **ValueError** - If `block_shape` is a list or tuple but not one-dimensional.
    - **ValueError** - If the length of `block_shape` is not 2 on the Ascend platform.
    - **ValueError** - If the shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
    - **ValueError** - If the elements of `block_shape` are not integers greater than one.
    - **ValueError** - If the elements of `paddings` are not non-negative integers.
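The shape arithmetic above can be checked with a few lines of plain Python. The sketch below is ours, not part of MindSpore (the helper name `s2b_output_shape` is invented for illustration); it simply evaluates the two formulas.

```python
# Sketch: evaluate the SpaceToBatchND shape formulas by hand.
def s2b_output_shape(input_shape, block_shape, paddings):
    """Output shape for SpaceToBatchND given the input shape,
    block_shape (M ints) and paddings (M pairs of ints)."""
    m = len(block_shape)
    offset = len(input_shape) - m          # index of the first spatial dimension
    out = list(input_shape)
    for i in range(m):
        padded = input_shape[offset + i] + paddings[i][0] + paddings[i][1]
        assert padded % block_shape[i] == 0, "padded size must divide evenly"
        out[offset + i] = padded // block_shape[i]   # w'_i
        out[0] *= block_shape[i]                     # n' accumulates the block product
    return tuple(out)

# (1, 1, 4, 4) with 2x2 blocks and no padding -> batch grows 4x, H and W halve.
print(s2b_output_shape((1, 1, 4, 4), [2, 2], [[0, 0], [0, 0]]))  # (4, 1, 2, 2)
```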
@@ -0,0 +1,42 @@
mindspore.ops.space_to_batch_nd
================================

.. py:function:: mindspore.ops.space_to_batch_nd(input_x, block_shape, paddings)

    Divides the spatial dimensions into blocks of the given size, then rearranges the tensor along the batch dimension.

    This function divides the spatial dimensions `[1, ..., M]` of the input into a grid of blocks of shape `block_shape`, and interleaves these blocks into the batch dimension (dimension 0 by default).
    Each slice of the output along the spatial dimensions is a grid of slices of the input along the corresponding spatial dimensions, and the size of the output batch dimension is the number of blocks the spatial dimensions are divided into multiplied by the size of the input batch dimension.
    Before the division, the spatial dimensions of the input are zero-padded according to `paddings`.
    Thus, assume the input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)`; then the output shape is :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`,
    where

    .. math::
        \begin{array}{ll} \\
            n' = n*(block\_shape[0] * ... * block\_shape[M]) \\
            w'_i = (w_i + paddings[i][0] + paddings[i][1])//block\_shape[i]
        \end{array}

    .. note::
        Ascend supports only 4-D tensors as input.

    **Parameters:**

    - **input_x** (Tensor) - The input tensor; must be 4-D on the Ascend platform.
    - **block_shape** (list[int], tuple[int], int) - The block shape, describing the number of blocks each spatial dimension is divided into. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is divided into `block_shape` blocks. On the Ascend backend `M` must be 2.
    - **paddings** (tuple, list) - The padding sizes for the spatial dimensions.

    **Returns:**

    Tensor, the result after division and rearrangement.

    **Raises:**

    - **TypeError** - If `block_shape` is not a list, tuple or int.
    - **TypeError** - If `paddings` is not a list or tuple.
    - **ValueError** - If `block_shape` is a list or tuple but not one-dimensional.
    - **ValueError** - If the length of `block_shape` is not 2 on the Ascend platform.
    - **ValueError** - If the shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
    - **ValueError** - If the elements of `block_shape` are not integers greater than one.
    - **ValueError** - If the elements of `paddings` are not non-negative integers.
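To verify the data movement itself and not just the shapes, here is a hedged NumPy reference for the common 4-D, M=2 case. `space_to_batch_2d_ref` is a name introduced here for illustration and is not a MindSpore API; the batch ordering (within-block index major, original batch minor) mirrors the description above.

```python
import numpy as np

def space_to_batch_2d_ref(x, block_shape, paddings):
    """NumPy reference for the 4-D (N, C, H, W) case with M = 2 spatial dims."""
    n, c, _, _ = x.shape
    bh, bw = block_shape
    x = np.pad(x, ((0, 0), (0, 0), tuple(paddings[0]), tuple(paddings[1])))
    hp, wp = x.shape[2], x.shape[3]
    # Split each spatial dim into (block grid index, within-block index) ...
    x = x.reshape(n, c, hp // bh, bh, wp // bw, bw)
    # ... move the within-block indices in front of the batch dim ...
    x = x.transpose(3, 5, 0, 1, 2, 4)
    # ... and fold them into the batch.
    return x.reshape(n * bh * bw, c, hp // bh, wp // bw)

x = np.arange(16, dtype=np.float32).reshape(1, 1, 4, 4)
print(space_to_batch_2d_ref(x, (2, 2), [[0, 0], [0, 0]])[0])  # [[[0. 2.] [8. 10.]]]
```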
@@ -253,6 +253,7 @@ Array Operation

    mindspore.ops.size
    mindspore.ops.tensor_scatter_add
    mindspore.ops.tensor_scatter_div
    mindspore.ops.space_to_batch_nd
    mindspore.ops.tile
    mindspore.ops.transpose
    mindspore.ops.unique
@@ -308,6 +309,7 @@ Parameter Operation Operators

    mindspore.ops.assign
    mindspore.ops.assign_add
    mindspore.ops.assign_sub
    mindspore.ops.index_add

.. list-table::
    :widths: 50 50
@@ -0,0 +1,199 @@
/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <string>
#include <utility>

#include "mindspore/core/ops/space_to_batch_nd.h"
#include "plugin/device/cpu/kernel/space_to_batch_nd_cpu_kernel.h"
#include "plugin/device/cpu/hal/device/cpu_device_address.h"

namespace mindspore {
namespace kernel {
constexpr size_t PADDING_SHAPE_1 = 2;
constexpr size_t kSpaceToBatchNDInputsNum = 1;
constexpr size_t kSpaceToBatchNDOutputsNum = 1;
constexpr char kKernelName[] = "SpaceToBatchND";

void SpaceToBatchNDCpuKernelMod::CheckParam() {
  for (size_t i = 0; i < block_rank_; i++) {
    if (block_size_[i] < 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the elements of 'block_size' should all be larger than 1, but got " << i
                        << "'th block size " << block_size_[i];
    }
  }

  // check paddings_
  if (paddings_.size() != block_rank_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the size of 'paddings' should be equal to the length of 'block_size': " << block_rank_
                      << ", but got " << paddings_.size();
  }

  for (size_t idx_i = 0; idx_i < block_rank_; ++idx_i) {
    if (paddings_[idx_i].size() != PADDING_SHAPE_1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the size of each vector of 'paddings' should be equal to " << PADDING_SHAPE_1
                        << ", but got " << idx_i << "'th element: " << paddings_[idx_i].size();
    }
    for (size_t idx_j = 0; idx_j < PADDING_SHAPE_1; ++idx_j) {
      if (paddings_[idx_i][idx_j] < 0) {
        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the element of 'paddings' cannot be less than 0, "
                          << "but got paddings[" << idx_i << "][" << idx_j << "]: " << paddings_[idx_i][idx_j];
      }
    }
    auto tmp_shape = input_shape_[idx_i + off_set_] + paddings_[idx_i][0] + paddings_[idx_i][1];
    if ((tmp_shape % block_size_[idx_i]) != 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', padded shape should be divisible by block_size, but got padded shape: " << tmp_shape
                        << ", block_size: " << block_size_[idx_i];
    }
    if ((tmp_shape / block_size_[idx_i]) == 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', padded shape cannot be less than block_size"
                        << ", but got padded shape: " << tmp_shape << ", block_size: " << block_size_[idx_i];
    }
  }
}

template <typename T>
bool SpaceToBatchNDCpuKernelMod::LaunchKernel(const std::vector<kernel::AddressPtr> &inputs,
                                              const std::vector<kernel::AddressPtr> &,
                                              const std::vector<kernel::AddressPtr> &outputs) {
  // check all shapes, blocks and paddings are valid
  CheckParam();

  const auto *input = reinterpret_cast<T *>(inputs[0]->addr);
  auto *output = reinterpret_cast<T *>(outputs[0]->addr);
  int ret = memset_s(output, outputs[0]->size, 0, sizeof(T) * output_size_);
  if (ret != 0) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', memset_s error. Error no: " << ret;
  }

  for (int64_t pos = 0; pos < input_size_; pos += 1) {
    // decompose the flat input position into a row-major multi-dimensional index
    std::vector<int64_t> input_index(input_shape_.size(), 0);
    int64_t cur_pos = pos;
    for (int rev_i = input_shape_.size() - 1; rev_i >= 0; rev_i -= 1) {
      input_index[rev_i] = cur_pos % input_shape_[rev_i];
      cur_pos = cur_pos / input_shape_[rev_i];
    }

    // split each spatial coordinate into a block coordinate and a within-block offset
    std::vector<int64_t> output_index(input_index);
    int64_t idx_on = 0;
    for (size_t i = off_set_; i < input_shape_.size(); i += 1) {
      output_index[i] = (input_index[i] + paddings_[i - off_set_][0]) / block_size_[i - off_set_];
      idx_on =
        idx_on * block_size_[i - off_set_] + (input_index[i] + paddings_[i - off_set_][0]) % block_size_[i - off_set_];
    }

    // fold the within-block offset into the batch dimension
    output_index[0] = idx_on * input_shape_[0] + input_index[0];

    int64_t out_pos = 0;

    for (size_t i = 0; i < output_shape_.size(); i += 1) {
      out_pos = out_pos * output_shape_[i] + output_index[i];
    }

    output[out_pos] = input[pos];
  }
  return true;
}

bool SpaceToBatchNDCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
                                      const std::vector<KernelTensorPtr> &outputs) {
  auto kernel_ptr = std::dynamic_pointer_cast<ops::SpaceToBatchND>(base_operator);
  if (!kernel_ptr) {
    MS_LOG(ERROR) << "cast SpaceToBatchND ops failed!";
    return false;
  }
  kernel_name_ = kernel_ptr->name();

  if (inputs.size() != kSpaceToBatchNDInputsNum || outputs.size() != kSpaceToBatchNDOutputsNum) {
    MS_LOG(ERROR) << kernel_name_ << ": input and output size should be " << kSpaceToBatchNDInputsNum << " and "
                  << kSpaceToBatchNDOutputsNum << ", but got " << inputs.size() << " and " << outputs.size();
    return false;
  }

  block_size_ = kernel_ptr->get_block_shape();
  paddings_ = kernel_ptr->get_paddings();
  block_rank_ = block_size_.size();

  auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
  std::vector<KernelAttr> support_list;
  (void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
                       [](const std::pair<KernelAttr, SpaceToBatchNDFunc> &pair) { return pair.first; });
  auto [is_match, index] = MatchKernelAttr(kernel_attr, support_list);
  if (!is_match) {
    MS_LOG(ERROR) << kernel_name_ << " does not support this kernel data type: " << kernel_attr;
    return false;
  }
  kernel_func_ = func_list_[index].second;
  return true;
}

bool SpaceToBatchNDCpuKernelMod::Resize(const BaseOperatorPtr &base_operator,
                                        const std::vector<KernelTensorPtr> &inputs,
                                        const std::vector<KernelTensorPtr> &outputs,
                                        const std::map<uint32_t, tensor::TensorPtr> &others) {
  if (!NativeCpuKernelMod::Resize(base_operator, inputs, outputs, others)) {
    MS_LOG(WARNING) << kernel_name_ << " reinit failed.";
    return false;
  }
  // get input_shape
  input_shape_ = inputs.at(kIndex0)->GetShapeVector();
  output_shape_ = outputs.at(kIndex0)->GetShapeVector();

  input_size_ = 1;
  output_size_ = 1;
  for (size_t i = 0; i < input_shape_.size(); ++i) {
    input_size_ = input_shape_[i] * input_size_;
  }
  for (size_t i = 0; i < output_shape_.size(); ++i) {
    output_size_ = output_shape_[i] * output_size_;
  }

  off_set_ = input_shape_.size() - block_size_.size();

  return true;
}

std::vector<std::pair<KernelAttr, SpaceToBatchNDCpuKernelMod::SpaceToBatchNDFunc>>
  SpaceToBatchNDCpuKernelMod::func_list_ = {
    {KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<int8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<int16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<int32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<int64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<uint8_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<uint16_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<uint32_t>},
    {KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<uint64_t>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<float16>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<float>},
    {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
     &SpaceToBatchNDCpuKernelMod::LaunchKernel<double>}};

MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, SpaceToBatchND, SpaceToBatchNDCpuKernelMod);
}  // namespace kernel
}  // namespace mindspore
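The heart of the kernel above is the flat-index walk in `LaunchKernel`: decompose each input offset into coordinates, split every spatial coordinate into a block coordinate and a within-block offset, fold the within-block part into the batch dimension, and re-flatten. A line-for-line Python transcription (our own sketch for stepping through the logic, not shipped code) reads:

```python
from math import prod

def space_to_batch_nd_scalar(x_flat, in_shape, out_shape, block, paddings):
    """Step-by-step transcription of the kernel's per-element index mapping."""
    offset = len(in_shape) - len(block)
    out = [0] * prod(out_shape)  # zero-filled, like the kernel's memset_s
    for pos in range(len(x_flat)):
        # Unflatten pos into a row-major input coordinate.
        idx, cur = [0] * len(in_shape), pos
        for d in reversed(range(len(in_shape))):
            idx[d], cur = cur % in_shape[d], cur // in_shape[d]
        # Spatial coords split into (block grid coord, within-block coord).
        out_idx, idx_on = list(idx), 0
        for d in range(offset, len(in_shape)):
            padded = idx[d] + paddings[d - offset][0]
            out_idx[d] = padded // block[d - offset]
            idx_on = idx_on * block[d - offset] + padded % block[d - offset]
        # The within-block coordinate is folded into the batch dimension.
        out_idx[0] = idx_on * in_shape[0] + idx[0]
        # Re-flatten the output coordinate (row-major).
        out_pos = 0
        for d in range(len(out_shape)):
            out_pos = out_pos * out_shape[d] + out_idx[d]
        out[out_pos] = x_flat[pos]
    return out

# 4x4 input, 2x2 blocks: batch 0 of the output is [0, 2, 8, 10].
print(space_to_batch_nd_scalar(list(range(16)), (1, 1, 4, 4), (4, 1, 2, 2),
                               [2, 2], [[0, 0], [0, 0]])[:4])
```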
@ -0,0 +1,72 @@
|
|||
/**
|
||||
* Copyright 2019-2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACE_TO_BATCH_ND_CPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACE_TO_BATCH_ND_CPU_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <map>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <utility>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SpaceToBatchNDCpuKernelMod : public NativeCpuKernelMod {
|
||||
public:
|
||||
SpaceToBatchNDCpuKernelMod() = default;
|
||||
~SpaceToBatchNDCpuKernelMod() override = default;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
bool Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &others = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override {
|
||||
return kernel_func_(this, inputs, workspace, outputs);
|
||||
}
|
||||
|
||||
private:
|
||||
void CheckParam();
|
||||
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<kernel::AddressPtr> &inputs, const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs);
|
||||
using SpaceToBatchNDFunc =
|
||||
std::function<bool(SpaceToBatchNDCpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &)>;
|
||||
static std::vector<std::pair<KernelAttr, SpaceToBatchNDFunc>> func_list_;
|
||||
SpaceToBatchNDFunc kernel_func_;
|
||||
|
||||
std::vector<std::vector<int64_t>> paddings_;
|
||||
std::vector<int64_t> block_size_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
size_t block_rank_;
|
||||
size_t off_set_;
|
||||
int64_t input_size_;
|
||||
int64_t output_size_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_CPU_SPACE_TO_BATCH_ND_CPU_KERNEL_H_
|
|
@@ -26,6 +26,109 @@

namespace mindspore {
namespace ops {
namespace {
constexpr size_t PADDING_SHAPE_1 = 2;

ShapeVector SpaceToBatchNDInferShapeImpl(const string &kernel_name_, const std::vector<int64_t> &block_size_,
                                         const std::vector<std::vector<int64_t>> &paddings_,
                                         const ShapeVector &input_shape_) {
  auto block_rank_ = block_size_.size();
  auto off_set_ = input_shape_.size() - block_size_.size();

  ShapeVector output_shape_ = input_shape_;
  for (size_t i = 0; i < block_rank_; i++) {
    if (block_size_[i] < 1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the elements of 'block_size' should all be larger than 1, but got " << i
                        << "'th block size " << block_size_[i];
    }
  }

  // check paddings_
  if (paddings_.size() != block_rank_) {
    MS_LOG(EXCEPTION) << "For '" << kernel_name_
                      << "', the size of 'paddings' should be equal to the length of 'block_size': " << block_rank_
                      << ", but got " << paddings_.size();
  }

  for (size_t idx_i = 0; idx_i < block_rank_; ++idx_i) {
    if (paddings_[idx_i].size() != PADDING_SHAPE_1) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', the size of each vector of 'paddings' should be equal to " << PADDING_SHAPE_1
                        << ", but got " << idx_i << "'th element: " << paddings_[idx_i].size();
    }
    for (size_t idx_j = 0; idx_j < PADDING_SHAPE_1; ++idx_j) {
      if (paddings_[idx_i][idx_j] < 0) {
        MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the element of 'paddings' cannot be less than 0, "
                          << "but got paddings[" << idx_i << "][" << idx_j << "]: " << paddings_[idx_i][idx_j];
      }
    }

    // check the paddings and block_sizes are valid
    auto tmp_shape = input_shape_[idx_i + off_set_] + paddings_[idx_i][0] + paddings_[idx_i][1];
    if ((tmp_shape % block_size_[idx_i]) != 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_
                        << "', padded shape should be divisible by block_size, but got padded shape: " << tmp_shape
                        << ", block_size: " << block_size_[idx_i];
    }
    if ((tmp_shape / block_size_[idx_i]) == 0) {
      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', padded shape cannot be less than block_size"
                        << ", but got padded shape: " << tmp_shape << ", block_size: " << block_size_[idx_i];
    }
    output_shape_[idx_i + off_set_] = tmp_shape / block_size_[idx_i];
    output_shape_[0] = output_shape_[0] * block_size_[idx_i];
  }

  return output_shape_;
}

abstract::ShapePtr SpaceToBatchNDInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto prim_name = primitive->name();
  (void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  auto input_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, kInputIndex0);

  auto paddings_value_ptr = primitive->GetAttr(kPaddings);
  MS_EXCEPTION_IF_NULL(paddings_value_ptr);
  auto paddings = GetValue<std::vector<std::vector<int64_t>>>(paddings_value_ptr);

  auto block_shapes_value_ptr = primitive->GetAttr(kBlockShape);
  MS_EXCEPTION_IF_NULL(block_shapes_value_ptr);
  auto block_shapes = GetValue<std::vector<int64_t>>(block_shapes_value_ptr);

  ShapeVector out_shape = SpaceToBatchNDInferShapeImpl(prim_name, block_shapes, paddings, input_shape_ptr->shape());

  if (!input_shape_ptr->IsDynamic()) {
    return std::make_shared<abstract::Shape>(out_shape);
  }

  ShapeVector max_out_shape =
    SpaceToBatchNDInferShapeImpl(prim_name, block_shapes, paddings, input_shape_ptr->max_shape());
  ShapeVector min_out_shape =
    SpaceToBatchNDInferShapeImpl(prim_name, block_shapes, paddings, input_shape_ptr->min_shape());

  return std::make_shared<abstract::Shape>(out_shape, min_out_shape, max_out_shape);
}

TypePtr SpaceToBatchNDInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = prim->name();
  const int64_t input_num = 1;
  (void)CheckAndConvertUtils::CheckInteger("input number", SizeToLong(input_args.size()), kEqual, input_num, op_name);
  for (const auto &item : input_args) {
    MS_EXCEPTION_IF_NULL(item);
  }
  const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16,
                                         kUInt32, kUInt64, kFloat16, kFloat32, kFloat64};
  auto var_type = input_args[kInputIndex0]->BuildType();

  return CheckAndConvertUtils::CheckTensorTypeValid("input type", var_type, valid_types, prim->name());
}
}  // namespace

void SpaceToBatchND::set_paddings(std::vector<std::vector<int64_t>> paddings) {
  const int64_t pad_size = 2;
  (void)CheckAndConvertUtils::CheckInteger(kPaddings, SizeToLong(paddings.size()), kEqual, pad_size, this->name());

@@ -59,12 +162,20 @@ std::vector<int64_t> SpaceToBatchND::get_block_shape() const {
  return GetValue<std::vector<int64_t>>(GetAttr(kBlockShape));
}

abstract::AbstractBasePtr SpaceToBatchNDInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                              const std::vector<abstract::AbstractBasePtr> &input_args) {
  return abstract::MakeAbstract(SpaceToBatchNDInferShape(primitive, input_args),
                                SpaceToBatchNDInferType(primitive, input_args));
}

void SpaceToBatchND::Init(const std::vector<int64_t> block_shape, const std::vector<std::vector<int64_t>> paddings) {
  this->set_paddings(paddings);
  this->set_block_shape(block_shape);
}

MIND_API_OPERATOR_IMPL(SpaceToBatchND, BaseOperator);
REGISTER_PRIMITIVE_C(kNameSpaceToBatchND, SpaceToBatchND);
REGISTER_PRIMITIVE_EVAL_IMPL(SpaceToBatchND, prim::kPrimSpaceToBatchND, SpaceToBatchNDInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
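Note that `SpaceToBatchNDInferShape` above applies one shape helper three times when the input is dynamic: to the actual shape and to the min/max bounds. A minimal Python sketch of that control flow, assuming hypothetical helper names (`s2b_shape` and `infer_s2b` are ours, not MindSpore APIs):

```python
def s2b_shape(shape, block, paddings):
    """Output shape: fold the block product into dim 0, shrink each spatial dim."""
    out, offset = list(shape), len(shape) - len(block)
    for i, b in enumerate(block):
        out[offset + i] = (shape[offset + i] + sum(paddings[i])) // b
        out[0] *= b
    return out

def infer_s2b(shape, block, paddings, min_shape=None, max_shape=None):
    out = s2b_shape(shape, block, paddings)
    if min_shape is None:  # static shape: a single result suffices
        return out
    # The formulas are monotonic in the input sizes, so mapping the
    # min/max bounds through the same function yields valid output bounds.
    return out, s2b_shape(min_shape, block, paddings), s2b_shape(max_shape, block, paddings)
```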
@@ -1058,10 +1058,10 @@ class Tensor(Tensor_):
        perm = tuple(range(0, self.ndim))
        if axis2 + 1 < self.ndim:
            new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
-               perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
+                      perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1] + perm[axis2 + 1:]
        else:
            new_perm = perm[0:axis1] + perm[axis2:axis2 + 1] + \
-               perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]
+                      perm[axis1 + 1:axis2] + perm[axis1:axis1 + 1]

        return tensor_operator_registry.get('transpose')()(self, new_perm)
@@ -2199,6 +2199,48 @@ class Tensor(Tensor_):
            j = tensor_operator_registry.get('select')(mask, mid, j)
        return j

    def space_to_batch_nd(self, block_shape, paddings):
        r"""
        Divides spatial dimensions into blocks and combines the block size with the original batch.

        Args:
            block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block with all values
                greater than 1.
            paddings (Union[tuple, list]): The padding values for spatial dimensions, containing M subtraction
                lists, each containing 2 integer values.

        Returns:
            Tensor, the output tensor with the same data type as the input. The input tensor must be a 4-D
            tensor on Ascend.

        Raises:
            TypeError: If `block_shape` is not one of list, tuple, int.
            TypeError: If `paddings` is neither list nor tuple.
            ValueError: If `block_shape` is not one dimensional when `block_shape` is a list or tuple.
            ValueError: If the length of `block_shape` is not 2 on Ascend.
            ValueError: If shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
            ValueError: If the element of `block_shape` is not an integer larger than 1.
            ValueError: If the element of `paddings` is not a non-negative integer.

        Supported Platforms:
            ``Ascend`` ``GPU`` ``CPU``

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Tensor
            >>> block_shape = [2, 2]
            >>> paddings = [[0, 0], [0, 0]]
            >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
            >>> output = input_x.space_to_batch_nd(block_shape, paddings)
            >>> print(output)
            [[[[1.]]]
             [[[2.]]]
             [[[3.]]]
             [[[4.]]]]
        """
        return tensor_operator_registry.get('space_to_batch_nd')(block_shape, paddings)(self)

    def var(self, axis=None, ddof=0, keepdims=False):
        """
        Compute the variance along the specified axis.
@@ -76,3 +76,4 @@ from .buffer_get import _buffer_get_cpu
from .buffer_sample import _buffer_sample_cpu
from .priority_replay_buffer import _prb_push_op_cpu
from .priority_replay_buffer import _prb_sample_op_cpu
from .space_to_batch_nd import _space_to_batch_nd_cpu
@@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SpaceToBatchND op"""
from mindspore.ops.op_info_register import op_info_register, CpuRegOp, DataType

space_to_batch_nd_op_info = CpuRegOp("SpaceToBatchND") \
    .input(0, "x", "required") \
    .output(0, "y", "required") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.I8_Default, DataType.I8_Default) \
    .dtype_format(DataType.I16_Default, DataType.I16_Default) \
    .dtype_format(DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I64_Default, DataType.I64_Default) \
    .dtype_format(DataType.U8_Default, DataType.U8_Default) \
    .dtype_format(DataType.U16_Default, DataType.U16_Default) \
    .dtype_format(DataType.U32_Default, DataType.U32_Default) \
    .dtype_format(DataType.U64_Default, DataType.U64_Default) \
    .get_op_info()


@op_info_register(space_to_batch_nd_op_info)
def _space_to_batch_nd_cpu():
    """SpaceToBatchND cpu register"""
    return
@@ -23,7 +23,7 @@ from . import array_func, parameter_func, math_func
from .array_func import (unique, eye, fill, fill_, tile, size, ones, ones_like, shape, shape_, dyn_shape, rank,
                         reshape, reshape_, tensor_slice, slice, scalar_to_array, scalar_to_tensor, tuple_to_array,
                         expand_dims, transpose, scatter_nd, gather, gather_d, gather_nd, scalar_cast, masked_fill,
-                        tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min)
+                        tensor_scatter_add, tensor_scatter_div, scatter_max, scatter_min, space_to_batch_nd)
from .parameter_func import assign, assign_add, assign_sub, index_add
from .math_func import (addn, absolute, abs, tensor_add, add, neg_tensor, neg, tensor_lt, less, tensor_le, le,
                        tensor_gt, gt, tensor_ge, ge, tensor_sub, sub, tensor_mul, mul, tensor_div, div,
@@ -999,6 +999,63 @@ def tensor_scatter_add(input_x, indices, updates):
    return tensor_scatter_add_(input_x, indices, updates)


def space_to_batch_nd(input_x, block_shape, paddings):
    r"""
    Divides a tensor's spatial dimensions into blocks and combines the block sizes with the original batch.

    This operation will divide spatial dimensions into blocks with `block_shape`,
    and after division, the output tensor's spatial dimension is the corresponding number of blocks.
    The output tensor's batch dimension is the product of the original batch and the product of `block_shape`.
    Before division, the spatial dimensions of the input are zero padded according to paddings if necessary.
    Assume input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)` with
    :math:`block\_shape` and :math:`paddings`. Then the shape of the output tensor will be
    :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`, where

    :math:`n' = n*(block\_shape[0]*...*block\_shape[M])`

    :math:`w'_i = (w_i+paddings[i][0]+paddings[i][1])//block\_shape[i]`

    Args:
        input_x (Tensor): The input tensor. It must be a 4-D tensor on Ascend.
        block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block with all values greater
            than 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding to the
            number of spatial dimensions. If `block_shape` is an int, the block size of M dimensions are the same,
            equal to `block_shape`. M must be 2 on Ascend.
        paddings (Union[tuple, list]): The padding values for spatial dimensions, containing M subtraction lists.
            Each contains 2 integer values. All values must be non-negative.
            `paddings[i]` specifies the paddings for the spatial dimension i,
            which corresponds to the input dimension i + offset.
            It is required that input_shape[i+offset]+paddings[i][0]+paddings[i][1] is divisible by block_shape[i].
            M must be 2 on Ascend.

    Returns:
        Tensor, the output tensor with the same data type as the input.

    Raises:
        ValueError: If `block_shape` is not one dimensional when `block_shape` is a list or tuple.
        ValueError: If the length of `block_shape` is not 2 on Ascend.
        ValueError: If the element of `block_shape` is not an integer larger than 1.
        ValueError: If shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
        ValueError: If the element of `paddings` is not a non-negative integer.
        TypeError: If `block_shape` is not one of list, tuple, int.
        TypeError: If `paddings` is neither list nor tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> import numpy as np
        >>> import mindspore
        >>> from mindspore import Tensor, ops
        >>> block_shape = [2, 2]
        >>> paddings = [[0, 0], [0, 0]]
        >>> input_x = Tensor(np.array([[[[1, 2], [3, 4]]]]), mindspore.float32)
        >>> output = ops.space_to_batch_nd(input_x, block_shape, paddings)
        >>> print(output)
        [[[[1.]]]
         [[[2.]]]
         [[[3.]]]
         [[[4.]]]]
    """
    return P.SpaceToBatchND(block_shape, paddings)(input_x)


##############################
# Type Conversion Functions.
##############################
@@ -1227,6 +1284,7 @@ __all__ = [
    'scalar_cast',
    'scalar_to_array',
    'scalar_to_tensor',
    'space_to_batch_nd',
    'tuple_to_array',
    'expand_dims',
    'transpose',
@@ -946,7 +946,7 @@ def print_info(info):

def make_sparse_tensor(indices, values, dense_shape):
    """Call make_coo_tensor in this function."""
-   print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " + \
+   print_info("WARNING: 'SparseTensor' is deprecated from version 1.7 and will be removed in a future version. " +
               "Please use 'COOTensor' instead.")
    return make_coo_tensor(indices, values, dense_shape)
@@ -982,6 +982,7 @@ tensor_operator_registry.register('split', P.Split)
tensor_operator_registry.register('index_add', P.IndexAdd)
tensor_operator_registry.register('scatter_max', P.ScatterMax)
tensor_operator_registry.register('scatter_min', P.ScatterMin)
tensor_operator_registry.register('space_to_batch_nd', P.SpaceToBatchND)
# ms cannot support Tensor(True) compare
tensor_operator_registry.register('__eq__', equal)
tensor_operator_registry.register('__ne__', not_equal)
@@ -24,6 +24,7 @@ from collections import Counter
import numpy as np

from mindspore import log as logger
from mindspore import context
from mindspore.common.initializer import Zero
from .. import signature as sig
from .._utils import get_broadcast_shape, is_shape_unknown
@@ -3718,7 +3719,7 @@ class StridedSlice(PrimitiveWithInfer):
        if has_ellipsis:
            # When there is ellipsis, handle the second half of the ellipsis split.
            ellipsis_occupied_dims = x_rank - i - (slice_len - (j + 1)) + \
-               len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
+                                    len(tuple(filter(lambda x: x == '1', new_axis_pos[j + 1:slice_len])))
            ret_shape.extend(x_shape[i:i + ellipsis_occupied_dims])
            j += 1
            i += ellipsis_occupied_dims
@@ -5636,46 +5637,49 @@ class SpaceToBatchND(PrimitiveWithInfer):
    r"""
    Divides spatial dimensions into blocks and combines the block size with the original batch.

-   This operation will divide spatial dimensions (H, W) into blocks with block_shape, the output tensor's H and W
+   This operation will divide spatial dimensions into blocks with `block_shape`, and then the output tensor's spatial
    dimension is the corresponding number of blocks after division. The output tensor's batch dimension is the
-   product of the original batch and the product of `block_shape`. Before division,
-   the spatial dimensions of the input are zero padded according to paddings if necessary.
+   product of the original batch and all elements in `block_shape`.
+   Before division, the spatial dimensions of the input are zero padded according to paddings if necessary.

    Args:
-       block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block with all value greater
-           than 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding to the
-           number of spatial dimensions. If `block_shape` is an int, the block size of M dimensions are the same,
-           equal to `block_shape`. M must be 2.
-       paddings (Union[tuple, list]): The padding values for H and W dimension, containing 2 subtraction list.
-           Each contains 2 integer value. All values must be greater than 0.
+       block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block
+           with all elements greater than 1. If `block_shape` is a list or tuple,
+           the length of `block_shape` is the number of spatial dimensions, called M later.
+           If `block_shape` is an int, the block size of M dimensions are the same, equal to `block_shape`.
+           On Ascend, M must be 2.
+       paddings (Union[tuple, list]): The padding values for spatial dimensions, containing M subtraction lists.
+           Each contains 2 integer values. All values must be non-negative.
            `paddings[i]` specifies the paddings for the spatial dimension i,
-           which corresponds to the input dimension i+2.
-           It is required that input_shape[i+2]+paddings[i][0]+paddings[i][1] is divisible by block_shape[i].
+           which corresponds to the input dimension i + offset.
+           For each i, input_shape[i + offset]+paddings[i][0]+paddings[i][1]
+           should be divisible by block_shape[i].

    Inputs:
-       - **input_x** (Tensor) - The input tensor. It must be a 4-D tensor.
+       - **input_x** (Tensor) - The input tensor. The input tensor must be a 4-D tensor on Ascend.

    Outputs:
-       Tensor, the output tensor with the same data type as input. Assume input shape is :math:`(n, c, h, w)` with
-       :math:`block\_shape` and :math:`paddings`. The shape of the output tensor will be :math:`(n', c', h', w')`,
+       Tensor, the output tensor with the same data type as input.
+       Assume input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)` with
+       :math:`block\_shape` and :math:`paddings`.
+       The shape of the output tensor will be :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`,
        where

-       :math:`n' = n*(block\_shape[0]*block\_shape[1])`
+       :math:`n' = n*(block\_shape[0]*...*block\_shape[M])`

-       :math:`c' = c`
-
-       :math:`h' = (h+paddings[0][0]+paddings[0][1])//block\_shape[0]`
-
-       :math:`w' = (w+paddings[1][0]+paddings[1][1])//block\_shape[1]`
+       :math:`w'_i = (w_i+paddings[i][0]+paddings[i][1])//block\_shape[i]`

    Raises:
        TypeError: If `block_shape` is not one of list, tuple, int.
        TypeError: If `paddings` is neither list nor tuple.
-       ValueError: If length of shape of `block_shape` is not equal to 1.
-       ValueError: If length of `block_shape` or `paddings` is not equal to 2.
+       ValueError: If `block_shape` is not one dimensional when `block_shape` is a list or tuple.
+       ValueError: If the length of `block_shape` is not 2 on Ascend.
+       ValueError: If shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
+       ValueError: If the element of `block_shape` is not an integer larger than 1.
+       ValueError: If the element of `paddings` is not a non-negative integer.

    Supported Platforms:
-       ``Ascend``
+       ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> block_shape = [2, 2]
@@ -5693,20 +5697,23 @@ class SpaceToBatchND(PrimitiveWithInfer):
    @prim_attr_register
    def __init__(self, block_shape, paddings):
        """Initialize SpaceToBatchND"""
+       validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
+       validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
+
        if isinstance(block_shape, int):
-           block_shape = (block_shape,) * 2
+           block_shape = (block_shape,) * np.array(paddings).shape[0]

+       self.add_prim_attr("block_shape", block_shape)
        validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
        validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
        block_rank = len(block_shape)
-       validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
+       if context.get_context("device_target") == "Ascend":
+           validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
        for elem in block_shape:
            validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
            validator.check_value_type('block_shape element', elem, [int], self.name)
        self.block_shape = block_shape

-       validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
-       validator.check('paddings length', len(paddings), '', 2, Rel.EQ, self.name)
        validator.check('paddings shape', np.array(paddings).shape, '', (block_rank, 2), Rel.EQ, self.name)
        for elem in itertools.chain(*paddings):
            validator.check_non_negative_int(elem, 'paddings element', self.name)
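One subtle point in the `__init__` diff above: an int `block_shape` is now broadcast to one entry per row of `paddings` instead of being hard-coded to two entries. A small sketch isolating that expansion (plain Python mirroring the validator path; `expand_block_shape` is an illustrative name, not MindSpore code):

```python
import numpy as np

def expand_block_shape(block_shape, paddings):
    """Mirror of the __init__ expansion: an int becomes one entry per paddings row."""
    if isinstance(block_shape, int):
        block_shape = (block_shape,) * np.array(paddings).shape[0]
    block_rank = len(block_shape)
    # Corresponds to the '(block_rank, 2)' validator check, i.e. paddings is (M, 2).
    assert np.array(paddings).shape == (block_rank, 2)
    return block_shape

print(expand_block_shape(2, [[0, 0], [0, 0]]))  # (2, 2)
```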
@@ -5719,18 +5726,23 @@ class SpaceToBatchND(PrimitiveWithInfer):

    def infer_shape(self, x_shape):
        x_rank = len(x_shape)
-       validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
+       if context.get_context("device_target") == "Ascend":
+           validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
        out_shape = copy.deepcopy(x_shape)

        block_shape_prod = 1
-       offset = 2
+       offset = len(x_shape) - len(self.block_shape)
+       if offset <= 0:
+           raise ValueError(f"For '{self.name}', the dim of the input should be larger than that of the blocks, "
+                            f"but the shape of the inputs is {x_shape} "
+                            f"while the shape of blocks is {self.block_shape}.")
        for i in range(len(self.block_shape)):
            padded = out_shape[i + offset] + self.paddings[i][0] + \
-               self.paddings[i][1]
+                    self.paddings[i][1]
            if padded % self.block_shape[i] != 0:
                raise ValueError(f"For '{self.name}', the padded should be divisible by 'block_shape', "
                                 f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "
-                                f"but got input_x_shape[{i + 2}]: {out_shape[i + offset]}, "
+                                f"but got input_x_shape[{i + offset}]: {out_shape[i + offset]}, "
                                 f"paddings[{i}][0]: {self.paddings[i][0]} and paddings[{i}][1]: {self.paddings[i][1]}."
                                 f" Please check the official api documents for "
                                 f"more information about the output tensor.")
@@ -0,0 +1,164 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore
import mindspore.context as context
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter


class SpaceToBatchNDNet(nn.Cell):
    def __init__(self, nptype, block_size=2, input_shape=(1, 1, 4, 4)):
        super(SpaceToBatchNDNet, self).__init__()
        self.space_to_batch_nd = ops.SpaceToBatchND(block_shape=block_size, paddings=[[0, 0], [0, 0]])
        input_size = np.prod(input_shape)
        data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
        self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1')

    @ms_function
    def construct(self):
        y1 = self.space_to_batch_nd(self.x1)
        return y1


def space_to_batch_nd_test_case(nptype, block_size=2, input_shape=(1, 1, 4, 4)):
    expect = np.array([[[[0, 2],
                         [8, 10]]],
                       [[[1, 3],
                         [9, 11]]],
                       [[[4, 6],
                         [12, 14]]],
                       [[[5, 7],
                         [13, 15]]]]).astype(nptype)

    dts = SpaceToBatchNDNet(nptype, block_size, input_shape)
    output = dts()

    assert (output.asnumpy() == expect).all()


def space_to_batch_nd_all_dtype():
    space_to_batch_nd_test_case(np.float32)
    space_to_batch_nd_test_case(np.float16)
    space_to_batch_nd_test_case(np.int8)
    space_to_batch_nd_test_case(np.int16)
    space_to_batch_nd_test_case(np.int32)
    space_to_batch_nd_test_case(np.int64)
    space_to_batch_nd_test_case(np.uint8)
    space_to_batch_nd_test_case(np.uint16)
    space_to_batch_nd_test_case(np.uint32)
    space_to_batch_nd_test_case(np.uint64)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_space_to_batch_nd_graph():
    """
    Feature: test SpaceToBatchND function interface.
    Description: test interface.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.GRAPH_MODE, device_target='CPU')
    space_to_batch_nd_all_dtype()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_space_to_batch_nd_pynative():
    """
    Feature: test SpaceToBatchND function interface.
    Description: test interface.
    Expectation: the result match with numpy result
    """
    context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
    space_to_batch_nd_all_dtype()


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_space_to_batch_nd_function():
    """
    Feature: test SpaceToBatchND function interface.
    Description: test interface.
    Expectation: the result match with numpy result
    """
    context.set_context(device_target="CPU")
    x = Tensor(np.arange(16).reshape((1, 1, 4, 4)).astype(np.float32), mindspore.float32)
    output = ops.space_to_batch_nd(x, 2, [[0, 0], [0, 0]])
    expect = np.array([[[[0, 2],
                         [8, 10]]],
                       [[[1, 3],
                         [9, 11]]],
                       [[[4, 6],
                         [12, 14]]],
                       [[[5, 7],
                         [13, 15]]]]).astype(np.float32)
    np.testing.assert_array_equal(output.asnumpy(), expect)


class SpaceToBatchNDDynamicShapeNetMS(nn.Cell):
    def __init__(self, block_size, paddings, axis=0):
        super().__init__()
        self.unique = ops.Unique()
        self.gather = ops.Gather()
        self.space_to_batch_nd = ops.SpaceToBatchND(block_size, paddings)
        self.axis = axis

    def construct(self, x, indices):
        unique_indices, _ = self.unique(indices)
        x = self.gather(x, unique_indices, self.axis)
        return self.space_to_batch_nd(x)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_space_to_batch_nd_dynamic():
    """
    Feature: test SpaceToBatchND dynamic shape.
    Description: the input to SpaceToBatchND is dynamic.
    Expectation: the result match with numpy result
    """
    x = np.array([[[[1, 2, 3, 4], [5, 6, 7, 8]]], [[[1, 2, 3, 4], [5, 6, 7, 8]]],
                  [[[1, 2, 3, 4], [5, 6, 7, 8]]], [[[1, 2, 3, 4], [5, 6, 7, 8]]]]).astype(np.float32)
    block_size = [2, 2]
    paddings = [[0, 0], [0, 0]]

    input_x = Tensor(x, mindspore.float32)
    input_y = Tensor(np.array([0, 0, 1, 0]), mindspore.int32)
    expect = np.array([[[[1., 3.]]],
                       [[[1., 3.]]],
                       [[[2., 4.]]],
                       [[[2., 4.]]],
                       [[[5., 7.]]],
                       [[[5., 7.]]],
                       [[[6., 8.]]],
                       [[[6., 8.]]]]).astype(np.float32)
    dyn_net = SpaceToBatchNDDynamicShapeNetMS(block_size, paddings)

    context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
    output = dyn_net(input_x, input_y)
    assert (output.asnumpy() == expect).all()
    context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
    output = dyn_net(input_x, input_y)
    assert (output.asnumpy() == expect).all()