diff --git a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc index b83d05f9675..6eeedeb13ed 100644 --- a/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc +++ b/mindspore/ccsrc/backend/common/optimizer/dynamic_shape/dynamic_shape_helper.cc @@ -74,30 +74,6 @@ void InferShapeForNopNode(const AnfNodePtr &input_node) { } } -bool InferShapeForDefiniteOutputNode(const CNodePtr &cnode) { - MS_EXCEPTION_IF_NULL(cnode); - if (!common::AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimShape)) { - return false; - } - auto input_size = common::AnfAlgo::GetInputTensorNum(cnode); - if (input_size != 1) { - MS_LOG(EXCEPTION) << "Node only has one input: " << cnode->fullname_with_scope(); - } - auto cur_shape = dynamic_cast(cnode->Shape().get())->shape(); - if (std::any_of(cur_shape.begin(), cur_shape.end(), [](int64_t x) { return x == kInvalidShape; })) { - return false; - } - std::vector output_shape = {static_cast(cur_shape.size())}; - mindspore::abstract::BaseShapePtr shape = std::make_shared(output_shape); - - // cppcheck-suppress unreadVariable - auto lock = AnfUtils::GetAbstractLock(cnode.get()); - auto abstract = cnode->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - abstract->set_shape(shape); - return true; -} - TypeId GetSequenceType(const abstract::AbstractSequencePtr &seq_abs) { auto elems = seq_abs->elements(); if (!elems[0]->isa()) { @@ -265,10 +241,6 @@ void InferShape(const CNodePtr &cnode, std::map *de MS_EXCEPTION_IF_NULL(depend_tensor_map); MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope(); std::set depend_list = abstract::GetValueDependArgIndices(cnode); - auto ret = InferShapeForDefiniteOutputNode(cnode); - if (ret) { - return; - } depend_tensor_map->clear(); auto &inputs = cnode->inputs(); diff --git a/mindspore/ccsrc/kernel/common_utils.h b/mindspore/ccsrc/kernel/common_utils.h index f8c1105c177..9ca363487fb 100644 --- a/mindspore/ccsrc/kernel/common_utils.h +++ b/mindspore/ccsrc/kernel/common_utils.h @@ -460,14 +460,14 @@ BACKEND_EXPORT int64_t CalOutputTupleSize(const AnfNodePtr &node); BACKEND_EXPORT void SetDynamicInputSizeAttr(const CNodePtr &cnode); BACKEND_EXPORT bool IsDynamicParamKernel(const std::string &op_name); -template +template class MatchKernelHelper { public: MatchKernelHelper() = default; virtual ~MatchKernelHelper() = default; - using KernelRunFunc = std::function &, const std::vector &, - const std::vector &)>; + using KernelRunFunc = std::function &, + const std::vector &, const std::vector &)>; virtual const std::vector> &GetFuncList() const = 0; protected: diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.cc b/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.cc new file mode 100644 index 00000000000..f9eb15167ba --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "plugin/device/cpu/kernel/shape_cpu_kernel.h" +#include "plugin/device/cpu/hal/device/cpu_device_address.h" + +namespace mindspore { +namespace kernel { +namespace { +constexpr size_t kShapeInputsNum = 1; +constexpr size_t kShapeOutputsNum = 1; +} // namespace + +bool ShapeCpuKernelMod::Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) { + MS_EXCEPTION_IF_NULL(base_operator); + kernel_name_ = base_operator->name(); + CHECK_KERNEL_INPUTS_NUM(inputs.size(), kShapeInputsNum, kernel_name_); + CHECK_KERNEL_OUTPUTS_NUM(outputs.size(), kShapeOutputsNum, kernel_name_); + return MatchKernelFunc(base_operator, inputs, outputs); +} + +int ShapeCpuKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, + const std::map &) { + if (auto ret = KernelMod::Resize(base_operator, inputs, outputs); ret != KRET_OK) { + return ret; + } + input_shape_ = inputs.at(kIndex0)->GetShapeVector(); + output_shape_ = outputs.at(kIndex0)->GetShapeVector(); + if (output_shape_.size() != 1) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', the dimension of output must be 1-D, but got: " << output_shape_.size(); + } + if (output_shape_[0] != SizeToLong(input_shape_.size())) { + MS_LOG(EXCEPTION) << "For '" << kernel_name_ + << "', 'output_shape[0]' must be equal to the dimension of input, but got 'output_shape[0]': " + << output_shape_[0] << " and the dimension of input: " << input_shape_.size(); + } + return KRET_OK; +} + +bool ShapeCpuKernelMod::LaunchKernel(const std::vector &inputs, const std::vector &, + const std::vector &outputs) { + auto output_addr = reinterpret_cast(outputs[0]->addr); + for (size_t i = 0; i < LongToSize(output_shape_[0]); ++i) { + output_addr[i] = input_shape_[i]; + } + return true; +} + +const std::vector> &ShapeCpuKernelMod::GetFuncList() const { + static const std::vector> func_list = { + {KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt16).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeInt64).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}, + {KernelAttr().AddInputAttr(kNumberTypeBool).AddOutputAttr(kObjectTypeTuple, kNumberTypeInt64), + &ShapeCpuKernelMod::LaunchKernel}}; + return func_list; +} + +MS_KERNEL_FACTORY_REG(NativeCpuKernelMod, Shape, ShapeCpuKernelMod); +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.h b/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.h new file mode 100644 index 00000000000..60445a2779e --- /dev/null +++ b/mindspore/ccsrc/plugin/device/cpu/kernel/shape_cpu_kernel.h @@ -0,0 +1,60 @@ +/** + * Copyright 2019-2022 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SHAPE_CPU_KERNEL_H_ +#define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SHAPE_CPU_KERNEL_H_ + +#include +#include +#include +#include +#include "plugin/device/cpu/kernel/cpu_kernel.h" +#include "plugin/factory/ms_factory.h" + +namespace mindspore { +namespace kernel { +class ShapeCpuKernelMod : public NativeCpuKernelMod, public MatchKernelHelper { + public: + ShapeCpuKernelMod() = default; + ~ShapeCpuKernelMod() override = default; + + bool Launch(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs) override { + MS_EXCEPTION_IF_NULL(kernel_func_); + return kernel_func_(this, inputs, workspace, outputs); + } + int Resize(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs, const std::map &) override; + + bool Init(const BaseOperatorPtr &base_operator, const std::vector &inputs, + const std::vector &outputs) override; + + const std::vector> &GetFuncList() const override; + + protected: + std::vector GetOpSupport() override { return OpSupport(); } + + bool LaunchKernel(const std::vector &inputs, const std::vector &workspace, + const std::vector &outputs); + + private: + ShapeVector input_shape_; + ShapeVector output_shape_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_KERNEL_SHAPE_CPU_KERNEL_H_