forked from mindspore-Ecosystem/mindspore
modified commenting
This commit is contained in:
parent 0388f8bccd
commit b518a4ed33

@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2021 Huawei Technologies Co., Ltd
  *

@@ -25,6 +24,9 @@
 
 namespace mindspore {
 namespace kernel {
+constexpr size_t SHAPE_SIZE = 4;
+constexpr size_t CROPS_SHAPE_0 = 2;
+constexpr size_t CROPS_SHAPE_1 = 2;
 template <typename T>
 class BatchToSpaceGpuKernel : public GpuKernel {
  public:

@@ -32,7 +34,6 @@ class BatchToSpaceGpuKernel : public GpuKernel {
   ~BatchToSpaceGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {

@@ -79,7 +80,6 @@ class BatchToSpaceGpuKernel : public GpuKernel {
     ow_ = 0;
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
     crops_.clear();
     input_shape_.clear();
   }

@@ -91,9 +91,9 @@ class BatchToSpaceGpuKernel : public GpuKernel {
   }
 
   bool CheckParam(const CNodePtr &kernel_node) {
-    block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
-    if (block_size_ < 2) {
-      MS_LOG(ERROR) << "block_size can not be less than 2.";
+    block_size_ = GetAttr<int64_t>(kernel_node, "block_size");
+    if (block_size_ < 1) {
+      MS_LOG(ERROR) << "block_size can not be less than 1.";
       return false;
     }
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);

@@ -109,7 +109,7 @@ class BatchToSpaceGpuKernel : public GpuKernel {
 
     // check input_shape
     auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
-    if (input_shape.size() != 4) {
+    if (input_shape.size() != SHAPE_SIZE) {
       MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor.";
       return false;
     }

@@ -117,30 +117,35 @@ class BatchToSpaceGpuKernel : public GpuKernel {
       MS_LOG(ERROR) << "input_shape[0] must be divisible by product of block_shape";
       return false;
     }
+    for (int idx = 0; idx < SHAPE_SIZE; ++idx) {
+      if (input_shape[idx] < 1) {
+        MS_LOG(ERROR) << "input_shape[" << idx << "] can not less than 1";
+        return false;
+      }
+    }
     input_shape_.assign(input_shape.begin(), input_shape.end());
 
     // check crops
-    crops_ =
-      static_cast<std::vector<std::vector<int64_t>>>(GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
+    crops_ = (GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
 
-    if (crops_.size() != 2) {
+    if (crops_.size() != CROPS_SHAPE_0) {
       MS_LOG(ERROR) << "crops.size() in BatchToSpace needs 2.";
       return false;
     }
-    if (crops_[0].size() != 2 || crops_[1].size() != 2) {
+    if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
       MS_LOG(ERROR) << "crops[i].size() in BatchToSpace needs 2.";
       return false;
     } else {
-      for (size_t idx_i = 0; idx_i < 2; ++idx_i) {
-        for (size_t idx_j = 0; idx_j < 2; ++idx_j) {
+      for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
+        for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
           if (crops_[idx_i][idx_j] < 0) {
             MS_LOG(ERROR) << "the number in crops can not be less than 0.";
             return false;
           }
         }
-        auto tmp_shape = input_shape[idx_i + 2] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
-        if (tmp_shape < 0) {
-          MS_LOG(ERROR) << "out_shape can not be less 0.";
+        auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
+        if (tmp_shape <= 0) {
+          MS_LOG(ERROR) << "out_shape can not be less 1.";
           return false;
         }
       }

@@ -151,7 +156,6 @@ class BatchToSpaceGpuKernel : public GpuKernel {
  private:
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-
   std::vector<std::vector<int64_t>> crops_;
   std::vector<size_t> input_shape_;

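Aside: the relationship these checks protect is easiest to see by computing the BatchToSpace output shape directly. The sketch below is a minimal standalone illustration, not MindSpore code; the helper name BatchToSpaceOutShape and the NCHW {N, C, H, W} layout are assumptions made for this example. It mirrors the constraints visible in the hunks above: a block_size of at least 1, a 2x2 non-negative crops table, a batch dimension divisible by block_size squared, and cropped spatial dimensions that stay strictly positive (the tmp_shape <= 0 rejection).

#include <array>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical helper: compute the BatchToSpace output shape for an NCHW input,
// applying the same validity rules the patched CheckParam enforces.
std::optional<std::array<int64_t, 4>> BatchToSpaceOutShape(const std::array<int64_t, 4> &in,  // {N, C, H, W}
                                                           int64_t block_size,
                                                           const std::vector<std::vector<int64_t>> &crops) {
  if (block_size < 1) return std::nullopt;  // block_size can not be less than 1
  if (crops.size() != 2 || crops[0].size() != 2 || crops[1].size() != 2) return std::nullopt;
  for (const auto &pair : crops) {
    for (int64_t c : pair) {
      if (c < 0) return std::nullopt;  // crops entries must be non-negative
    }
  }
  if (in[0] % (block_size * block_size) != 0) return std::nullopt;  // N must be divisible by block_size^2
  std::array<int64_t, 4> out{};
  out[0] = in[0] / (block_size * block_size);
  out[1] = in[1];
  // Spatial dims grow by block_size and then lose the crops; the result must stay
  // positive, which is exactly what the tmp_shape <= 0 rejection guards.
  out[2] = in[2] * block_size - crops[0][0] - crops[0][1];
  out[3] = in[3] * block_size - crops[1][0] - crops[1][1];
  if (out[2] <= 0 || out[3] <= 0) return std::nullopt;
  return out;
}

int main() {
  // 4 batches of 1x2x2 rearranged with block_size 2 and no cropping -> 1x1x4x4.
  auto shape = BatchToSpaceOutShape({4, 1, 2, 2}, 2, {{0, 0}, {0, 0}});
  if (shape) {
    std::cout << (*shape)[0] << "x" << (*shape)[1] << "x" << (*shape)[2] << "x" << (*shape)[3] << "\n";
  }
  return 0;
}

Running it prints 1x1x4x4, i.e. the four batches are tiled back into a single 4x4 spatial plane.
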
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020-2021 Huawei Technologies Co., Ltd
+ * Copyright 2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.

@@ -1,4 +1,3 @@
-
 /**
  * Copyright 2021 Huawei Technologies Co., Ltd
  *

@@ -25,6 +24,10 @@
 
 namespace mindspore {
 namespace kernel {
+constexpr size_t SHAPE_SIZE = 4;
+constexpr size_t PADDING_SHAPE_0 = 2;
+constexpr size_t PADDING_SHAPE_1 = 2;
+
 template <typename T>
 class SpaceToBatchGpuKernel : public GpuKernel {
  public:

@@ -32,7 +35,6 @@ class SpaceToBatchGpuKernel : public GpuKernel {
   ~SpaceToBatchGpuKernel() {}
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
   const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
-
   bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {

@@ -79,7 +81,6 @@ class SpaceToBatchGpuKernel : public GpuKernel {
     ow_ = 0;
     input_size_list_.clear();
     output_size_list_.clear();
     workspace_size_list_.clear();
     paddings_.clear();
     input_shape_.clear();
   }

@@ -93,8 +94,8 @@ class SpaceToBatchGpuKernel : public GpuKernel {
  private:
   bool CheckParam(const CNodePtr &kernel_node) {
     block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
-    if (block_size_ < 2) {
-      MS_LOG(ERROR) << "block_size can not be less than 2.";
+    if (block_size_ < 1) {
+      MS_LOG(ERROR) << "block_size can not be less than 1.";
       return false;
     }
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);

@@ -110,34 +111,37 @@ class SpaceToBatchGpuKernel : public GpuKernel {
 
     // check input_shape
     auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
-    if (input_shape.size() != 4) {
+    if (input_shape.size() != SHAPE_SIZE) {
       MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor.";
       return false;
     }
     input_shape_.assign(input_shape.begin(), input_shape.end());
     // check paddings_
-    paddings_ = static_cast<std::vector<std::vector<int64_t>>>(
-      GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "paddings"));
-    if (paddings_.size() != 2) {
+    paddings_ = GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "paddings");
+    if (paddings_.size() != PADDING_SHAPE_0) {
       MS_LOG(ERROR) << "paddings.size() in BatchToSpace needs 2.";
       return false;
     }
-    if (paddings_[0].size() != 2 || paddings_[1].size() != 2) {
+    if (paddings_[0].size() != PADDING_SHAPE_1 || paddings_[1].size() != PADDING_SHAPE_1) {
       MS_LOG(ERROR) << "paddings[i].size() in BatchToSpace needs 2.";
       return false;
     } else {
-      for (size_t idx_i = 0; idx_i < 2; ++idx_i) {
-        for (size_t idx_j = 0; idx_j < 2; ++idx_j) {
+      for (size_t idx_i = 0; idx_i < PADDING_SHAPE_0; ++idx_i) {
+        for (size_t idx_j = 0; idx_j < PADDING_SHAPE_1; ++idx_j) {
           if (paddings_[idx_i][idx_j] < 0) {
             MS_LOG(ERROR) << "the number in paddings can not be less than 0.";
             return false;
           }
         }
-        auto tmp_shape = input_shape[idx_i + 2] + paddings_[idx_i][0] + paddings_[idx_i][1];
+        auto tmp_shape = input_shape[idx_i + PADDING_SHAPE_1] + paddings_[idx_i][0] + paddings_[idx_i][1];
         if ((tmp_shape % block_size_) != 0) {
           MS_LOG(ERROR) << "padded shape must be divisible by block_size";
           return false;
         }
+        if ((tmp_shape / block_size_) == 0) {
+          MS_LOG(ERROR) << "padded shape can not be less than block_size";
+          return false;
+        }
       }
     }
     return true;

@@ -145,8 +149,6 @@ class SpaceToBatchGpuKernel : public GpuKernel {
-
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-
   std::vector<std::vector<int64_t>> paddings_;
   std::vector<size_t> input_shape_;
   size_t block_size_;

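Aside: the same kind of standalone sketch for the SpaceToBatch side, again with an assumed helper name (SpaceToBatchOutShape) and NCHW layout rather than anything taken from the patch. It reproduces the two checks inside the loop above: each padded spatial size must divide evenly by block_size, and the quotient must be at least 1, which is what the "padded shape can not be less than block_size" branch guards.

#include <array>
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Hypothetical helper: compute the SpaceToBatch output shape for an NCHW input,
// enforcing the rules the patched CheckParam applies to "paddings" and block_size.
std::optional<std::array<int64_t, 4>> SpaceToBatchOutShape(const std::array<int64_t, 4> &in,  // {N, C, H, W}
                                                           int64_t block_size,
                                                           const std::vector<std::vector<int64_t>> &paddings) {
  if (block_size < 1) return std::nullopt;
  if (paddings.size() != 2 || paddings[0].size() != 2 || paddings[1].size() != 2) return std::nullopt;
  std::array<int64_t, 4> out{};
  out[0] = in[0] * block_size * block_size;  // batch grows by block_size^2
  out[1] = in[1];
  for (size_t i = 0; i < 2; ++i) {
    if (paddings[i][0] < 0 || paddings[i][1] < 0) return std::nullopt;  // paddings must be non-negative
    const int64_t padded = in[i + 2] + paddings[i][0] + paddings[i][1];
    if (padded % block_size != 0) return std::nullopt;  // padded shape must be divisible by block_size
    if (padded / block_size == 0) return std::nullopt;  // padded shape can not be less than block_size
    out[i + 2] = padded / block_size;
  }
  return out;
}

int main() {
  // A 1x1x2x2 input padded to 4x4 and split with block_size 2 -> 4x1x2x2.
  auto shape = SpaceToBatchOutShape({1, 1, 2, 2}, 2, {{1, 1}, {1, 1}});
  if (shape) {
    std::cout << (*shape)[0] << "x" << (*shape)[1] << "x" << (*shape)[2] << "x" << (*shape)[3] << "\n";
  }
  return 0;
}

With a 1x1x2x2 input, symmetric padding of 1 on each side, and block_size 2, it prints 4x1x2x2.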