From 6aa5cab829f8016356c98f5c2ba581bf045eae1e Mon Sep 17 00:00:00 2001
From: hangangqiang
Date: Thu, 8 Sep 2022 21:01:30 +0800
Subject: [PATCH] support dynamic shape and rank for batch_to_space

---
 .../kernel/arrays/batchtospace_gpu_kernel.h | 156 ++++++++----------
 mindspore/core/ops/batch_to_space.cc        |  65 +++++++-
 mindspore/core/ops/batch_to_space.h         |   3 -
 3 files changed, 134 insertions(+), 90 deletions(-)

diff --git a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/batchtospace_gpu_kernel.h b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/batchtospace_gpu_kernel.h
index 204c16814ed..e02b1af4fae 100644
--- a/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/batchtospace_gpu_kernel.h
+++ b/mindspore/ccsrc/plugin/device/gpu/kernel/arrays/batchtospace_gpu_kernel.h
@@ -20,6 +20,7 @@
 #include
 #include
 #include
+#include
 #include "plugin/device/gpu/kernel/gpu_kernel.h"
 #include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
 #include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/batchtospace_impl.cuh"
@@ -30,9 +31,23 @@ constexpr size_t SHAPE_SIZE = 4;
 constexpr size_t CROPS_SHAPE_0 = 2;
 constexpr size_t CROPS_SHAPE_1 = 2;
 template <typename T>
-class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
+class BatchToSpaceGpuKernelMod : public NativeGpuKernelMod {
  public:
-  BatchToSpaceGpuKernelMod() { ResetResource(); }
+  BatchToSpaceGpuKernelMod() {
+    in_ = 0;
+    ic_ = 0;
+    ih_ = 0;
+    iw_ = 0;
+    on_ = 0;
+    oc_ = 0;
+    oh_ = 0;
+    ow_ = 0;
+    kernel_name_ = "BatchToSpace";
+    crops_.clear();
+    input_size_list_.clear();
+    output_size_list_.clear();
+    input_shape_.clear();
+  }
   ~BatchToSpaceGpuKernelMod() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -43,75 +58,52 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
     T *input = GetDeviceAddress<T>(inputs, 0);
     T *output = GetDeviceAddress<T>(outputs, 0);
 
-    size_t size = output_size_ / sizeof(T);
+    size_t size = output_size_list_[0] / sizeof(T);
 
     CalBatchToSpace(size, input, in_, ih_, iw_, ic_, on_, oh_, ow_, oc_, crops_[0][0], crops_[0][1], crops_[1][0],
                     crops_[1][1], block_size_, output, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
 
-  bool Init(const CNodePtr &kernel_node) override {
+  bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+            const std::vector<KernelTensorPtr> &outputs) override {
+    MS_EXCEPTION_IF_NULL(base_operator);
+    PrimitivePtr prim = base_operator->GetPrim();
+    MS_EXCEPTION_IF_NULL(prim);
+    kernel_name_ = prim->name();
+    device_id_ = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
-    kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
-    (void)CheckParam(kernel_node);
-    input_size_ = sizeof(T);
-    for (size_t idx = 0; idx < input_shape_.size(); ++idx) {
-      input_size_ *= static_cast<size_t>(input_shape_[idx]);
-    }
-    constexpr int IDX_2 = 2;
-    constexpr int IDX_3 = 3;
-    in_ = static_cast<size_t>(input_shape_[0]);
-    ic_ = static_cast<size_t>(input_shape_[1]);
-    ih_ = static_cast<size_t>(input_shape_[IDX_2]);
-    iw_ = static_cast<size_t>(input_shape_[IDX_3]);
-
-    on_ = in_ / (block_size_ * block_size_);
-    oc_ = ic_;
-    oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
-    ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
-    output_size_ = on_ * oc_ * oh_ * ow_ * sizeof(T);
-    InitSizeLists();
-    return true;
-  }
-  void ResetResource() noexcept override {
-    in_ = 0;
-    ic_ = 0;
-    ih_ = 0;
-    iw_ = 0;
-    on_ = 0;
-    oc_ = 0;
-    oh_ = 0;
-    ow_ = 0;
-    kernel_name_ = "BatchToSpace";
-    input_size_list_.clear();
-    output_size_list_.clear();
-    crops_.clear();
-    input_shape_.clear();
-  }
-
- protected:
-  void InitSizeLists() override {
-    input_size_list_.push_back(input_size_);
-    output_size_list_.push_back(output_size_);
-  }
-
-  void CheckParam(const CNodePtr &kernel_node) {
-    block_size_ = GetAttr<int64_t>(kernel_node, "block_size");
+    // wait for the primitive to be unified between lite and cloud.
+    block_size_ = GetValue<int64_t>(prim->GetAttr("block_size"));
     if (block_size_ < 1) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'block_size' cannot be less than 1, but got "
                         << block_size_;
     }
 
-    size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1, but got " << input_num;
+    // check crops
+    crops_ = GetValue<std::vector<std::vector<int64_t>>>(prim->GetAttr("crops"));
+    if (crops_.size() != CROPS_SHAPE_0) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
+                        << crops_.size();
     }
-    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
-    if (output_num != 1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
+    if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
+      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
+                        << ", but got the size of crops[0]: " << crops_[0].size()
+                        << ", the size of crops[1]: " << crops_[1].size();
     }
+    CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
+    CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
+    return true;
+  }
 
+  int Resize(
+    const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
+    const std::vector<KernelTensorPtr> &outputs,
+    const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
+    if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
+      return ret;
+    }
     // check input_shape
-    auto input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
+    auto input_shape = inputs[0]->GetShapeVector();
     if (input_shape.size() != SHAPE_SIZE) {
       MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 4, but got "
                         << input_shape.size();
@@ -129,46 +121,38 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
       }
     }
     input_shape_.assign(input_shape.begin(), input_shape.end());
-
-    // check crops
-    crops_ = (GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
-
-    if (crops_.size() != CROPS_SHAPE_0) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
-                        << crops_.size();
-    }
-    if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
-      MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
-                        << ", but got the size of crops[0]: " << crops_[0].size()
-                        << ", the size of crops[1]: " << crops_[1].size();
-    } else {
-      for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
-        for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
-          if (crops_[idx_i][idx_j] < 0) {
-            MS_LOG(EXCEPTION) << "For '" << kernel_name_
-                              << "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
-                              << "][" << idx_j << "]: " << crops_[idx_i][idx_j];
-          }
-        }
-        auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
-        if (tmp_shape <= 0) {
+    for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
+      for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
+        if (crops_[idx_i][idx_j] < 0) {
           MS_LOG(EXCEPTION) << "For '" << kernel_name_
-                            << "', the element of shape of output must be greater than 0, but got " << tmp_shape;
+                            << "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
+                            << "][" << idx_j << "]: " << crops_[idx_i][idx_j];
         }
       }
+      auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
+      if (tmp_shape <= 0) {
+        MS_LOG(EXCEPTION) << "For '" << kernel_name_
+                          << "', the element of shape of output must be greater than 0, but got " << tmp_shape;
+      }
     }
+    constexpr int IDX_2 = 2;
+    constexpr int IDX_3 = 3;
+    in_ = static_cast<size_t>(input_shape_[0]);
+    ic_ = static_cast<size_t>(input_shape_[1]);
+    ih_ = static_cast<size_t>(input_shape_[IDX_2]);
+    iw_ = static_cast<size_t>(input_shape_[IDX_3]);
+
+    on_ = in_ / (block_size_ * block_size_);
+    oc_ = ic_;
+    oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
+    ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
+    return static_cast<int>(KRET_OK);
   }
 
  private:
-  std::vector<size_t> input_size_list_;
-  std::vector<size_t> output_size_list_;
-  std::vector<size_t> workspace_size_list_;
-
   std::vector<std::vector<int64_t>> crops_;
   std::vector<int64_t> input_shape_;
   size_t block_size_;
-  size_t input_size_;
-  size_t output_size_;
   size_t in_;
   size_t ic_;
   size_t ih_;
diff --git a/mindspore/core/ops/batch_to_space.cc b/mindspore/core/ops/batch_to_space.cc
index d794c44bed2..f965ab19d20 100644
--- a/mindspore/core/ops/batch_to_space.cc
+++ b/mindspore/core/ops/batch_to_space.cc
@@ -15,14 +15,17 @@
  */
 
 #include "ops/batch_to_space.h"
+#include
 #include "ops/op_utils.h"
 #include "utils/check_convert_utils.h"
+#include "utils/shape_utils.h"
 #include "abstract/ops/primitive_infer_map.h"
 #include "mindapi/src/helper.h"
 
 namespace mindspore {
 namespace ops {
 MIND_API_OPERATOR_IMPL(BatchToSpace, BaseOperator);
+
 void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) {
   this->set_block_size(block_size);
   this->set_crops(crops);
@@ -46,6 +49,66 @@ std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
   return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
 }
 
-REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace);
+class BatchToSpaceInfer : public abstract::OpInferBase {
+ public:
+  BaseShapePtr InferShape(const PrimitivePtr &primitive,
+                          const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    const int64_t input_num = 1;
+    CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+    auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
+    auto x_shape = x->BuildShape();
+    MS_EXCEPTION_IF_NULL(x_shape);
+    auto shape_element = x_shape->cast<abstract::ShapePtr>();
+    MS_EXCEPTION_IF_NULL(shape_element);
+    auto input_shape = shape_element->shape();
+    if (mindspore::IsDynamicRank(shape_element->shape())) {
+      return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
+    }
+    const size_t input_rank = 4;
+    if (input_shape.size() != input_rank) {
+      MS_EXCEPTION(ValueError) << "Rank of input should be 4, got " << shape_element->shape().size();
+    }
+    auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
+    auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
+    const size_t height_dim_index = 2;
+    ShapeVector output_shape(input_rank);
+    for (size_t i = 0; i < height_dim_index; i++) {
+      output_shape[i] = input_shape[i];
+    }
+    for (size_t i = height_dim_index; i < input_rank; i++) {
+      auto x_block_prod = input_shape[i] * block_size;
+      auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
+      if (x_block_prod < crop_sum) {
+        MS_EXCEPTION(ValueError) << "x block shape prod should be greater or equal to crops sum, got x_block_prod: "
+                                 << x_block_prod << ", crop_sum: " << crop_sum;
+      }
+      output_shape[i] = x_block_prod - crop_sum;
+    }
+    auto block_size_prod = block_size * block_size;
+    if (output_shape[0] % block_size_prod != 0) {
+      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly "
+                               << "by block_size_prod, but got the shape of output: " << output_shape << " and "
+                               << "block_size_prod: " << block_size_prod << ".";
+    }
+    output_shape[0] = output_shape[0] / block_size_prod;
+    return std::make_shared<abstract::Shape>(output_shape);
+  }
+
+  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
+    MS_EXCEPTION_IF_NULL(primitive);
+    auto prim_name = primitive->name();
+    const int64_t input_num = 1;
+    CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+    const std::set<TypePtr> valid_types = {kInt8,   kInt16,   kInt32,   kInt64,   kUInt8,     kUInt16,    kUInt32,
+                                           kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
+    auto x_type = input_args[kInputIndex0]->BuildType();
+    (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
+    return input_args[kInputIndex0]->BuildType();
+  }
+};
+
+REGISTER_PRIMITIVE_OP_INFER_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer, false);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/batch_to_space.h b/mindspore/core/ops/batch_to_space.h
index 0342fd70bd9..a279ad08e8d 100644
--- a/mindspore/core/ops/batch_to_space.h
+++ b/mindspore/core/ops/batch_to_space.h
@@ -47,9 +47,6 @@ class MIND_API BatchToSpace : public BaseOperator {
   /// \return crops.
   std::vector<std::vector<int64_t>> get_crops() const;
 };
-
-abstract::AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                            const std::vector<abstract::AbstractBasePtr> &input_args);
 }  // namespace ops
 }  // namespace mindspore
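
For reference, the output-shape rule that the new BatchToSpaceInfer::InferShape and BatchToSpaceGpuKernelMod::Resize both encode for a static NCHW input (out_n = n / block_size^2, out_c = c, out_h/out_w = dim * block_size - crop_top - crop_bottom) can be exercised outside MindSpore with the small standalone C++ sketch below. It is illustrative only: the helper InferBatchToSpaceShape and the main driver are hypothetical names and are not part of this patch or of MindSpore.

#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

// Hypothetical standalone mirror of the static-shape path of BatchToSpaceInfer::InferShape:
// out[0] = n / (block_size * block_size), out[1] = c,
// out[2] = h * block_size - crops[0][0] - crops[0][1],
// out[3] = w * block_size - crops[1][0] - crops[1][1].
std::vector<int64_t> InferBatchToSpaceShape(const std::vector<int64_t> &x_shape, int64_t block_size,
                                            const std::vector<std::vector<int64_t>> &crops) {
  if (x_shape.size() != 4) {
    throw std::invalid_argument("rank of input should be 4");
  }
  std::vector<int64_t> out(x_shape);
  for (size_t i = 2; i < 4; ++i) {
    const int64_t prod = x_shape[i] * block_size;
    const int64_t crop_sum = crops[i - 2][0] + crops[i - 2][1];
    if (prod < crop_sum) {
      throw std::invalid_argument("block-expanded dim smaller than crops sum");
    }
    out[i] = prod - crop_sum;
  }
  const int64_t block_size_prod = block_size * block_size;
  if (out[0] % block_size_prod != 0) {
    throw std::invalid_argument("batch dim not divisible by block_size * block_size");
  }
  out[0] /= block_size_prod;
  return out;
}

int main() {
  // x = (4, 3, 2, 2), block_size = 2, crops = {{0, 0}, {0, 0}} gives (1, 3, 4, 4),
  // matching on_/oc_/oh_/ow_ computed in BatchToSpaceGpuKernelMod::Resize for the same static shape.
  const auto out = InferBatchToSpaceShape({4, 3, 2, 2}, 2, {{0, 0}, {0, 0}});
  for (const auto d : out) std::cout << d << ' ';
  std::cout << '\n';  // prints: 1 3 4 4
  return 0;
}

Compiled with any C++11 compiler, the driver prints 1 3 4 4 for the example input; the dynamic-rank branch of the real infer (returning a shape of {UNKNOWN_RANK}) is intentionally not modeled here.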