forked from mindspore-Ecosystem/mindspore
!41716 support dynamic shape and rank for batch_to_shape
Merge pull request !41716 from hangq/wood
This commit is contained in:
commit
a4b4fcec76
|
@ -20,6 +20,7 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel.h"
|
||||
#include "plugin/device/gpu/kernel/gpu_kernel_factory.h"
|
||||
#include "plugin/device/gpu/kernel/cuda_impl/cuda_ops/batchtospace_impl.cuh"
|
||||
|
@ -30,9 +31,23 @@ constexpr size_t SHAPE_SIZE = 4;
|
|||
constexpr size_t CROPS_SHAPE_0 = 2;
|
||||
constexpr size_t CROPS_SHAPE_1 = 2;
|
||||
template <typename T>
|
||||
class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
||||
class BatchToSpaceGpuKernelMod : public NativeGpuKernelMod {
|
||||
public:
|
||||
BatchToSpaceGpuKernelMod() { ResetResource(); }
|
||||
BatchToSpaceGpuKernelMod() {
|
||||
in_ = 0;
|
||||
ic_ = 0;
|
||||
ih_ = 0;
|
||||
iw_ = 0;
|
||||
on_ = 0;
|
||||
oc_ = 0;
|
||||
oh_ = 0;
|
||||
ow_ = 0;
|
||||
kernel_name_ = "BatchToSpace";
|
||||
crops_.clear();
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
input_shape_.clear();
|
||||
}
|
||||
~BatchToSpaceGpuKernelMod() = default;
|
||||
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
|
||||
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
|
||||
|
@ -43,75 +58,52 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
T *input = GetDeviceAddress<T>(inputs, 0);
|
||||
T *output = GetDeviceAddress<T>(outputs, 0);
|
||||
|
||||
size_t size = output_size_ / sizeof(T);
|
||||
size_t size = output_size_list_[0] / sizeof(T);
|
||||
|
||||
CalBatchToSpace<T>(size, input, in_, ih_, iw_, ic_, on_, oh_, ow_, oc_, crops_[0][0], crops_[0][1], crops_[1][0],
|
||||
crops_[1][1], block_size_, output, device_id_, reinterpret_cast<cudaStream_t>(stream_ptr));
|
||||
return true;
|
||||
}
|
||||
|
||||
bool Init(const CNodePtr &kernel_node) override {
|
||||
bool Init(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs) override {
|
||||
MS_EXCEPTION_IF_NULL(base_operator);
|
||||
PrimitivePtr prim = base_operator->GetPrim();
|
||||
MS_EXCEPTION_IF_NULL(prim);
|
||||
kernel_name_ = prim->name();
|
||||
|
||||
device_id_ = MsContext::GetInstance()->get_param<uint32_t>(MS_CTX_DEVICE_ID);
|
||||
kernel_name_ = common::AnfAlgo::GetCNodeName(kernel_node);
|
||||
(void)CheckParam(kernel_node);
|
||||
input_size_ = sizeof(T);
|
||||
for (size_t idx = 0; idx < input_shape_.size(); ++idx) {
|
||||
input_size_ *= static_cast<size_t>(input_shape_[idx]);
|
||||
}
|
||||
constexpr int IDX_2 = 2;
|
||||
constexpr int IDX_3 = 3;
|
||||
in_ = static_cast<size_t>(input_shape_[0]);
|
||||
ic_ = static_cast<size_t>(input_shape_[1]);
|
||||
ih_ = static_cast<size_t>(input_shape_[IDX_2]);
|
||||
iw_ = static_cast<size_t>(input_shape_[IDX_3]);
|
||||
|
||||
on_ = in_ / (block_size_ * block_size_);
|
||||
oc_ = ic_;
|
||||
oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
|
||||
ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
|
||||
output_size_ = on_ * oc_ * oh_ * ow_ * sizeof(T);
|
||||
InitSizeLists();
|
||||
return true;
|
||||
}
|
||||
void ResetResource() noexcept override {
|
||||
in_ = 0;
|
||||
ic_ = 0;
|
||||
ih_ = 0;
|
||||
iw_ = 0;
|
||||
on_ = 0;
|
||||
oc_ = 0;
|
||||
oh_ = 0;
|
||||
ow_ = 0;
|
||||
kernel_name_ = "BatchToSpace";
|
||||
input_size_list_.clear();
|
||||
output_size_list_.clear();
|
||||
crops_.clear();
|
||||
input_shape_.clear();
|
||||
}
|
||||
|
||||
protected:
|
||||
void InitSizeLists() override {
|
||||
input_size_list_.push_back(input_size_);
|
||||
output_size_list_.push_back(output_size_);
|
||||
}
|
||||
|
||||
void CheckParam(const CNodePtr &kernel_node) {
|
||||
block_size_ = GetAttr<int64_t>(kernel_node, "block_size");
|
||||
// wait for primitive unified between lite and cloud.
|
||||
block_size_ = GetValue<int64_t>(prim->GetAttr("block_size"));
|
||||
if (block_size_ < 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the 'block_size' cannot be less than 1, but got "
|
||||
<< block_size_;
|
||||
}
|
||||
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
|
||||
if (input_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of inputs must be 1, but got " << input_num;
|
||||
// check crops
|
||||
crops_ = GetValue<std::vector<std::vector<int64_t>>>(prim->GetAttr("crops"));
|
||||
if (crops_.size() != CROPS_SHAPE_0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
|
||||
<< crops_.size();
|
||||
}
|
||||
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
|
||||
if (output_num != 1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the number of outputs must be 1, but got " << output_num;
|
||||
if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
|
||||
<< ", but got the size of crops[0]: " << crops_[0].size()
|
||||
<< ", the size of crops[1]: " << crops_[1].size();
|
||||
}
|
||||
CHECK_KERNEL_INPUTS_NUM(inputs.size(), 1, kernel_name_);
|
||||
CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), 1, kernel_name_);
|
||||
return true;
|
||||
}
|
||||
|
||||
int Resize(
|
||||
const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &inputsOnHost = std::map<uint32_t, tensor::TensorPtr>()) override {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
// check input_shape
|
||||
auto input_shape = AnfAlgo::GetInputDeviceShapeAdaptively(kernel_node, 0);
|
||||
auto input_shape = inputs[0]->GetShapeVector();
|
||||
if (input_shape.size() != SHAPE_SIZE) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the dimension of input must be 4, but got "
|
||||
<< input_shape.size();
|
||||
|
@ -129,46 +121,38 @@ class BatchToSpaceGpuKernelMod : public DeprecatedNativeGpuKernelMod {
|
|||
}
|
||||
}
|
||||
input_shape_.assign(input_shape.begin(), input_shape.end());
|
||||
|
||||
// check crops
|
||||
crops_ = (GetAttr<std::vector<std::vector<int64_t>>>(kernel_node, "crops"));
|
||||
|
||||
if (crops_.size() != CROPS_SHAPE_0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of 'crops' must be " << CROPS_SHAPE_0 << ", but got "
|
||||
<< crops_.size();
|
||||
}
|
||||
if (crops_[0].size() != CROPS_SHAPE_1 || crops_[1].size() != CROPS_SHAPE_1) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_ << "', the size of element of 'crops' must be " << CROPS_SHAPE_1
|
||||
<< ", but got the size of crops[0]: " << crops_[0].size()
|
||||
<< ", the size of crops[1]: " << crops_[1].size();
|
||||
} else {
|
||||
for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
|
||||
for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
|
||||
if (crops_[idx_i][idx_j] < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
|
||||
<< "][" << idx_j << "]: " << crops_[idx_i][idx_j];
|
||||
}
|
||||
}
|
||||
auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
|
||||
if (tmp_shape <= 0) {
|
||||
for (size_t idx_i = 0; idx_i < CROPS_SHAPE_0; ++idx_i) {
|
||||
for (size_t idx_j = 0; idx_j < CROPS_SHAPE_1; ++idx_j) {
|
||||
if (crops_[idx_i][idx_j] < 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the element of shape of output must be greater than 0, but got " << tmp_shape;
|
||||
<< "', the element of 'crops' must be greater than or equal to 0, but got crops[" << idx_i
|
||||
<< "][" << idx_j << "]: " << crops_[idx_i][idx_j];
|
||||
}
|
||||
}
|
||||
auto tmp_shape = input_shape[idx_i + CROPS_SHAPE_1] * block_size_ - crops_[idx_i][0] - crops_[idx_i][1];
|
||||
if (tmp_shape <= 0) {
|
||||
MS_LOG(EXCEPTION) << "For '" << kernel_name_
|
||||
<< "', the element of shape of output must be greater than 0, but got " << tmp_shape;
|
||||
}
|
||||
}
|
||||
constexpr int IDX_2 = 2;
|
||||
constexpr int IDX_3 = 3;
|
||||
in_ = static_cast<size_t>(input_shape_[0]);
|
||||
ic_ = static_cast<size_t>(input_shape_[1]);
|
||||
ih_ = static_cast<size_t>(input_shape_[IDX_2]);
|
||||
iw_ = static_cast<size_t>(input_shape_[IDX_3]);
|
||||
|
||||
on_ = in_ / (block_size_ * block_size_);
|
||||
oc_ = ic_;
|
||||
oh_ = ih_ * block_size_ - crops_[0][0] - crops_[0][1];
|
||||
ow_ = iw_ * block_size_ - crops_[1][0] - crops_[1][1];
|
||||
return static_cast<int>(KRET_OK);
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<size_t> input_size_list_;
|
||||
std::vector<size_t> output_size_list_;
|
||||
std::vector<size_t> workspace_size_list_;
|
||||
|
||||
std::vector<std::vector<int64_t>> crops_;
|
||||
std::vector<int64_t> input_shape_;
|
||||
size_t block_size_;
|
||||
size_t input_size_;
|
||||
size_t output_size_;
|
||||
size_t in_;
|
||||
size_t ic_;
|
||||
size_t ih_;
|
||||
|
|
|
@ -15,14 +15,17 @@
|
|||
*/
|
||||
|
||||
#include "ops/batch_to_space.h"
|
||||
#include <set>
|
||||
#include "ops/op_utils.h"
|
||||
#include "utils/check_convert_utils.h"
|
||||
#include "utils/shape_utils.h"
|
||||
#include "abstract/ops/primitive_infer_map.h"
|
||||
#include "mindapi/src/helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace ops {
|
||||
MIND_API_OPERATOR_IMPL(BatchToSpace, BaseOperator);
|
||||
|
||||
void BatchToSpace::Init(const std::vector<int64_t> &block_size, const std::vector<std::vector<int64_t>> &crops) {
|
||||
this->set_block_size(block_size);
|
||||
this->set_crops(crops);
|
||||
|
@ -46,6 +49,66 @@ std::vector<std::vector<int64_t>> BatchToSpace::get_crops() const {
|
|||
return GetValue<std::vector<std::vector<int64_t>>>(value_ptr);
|
||||
}
|
||||
|
||||
REGISTER_PRIMITIVE_C(kNameBatchToSpace, BatchToSpace);
|
||||
class BatchToSpaceInfer : public abstract::OpInferBase {
|
||||
public:
|
||||
BaseShapePtr InferShape(const PrimitivePtr &primitive,
|
||||
const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
auto x = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, 0);
|
||||
auto x_shape = x->BuildShape();
|
||||
MS_EXCEPTION_IF_NULL(x_shape);
|
||||
auto shape_element = x_shape->cast<abstract::ShapePtr>();
|
||||
MS_EXCEPTION_IF_NULL(shape_element);
|
||||
auto input_shape = shape_element->shape();
|
||||
const size_t input_rank = 4;
|
||||
if (input_shape.size() != input_rank) {
|
||||
MS_EXCEPTION(ValueError) << "Rank of input should be 4, got " << shape_element->shape().size();
|
||||
}
|
||||
if (mindspore::IsDynamicRank(shape_element->shape())) {
|
||||
return std::make_shared<abstract::Shape>(std::vector<int64_t>{UNKNOWN_RANK});
|
||||
}
|
||||
auto block_size = GetValue<int64_t>(primitive->GetAttr(kBlockSize));
|
||||
auto crops = GetValue<std::vector<std::vector<int64_t>>>(primitive->GetAttr(kCrops));
|
||||
const size_t height_dim_index = 2;
|
||||
ShapeVector output_shape(input_rank);
|
||||
for (size_t i = 0; i < height_dim_index; i++) {
|
||||
output_shape[i] = input_shape[i];
|
||||
}
|
||||
for (size_t i = height_dim_index; i < input_rank; i++) {
|
||||
auto x_block_prod = input_shape[i] * block_size;
|
||||
auto crop_sum = crops[i - height_dim_index][0] + crops[i - height_dim_index][1];
|
||||
if (x_block_prod < crop_sum) {
|
||||
MS_EXCEPTION(ValueError) << "x block shape prod should be greater or equal to crops sum, got x_block_prod: "
|
||||
<< x_block_prod << ", crop_sum: " << crop_sum;
|
||||
}
|
||||
output_shape[i] = x_block_prod - crop_sum;
|
||||
}
|
||||
auto block_size_prod = block_size * block_size;
|
||||
if (output_shape[0] % block_size_prod != 0) {
|
||||
MS_EXCEPTION(ValueError) << "For '" << prim_name << "', the shape of output with index 0 must be divided exactly "
|
||||
<< "by block_size_prod, but got the shape of output: " << output_shape << " and "
|
||||
<< "block_size_prod: " << block_size_prod << ".";
|
||||
}
|
||||
output_shape[0] = output_shape[0] / block_size_prod;
|
||||
return std::make_shared<abstract::Shape>(output_shape);
|
||||
}
|
||||
|
||||
TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const override {
|
||||
MS_EXCEPTION_IF_NULL(primitive);
|
||||
auto prim_name = primitive->name();
|
||||
const int64_t input_num = 1;
|
||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||
const std::set<TypePtr> valid_types = {kInt8, kInt16, kInt32, kInt64, kUInt8, kUInt16, kUInt32,
|
||||
kUInt64, kFloat16, kFloat32, kFloat64, kComplex64, kComplex128};
|
||||
auto x_type = input_args[kInputIndex0]->BuildType();
|
||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("x", x_type, valid_types, prim_name);
|
||||
return input_args[kInputIndex0]->BuildType();
|
||||
}
|
||||
};
|
||||
|
||||
REGISTER_PRIMITIVE_OP_INFER_IMPL(BatchToSpace, prim::kPrimBatchToSpace, BatchToSpaceInfer, false);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -47,9 +47,6 @@ class MIND_API BatchToSpace : public BaseOperator {
|
|||
/// \return crops.
|
||||
std::vector<std::vector<int64_t>> get_crops() const;
|
||||
};
|
||||
|
||||
abstract::AbstractBasePtr BatchToSpaceInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const std::vector<abstract::AbstractBasePtr> &input_args);
|
||||
} // namespace ops
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
Loading…
Reference in New Issue