spacetobatch modified cpplint

This commit is contained in:
zong_shuai 2021-08-24 22:14:49 +08:00
parent 81cf9f5ccf
commit 5a861816fd
6 changed files with 33 additions and 33 deletions

View File

@ -48,11 +48,11 @@ class BatchToSpaceGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
if(!CheckParam(kernel_node)) {
if (!CheckParam(kernel_node)) {
return false;
}
input_size_ = sizeof(T);
for(size_t idx = 0; idx < input_shape_.size(); ++idx){
for (size_t idx = 0; idx < input_shape_.size(); ++idx) {
input_size_ *= input_shape_[idx];
}
@ -93,54 +93,54 @@ class BatchToSpaceGpuKernel : public GpuKernel {
bool CheckParam(const CNodePtr &kernel_node) {
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
if(block_size_ < 2) {
if (block_size_ < 2) {
MS_LOG(ERROR) << "block_size can not be less than 2.";
return false;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if(input_num != 1) {
if (input_num != 1) {
MS_LOG(ERROR) << "input_num is " << input_num << ", but BatchToSpace needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if(output_num != 1) {
if (output_num != 1) {
MS_LOG(ERROR) << "output_num is " << output_num << ", but BatchToSpace needs 1 output.";
return false;
}
// check input_shape
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
if(input_shape.size() != 4) {
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor.";
return false;
}
if((input_shape[0] % (block_size_ * block_size_)) != 0) {
if ((input_shape[0] % (block_size_ * block_size_)) != 0) {
MS_LOG(ERROR) << "input_shape[0] must be divisible by product of block_shape";
return false;
}
input_shape_.assign(input_shape.begin(),input_shape.end());
input_shape_.assign(input_shape.begin(), input_shape.end());
// check crops
crops_ =
static_cast<std::vector<std::vector<int64_t>>>(GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
if(crops_.size() != 2) {
if (crops_.size() != 2) {
MS_LOG(ERROR) << "crops.size() in BatchToSpace needs 2.";
return false;
}
if(crops_[0].size() != 2 || crops_[1].size() != 2) {
if (crops_[0].size() != 2 || crops_[1].size() != 2) {
MS_LOG(ERROR) << "crops[i].size() in BatchToSpace needs 2.";
return false;
}else {
for(size_t idx_i = 0; idx_i < 2; ++idx_i) {
for(size_t idx_j = 0; idx_j < 2; ++idx_j) {
if(crops_[idx_i][idx_j] < 0) {
} else {
for (size_t idx_i = 0; idx_i < 2; ++idx_i) {
for (size_t idx_j = 0; idx_j < 2; ++idx_j) {
if (crops_[idx_i][idx_j] < 0) {
MS_LOG(ERROR) << "the number in crops can not be less than 0.";
return false;
}
}
auto tmp_shape = input_shape[idx_i + 2] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
if(tmp_shape < 0) {
if (tmp_shape < 0) {
MS_LOG(ERROR) << "out_shape can not be less 0.";
}
}

View File

@ -48,11 +48,11 @@ class SpaceToBatchGpuKernel : public GpuKernel {
}
bool Init(const CNodePtr &kernel_node) override {
if(!CheckParam(kernel_node)) {
if (!CheckParam(kernel_node)) {
return false;
}
input_size_ = sizeof(T);
for(size_t idx = 0; idx < input_shape_.size(); ++idx){
for (size_t idx = 0; idx < input_shape_.size(); ++idx) {
input_size_ *= input_shape_[idx];
}
in_ = input_shape_[0];
@ -93,48 +93,48 @@ class SpaceToBatchGpuKernel : public GpuKernel {
private:
bool CheckParam(const CNodePtr &kernel_node) {
block_size_ = static_cast<int64_t>(GetAttr<int64_t>(kernel_node, "block_size"));
if(block_size_ < 2) {
if (block_size_ < 2) {
MS_LOG(ERROR) << "block_size can not be less than 2.";
return false;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if(input_num != 1) {
if (input_num != 1) {
MS_LOG(ERROR) << "input_num is " << input_num << ", but BatchToSpace needs 1 input.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if(output_num != 1) {
if (output_num != 1) {
MS_LOG(ERROR) << "output_num is " << output_num << ", but BatchToSpace needs 1 output.";
return false;
}
// check input_shape
auto input_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, 0);
if(input_shape.size() != 4) {
if (input_shape.size() != 4) {
MS_LOG(ERROR) << "Input is " << input_shape.size() << "-D, but BatchToSpace supports 4-D tensor.";
return false;
}
input_shape_.assign(input_shape.begin(),input_shape.end());
input_shape_.assign(input_shape.begin(), input_shape.end());
// check paddings_
paddings_ =
static_cast<std::vector<std::vector<int64_t>>>(GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "paddings"));
if(paddings_.size() != 2) {
if (paddings_.size() != 2) {
MS_LOG(ERROR) << "paddings.size() in BatchToSpace needs 2.";
return false;
}
if(paddings_[0].size() != 2 || paddings_[1].size() != 2) {
if (paddings_[0].size() != 2 || paddings_[1].size() != 2) {
MS_LOG(ERROR) << "paddings[i].size() in BatchToSpace needs 2.";
return false;
}else {
for(size_t idx_i = 0; idx_i < 2; ++idx_i) {
for(size_t idx_j = 0; idx_j < 2; ++idx_j) {
if(paddings_[idx_i][idx_j] < 0) {
} else {
for (size_t idx_i = 0; idx_i < 2; ++idx_i) {
for (size_t idx_j = 0; idx_j < 2; ++idx_j) {
if (paddings_[idx_i][idx_j] < 0) {
MS_LOG(ERROR) << "the number in paddings can not be less than 0.";
return false;
}
}
auto tmp_shape = input_shape[idx_i + 2] + paddings_[idx_i][0] + paddings_[idx_i][1];
if((tmp_shape % block_size_) != 0) {
if ((tmp_shape % block_size_) != 0) {
MS_LOG(ERROR) << "padded shape must be divisible by block_size";
}
}

View File

@ -130,4 +130,4 @@ template void CalBatchToSpace<uint64_t>(const size_t size, const uint64_t *input
const size_t on, const size_t oh, const size_t ow,
const size_t oc, const size_t crop_up, const size_t crop_dn,
const size_t crop_lft, const size_t crop_rht, const size_t block_num,
uint64_t *output, cudaStream_t cuda_stream);
uint64_t *output, cudaStream_t cuda_stream);

View File

@ -24,4 +24,4 @@ void CalBatchToSpace(const size_t size, const T *input, const size_t in,
const size_t crop_lft, const size_t crop_rht, const size_t block_num,
T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHTOSPACE_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BATCHTOSPACE_H_

View File

@ -131,4 +131,4 @@ template void CalSpaceToBatch<uint64_t>(const size_t size, const uint64_t *input
const size_t on, const size_t oh, const size_t ow,
const size_t oc, const size_t pad_up, const size_t pad_dn,
const size_t pad_lft, const size_t pad_rht, const size_t block_num,
uint64_t *output, cudaStream_t cuda_stream);
uint64_t *output, cudaStream_t cuda_stream);

View File

@ -24,4 +24,4 @@ void CalSpaceToBatch(const size_t size, const T *input, const size_t in,
const size_t pad_lft, const size_t pad_rht, const size_t block_num,
T *output, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETOBATCH_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPACETOBATCH_H_