!41716 support dynamic shape and rank for batch_to_space

Merge pull request !41716 from hangq/wood
This commit is contained in:
i-robot 2022-09-13 07:29:25 +00:00 committed by Gitee
commit a4b4fcec76
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 134 additions and 90 deletions

View File

@ -20,6 +20,7 @@
#include <vector>
#include <string>
#include <memory>
#include <map>
#include "plugin/device/gpu/kernel/gpu_kernel.h"
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/batchtospace_impl.cuh"
@ -30,9 +31,23 @@ constexpr size_t SHAPE_SIZE = 4;
constexpr size_t CROPS_SHAPE_0 = 2;
constexpr size_t CROPS_SHAPE_1 = 2;
template <typename T>
class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
class BatchToSpaceGpuKernelMod : public NativeGpuKernelMod {
public:
BatchToSpaceGpuKernelMod() { ResetResource(); }
BatchToSpaceGpuKernelMod() {
in_ = 0;
ic_ = 0;
ih_ = 0;
iw_ = 0;
on_ = 0;
oc_ = 0;
oh_ = 0;
ow_ = 0;
kernel_name_ = "BatchToSpace";
crops_.clear();
input_size_list_.clear();
output_size_list_.clear();
input_shape_.clear();
}
~BatchToSpaceGpuKernelMod() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -43,75 +58,52 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
T *input = GetDeviceAddress<T>(inputs, 0);
T *output = GetDeviceAddress<T>(outputs, 0);
size_t size = output_size_ / sizeof(T);
size_t size = output_size_list_[0] / sizeof(T);
CalBatchToSpace<T>(size, input, in_, ih_, iw_, ic_, on_, oh_, ow_, oc_, crops_[0][0], crops_[0][1], crops_[1][0],
crops_[1][1], block_size_, output, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs) override {
MS_EXCEPTION_IF_NULL(base_operator);
PrimitivePtr prim = base_operator->GetPrim();
MS_EXCEPTION_IF_NULL(prim);
kernel_name_ = prim->name();
device_id_ = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
(void)CheckParam(kernel_node);
input_size_ = sizeof(T);
for (size_t idx = 0; idx < input_shape_.size(); ++idx) {
input_size_ *= static_cast<size_t>(input_shape_[idx]);
}
constexpr int IDX_2 = 2;
constexpr int IDX_3 = 3;
in_ = static_cast<size_t>(input_shape_[0]);
ic_ = static_cast<size_t>(input_shape_[1]);
ih_ = static_cast<size_t>(input_shape_[IDX_2]);
iw_ = static_cast<size_t>(input_shape_[IDX_3]);
on_ = in_ / (block_size_ * block_size_);
oc_ = ic_;
oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
output_size_ = on_ * oc_ * oh_ * ow_ * sizeof(T);
InitSizeLists();
return true;
}
void ResetResource() noexcept override {
in_ = 0;
ic_ = 0;
ih_ = 0;
iw_ = 0;
on_ = 0;
oc_ = 0;
oh_ = 0;
ow_ = 0;
kernel_name_ = "BatchToSpace";
input_size_list_.clear();
output_size_list_.clear();
crops_.clear();
input_shape_.clear();
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_);
output_size_list_.push_back(output_size_);
}
void CheckParam(const CNodePtr &kernel_node) {
block_size_ = GetAttr<int64_t>(kernel_node, "block_size");
// wait for primitive unified between lite and cloud.
block_size_ = GetValue<int64_t>(prim->GetAttr("block_size"));
if (block_size_ < 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'block_size' cannot be less than 1, but got "
<< block_size_;
}
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1, but got " << input_num;
// check crops
crops_ = GetValue<std::vector<std::vector<int64_t>>>(prim->GetAttr("crops"));
if (crops_.size() != CROPS_SHAPE_0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
<< crops_.size();
}
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
<< ", but got the size of crops[0]: " << crops_[0].size()
<< ", the size of crops[1]: " << crops_[1].size();
}
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
return true;
}
int Resize(
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
// check input_shape
auto input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
auto input_shape = inputs[0]->GetShapeVector();
if (input_shape.size() != SHAPE_SIZE) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 4, but got "
<< input_shape.size();
@ -129,46 +121,38 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
}
}
input_shape_.assign(input_shape.begin(), input_shape.end());
// check crops
crops_ = (GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
if (crops_.size() != CROPS_SHAPE_0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
<< crops_.size();
}
if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
<< ", but got the size of crops[0]: " << crops_[0].size()
<< ", the size of crops[1]: " << crops_[1].size();
} else {
for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
if (crops_[idx_i][idx_j] < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
<< "][" << idx_j << "]: " << crops_[idx_i][idx_j];
}
}
auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
if (tmp_shape <= 0) {
for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
if (crops_[idx_i][idx_j] < 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the element of shape of output must be greater than 0, but got " << tmp_shape;
<< "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
<< "][" << idx_j << "]: " << crops_[idx_i][idx_j];
}
}
auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
if (tmp_shape <= 0) {
MS_LOG(EXCEPTION) << "For '" << kernel_name_
<< "', the element of shape of output must be greater than 0, but got " << tmp_shape;
}
}
constexpr int IDX_2 = 2;
constexpr int IDX_3 = 3;
in_ = static_cast<size_t>(input_shape_[0]);
ic_ = static_cast<size_t>(input_shape_[1]);
ih_ = static_cast<size_t>(input_shape_[IDX_2]);
iw_ = static_cast<size_t>(input_shape_[IDX_3]);
on_ = in_ / (block_size_ * block_size_);
oc_ = ic_;
oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
return static_cast<int>(KRET_OK);
}
private:
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
std::vector<std::vector<int64_t>> crops_;
std::vector<int64_t> input_shape_;
size_t block_size_;
size_t input_size_;
size_t output_size_;
size_t in_;
size_t ic_;
size_t ih_;

View File

@ -15,14 +15,17 @@
*/
#include "ops/batch_to_space.h"
#include <set>
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/shape_utils.h"
#include "abstract/ops/primitive_infer_map.h"
#include "mindapi/src/helper.h"
namespace mindspore {
namespace ops {
MIND_API_OPERATOR_IMPL(BatchToSpace, BaseOperator);
void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) {
this->set_block_size(block_size);
this->set_crops(crops);
@ -46,6 +49,66 @@ std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
}
REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace);
/// \brief Shape and type inference for the BatchToSpace primitive.
///
/// For a rank-4 input, the inferred output keeps dimension 1, divides dimension 0 by
/// block_size * block_size, and maps each of dimensions 2 and 3 to
/// `dim * block_size - crops[k][0] - crops[k][1]`.
class BatchToSpaceInfer : public abstract::OpInferBase {
 public:
  /// \brief Infer the output shape from the single tensor input and the block_size / crops attributes.
  ///
  /// \param[in] primitive The BatchToSpace primitive carrying the `block_size` and `crops` attributes.
  /// \param[in] input_args The abstract inputs; exactly one tensor is expected.
  ///
  /// \return `{UNKNOWN_RANK}` when the input rank is dynamic, otherwise the rank-4 output shape.
  ///
  /// Raises ValueError when the input rank is not 4, when `block_size` is less than 1, when `crops` is not
  /// 2x2, when the crops exceed the expanded spatial size, or when dimension 0 is not divisible by
  /// block_size * block_size.
  BaseShapePtr InferShape(const PrimitivePtr &primitive,
                          const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    auto prim_name = primitive->name();
    const int64_t input_num = 1;
    CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
    auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
    auto x_shape = x->BuildShape();
    MS_EXCEPTION_IF_NULL(x_shape);
    auto shape_element = x_shape->cast<abstract::ShapePtr>();
    MS_EXCEPTION_IF_NULL(shape_element);
    const auto &input_shape = shape_element->shape();

    // The dynamic-rank case must be handled before the rank check: a dynamic-rank shape is expressed as the
    // one-element vector {UNKNOWN_RANK} (as returned here), so it can never satisfy `size() == 4`.
    if (mindspore::IsDynamicRank(input_shape)) {
      return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
    }
    const size_t input_rank = 4;
    if (input_shape.size() != input_rank) {
      MS_EXCEPTION(ValueError) << "Rank of input should be 4, got " << input_shape.size();
    }

    // Reject a non-positive block_size up front; block_size * block_size is used as a divisor below.
    auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
    if (block_size < 1) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the 'block_size' cannot be less than 1, but got "
                               << block_size;
    }

    // crops is indexed as crops[0..1][0..1] below, so verify its 2x2 layout before touching it.
    auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
    const size_t crops_rows = 2;
    const size_t crops_cols = 2;
    if (crops.size() != crops_rows) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the size of 'crops' must be " << crops_rows
                               << ", but got " << crops.size();
    }
    if (crops[0].size() != crops_cols || crops[1].size() != crops_cols) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the size of element of 'crops' must be " << crops_cols
                               << ", but got the size of crops[0]: " << crops[0].size()
                               << ", the size of crops[1]: " << crops[1].size();
    }

    // Dimensions 0 and 1 are copied as-is first; dimension 0 is rescaled after the divisibility check.
    // NOTE(review): individual dimensions flagged as unknown (negative) are not special-cased here - confirm
    // whether they need to be propagated instead of being run through the arithmetic below.
    const size_t height_dim_index = 2;
    ShapeVector output_shape(input_rank);
    for (size_t i = 0; i < height_dim_index; i++) {
      output_shape[i] = input_shape[i];
    }
    for (size_t i = height_dim_index; i < input_rank; i++) {
      const auto &crop = crops[i - height_dim_index];
      auto x_block_prod = input_shape[i] * block_size;
      auto crop_sum = crop[0] + crop[1];
      if (x_block_prod < crop_sum) {
        MS_EXCEPTION(ValueError) << "x block shape prod should be greater or equal to crops sum, got x_block_prod: "
                                 << x_block_prod << ", crop_sum: " << crop_sum;
      }
      output_shape[i] = x_block_prod - crop_sum;
    }

    auto block_size_prod = block_size * block_size;
    if (output_shape[0] % block_size_prod != 0) {
      MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly "
                               << "by block_size_prod, but got the shape of output: " << output_shape << " and "
                               << "block_size_prod: " << block_size_prod << ".";
    }
    output_shape[0] = output_shape[0] / block_size_prod;
    return std::make_shared<abstract::Shape>(output_shape);
  }

  /// \brief Infer the output type, which is the type of the single input.
  ///
  /// \param[in] primitive The BatchToSpace primitive; only its name is used, for error messages.
  /// \param[in] input_args The abstract inputs; exactly one is expected.
  ///
  /// \return The type built from the first input, after it has been checked against the accepted tensor types.
  TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
    MS_EXCEPTION_IF_NULL(primitive);
    auto prim_name = primitive->name();
    const int64_t input_num = 1;
    CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
    const std::set<TypePtr> valid_types = {kInt8,   kInt16,   kInt32,   kInt64,   kUInt8,     kUInt16,    kUInt32,
                                           kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
    auto x_type = input_args[kInputIndex0]->BuildType();
    (void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
    // Reuse the type that was just validated instead of building it a second time.
    return x_type;
  }
};
REGISTER_PRIMITIVE_OP_INFER_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer, false);
} // namespace ops
} // namespace mindspore

View File

@ -47,9 +47,6 @@ class MIND_API BatchToSpace : public BaseOperator {
/// \return crops.
std::vector<std::vector<int64_t>> get_crops() const;
};
abstract::AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<abstract::AbstractBasePtr> &input_args);
} // namespace ops
} // namespace mindspore