forked from mindspore-Ecosystem/mindspore
ops Channel_Shuffle supports dynamic shape feature
type: feature reason: add codes to supports dynamic shape for Channel_Shuffle. ------ Signed-off-by: wang_ziqi <wangziqi4@huawei.com>
This commit is contained in:
parent
883382e86c
commit
9c4a7c6345
|
@ -36,8 +36,6 @@ bool ChannelShuffleCpuKernelMod::Init(const BaseOperatorPtr &base_operator, cons
|
|||
const std::vector<KernelTensorPtr> &outputs) {
|
||||
kernel_name_ = base_operator->name();
|
||||
input_dtype_ = inputs[0]->GetDtype();
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
outputs_ = outputs;
|
||||
group_ = GetValue<int64_t>(base_operator->GetAttr("group"));
|
||||
return true;
|
||||
}
|
||||
|
@ -66,6 +64,17 @@ bool ChannelShuffleCpuKernelMod::Launch(const std::vector<kernel::AddressPtr> &i
|
|||
return true;
|
||||
}
|
||||
|
||||
int ChannelShuffleCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs,
|
||||
const std::map<uint32_t, tensor::TensorPtr> &) {
|
||||
if (int ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) {
|
||||
return ret;
|
||||
}
|
||||
input_shape_ = inputs[0]->GetShapeVector();
|
||||
outputs_ = outputs;
|
||||
return KRET_OK;
|
||||
}
|
||||
|
||||
std::vector<KernelAttr> ChannelShuffleCpuKernelMod::GetOpSupport() {
|
||||
static std::vector<KernelAttr> support_list = {
|
||||
KernelAttr().AddInputAttr(kNumberTypeInt8).AddOutputAttr(kNumberTypeInt8),
|
||||
|
|
|
@ -19,6 +19,7 @@
|
|||
#include <vector>
|
||||
#include <memory>
|
||||
#include <utility>
|
||||
#include <map>
|
||||
|
||||
#include "plugin/device/cpu/kernel/cpu_kernel.h"
|
||||
#include "plugin/factory/ms_factory.h"
|
||||
|
@ -36,6 +37,9 @@ class ChannelShuffleCpuKernelMod : public NativeCpuKernelMod {
|
|||
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs) override;
|
||||
|
||||
int Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
|
||||
const std::vector<KernelTensorPtr> &outputs, const std::map<uint32_t, tensor::TensorPtr> &) override;
|
||||
|
||||
protected:
|
||||
std::vector<KernelAttr> GetOpSupport() override;
|
||||
|
||||
|
|
|
@ -36,6 +36,11 @@ abstract::ShapePtr ChannelShuffleInferShape(const PrimitivePtr &primitive,
|
|||
int64_t group = GetValue<int64_t>(primitive->GetAttr("group"));
|
||||
auto input_shape_ = shape_map[kShape];
|
||||
auto dims = input_shape_.size();
|
||||
|
||||
if (IsDynamic(input_shape_)) {
|
||||
return std::make_shared<abstract::Shape>(input_shape_);
|
||||
}
|
||||
|
||||
if (dims <= min_dims) {
|
||||
MS_EXCEPTION(TypeError) << "For ChannelShuffle, expect input with > 3 dims, "
|
||||
<< "but got " << input_shape_.size() << ".";
|
||||
|
|
Loading…
Reference in New Issue