ops Channel_Shuffle supports dynamic shape feature

type: feature
reason: add codes to supports dynamic shape for Channel_Shuffle.

------

Signed-off-by: wang_ziqi <wangziqi4@huawei.com>
This commit is contained in:
wang_ziqi 2023-02-09 19:04:24 +08:00
parent 883382e86c
commit 9c4a7c6345
3 changed files with 20 additions and 2 deletions

View File

@ -36,8 +36,6 @@ bool ChannelShuffleCpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
const std::vector<KernelTensorPtr> &outputs) {
kernel_name_ = base_operator->name();
input_dtype_ = inputs[0]->GetDtype();
input_shape_ = inputs[0]->GetShapeVector();
outputs_ = outputs;
group_ = GetValue<int64_t>(base_operator->GetAttr("group"));
return true;
}
@ -66,6 +64,17 @@ bool ChannelShuffleCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &i
return true;
}
int ChannelShuffleCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs,
const std::map<uint32_t, tensor::TensorPtr> &) {
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
return ret;
}
input_shape_ = inputs[0]->GetShapeVector();
outputs_ = outputs;
return KRET_OK;
}
std::vector<KernelAttr> ChannelShuffleCpuKernelMod::GetOpSupport() {
static std::vector<KernelAttr> support_list = {
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),

View File

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <utility>
#include <map>
#include "plugin/device/cpu/kernel/cpu_kernel.h"
#include "plugin/factory/ms_factory.h"
@ -36,6 +37,9 @@ class ChannelShuffleCpuKernelMod : public NativeCpuKernelMod {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs) override;
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
protected:
std::vector<KernelAttr> GetOpSupport() override;

View File

@ -36,6 +36,11 @@ abstract::ShapePtr ChannelShuffleInferShape(const PrimitivePtr &primitive,
int64_t group = GetValue<int64_t>(primitive->GetAttr("group"));
auto input_shape_ = shape_map[kShape];
auto dims = input_shape_.size();
if (IsDynamic(input_shape_)) {
return std::make_shared<abstract::Shape>(input_shape_);
}
if (dims <= min_dims) {
MS_EXCEPTION(TypeError) << "For ChannelShuffle, expect input with > 3 dims, "
<< "but got " << input_shape_.size() << ".";