!43683 [assistant][ops] Add Betainc, SpaceToBatchND and BatchToSpaceND
Merge pull request !43683 from zwy9901/merge
Commit: 28fdfb512a
@@ -12,8 +12,8 @@ mindspore.ops.SpaceToBatchND
 Before being divided into blocks, the spatial dimensions of the input are padded with zeros according to `paddings`.

 Parameters:
-    - **block_shape** (Union[list(int), tuple(int), int]) - The block shape, describing how many pieces each spatial dimension is split into; every value must be greater than 1. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is split into `block_shape` pieces. On the Ascend backend `M` must be 2.
-    - **paddings** (Union[tuple, list]) - The padding sizes for the spatial dimensions. It contains M lists, each holding 2 integer values, and every value must be greater than 0. `paddings[i]` is the padding of spatial dimension `i`, which corresponds to dimension `i+offset` of the input tensor, where `offset` is the offset of the spatial dimensions within the input tensor's dimensions.
+    - **block_shape** (Union[list(int), tuple(int), int]) - The block shape, describing how many pieces each spatial dimension is split into; every value must be greater than or equal to 1. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is split into `block_shape` pieces. On the Ascend backend `M` must be 2.
+    - **paddings** (Union[tuple, list]) - The padding sizes for the spatial dimensions. It contains M lists, each holding 2 integer values, and every value must be greater than or equal to 0. `paddings[i]` is the padding of spatial dimension `i`, which corresponds to dimension `i+offset` of the input tensor, where `offset` is the offset of the spatial dimensions within the input tensor's dimensions, `offset=N-M`, and `N` is the number of input dimensions.
 For each spatial dimension i, `input_shape[i+offset]+paddings[i][0]+paddings[i][1]` must be divisible by `block_shape[i]`.

 Inputs:
@@ -26,8 +26,8 @@ mindspore.ops.SpaceToBatchND

 .. math::
     \begin{array}{ll} \\
-        n' = n*(block\_shape[0]*...*block\_shape[M]) \\
-        w'_i = (w_i+paddings[i][0]+paddings[i][1])//block\_shape[i]
+        n' = n*(block\_shape[0]*...*block\_shape[M-1]) \\
+        w'_i = (w_i+paddings[i-1][0]+paddings[i-1][1])//block\_shape[i-1]
     \end{array}

 Raises:
@@ -36,5 +36,5 @@ mindspore.ops.SpaceToBatchND
 - **ValueError** - If `block_shape` is a list or tuple but is not one-dimensional.
 - **ValueError** - If the length of `block_shape` is not 2 on the Ascend platform.
 - **ValueError** - If the shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
-- **ValueError** - If the elements of `block_shape` are not integers greater than 1.
+- **ValueError** - If the elements of `block_shape` are not integers greater than or equal to 1.
 - **ValueError** - If the elements of `paddings` are not non-negative integers.
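To make the updated formulas concrete, here is a small NumPy sketch (illustrative only, not part of this merge request; the helper name `space_to_batch_nd_shape` is made up) that reproduces the SpaceToBatchND output-shape rule described above:

import numpy as np

def space_to_batch_nd_shape(input_shape, block_shape, paddings):
    # Illustrative only: output shape of SpaceToBatchND.
    # n' = n * prod(block_shape); for spatial dim i (mapped to input dim i + offset, offset = N - M):
    # w'_i = (w_i + paddings[i][0] + paddings[i][1]) // block_shape[i]
    n_dims, m = len(input_shape), len(block_shape)
    offset = n_dims - m
    out = list(input_shape)
    out[0] = input_shape[0] * int(np.prod(block_shape))
    for i in range(m):
        padded = input_shape[i + offset] + paddings[i][0] + paddings[i][1]
        assert padded % block_shape[i] == 0, "padded size must be divisible by block_shape[i]"
        out[i + offset] = padded // block_shape[i]
    return tuple(out)

print(space_to_batch_nd_shape((1, 1, 2, 2), [2, 2], [[0, 0], [0, 0]]))  # -> (4, 1, 1, 1)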
@@ -6,22 +6,20 @@ mindspore.ops.batch_to_space_nd
 Divides the batch dimension into blocks and interleaves these blocks back into the spatial dimensions.

 This function divides the batch dimension `N` into blocks of `block_shape`, i.e. the `N` dimension of the output tensor is the corresponding number of blocks after the division.
-The `H` and `W` dimensions of the output tensor are given by the products of the original `H`, `W` dimensions and `block_shape`, with the crops subtracted from those dimensions.
-Thus, if the input shape is :math:`(n, c, h, w)`, the output shape is :math:`(n', c', h', w')`.
+The :math:`w_1, ..., w_M` dimensions of the output tensor are given by the products of the original :math:`w_1, ..., w_M` dimensions and `block_shape`, with the crops subtracted from those dimensions.
+Thus, if the input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)`, the output shape is :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`.
 where

-:math:`n' = n//(block\_shape[0]*block\_shape[1])`
-
-:math:`c' = c`
-
-:math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
-
-:math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
+.. math::
+    \begin{array}{ll} \\
+        n' = n//(block\_shape[0]*...*block\_shape[M-1]) \\
+        w'_i = w_i*block\_shape[i-1]-crops[i-1][0]-crops[i-1][1]
+    \end{array}

 Parameters:
-    - **input_x** (Tensor) - The input tensor; it must be at least 4-D (exactly 4-D on the Ascend platform). The batch dimension must be divisible by `block_shape`. Supported dtypes: float16 and float32.
-    - **block_shape** (Union[list(int), tuple(int), int]) - The number of blocks the batch dimension is split into; every value must be greater than 1. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is split into `block_shape` pieces. `M` must be 2.
-    - **crops** (Union[list(int), tuple(int)]) - The crop sizes for the spatial dimensions; it contains two lists of length 2, corresponding to spatial dimensions H and W. Values must be greater than or equal to 0, and `input_shape[i+2] * block_shape[i] > crops[i][0] + crops[i][1]` must hold.
+    - **input_x** (Tensor) - The input tensor; it must be at least 2-D (exactly 4-D on the Ascend platform). The batch dimension must be divisible by `block_shape`.
+    - **block_shape** (Union[list(int), tuple(int), int]) - The number of blocks the batch dimension is split into; every value must be greater than or equal to 1. If `block_shape` is a list or tuple, its length `M` equals the number of spatial dimensions. If `block_shape` is an int, every spatial dimension is split into `block_shape` pieces. On the Ascend backend `M` must be 2.
+    - **crops** (Union[list(int), tuple(int)]) - The crop sizes for the spatial dimensions; it contains `M` lists of length 2, with values greater than or equal to 0. `crops[i]` is the crop of spatial dimension `i`, which corresponds to dimension `i+offset` of the input tensor, where `offset` is the offset of the spatial dimensions within the input tensor's dimensions, `offset=N-M`, and `N` is the number of input dimensions. In addition, `input_shape[i+offset] * block_shape[i] > crops[i][0] + crops[i][1]` must hold.

 Returns:
     Tensor, the result after the division and rearrangement.
@@ -30,7 +28,7 @@ mindspore.ops.batch_to_space_nd
 - **TypeError** - If `block_shape` is not a list, tuple, or int.
 - **TypeError** - If `crops` is not a list or tuple.
 - **ValueError** - If `block_shape` is a list or tuple but is not one-dimensional.
-- **ValueError** - If the length of `block_shape` or `crops` is not 2.
-- **ValueError** - If the elements of `block_shape` are not integers greater than 1.
+- **ValueError** - If the length of `block_shape` is not 2 on the Ascend platform.
+- **ValueError** - If the elements of `block_shape` are not integers greater than or equal to 1.
 - **ValueError** - If the shape of `crops` is not (M, 2), where M is the length of `block_shape`.
 - **ValueError** - If the elements of `crops` are not non-negative integers.
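For the inverse direction, a matching sketch (again illustrative only; `batch_to_space_nd_shape` is a made-up helper) computes the batch_to_space_nd output shape from the formula above:

import numpy as np

def batch_to_space_nd_shape(input_shape, block_shape, crops):
    # Illustrative only: output shape of batch_to_space_nd (inverse of SpaceToBatchND).
    # n' = n // prod(block_shape); for spatial dim i (mapped to input dim i + offset, offset = N - M):
    # w'_i = w_i * block_shape[i] - crops[i][0] - crops[i][1]
    n_dims, m = len(input_shape), len(block_shape)
    offset = n_dims - m
    block_prod = int(np.prod(block_shape))
    assert input_shape[0] % block_prod == 0, "batch must be divisible by prod(block_shape)"
    out = list(input_shape)
    out[0] = input_shape[0] // block_prod
    for i in range(m):
        out[i + offset] = input_shape[i + offset] * block_shape[i] - crops[i][0] - crops[i][1]
    return tuple(out)

print(batch_to_space_nd_shape((4, 1, 1, 1), [2, 2], [[0, 0], [0, 0]]))  # -> (1, 1, 2, 2)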
@@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/arrays/batch_to_space_nd_gpu_kernel.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateBatchToSpaceNDKernelPtr(const std::string &kernel_name,
|
||||
const uint32_t &device_id) {
|
||||
return std::make_unique<cukernel::BatchToSpaceNDHelperGpuKernel<T>>(kernel_name, device_id);
|
||||
}
|
||||
using BatchToSpaceNDPtrCreatorFunc =
|
||||
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||
|
||||
const std::vector<std::pair<KernelAttr, BatchToSpaceNDPtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateBatchToSpaceNDKernelPtr<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateBatchToSpaceNDKernelPtr<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateBatchToSpaceNDKernelPtr<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateBatchToSpaceNDKernelPtr<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateBatchToSpaceNDKernelPtr<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateBatchToSpaceNDKernelPtr<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateBatchToSpaceNDKernelPtr<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateBatchToSpaceNDKernelPtr<uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateBatchToSpaceNDKernelPtr<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateBatchToSpaceNDKernelPtr<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateBatchToSpaceNDKernelPtr<double>}};
|
||||
} // namespace
|
||||
|
||||
bool BatchToSpaceNDGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool BatchToSpaceNDGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::BatchToSpaceND>(base_operator);
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
return false;
|
||||
}
|
||||
attr_ptr_->block_shape = kernel_ptr->get_block_shape();
|
||||
attr_ptr_->crops = kernel_ptr->get_crops();
|
||||
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||
helper_ptr_->SetKernelParam(attr_ptr_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int BatchToSpaceNDGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, the input shape is dynamic; just return and do nothing for now.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<std::vector<int64_t>> output_shapes;
|
||||
std::vector<int64_t> inp_shape = inputs[kIndex0]->GetShapeVector();
|
||||
std::vector<int64_t> out_shape = outputs[kIndex0]->GetShapeVector();
|
||||
input_shapes.emplace_back(inp_shape);
|
||||
output_shapes.emplace_back(out_shape);
|
||||
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> BatchToSpaceNDGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, BatchToSpaceNDPtrCreatorFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, BatchToSpaceND, BatchToSpaceNDGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCH_TO_SPACE_ND_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCH_TO_SPACE_ND_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/batch_to_space_nd.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/batch_to_space_nd_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BatchToSpaceNDGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
BatchToSpaceNDGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::BatchToSpaceNDAttr>(); }
|
||||
~BatchToSpaceNDGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||
std::shared_ptr<cukernel::BatchToSpaceNDAttr> attr_ptr_{nullptr};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_BATCH_TO_SPACE_ND_KERNEL_H_
|
|
@@ -0,0 +1,113 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/arrays/space_to_batch_nd_gpu_kernel.h"
|
||||
#include <utility>
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
template <typename T>
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> CreateSpaceToBatchNDKernelPtr(const std::string &kernel_name,
|
||||
const uint32_t &device_id) {
|
||||
return std::make_unique<cukernel::SpaceToBatchNDHelperGpuKernel<T>>(kernel_name, device_id);
|
||||
}
|
||||
using SpaceToBatchNDPtrCreatorFunc =
|
||||
std::function<std::unique_ptr<cukernel::GpuKernelHelperBase>(const std::string &, const uint32_t &)>;
|
||||
|
||||
const std::vector<std::pair<KernelAttr, SpaceToBatchNDPtrCreatorFunc>> kernel_attr = {
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8), CreateSpaceToBatchNDKernelPtr<int8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kNumberTypeInt16), CreateSpaceToBatchNDKernelPtr<int16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), CreateSpaceToBatchNDKernelPtr<int32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kNumberTypeInt64), CreateSpaceToBatchNDKernelPtr<int64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt8).AddOutputAttr(kNumberTypeUInt8), CreateSpaceToBatchNDKernelPtr<uint8_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt16).AddOutputAttr(kNumberTypeUInt16),
|
||||
CreateSpaceToBatchNDKernelPtr<uint16_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt32).AddOutputAttr(kNumberTypeUInt32),
|
||||
CreateSpaceToBatchNDKernelPtr<uint32_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeUInt64).AddOutputAttr(kNumberTypeUInt64),
|
||||
CreateSpaceToBatchNDKernelPtr<uint64_t>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
|
||||
CreateSpaceToBatchNDKernelPtr<half>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
|
||||
CreateSpaceToBatchNDKernelPtr<float>},
|
||||
{KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
|
||||
CreateSpaceToBatchNDKernelPtr<double>}};
|
||||
} // namespace
|
||||
|
||||
bool SpaceToBatchNDGpuKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
std::vector<void *> input_ptrs = ConvertPtrs(inputs);
|
||||
std::vector<void *> work_ptrs = ConvertPtrs(workspace);
|
||||
std::vector<void *> output_ptrs = ConvertPtrs(outputs);
|
||||
if (helper_ptr_->Process(input_ptrs, output_ptrs, work_ptrs, stream_ptr) != 0) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool SpaceToBatchNDGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
auto kernel_ptr = std::dynamic_pointer_cast<ops::SpaceToBatchND>(base_operator);
|
||||
kernel_name_ = kernel_ptr->name();
|
||||
auto tensor_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(tensor_attr, GetOpSupport());
|
||||
if (!is_match) {
|
||||
return false;
|
||||
}
|
||||
attr_ptr_->block_shape = kernel_ptr->get_block_shape();
|
||||
attr_ptr_->paddings = kernel_ptr->get_paddings();
|
||||
helper_ptr_ = std::move(kernel_attr[index].second(kernel_name_, device_id_));
|
||||
helper_ptr_->SetKernelParam(attr_ptr_);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
int SpaceToBatchNDGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, the input shape is dynamic; just return and do nothing for now.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
std::vector<std::vector<int64_t>> input_shapes;
|
||||
std::vector<std::vector<int64_t>> output_shapes;
|
||||
std::vector<int64_t> inp_shape = inputs[kIndex0]->GetShapeVector();
|
||||
std::vector<int64_t> out_shape = outputs[kIndex0]->GetShapeVector();
|
||||
input_shapes.emplace_back(inp_shape);
|
||||
output_shapes.emplace_back(out_shape);
|
||||
if (helper_ptr_->CalMemSize(input_shapes, output_shapes) == -1) {
|
||||
return KRET_RESIZE_FAILED;
|
||||
}
|
||||
input_size_list_ = helper_ptr_->GetInputSizeList();
|
||||
output_size_list_ = helper_ptr_->GetOutputSizeList();
|
||||
workspace_size_list_ = helper_ptr_->GetWorkSizeList();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> SpaceToBatchNDGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(kernel_attr.begin(), kernel_attr.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, SpaceToBatchNDPtrCreatorFunc> &item) { return item.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, SpaceToBatchND, SpaceToBatchNDGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@@ -0,0 +1,56 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACE_TO_BATCH_ND_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACE_TO_BATCH_ND_KERNEL_H_
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include "mindspore/core/ops/space_to_batch_nd.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/space_to_batch_nd_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class SpaceToBatchNDGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
SpaceToBatchNDGpuKernelMod() { attr_ptr_ = std::make_shared<cukernel::SpaceToBatchNDAttr>(); }
|
||||
~SpaceToBatchNDGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<cukernel::GpuKernelHelperBase> helper_ptr_{nullptr};
|
||||
std::shared_ptr<cukernel::SpaceToBatchNDAttr> attr_ptr_{nullptr};
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_SPACE_TO_BATCH_ND_KERNEL_H_
|
|
@@ -0,0 +1,198 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_BATCH_TO_SPACE_ND_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_BATCH_TO_SPACE_ND_HELPER_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/batch_to_space_nd_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cukernel {
|
||||
class BatchToSpaceNDAttr : public GpuKernelAttrBase {
|
||||
public:
|
||||
BatchToSpaceNDAttr() = default;
|
||||
~BatchToSpaceNDAttr() override = default;
|
||||
std::vector<std::vector<int64_t>> crops;
|
||||
std::vector<int64_t> block_shape;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BatchToSpaceNDHelperGpuKernel : public GpuKernelHelperBase {
|
||||
public:
|
||||
explicit BatchToSpaceNDHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||
: GpuKernelHelperBase(kernel_name, device_id) {
|
||||
is_null_input_ = false;
|
||||
}
|
||||
|
||||
virtual ~BatchToSpaceNDHelperGpuKernel() = default;
|
||||
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||
const std::vector<std::vector<int64_t>> &output_shapes) override {
|
||||
constexpr size_t INPUT_NUM = 1;
|
||||
constexpr size_t OUTPUT_NUM = 1;
|
||||
ResetResource();
|
||||
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
|
||||
if (inp_flag == -1) {
|
||||
return inp_flag;
|
||||
}
|
||||
input_shape_ = input_shapes[0];
|
||||
|
||||
int out_flag =
|
||||
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
|
||||
if (out_flag == -1) {
|
||||
return out_flag;
|
||||
}
|
||||
output_shape_ = output_shapes[0];
|
||||
is_null_input_ = (inp_flag == 1 || out_flag == 1);
|
||||
return CheckKernelParam();
|
||||
}
|
||||
|
||||
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return 0;
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < (size_t) static_cast<int64_t>(input_shape_.size()); ++i) {
|
||||
input_size_ = input_shape_[i] * input_size_;
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < (size_t) static_cast<int64_t>(output_shape_.size()); ++i) {
|
||||
output_size_ = output_shape_[i] * output_size_;
|
||||
}
|
||||
|
||||
T *input_ptr = nullptr;
|
||||
T *output_ptr = nullptr;
|
||||
input_shape_size = input_shape_.size();
|
||||
output_shape_size = output_shape_.size();
|
||||
|
||||
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
std::vector<int64_t> crops_start_(crops_.size(), 0);
|
||||
for (int i = 0; i < static_cast<int>(crops_.size()); i++) crops_start_[i] = crops_[i][0];
|
||||
std::vector<int64_t> stride_(input_shape_size, 1);
|
||||
for (int i = static_cast<int>(input_shape_size) - 2; i >= 0; i--) stride_[i] = stride_[i + 1] * input_shape_[i + 1];
|
||||
std::vector<int64_t> on_stride_(block_rank_, 1);
|
||||
if (block_rank_ > 1) {
|
||||
for (int i = static_cast<int>(block_rank_) - 2; i >= 0; i--)
|
||||
on_stride_[i] = on_stride_[i + 1] * block_shape_[i + 1];
|
||||
}
|
||||
// call cuda kernel
|
||||
CalBatchToSpaceND(input_ptr, crops_start_.data(), block_shape_.data(), output_shape_.data(), output_shape_size,
|
||||
stride_.data(), on_stride_.data(), off_set_, output_size_, output_ptr, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
return 0;
|
||||
}
|
||||
|
||||
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||
attr_ptr_ = std::dynamic_pointer_cast<BatchToSpaceNDAttr>(kernel_attr);
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
block_rank_ = 0;
|
||||
off_set_ = 0;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
input_shape_size = 0;
|
||||
output_shape_size = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
work_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
int CheckKernelParam() override {
|
||||
constexpr size_t CROP_SHAPE_1 = 2;
|
||||
crops_ = attr_ptr_->crops;
|
||||
block_shape_ = attr_ptr_->block_shape;
|
||||
block_rank_ = block_shape_.size();
|
||||
off_set_ = input_shape_.size() - block_shape_.size();
|
||||
|
||||
if (static_cast<int>(block_shape_.size()) - static_cast<int>(input_shape_.size()) >= 0) {
MS_LOG(ERROR) << kernel_name_ << " resize failed because the rank of the input must be greater than the "
<< "length of 'block_shape', but the input shape is " << input_shape_ << " and 'block_shape' is " << block_shape_;
return -1;
}
|
||||
|
||||
// check crops_
|
||||
if (crops_.size() != block_rank_) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the size of 'crops' should be equal to the length of 'block_shape': " << block_rank_
|
||||
<< ", but got " << crops_.size();
|
||||
return -1;
|
||||
}
|
||||
int64_t block_shape_prod = 1;
|
||||
for (size_t idx_i = 0; idx_i < block_rank_; ++idx_i) {
|
||||
if (block_shape_[idx_i] < 1) {
MS_LOG(ERROR) << "For '" << kernel_name_
<< "', the elements of 'block_shape' should be no less than 1, but got the " << idx_i
<< "'th block size: " << block_shape_[idx_i];
return -1;
}
|
||||
block_shape_prod = block_shape_prod * block_shape_[idx_i];
|
||||
if (crops_[idx_i].size() != CROP_SHAPE_1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the size of each vector of 'crops' should be equal to the length of 'block_shape': "
|
||||
<< CROP_SHAPE_1 << ", but got " << idx_i << "'th element: " << crops_[idx_i].size();
|
||||
return -1;
|
||||
}
|
||||
for (size_t idx_j = 0; idx_j < CROP_SHAPE_1; ++idx_j) {
|
||||
if (crops_[idx_i][idx_j] < 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the element of 'crops' cannot be less than 0, "
|
||||
<< "but got crops[" << idx_i << "][ " << idx_j << "]: " << crops_[idx_i][idx_j];
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (input_shape_[0] % block_shape_prod != 0) {
|
||||
MS_LOG(ERROR)
|
||||
<< "For '" << kernel_name_
|
||||
<< "', the first dim of 'input_x' must be divisible by 'block_shape_prod'. But got first dim of 'input_x': "
|
||||
<< input_shape_[0] << ", 'block_shape_prod' with value: " << block_shape_prod << ".";
|
||||
return -1;
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<BatchToSpaceNDAttr> attr_ptr_;
|
||||
std::vector<std::vector<int64_t>> crops_;
|
||||
std::vector<int64_t> block_shape_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
size_t block_rank_;
|
||||
size_t off_set_;
|
||||
size_t input_shape_size;
|
||||
size_t output_shape_size;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
bool is_null_input_;
|
||||
};
|
||||
} // namespace cukernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_BATCH_TO_SPACE_ND_HELPER_H_
|
|
@@ -0,0 +1,190 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SPACE_TO_BATCH_ND_HELPER_H_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SPACE_TO_BATCH_ND_HELPER_H_
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_class/helper_base.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/space_to_batch_nd_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace cukernel {
|
||||
class SpaceToBatchNDAttr : public GpuKernelAttrBase {
|
||||
public:
|
||||
SpaceToBatchNDAttr() = default;
|
||||
~SpaceToBatchNDAttr() override = default;
|
||||
std::vector<std::vector<int64_t>> paddings;
|
||||
std::vector<int64_t> block_shape;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class SpaceToBatchNDHelperGpuKernel : public GpuKernelHelperBase {
|
||||
public:
|
||||
explicit SpaceToBatchNDHelperGpuKernel(const std::string &kernel_name, const uint32_t &device_id)
|
||||
: GpuKernelHelperBase(kernel_name, device_id) {
|
||||
is_null_input_ = false;
|
||||
}
|
||||
|
||||
virtual ~SpaceToBatchNDHelperGpuKernel() = default;
|
||||
int CalMemSize(const std::vector<std::vector<int64_t>> &input_shapes,
|
||||
const std::vector<std::vector<int64_t>> &output_shapes) override {
|
||||
constexpr size_t INPUT_NUM = 1;
|
||||
constexpr size_t OUTPUT_NUM = 1;
|
||||
ResetResource();
|
||||
int inp_flag = CalShapesSizeInBytes<T>(input_shapes, INPUT_NUM, kernel_name_, "input_shapes", &input_size_list_);
|
||||
if (inp_flag == -1) {
|
||||
return inp_flag;
|
||||
}
|
||||
input_shape_ = input_shapes[0];
|
||||
|
||||
int out_flag =
|
||||
CalShapesSizeInBytes<T>(output_shapes, OUTPUT_NUM, kernel_name_, "output_shapes", &output_size_list_);
|
||||
if (out_flag == -1) {
|
||||
return out_flag;
|
||||
}
|
||||
output_shape_ = output_shapes[0];
|
||||
is_null_input_ = (inp_flag == 1 || out_flag == 1);
|
||||
return CheckKernelParam();
|
||||
}
|
||||
|
||||
int Process(const std::vector<void *> &input_ptrs, const std::vector<void *> &output_ptrs,
|
||||
const std::vector<void *> &work_ptrs, void *cuda_stream) override {
|
||||
if (is_null_input_) {
|
||||
return 0;
|
||||
}
|
||||
input_size_ = 1;
|
||||
for (size_t i = 0; i < (size_t) static_cast<int64_t>(input_shape_.size()); ++i) {
|
||||
input_size_ = input_shape_[i] * input_size_;
|
||||
}
|
||||
|
||||
output_size_ = 1;
|
||||
for (size_t i = 0; i < (size_t) static_cast<int64_t>(output_shape_.size()); ++i) {
|
||||
output_size_ = output_shape_[i] * output_size_;
|
||||
}
|
||||
|
||||
T *input_ptr = nullptr;
|
||||
T *output_ptr = nullptr;
|
||||
input_shape_size = input_shape_.size();
|
||||
output_shape_size = output_shape_.size();
|
||||
|
||||
int flag = GetDeviceAddress<T>(input_ptrs, 0, kernel_name_, &input_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
flag = GetDeviceAddress<T>(output_ptrs, 0, kernel_name_, &output_ptr);
|
||||
if (flag != 0) {
|
||||
return flag;
|
||||
}
|
||||
|
||||
std::vector<int64_t> paddings_start(paddings_.size(), 0);
|
||||
for (int i = 0; i < static_cast<int>(paddings_.size()); i++) paddings_start[i] = paddings_[i][0];
|
||||
std::vector<int64_t> stride(output_shape_size, 1);
|
||||
for (int i = static_cast<int>(output_shape_size) - 2; i >= 0; i--) stride[i] = stride[i + 1] * output_shape_[i + 1];
|
||||
std::vector<int64_t> on_stride(block_rank, 1);
|
||||
if (block_rank > 1) {
|
||||
for (int i = static_cast<int>(block_rank) - 2; i >= 0; i--) on_stride[i] = on_stride[i + 1] * block_shape_[i + 1];
|
||||
}
|
||||
|
||||
// call cuda kernel
|
||||
CalSpaceToBatchND(input_ptr, paddings_start.data(), block_shape_.data(), input_shape_.data(), input_shape_size,
|
||||
stride.data(), on_stride.data(), off_set, input_size_, output_size_, output_ptr, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
return 0;
|
||||
}
|
||||
|
||||
void SetKernelParam(const GpuKernelAttrBasePtr &kernel_attr) override {
|
||||
attr_ptr_ = std::dynamic_pointer_cast<SpaceToBatchNDAttr>(kernel_attr);
|
||||
}
|
||||
|
||||
void ResetResource() noexcept override {
|
||||
block_rank = 0;
|
||||
off_set = 0;
|
||||
input_size_ = 0;
|
||||
output_size_ = 0;
|
||||
input_shape_size = 0;
|
||||
output_shape_size = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
work_size_list_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
int CheckKernelParam() override {
|
||||
constexpr size_t PADDING_SHAPE_1 = 2;
|
||||
paddings_ = attr_ptr_->paddings;
|
||||
block_shape_ = attr_ptr_->block_shape;
|
||||
block_rank = block_shape_.size();
|
||||
off_set = input_shape_.size() - block_shape_.size();
|
||||
|
||||
// check paddings_
|
||||
if (paddings_.size() != block_rank) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the size of 'paddings' should be equal to the length of 'block_size': " << block_rank
|
||||
<< ", but got " << paddings_.size();
|
||||
return -1;
|
||||
}
|
||||
|
||||
for (size_t idx_i = 0; idx_i < block_rank; ++idx_i) {
|
||||
if (paddings_[idx_i].size() != PADDING_SHAPE_1) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the size of each vector of 'paddings' should be equal to the length of 'block_size': "
|
||||
<< PADDING_SHAPE_1 << ", but got " << idx_i << "'th element: " << paddings_[idx_i].size();
|
||||
return -1;
|
||||
}
|
||||
for (size_t idx_j = 0; idx_j < PADDING_SHAPE_1; ++idx_j) {
|
||||
if (paddings_[idx_i][idx_j] < 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the element of 'paddings' cannot be less than 0, "
|
||||
<< "but got paddings[" << idx_i << "][ " << idx_j << "]: " << paddings_[idx_i][idx_j];
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
auto tmp_shape = input_shape_[idx_i + off_set] + paddings_[idx_i][0] + paddings_[idx_i][1];
|
||||
if ((tmp_shape % block_shape_[idx_i]) != 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', padded shape should be divisible by block_size, , but got padded shape: " << tmp_shape
|
||||
<< ", block_size: " << block_shape_[idx_i];
|
||||
return -1;
|
||||
}
|
||||
if ((tmp_shape / block_shape_[idx_i]) == 0) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', padded shape cannot be less than block_size"
|
||||
<< ", but got padded shape: " << tmp_shape << ", block_size: " << block_shape_[idx_i];
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<SpaceToBatchNDAttr> attr_ptr_;
|
||||
std::vector<std::vector<int64_t>> paddings_;
|
||||
std::vector<int64_t> block_shape_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
std::vector<int64_t> output_shape_;
|
||||
size_t block_rank;
|
||||
size_t off_set;
|
||||
size_t input_shape_size;
|
||||
size_t output_shape_size;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
bool is_null_input_;
|
||||
};
|
||||
} // namespace cukernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_CLASS_SPACE_TO_BATCH_ND_HELPER_H_
|
|
@@ -0,0 +1,141 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cuda_runtime.h>
|
||||
#include "batch_to_space_nd_impl.cuh"
|
||||
|
||||
__constant__ int64_t con_crop_start[8];
|
||||
__constant__ int64_t con_block_shape[8];
|
||||
__constant__ int64_t con_output_shape[8];
|
||||
__constant__ int64_t stride[8];
|
||||
__constant__ int64_t on_stride[8];
|
||||
|
||||
template <typename T>
|
||||
__global__ void BatchToSpaceND(const T *__restrict__ input, const int64_t *crops_start, const int64_t *block_shape,
|
||||
const int64_t *output_shape, const size_t output_shape_size, const size_t off_set_,
|
||||
const size_t output_size_, T *__restrict__ output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < output_size_; pos += gridDim.x * blockDim.x) {
|
||||
int64_t cur_pos = pos;
|
||||
int idx = output_shape_size - 1;
|
||||
int64_t in_pos = 0;
|
||||
int64_t idx_on = 0;
|
||||
int64_t cur_out_idx = 0;
|
||||
int64_t cur_inp_idx = 0;
|
||||
int64_t temp_idx = 0;
|
||||
int64_t offset_idx = 0;
|
||||
for (; idx >= off_set_; --idx) {
|
||||
cur_out_idx = cur_pos % con_output_shape[idx];
|
||||
cur_pos /= con_output_shape[idx];
|
||||
offset_idx = idx - off_set_;
|
||||
temp_idx = (cur_out_idx + con_crop_start[offset_idx]);
|
||||
cur_inp_idx = temp_idx / con_block_shape[offset_idx];
|
||||
in_pos += cur_inp_idx * stride[idx];
|
||||
idx_on += (temp_idx % con_block_shape[offset_idx]) * on_stride[offset_idx];
|
||||
}
|
||||
for (; idx > 0; --idx) {
|
||||
cur_out_idx = cur_pos % con_output_shape[idx];
|
||||
cur_pos /= con_output_shape[idx];
|
||||
in_pos += cur_out_idx * stride[idx];
|
||||
}
|
||||
|
||||
cur_inp_idx = idx_on * con_output_shape[0] + cur_pos % con_output_shape[0];
|
||||
|
||||
in_pos += cur_inp_idx * stride[0];
|
||||
output[pos] = input[in_pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalBatchToSpaceND(const T *input, const int64_t *crops_start, const int64_t *block_shape,
|
||||
const int64_t *output_shape, const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_, const size_t output_size_, T *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream) {
|
||||
cudaMemcpyToSymbol(con_output_shape, output_shape, sizeof(int64_t) * output_shape_size);
|
||||
cudaMemcpyToSymbol(con_block_shape, block_shape, sizeof(int64_t) * (output_shape_size - off_set_));
|
||||
cudaMemcpyToSymbol(con_crop_start, crops_start, sizeof(int64_t) * (output_shape_size - off_set_));
|
||||
cudaMemcpyToSymbol(stride, stride_, sizeof(int64_t) * output_shape_size);
|
||||
cudaMemcpyToSymbol(on_stride, on_stride_, sizeof(int64_t) * (output_shape_size - off_set_));
|
||||
BatchToSpaceND<<<CUDA_BLOCKS(device_id, output_size_), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
input, crops_start, block_shape, output_shape, output_shape_size, off_set_, output_size_, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<int8_t>(const int8_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, int8_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<int16_t>(const int16_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, int16_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<int32_t>(const int32_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, int32_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<int64_t>(const int64_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, int64_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<uint8_t>(const uint8_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, uint8_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<uint16_t>(const uint16_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, uint16_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<uint32_t>(const uint32_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, uint32_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<uint64_t>(const uint64_t *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, uint64_t *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<half>(const half *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, half *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<float>(const float *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, float *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalBatchToSpaceND<double>(const double *input, const int64_t *crops_start,
|
||||
const int64_t *block_shape, const int64_t *output_shape,
|
||||
const size_t output_shape_size, const int64_t *stride_,
|
||||
const int64_t *on_stride_, const size_t off_set_,
|
||||
const size_t output_size_, double *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
@@ -0,0 +1,29 @@
/**
 * Copyright 2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BATCH_TO_SPACE_ND_IMPL_CUH_
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BATCH_TO_SPACE_ND_IMPL_CUH_
#include <vector>
#include "include/cuda_fp16.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
template <typename T>
CUDA_LIB_EXPORT void CalBatchToSpaceND(const T *input, const int64_t *crops_start, const int64_t *block_shape,
                                       const int64_t *output_shape, const size_t output_shape_size,
                                       const int64_t *stride_, const int64_t *on_stride_, const size_t off_set_,
                                       const size_t output_size_, T *output, const uint32_t &device_id,
                                       cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BATCH_TO_SPACE_ND_IMPL_CUH_
|
|
@@ -0,0 +1,44 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#define EIGEN_USE_GPU
|
||||
#include "betainc_impl.cuh"
|
||||
#include <math.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
template <typename T>
|
||||
void CalBetainc(const size_t size, T *input_a, T *input_b, T *input_x, T *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream) {
|
||||
int num = static_cast<int>(size);
|
||||
T *agpu = input_a, *bgpu = input_b, *xgpu = input_x;
|
||||
int gpudevice = device_id;
|
||||
Eigen::GpuStreamDevice stream(&cuda_stream, gpudevice);
|
||||
Eigen::GpuDevice gpu_device(&stream);
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>> Eigen_a(agpu, num);
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>> Eigen_b(bgpu, num);
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>> Eigen_x(xgpu, num);
|
||||
Eigen::TensorMap<Eigen::Tensor<T, 1>> Eigen_z(output, num);
|
||||
Eigen_z.device(gpu_device) = Eigen::betainc(Eigen_a, Eigen_b, Eigen_x);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalBetainc<float>(const size_t size, float *input_a, float *input_b, float *input_x,
|
||||
float *output, const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
template CUDA_LIB_EXPORT void CalBetainc<double>(const size_t size, double *input_a, double *input_b, double *input_x,
|
||||
double *output, const uint32_t &device_id, cudaStream_t cuda_stream);
|
|
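For context, the Eigen::betainc call above evaluates the regularized incomplete beta function I_x(a, b) elementwise. A quick host-side reference check (illustrative only; it uses SciPy, which is not part of this merge request) could look like:

import numpy as np
from scipy.special import betainc  # regularized incomplete beta I_x(a, b)

a = np.array([1.0, 2.0, 3.0])
b = np.array([2.0, 2.0, 2.0])
x = np.array([0.1, 0.5, 0.9])
print(betainc(a, b, x))  # elementwise I_x(a, b); what CalBetainc computes on the GPU via Eigen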
@@ -0,0 +1,27 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BETAINC_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BETAINC_IMPL_CUH_
|
||||
|
||||
#include <vector>
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalBetainc(const size_t size, T *a, T *b, T *x, T *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_BETAINC_IMPL_CUH_
|
|
@@ -0,0 +1,152 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include <cuda_runtime.h>
|
||||
#include "space_to_batch_nd_impl.cuh"
|
||||
|
||||
__constant__ int64_t con_paddings_start[8];
|
||||
__constant__ int64_t con_block_shape[8];
|
||||
__constant__ int64_t con_input_shape[8];
|
||||
__constant__ int64_t con_stride[8];
|
||||
__constant__ int64_t con_on_stride[8];
|
||||
|
||||
template <typename T>
|
||||
__global__ void SpaceToBatchND(const T *__restrict__ input, const int64_t *paddings_start, const int64_t *block_shape,
|
||||
const int64_t *input_shape, const size_t input_shape_size, const size_t off_set,
|
||||
const size_t input_size_, T *__restrict__ output) {
|
||||
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < input_size_; pos += gridDim.x * blockDim.x) {
|
||||
int64_t cur_pos = pos;
|
||||
int idx = input_shape_size - 1;
|
||||
int64_t out_pos = 0;
|
||||
int64_t idx_on = 0;
|
||||
int64_t cur_out_idx = 0;
|
||||
int64_t cur_inp_idx = 0;
|
||||
int64_t temp_idx = 0;
|
||||
int64_t offset_idx = 0;
|
||||
for (; idx >= off_set; --idx) {
|
||||
cur_out_idx = cur_pos % con_input_shape[idx];
|
||||
cur_pos /= con_input_shape[idx];
|
||||
offset_idx = idx - off_set;
|
||||
temp_idx = (cur_out_idx + con_paddings_start[offset_idx]);
|
||||
cur_inp_idx = temp_idx / con_block_shape[offset_idx];
|
||||
out_pos += cur_inp_idx * con_stride[idx];
|
||||
idx_on += (temp_idx % con_block_shape[offset_idx]) * con_on_stride[offset_idx];
|
||||
}
|
||||
for (; idx > 0; --idx) {
|
||||
cur_out_idx = cur_pos % con_input_shape[idx];
|
||||
cur_pos /= con_input_shape[idx];
|
||||
out_pos += cur_out_idx * con_stride[idx];
|
||||
}
|
||||
|
||||
cur_inp_idx = idx_on * con_input_shape[0] + cur_pos % con_input_shape[0];
|
||||
|
||||
out_pos += cur_inp_idx * con_stride[0];
|
||||
output[out_pos] = input[pos];
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void CalSpaceToBatchND(const T *input, const int64_t *paddings_start, const int64_t *block_shape,
|
||||
const int64_t *input_shape, const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set, const size_t input_size_,
|
||||
const size_t output_size_, T *output, const uint32_t &device_id, cudaStream_t cuda_stream) {
|
||||
cudaMemset(output, 0, output_size_ * sizeof(T));
|
||||
cudaMemcpyToSymbol(con_input_shape, input_shape, sizeof(int64_t) * input_shape_size);
|
||||
cudaMemcpyToSymbol(con_block_shape, block_shape, sizeof(int64_t) * (input_shape_size - off_set));
|
||||
cudaMemcpyToSymbol(con_paddings_start, paddings_start, sizeof(int64_t) * (input_shape_size - off_set));
|
||||
cudaMemcpyToSymbol(con_on_stride, on_stride, sizeof(int64_t) * (input_shape_size - off_set));
|
||||
cudaMemcpyToSymbol(con_stride, stride, sizeof(int64_t) * input_shape_size);
|
||||
SpaceToBatchND<<<CUDA_BLOCKS(device_id, input_size_), CUDA_THREADS(device_id), 0, cuda_stream>>>(
|
||||
input, paddings_start, block_shape, input_shape, input_shape_size, off_set, input_size_, output);
|
||||
return;
|
||||
}
|
||||
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<uint8_t>(const uint8_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
uint8_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<uint16_t>(const uint16_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
uint16_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<uint32_t>(const uint32_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
uint32_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<uint64_t>(const uint64_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
uint64_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<int8_t>(const int8_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
int8_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<int16_t>(const int16_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
int16_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<int32_t>(const int32_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
int32_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<int64_t>(const int64_t *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
int64_t *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<half>(const half *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_, half *output,
|
||||
const uint32_t &device_id, cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<float>(const float *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
float *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
template CUDA_LIB_EXPORT void CalSpaceToBatchND<double>(const double *input, const int64_t *paddings_start,
|
||||
const int64_t *block_shape, const int64_t *input_shape,
|
||||
const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set,
|
||||
const size_t input_size_, const size_t output_size_,
|
||||
double *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPACE_TO_BATCH_ND_IMPL_CUH_
|
||||
#define MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPACE_TO_BATCH_ND_IMPL_CUH_
|
||||
#include <vector>
|
||||
#include "include/cuda_fp16.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/cuda_device_info.h"
|
||||
template <typename T>
|
||||
CUDA_LIB_EXPORT void CalSpaceToBatchND(const T *input, const int64_t *paddings_start, const int64_t *block_shape,
|
||||
const int64_t *input_shape, const size_t input_shape_size, const int64_t *stride,
|
||||
const int64_t *on_stride, const size_t off_set, const size_t input_size_,
|
||||
const size_t output_size_, T *output, const uint32_t &device_id,
|
||||
cudaStream_t cuda_stream);
|
||||
|
||||
#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_GPU_KERNEL_CUDA_IMPL_CUDA_OPS_SPACE_TO_BATCH_ND_IMPL_CUH_
|
|
@ -0,0 +1,126 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "plugin/device/gpu/kernel/math/betainc_gpu_kernel.h"
|
||||
#include <utility>
|
||||
#include <algorithm>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "abstract/utils.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/betainc_impl.cuh"
|
||||
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
namespace {
|
||||
constexpr size_t kBetaincInputsNum = 3;
|
||||
constexpr size_t kAIndex = 0;
|
||||
constexpr size_t kBIndex = 1;
|
||||
constexpr size_t kXIndex = 2;
|
||||
} // namespace
|
||||
|
||||
bool BetaincGpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
auto kernel_attr = GetKernelAttrFromTensors(inputs, outputs);
|
||||
auto [is_match, index] = MatchKernelAttr(kernel_attr, GetOpSupport());
|
||||
if (inputs.empty() || outputs.empty()) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "' got empty inputs or outputs, which is invalid.";
|
||||
return false;
|
||||
}
|
||||
if (inputs.size() != kBetaincInputsNum) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of inputs should be 3, but got " << inputs.size();
|
||||
return false;
|
||||
}
|
||||
constexpr int OUTPUT_NUM = 1;
|
||||
if (outputs.size() != OUTPUT_NUM) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_ << "', the number of outputs should be 1, but got " << outputs.size();
|
||||
return false;
|
||||
}
|
||||
if (!is_match) {
|
||||
MS_LOG(ERROR) << "For '" << kernel_name_
|
||||
<< "', the kernel type should be in [float32, float64], but got: " << kernel_attr << ".";
|
||||
return false;
|
||||
}
|
||||
kernel_func_ = func_list_[index].second;
|
||||
input_size_ = abstract::TypeIdSize(kernel_attr.GetInputAttr(kIndex0).first);
|
||||
return true;
|
||||
}
|
||||
int BetaincGpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost) {
|
||||
for (const auto &input : inputs) {
|
||||
// If any input shape contains -1, means input shape is dynamic, so just return do nothing.
|
||||
auto input_shape = input->GetShapeVector();
|
||||
if (!IsValidShape(input_shape)) {
|
||||
return KRET_UNKNOWN_SHAPE;
|
||||
}
|
||||
}
|
||||
ResetResource();
|
||||
std::vector<int64_t> input_shape_ = std::vector<int64_t>(inputs.at(kIndex0)->GetDeviceShapeAdaptively().begin(),
|
||||
inputs.at(kIndex0)->GetDeviceShapeAdaptively().end());
|
||||
for (int64_t i = 0; i < static_cast<int64_t>(input_shape_.size()); i++) {
|
||||
input_size_ *= input_shape_[i];
|
||||
}
|
||||
InitSizeLists();
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
bool BetaincGpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &inputs,
|
||||
const std::vector<kernel::AddressPtr> &workspace,
|
||||
const std::vector<kernel::AddressPtr> &outputs, void *cuda_stream) {
|
||||
if (is_null_input_) {
|
||||
return true;
|
||||
}
|
||||
return kernel_func_(this, inputs, outputs, workspace, cuda_stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BetaincGpuKernelMod::LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
||||
const std::vector<AddressPtr> &workspace, void *cuda_stream) {
|
||||
T *input_a = GetDeviceAddress<T>(inputs, kAIndex);
|
||||
T *input_b = GetDeviceAddress<T>(inputs, kBIndex);
|
||||
T *input_x = GetDeviceAddress<T>(inputs, kXIndex);
|
||||
T *output = GetDeviceAddress<T>(outputs, kAIndex);
|
||||
|
||||
CalBetainc(input_size_ / sizeof(T), input_a, input_b, input_x, output, device_id_,
|
||||
reinterpret_cast<cudaStream_t>(cuda_stream));
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<std::pair<KernelAttr, BetaincGpuKernelMod::KernelFunc>> BetaincGpuKernelMod::func_list_ = {
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddInputAttr(kNumberTypeFloat32)
|
||||
.AddOutputAttr(kNumberTypeFloat32),
|
||||
&BetaincGpuKernelMod::LaunchKernel<float>},
|
||||
{KernelAttr()
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddInputAttr(kNumberTypeFloat64)
|
||||
.AddOutputAttr(kNumberTypeFloat64),
|
||||
&BetaincGpuKernelMod::LaunchKernel<double>}};
|
||||
|
||||
std::vector<KernelAttr> BetaincGpuKernelMod::GetOpSupport() {
|
||||
std::vector<KernelAttr> support_list;
|
||||
(void)std::transform(func_list_.begin(), func_list_.end(), std::back_inserter(support_list),
|
||||
[](const std::pair<KernelAttr, KernelFunc> &pair) { return pair.first; });
|
||||
return support_list;
|
||||
}
|
||||
|
||||
MS_KERNEL_FACTORY_REG(NativeGpuKernelMod, Betainc, BetaincGpuKernelMod);
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,96 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BETAINC_GPU_KERNEL_H_
|
||||
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BETAINC_GPU_KERNEL_H_
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
#include <utility>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/betainc_impl.cuh"
|
||||
namespace mindspore {
|
||||
namespace kernel {
|
||||
class BetaincGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
BetaincGpuKernelMod() { ResetResource(); }
|
||||
~BetaincGpuKernelMod() override = default;
|
||||
|
||||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) override;
|
||||
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override;
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override;
|
||||
|
||||
void ResetResource() noexcept {
|
||||
is_null_input_ = false;
|
||||
input_elements_ = 0;
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
protected:
|
||||
void InitSizeLists() {
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
workspace_size_list_.clear();
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
workspace_size_list_.push_back(input_size_);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
bool LaunchKernel(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs,
|
||||
const std::vector<AddressPtr> &workspace, void *cuda_stream);
|
||||
using KernelFunc =
|
||||
std::function<bool(BetaincGpuKernelMod *, const std::vector<kernel::AddressPtr> &,
|
||||
const std::vector<kernel::AddressPtr> &, const std::vector<kernel::AddressPtr> &, void *)>;
|
||||
KernelFunc kernel_func_{};
|
||||
static std::vector<std::pair<KernelAttr, KernelFunc>> func_list_;
|
||||
bool is_null_input_;
|
||||
size_t input_elements_;
|
||||
size_t input_size_;
|
||||
size_t output_size;
|
||||
std::vector<size_t> a_shape_;
|
||||
std::vector<size_t> b_shape_;
|
||||
std::vector<size_t> x_shape_;
|
||||
std::vector<size_t> output_shape_;
|
||||
};
|
||||
} // namespace kernel
|
||||
} // namespace mindspore
|
||||
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_BETAINC_GPU_KERNEL_H_
|
|
@ -31,6 +31,10 @@ abstract::ShapePtr AngleInferShape(const PrimitivePtr &primitive, const std::vec
|
|||
(void)CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(primitive_name, input_args, 0);
|
||||
auto x = input_args[0]->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->GetShapeTrack())[kShape];
|
||||
if (IsDynamicRank(x_shape)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||
}
|
||||
auto shape_element = x->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
return shape_element;
|
||||
|
|
|
@ -36,19 +36,19 @@ abstract::ShapePtr BatchToSpaceNDInferShape(const PrimitivePtr &primitive,
|
|||
if (IsDynamicRank(x_shape)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
constexpr int64_t len = 4;
|
||||
constexpr int64_t ascend_len = 4;
|
||||
constexpr int64_t len = 2;
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
||||
(void)CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()),
|
||||
(is_ascend ? kEqual : kGreaterEqual), len, prim_name);
|
||||
(is_ascend ? kEqual : kGreaterEqual), (is_ascend ? ascend_len : len),
|
||||
prim_name);
|
||||
if (IsDynamicShape(x_shape)) {
|
||||
std::vector<int64_t> res(x_shape.size(), abstract::Shape::kShapeDimAny);
|
||||
return std::make_shared<abstract::Shape>(res);
|
||||
}
|
||||
|
||||
auto out_shape = x_shape;
|
||||
|
||||
int64_t block_shape_prod = 1;
|
||||
auto block_shape = GetValue<std::vector<int64_t>>(primitive->GetAttr(kBlockShape));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
|
@ -58,14 +58,17 @@ abstract::ShapePtr BatchToSpaceNDInferShape(const PrimitivePtr &primitive,
|
|||
block_shape_prod = block_shape_prod * block_shape[i];
|
||||
auto x_block_prod = out_shape[i + offset] * block_shape[i];
|
||||
auto crops_sum = crops[i][0] + crops[i][1];
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, crops_sum, prim_name);
|
||||
if (out_shape[i + offset] >= 0)
|
||||
CheckAndConvertUtils::Check("x block shape prod", x_block_prod, kGreaterThan, crops_sum, prim_name);
|
||||
out_shape[i + offset] = x_block_prod - crops_sum;
|
||||
}
|
||||
if (out_shape[0] % block_shape_prod != 0) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For '" << prim_name
|
||||
<< "', the first dim of 'input_x' must be divisible by 'block_shape_prod'. But got first dim of 'input_x': "
|
||||
<< out_shape[0] << ", 'block_shape_prod' with value: " << block_shape_prod << ".";
|
||||
if (out_shape[0] >= 0) {
|
||||
if (out_shape[0] % block_shape_prod != 0) {
|
||||
MS_EXCEPTION(ValueError)
|
||||
<< "For '" << prim_name
|
||||
<< "', the first dim of 'input_x' must be divisible by 'block_shape_prod'. But got first dim of 'input_x': "
|
||||
<< out_shape[0] << ", 'block_shape_prod' with value: " << block_shape_prod << ".";
|
||||
}
|
||||
}
|
||||
out_shape[0] = int64_t(floor(out_shape[0] / static_cast<float>(block_shape_prod)));
|
||||
return std::make_shared<abstract::Shape>(out_shape);
|
||||
|
|
|
@ -34,6 +34,12 @@ abstract::ShapePtr BetaincInferShape(const PrimitivePtr &primitive, const std::v
|
|||
auto b_shape_ptr = b_shape->cast<abstract::ShapePtr>();
|
||||
auto x_shape = input_args[kInputIndex2]->BuildShape();
|
||||
auto x_shape_ptr = x_shape->cast<abstract::ShapePtr>();
|
||||
auto a_rank_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
auto b_rank_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
|
||||
auto x_rank_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
|
||||
if (IsDynamic(a_rank_shape) || IsDynamic(b_rank_shape) || IsDynamic(x_rank_shape)) {
|
||||
return a_shape_ptr;
|
||||
}
|
||||
if (*a_shape_ptr != *b_shape_ptr) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << ", shape of b " << b_shape->ToString()
|
||||
<< " are not consistent with the shape a " << a_shape->ToString() << " .";
|
||||
|
|
|
@ -94,13 +94,16 @@ abstract::ShapePtr SpaceToBatchNDInferShape(const PrimitivePtr &primitive,
|
|||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
(void)CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, 1, prim_name);
|
||||
auto x_shape = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
|
||||
if (x_shape.size() != 0 && (IsDynamicRank(x_shape) || x_shape[0] == -1)) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{-2});
|
||||
}
|
||||
constexpr size_t x_min_len = 2;
|
||||
CheckAndConvertUtils::CheckInteger("input_x rank", SizeToLong(x_shape.size()), kGreaterEqual, x_min_len, prim_name);
|
||||
for (const auto &item : input_args) {
|
||||
MS_EXCEPTION_IF_NULL(item);
|
||||
}
|
||||
auto input_shape_ptr = CheckAndConvertUtils::GetTensorInputShape(prim_name, input_args, 0);
|
||||
if (IsDynamicRank(input_shape_ptr->shape())) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{abstract::Shape::kShapeRankAny});
|
||||
}
|
||||
|
||||
auto paddings_value_ptr = primitive->GetAttr(kPaddings);
|
||||
MS_EXCEPTION_IF_NULL(paddings_value_ptr);
|
||||
|
|
|
@ -3091,56 +3091,51 @@ def batch_to_space_nd(input_x, block_shape, crops):
|
|||
Divides batch dimension with blocks and interleaves these blocks back into spatial dimensions.
|
||||
|
||||
This operation will divide batch dimension N into blocks with block_shape, the output tensor's N dimension
|
||||
is the corresponding number of blocks after division. The output tensor's H, W dimension is the product of
|
||||
original H, W dimension and block_shape with given amount to crop from dimension, respectively.
|
||||
is the corresponding number of blocks after division. The output tensor's :math:`w_1, ..., w_M` dimension is
|
||||
the product of original :math:`w_1, ..., w_M` dimension and block_shape with given amount to crop from dimension,
|
||||
respectively.
|
||||
|
||||
If the input shape is :math:`(n, c, h, w)`, the output shape is :math:`(n', c', h', w')`.
|
||||
If the input shape is :math:`(n, c_1, ... c_k, w_1, ..., w_M)`, the output shape is
|
||||
:math:`(n, c_1, ... c_k, w_1, ..., w_M)`.
|
||||
|
||||
:math:`n' = n//(block\_shape[0]*block\_shape[1])`
|
||||
:math:`n' = n//(block\_shape[0]*...*block\_shape[M-1])`
|
||||
|
||||
:math:`c' = c`
|
||||
|
||||
:math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
|
||||
|
||||
:math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
|
||||
:math:`w'_i = w_i*block\_shape[i-1]-crops[i-1][0]-crops[i-1][1]`
|
||||
|
||||
Args:
|
||||
input_x (Tensor): The input tensor. It must be greater or equal to 4-D tensor(equal to 4-D tensor on Ascend),
|
||||
batch dimension must be divisible by product of `block_shape`. The data type is float16 or float32.
|
||||
input_x (Tensor): The input tensor. It must be greater or equal to 2-D tensor(equal to 4-D tensor on Ascend),
|
||||
batch dimension must be divisible by product of `block_shape`.
|
||||
block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block with all value greater
|
||||
than 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding to the
|
||||
number of spatial dimensions. If `block_shape` is an int, the block size of M dimensions are the same,
|
||||
equal to `block_shape`. M must be 2.
|
||||
crops (Union[list(int), tuple(int)]): The crop value for H and W dimension, containing 2 subtraction list,
|
||||
each containing 2 int value.
|
||||
All values must be >= 0. crops[i] specifies the crop values for spatial dimension i, which corresponds to
|
||||
input dimension i+2. It is required that
|
||||
than or equal to 1. If `block_shape` is a tuple or list, the length of `block_shape` is M corresponding
|
||||
to the number of spatial dimensions. If `block_shape` is an int, the block size of M dimensions are the
|
||||
same, equal to `block_shape`. In this case of Ascend, M must be 2.
|
||||
crops (Union[list(int), tuple(int)]): The crops values for spatial dimensions, containing M subtraction list.
|
||||
Each contains 2 integer values. All values must be >= 0. crops[i] specifies the crops values for spatial
|
||||
dimension i, which corresponds to input dimension i + offset,where offset = N-M, and N is the number of
|
||||
input dimensions. It is required that
|
||||
|
||||
:math:`input\_shape[i+2]*block\_shape[i] > crops[i][0]+crops[i][1]`
|
||||
:math:`input\_shape[i+offset]*block\_shape[i] > crops[i][0]+crops[i][1]`
|
||||
|
||||
Returns:
|
||||
Tensor, the output tensor with the same type as input. Assume input shape is (n, c, h, w) with block_shape
|
||||
and crops. The output shape will be (n', c', h', w'), where
|
||||
Tensor, the output tensor with the same type as input. Assume input shape is
|
||||
:math:`(n, c_1, ... c_k, w_1, ..., w_M)` with block_shape and crops. The output shape will be
|
||||
:math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`, where
|
||||
|
||||
:math:`n' = n//(block\_shape[0]*block\_shape[1])`
|
||||
:math:`n' = n//(block\_shape[0]*...*block\_shape[M-1])`
|
||||
|
||||
:math:`c' = c`
|
||||
|
||||
:math:`h' = h*block\_shape[0]-crops[0][0]-crops[0][1]`
|
||||
|
||||
:math:`w' = w*block\_shape[1]-crops[1][0]-crops[1][1]`
|
||||
:math:`w'_i = w_i*block\_shape[i-1]-crops[i-1][0]-crops[i-1][1]`
|
||||
|
||||
Raises:
|
||||
TypeError: If `block_shape` is not one of list, tuple, int.
|
||||
TypeError: If `crops` is neither list nor tuple.
|
||||
ValueError: If `block_shape` is not one dimensional when `block_shape` is a list or tuple.
|
||||
ValueError: If length of `block_shape` or `crops` is not equal to 2.
|
||||
ValueError: If the element of `block_shape` is not an integer larger than 1.
|
||||
ValueError: If the length of `block_shape` is not 2 on Ascend.
|
||||
ValueError: If the element of `block_shape` is not an integer larger than or euqal to 1.
|
||||
ValueError: If shape of `crops` is not (M, 2), where M is the length of `block_shape`.
|
||||
ValueError: If the element of `crops` is not an integer larger than 0.
|
||||
ValueError: If the element of `crops` is not an integer larger than or euqal to 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> block_shape = [2, 2]
|
||||
|
|
|
@ -5264,7 +5264,7 @@ class BatchToSpace(PrimitiveWithInfer):
|
|||
return out_shape
|
||||
|
||||
|
||||
class SpaceToBatchND(PrimitiveWithInfer):
|
||||
class SpaceToBatchND(Primitive):
|
||||
r"""
|
||||
Divides spatial dimensions into blocks and combines the block size with the original batch.
|
||||
|
||||
|
@ -5275,14 +5275,15 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
|
||||
Args:
|
||||
block_shape (Union[list(int), tuple(int), int]): The block shape of dividing block
|
||||
with all elements greater than 1. If `block_shape` is a list or tuple,
|
||||
with all elements greater than or euqal to 1. If `block_shape` is a list or tuple,
|
||||
the length of `block_shape` is the number of spatial dimensions, called M later.
|
||||
If `block_shape` is an int, the block size of M dimensions are the same, equal to `block_shape`.
|
||||
In this case of Ascend, M must be 2.
|
||||
paddings (Union[tuple, list]): The padding values for spatial dimensions, containing M subtraction list.
|
||||
Each contains 2 integer values. All values must be greater than 0.
|
||||
Each contains 2 integer values. All values must be greater than or equal to 0.
|
||||
`paddings[i]` specifies the paddings for the spatial dimension i,
|
||||
which corresponds to the input dimension i + offset.
|
||||
which corresponds to the input dimension i + offset,where offset = N-M,
|
||||
and N is the number of input dimensions.
|
||||
For each i, input_shape[i + offset]+paddings[i][0]+paddings[i][1]
|
||||
should be divisible by block_shape[i].
|
||||
|
||||
|
@ -5296,9 +5297,9 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
The shape of the output tensor will be :math:`(n', c_1, ... c_k, w'_1, ..., w'_M)`,
|
||||
where
|
||||
|
||||
:math:`n' = n*(block\_shape[0]*...*block\_shape[M])`
|
||||
:math:`n' = n*(block\_shape[0]*...*block\_shape[M-1])`
|
||||
|
||||
:math:`w'_i = (w_i+paddings[i][0]+paddings[i][1])//block\_shape[i]`
|
||||
:math:`w'_i = (w_i+paddings[i-1][0]+paddings[i-1][1])//block\_shape[i-1]`
|
||||
|
||||
Raises:
|
||||
TypeError: If `block_shape` is not one of list, tuple, int.
|
||||
|
@ -5306,11 +5307,11 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
ValueError: If `block_shape` is not one dimensional when `block_shape` is a list or tuple.
|
||||
ValueError: If the length of `block_shape` is not 2 on Ascend.
|
||||
ValueError: If shape of `paddings` is not (M, 2), where M is the length of `block_shape`.
|
||||
ValueError: If the element of `block_shape` is not an integer larger than 1.
|
||||
ValueError: If the element of `paddings` is not an integer larger than 0.
|
||||
ValueError: If the element of `block_shape` is not an integer larger than or equal to 1.
|
||||
ValueError: If the element of `paddings` is not an integer larger than or euqal to 0.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> block_shape = [2, 2]
|
||||
|
@ -5329,6 +5330,7 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
def __init__(self, block_shape, paddings):
|
||||
"""Initialize SpaceToBatchND"""
|
||||
validator.check_value_type('paddings type', paddings, [list, tuple], self.name)
|
||||
validator.check('paddings length', len(paddings), '', 1, Rel.GE, self.name)
|
||||
|
||||
if isinstance(block_shape, int):
|
||||
block_shape = (block_shape,) * np.array(paddings).shape[0]
|
||||
|
@ -5351,37 +5353,6 @@ class SpaceToBatchND(PrimitiveWithInfer):
|
|||
validator.check_value_type('paddings element', elem, [int], self.name)
|
||||
self.paddings = paddings
|
||||
|
||||
def infer_dtype(self, x_dtype):
|
||||
validator.check_tensor_dtype_valid('input_x', x_dtype, mstype.number_type, self.name)
|
||||
return x_dtype
|
||||
|
||||
def infer_shape(self, x_shape):
|
||||
x_rank = len(x_shape)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
validator.check_equal_int(x_rank, 4, 'x_shape rank', self.name)
|
||||
out_shape = copy.deepcopy(x_shape)
|
||||
|
||||
block_shape_prod = 1
|
||||
offset = len(x_shape) - len(self.block_shape)
|
||||
if offset <= 0:
|
||||
raise ValueError(f"For '{self.name}', the dim of the input should be larger than that of the blocks, "
|
||||
f"but the shape of the inputs is {x_shape} "
|
||||
f"while the shape of blocks is {self.block_shape}.")
|
||||
for i in range(len(self.block_shape)):
|
||||
padded = out_shape[i + offset] + self.paddings[i][0] + \
|
||||
self.paddings[i][1]
|
||||
if padded % self.block_shape[i] != 0:
|
||||
raise ValueError(f"For '{self.name}', the padded must be divisible by 'block_shape', "
|
||||
f"where padded = input_x_shape[i + 2] + paddings[i][0] + paddings[i][1], "
|
||||
f"but got input_x_shape[{i + offset}]: {out_shape[i + offset]}, "
|
||||
f"paddings[{i}][0]: {self.paddings[i][0]} and paddings[{i}][1]: {self.paddings[i][1]}."
|
||||
f" Please check the official api documents for "
|
||||
f"more information about the output tensor.")
|
||||
out_shape[i + offset] = padded // self.block_shape[i]
|
||||
block_shape_prod = block_shape_prod * self.block_shape[i]
|
||||
out_shape[0] *= block_shape_prod
|
||||
return out_shape
|
||||
|
||||
|
||||
class BatchToSpaceND(Primitive):
|
||||
r"""
|
||||
|
@ -5390,26 +5361,37 @@ class BatchToSpaceND(Primitive):
|
|||
Refer to :func:`mindspore.ops.batch_to_space_nd` for more detail.
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Examples:
|
||||
>>> block_size = 2
|
||||
>>> crops = [[0, 0], [0, 0]]
|
||||
>>> batch_to_space = ops.BatchToSpace(block_size, crops)
|
||||
>>> input_x = Tensor(np.array([[[[1]]], [[[2]]], [[[3]]], [[[4]]]]), mindspore.float32)
|
||||
>>> output = batch_to_space(input_x)
|
||||
>>> print(output)
|
||||
[[[[1. 2.]
|
||||
[3. 4.]]]]
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self, block_shape, crops):
|
||||
"""Initialize BatchToSpaceND"""
|
||||
if isinstance(block_shape, int):
|
||||
block_shape = (block_shape,) * 2
|
||||
block_shape = (block_shape,) * np.array(crops).shape[0]
|
||||
self.add_prim_attr("block_shape", block_shape)
|
||||
validator.check_value_type('block_shape type', block_shape, [list, tuple], self.name)
|
||||
validator.check('block_shape shape', len(np.array(block_shape).shape), '', 1, Rel.EQ, self.name)
|
||||
block_rank = len(block_shape)
|
||||
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
|
||||
if context.get_context("device_target") == "Ascend":
|
||||
validator.check('block_shape length', block_rank, '', 2, Rel.EQ, self.name)
|
||||
for elem in block_shape:
|
||||
validator.check('block_shape element', elem, '', 1, Rel.GE, self.name)
|
||||
validator.check_value_type('block_shape element', elem, [int], self.name)
|
||||
self.block_shape = block_shape
|
||||
|
||||
validator.check_value_type('crops type', crops, [list, tuple], self.name)
|
||||
validator.check('crops length', len(crops), '', 2, Rel.EQ, self.name)
|
||||
validator.check('crops length', len(crops), '', 1, Rel.GE, self.name)
|
||||
validator.check('crops shape', np.array(crops).shape, '', (block_rank, 2), Rel.EQ, self.name)
|
||||
for elem in itertools.chain(*crops):
|
||||
validator.check_non_negative_int(elem, 'crops element', self.name)
|
||||
|
|
|
@ -1549,7 +1549,7 @@ class Betainc(Primitive):
|
|||
|
||||
|
||||
Supported Platforms:
|
||||
``Ascend`` ``CPU``
|
||||
``Ascend`` ``GPU`` ``CPU``
|
||||
|
||||
Example:
|
||||
>>> a = Tensor(np.array([1, 1, 1]), mindspore.float32)
|
||||
|
@ -1563,6 +1563,7 @@ class Betainc(Primitive):
|
|||
@prim_attr_register
|
||||
def __init__(self):
|
||||
"""Initialize Betainc"""
|
||||
self.init_prim_io_names(inputs=['a', 'b', 'x'], outputs=['output'])
|
||||
|
||||
|
||||
class CumSum(Primitive):
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.array_ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class BatchToSpaceNDNet(nn.Cell):
|
||||
def __init__(self, nptype, block_shape=2, input_shape=(4, 1, 1, 1)):
|
||||
super(BatchToSpaceNDNet, self).__init__()
|
||||
self.batch_to_space_nd = ops.BatchToSpaceND(block_shape=block_shape, crops=[[0, 0], [0, 0]])
|
||||
input_size = np.prod(input_shape)
|
||||
data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
y1 = self.batch_to_space_nd(self.x1)
|
||||
return y1
|
||||
|
||||
|
||||
def batch_to_space_nd_test_case(nptype, block_shape=2, input_shape=(4, 1, 1, 1)):
|
||||
expect = np.array([[[[0, 1],
|
||||
[2, 3]]]]).astype(nptype)
|
||||
|
||||
dts = BatchToSpaceNDNet(nptype, block_shape, input_shape)
|
||||
output = dts()
|
||||
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_batch_to_space_nd_graph():
|
||||
"""
|
||||
Feature: test BatchToSpaceND function interface.
|
||||
Description: test interface.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
batch_to_space_nd_test_case(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_batch_to_space_nd_pynative():
|
||||
"""
|
||||
Feature: test BatchToSpaceND function interface.
|
||||
Description: test interface.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
batch_to_space_nd_test_case(np.float32)
|
|
@ -0,0 +1,76 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
import mindspore.context as context
|
||||
from mindspore import Tensor
|
||||
from mindspore.ops.operations import math_ops as P
|
||||
|
||||
|
||||
|
||||
class NetBetainc(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.betainc = P.Betainc()
|
||||
|
||||
def construct(self, a, b, x):
|
||||
return self.betainc(a, b, x)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_betainc_graph():
|
||||
"""
|
||||
Feature: Betainc
|
||||
Description: Test of input fp64 graph
|
||||
Expectation: match to tf.raw_ops.Betainc
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
a_np = np.array([[1, 2], [3, 4]]).astype(np.float64)
|
||||
b_np = np.array([[2, 3], [4, 5]]).astype(np.float64)
|
||||
x_np = np.array([[0.5, 0.5], [0.4, 0.3]]).astype(np.float64)
|
||||
a = Tensor(a_np)
|
||||
b = Tensor(b_np)
|
||||
x = Tensor(x_np)
|
||||
net = NetBetainc()
|
||||
output_ms = net(a, b, x)
|
||||
expect = np.array([[0.75, 0.6875], [0.45568, 0.19410435]], dtype=np.float64)
|
||||
assert np.allclose(output_ms.asnumpy(), expect, 1e-4, 1e-4)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_betainc_pynative():
|
||||
"""
|
||||
Feature: Betainc
|
||||
Description: Test of input fp32 pynative
|
||||
Expectation: match to tf.raw_ops.Betainc
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
a_np = np.array([[1, 2], [3, 4]]).astype(np.float32)
|
||||
b_np = np.array([[2, 3], [4, 5]]).astype(np.float32)
|
||||
x_np = np.array([[0.5, 0.5], [0.4, 0.3]]).astype(np.float32)
|
||||
a = Tensor(a_np)
|
||||
b = Tensor(b_np)
|
||||
x = Tensor(x_np)
|
||||
cast = P.Betainc()
|
||||
output_ms = cast(a, b, x)
|
||||
expect = np.array([[0.75, 0.6875], [0.45568, 0.19410435]], dtype=np.float32)
|
||||
assert np.allclose(output_ms.asnumpy(), expect, 1e-4, 1e-4)
|
|
@ -0,0 +1,79 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations.array_ops as ops
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.api import ms_function
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
|
||||
|
||||
class SpaceToBatchNDNet(nn.Cell):
|
||||
def __init__(self, nptype, block_size=2, input_shape=(1, 1, 4, 4)):
|
||||
super(SpaceToBatchNDNet, self).__init__()
|
||||
self.space_to_batch_nd = ops.SpaceToBatchND(block_shape=block_size, paddings=[[0, 0], [0, 0]])
|
||||
input_size = np.prod(input_shape)
|
||||
data_np = np.arange(input_size).reshape(input_shape).astype(nptype)
|
||||
self.x1 = Parameter(initializer(Tensor(data_np), input_shape), name='x1')
|
||||
|
||||
@ms_function
|
||||
def construct(self):
|
||||
y1 = self.space_to_batch_nd(self.x1)
|
||||
return y1
|
||||
|
||||
|
||||
def space_to_batch_nd_test_case(nptype, block_size=2, input_shape=(1, 1, 4, 4)):
|
||||
expect = np.array([[[[0, 2],
|
||||
[8, 10]]],
|
||||
[[[1, 3],
|
||||
[9, 11]]],
|
||||
[[[4, 6],
|
||||
[12, 14]]],
|
||||
[[[5, 7],
|
||||
[13, 15]]]]).astype(nptype)
|
||||
|
||||
dts = SpaceToBatchNDNet(nptype, block_size, input_shape)
|
||||
output = dts()
|
||||
|
||||
assert (output.asnumpy() == expect).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_space_to_batch_nd_graph():
|
||||
"""
|
||||
Feature: test SpaceToBatchND function interface.
|
||||
Description: test interface.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
|
||||
space_to_batch_nd_test_case(np.float32)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_space_to_batch_nd_pynative():
|
||||
"""
|
||||
Feature: test SpaceToBatchND function interface.
|
||||
Description: test interface.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
|
||||
space_to_batch_nd_test_case(np.float32)
|
Loading…
Reference in New Issue