From e9c5192d273bc08f577e376ea74cf7970958a7c8 Mon Sep 17 00:00:00 2001
From: Yang Jiao
Date: Fri, 19 Aug 2022 09:29:27 +0800
Subject: [PATCH] para to tensor

---
 .../common/graph_kernel/model/op_node.cc      | 129 ++++++++++++++++++
 .../ccsrc/common/graph_kernel/model/op_node.h |  16 +++
 .../common/graph_kernel/model/op_register.cc  |   1 +
 .../converter/expanders/strided_slice.cc      |  38 ++++++
 .../converter/graph_kernel_expander_lite.cc   |  60 +++-----
 .../converter/graph_kernel_expander_lite.h    |  26 +---
 .../converter/graph_kernel_optimization.cc    |   2 +
 .../converter/parameter_to_tensor.cc          |  50 +++++++
 .../converter/parameter_to_tensor.h           |  29 ++++
 9 files changed, 287 insertions(+), 64 deletions(-)
 create mode 100644 mindspore/lite/tools/graph_kernel/converter/expanders/strided_slice.cc
 create mode 100644 mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.cc
 create mode 100644 mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.h

diff --git a/mindspore/ccsrc/common/graph_kernel/model/op_node.cc b/mindspore/ccsrc/common/graph_kernel/model/op_node.cc
index 4497e77ce9d..385f4bb4b50 100644
--- a/mindspore/ccsrc/common/graph_kernel/model/op_node.cc
+++ b/mindspore/ccsrc/common/graph_kernel/model/op_node.cc
@@ -18,6 +18,8 @@
 #include
 #include
 #include
+#include <unordered_map>
+#include <unordered_set>
 #include
 #include
 #include
@@ -1002,6 +1004,133 @@ void StridedSliceOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBase
   SetAbastractsFromAttrs(primitive, convert_input_list, inputs_abstract, input_names_vec);
 }
 
+template <typename TM>
+tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &attrs) {
+  constexpr size_t input_index = 0;
+  constexpr size_t begin_index = 1;
+  constexpr size_t end_index = 2;
+  constexpr size_t axes_index = 3;
+  constexpr size_t stride_index = 4;
+
+  ShapeVector input_shape = inputs[input_index]->shape;
+  std::vector<int> begin = ChangeDataToVec(inputs[begin_index]);
+  std::vector<int> end = ChangeDataToVec(inputs[end_index]);
+  std::vector<int> axes = ChangeDataToVec(inputs[axes_index]);
+  std::vector<int> stride = ChangeDataToVec(inputs[stride_index]);
+
+  std::unordered_map<int, std::unordered_set<size_t>> info;
+  for (size_t i = 0; i < axes.size(); i++) {
+    int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i];
+    if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) {
+      MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative.";
+      return nullptr;
+    }
+    std::unordered_set<size_t> pos;
+    int index = begin[i];
+    while (index < end[i]) {
+      (void)pos.insert(IntToSize(index));
+      index += stride[i];
+    }
+    (void)info.emplace(axis, pos);
+  }
+
+  TM *input_x =
+    static_cast<TM *>(std::static_pointer_cast<ConstTensorNode>(inputs[input_index])->data()->data_c());
+
+  std::vector<TM> res;
+
+  std::function<void(size_t, size_t)> func;
+  func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) {
+    if ((dim + 1) == input_shape.size()) {
+      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
+        if (info.count(SizeToInt(dim)) > 0) {
+          if (info[SizeToInt(dim)].count(i) > 0) {
+            (void)res.emplace_back(input_x[offset + i]);
+          }
+        } else {
+          (void)res.emplace_back(input_x[offset + i]);
+        }
+      }
+    } else if ((dim + 1) < input_shape.size()) {
+      size_t accu = 1;
+      for (size_t j = dim + 1; j < input_shape.size(); j++) {
+        accu *= LongToSize(input_shape[j]);
+      }
+      for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
+        if (info.count(SizeToInt(dim)) > 0) {
+          if (info[SizeToInt(dim)].count(i) > 0) {
+            func(dim + 1, offset + i * accu);
+          }
+        } else {
+          func(dim + 1, offset + i * accu);
+        }
+      }
+    }
+    return;
+  };
+  func(0, 0);
+  return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
+}
+
+NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
+  for (auto i : inputs) {
+    if (i->NodeType() != NType::Value) {
+      return nullptr;
+    }
+  }
+  TypeId output_type = this->type;
+  tensor::TensorPtr res = nullptr;
+  switch (static_cast<int>(output_type)) {
+    case TypeId::kNumberTypeUInt8: {
+      res = CalcStridedSliceOnnx<uint8_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeInt8: {
+      res = CalcStridedSliceOnnx<int8_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeInt16: {
+      res = CalcStridedSliceOnnx<int16_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeInt32: {
+      res = CalcStridedSliceOnnx<int32_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeInt64: {
+      res = CalcStridedSliceOnnx<int64_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeUInt16: {
+      res = CalcStridedSliceOnnx<uint16_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeUInt32: {
+      res = CalcStridedSliceOnnx<uint32_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeUInt64: {
+      res = CalcStridedSliceOnnx<uint64_t>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeFloat16: {
+      res = CalcStridedSliceOnnx<float16>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeFloat32: {
+      res = CalcStridedSliceOnnx<float>(inputs, attrs);
+      break;
+    }
+    case TypeId::kNumberTypeFloat64: {
+      res = CalcStridedSliceOnnx<double>(inputs, attrs);
+      break;
+    }
+    default:
+      return nullptr;
+  }
+  return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
+}
+
 void MatMulOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) {
   if (primitive->HasAttr("dst_type")) {
     auto out_type = primitive->GetAttr("dst_type");
diff --git a/mindspore/ccsrc/common/graph_kernel/model/op_node.h b/mindspore/ccsrc/common/graph_kernel/model/op_node.h
index c90c9e2033f..cfaf3ec7d19 100644
--- a/mindspore/ccsrc/common/graph_kernel/model/op_node.h
+++ b/mindspore/ccsrc/common/graph_kernel/model/op_node.h
@@ -326,6 +326,22 @@ class StridedSliceOp : public OpaqueOp {
   void RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) override;
 };
 
+class StridedSliceOnnxOp : public OpaqueOp {
+ public:
+  explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {}
+  ~StridedSliceOnnxOp() = default;
+  NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
+
+ protected:
+  template <typename TM>
+  tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &attrs);
+  std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override {
+    return GetValue<std::vector<DShape>>(attrs.find("output_shape")->second);
+  }
+  std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
+  DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }
+};
+
 class MatMulOp : public OpaqueOp {
  public:
   explicit MatMulOp(const std::string &op) : OpaqueOp(op) {}
diff --git a/mindspore/ccsrc/common/graph_kernel/model/op_register.cc b/mindspore/ccsrc/common/graph_kernel/model/op_register.cc
index 251e84a896d..9468e3aab23 100644
--- a/mindspore/ccsrc/common/graph_kernel/model/op_register.cc
+++ b/mindspore/ccsrc/common/graph_kernel/model/op_register.cc
@@ -113,6 +113,7 @@ OP_REGISTER("BatchMatMul", OpaqueOp);
 OP_REGISTER("CumSum", OpaqueOp);
 OP_REGISTER("OneHot", OpaqueOp);
 OP_REGISTER("StridedSlice", StridedSliceOp);
+OP_REGISTER("StridedSliceOnnx", StridedSliceOnnxOp);
 OP_REGISTER("Concat", ConcatOp);
 OP_REGISTER("Gather", GatherOp);
 OP_REGISTER("Shape", ShapeOp);
diff --git a/mindspore/lite/tools/graph_kernel/converter/expanders/strided_slice.cc b/mindspore/lite/tools/graph_kernel/converter/expanders/strided_slice.cc
new file mode 100644
index 00000000000..9fc209939f3
--- /dev/null
+++ b/mindspore/lite/tools/graph_kernel/converter/expanders/strided_slice.cc
@@ -0,0 +1,38 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <vector>
+
+#include "common/graph_kernel/expanders/op_desc_registry.h"
+
+namespace mindspore::graphkernel::expanders {
+class StridedSlice : public OpDesc {
+ public:
+  StridedSlice() {}
+  ~StridedSlice() = default;
+
+ protected:
+  NodePtrList Expand(const NodePtrList &inputs) override {
+    const size_t onnx_slice_input_num = 5;
+    if (inputs.size() != onnx_slice_input_num) return {};
+    std::vector<ShapeVector> shp;
+    (void)shp.emplace_back(outputs_info_[0].shape);
+    auto result = gb.Emit("StridedSliceOnnx", inputs, {{"output_shape", MakeValue(shp)}});
+    return {result};
+  }
+};
+EXPANDER_OP_DESC_REGISTER("StridedSlice", StridedSlice);
+}  // namespace mindspore::graphkernel::expanders
diff --git a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc
index 55f8cd953f5..3e42818606a 100644
--- a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc
+++ b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.cc
@@ -35,44 +35,19 @@
 #include "tools/graph_kernel/converter/preprocess_weight.h"
 
 namespace mindspore::graphkernel {
-AnfNodePtr ParaToValueDeco::Run(const AnfNodePtr &node) {
+AnfNodePtr TensorToValueDeco::Run(const AnfNodePtr &node) {
   auto cnode = QuickCloneCNode(node);
   for (const auto &idx : input_idx_) {
-    if (cnode->input(idx + 1)->isa<Parameter>()) {
-      auto param_value = cnode->input(idx + 1)->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
-      auto int_value = static_cast<int *>(param_value->data_ptr()->data());
-      ShapeVector out_list;
-      std::transform(int_value, int_value + param_value->data_ptr()->size(), std::back_inserter(out_list), IntToLong);
-      auto value = std::make_shared<ValueNode>(MakeValue(out_list));
-      cnode->set_input(idx + 1, value);
-    }
-  }
-  return decorated_->Run(cnode);
-}
-
-AnfNodePtr ParaToTensorDeco::Run(const AnfNodePtr &node) {
-  auto cnode = QuickCloneCNode(node);
-  HashSet<size_t> ids;
-  if (convert_all_) {
-    for (size_t i = 1; i < cnode->inputs().size(); i++) {
-      (void)ids.insert(i - 1);
-    }
-  } else {
-    ids = input_idx_;
-  }
-  for (const auto &idx : ids) {
-    if (cnode->input(idx + 1)->isa<Parameter>()) {
-      auto default_param = cnode->input(idx + 1)->cast<ParameterPtr>()->default_param();
-      if (default_param == nullptr) {
-        continue;
+    if (cnode->input(idx + 1)->isa<ValueNode>()) {
+      auto value = cnode->input(idx + 1)->cast<ValueNodePtr>()->value();
+      if (value->isa<tensor::Tensor>()) {
+        auto param_value = value->cast<tensor::TensorPtr>();
+        auto int_value = static_cast<int *>(param_value->data_ptr()->data());
+        ShapeVector out_list;
+        std::transform(int_value, int_value + param_value->data_ptr()->size(), std::back_inserter(out_list), IntToLong);
+        auto new_value = std::make_shared<ValueNode>(MakeValue(out_list));
+        cnode->set_input(idx + 1, new_value);
       }
-      auto param_value = default_param->cast<tensor::TensorPtr>();
-      if (param_value == nullptr) {
-        continue;
-      }
-      auto value = NewValueNode(param_value);
-      value->set_abstract(param_value->ToAbstract());
-      cnode->set_input(idx + 1, value);
     }
   }
   return decorated_->Run(cnode);
 }
@@ -180,7 +155,8 @@ std::vector<OpWithLevel> GraphKernelExpanderLite::InitOpList() {
     {kCPUDevice, OpLevel_1, prim::kPrimUnsqueeze},       {kCPUDevice, OpLevel_1, prim::kPrimGather},
     {kCPUDevice, OpLevel_1, prim::kPrimShape},           {kCPUDevice, OpLevel_1, prim::kPrimConcat},
     {kCPUDevice, OpLevel_1, prim::kPrimConstantOfShape}, {kCPUDevice, OpLevel_1, prim::kPrimConv2DFusion},
-    {kCPUDevice, OpLevel_1, prim::kPrimAvgPoolFusion},   {kCPUDevice, OpLevel_1, prim::kPrimMaxPoolFusion}};
+    {kCPUDevice, OpLevel_1, prim::kPrimAvgPoolFusion},   {kCPUDevice, OpLevel_1, prim::kPrimMaxPoolFusion},
+    {kCPUDevice, OpLevel_1, prim::kPrimStridedSlice}};
   const auto &flags = GraphKernelFlags::GetInstance();
   return GkUtils::GetValidOps(expand_ops_with_level, flags.fusion_ops_level, flags.enable_expand_ops_only,
                               flags.enable_expand_ops, flags.disable_expand_ops);
@@ -213,13 +189,13 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
     {prim::kPrimShape->name(), {FixFormatDeco::Creator}},
     {prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
     {prim::kPrimConstantOfShape->name(), {InputToAttrDeco::GetCreator({0}), FixFormatDeco::Creator}},
-    {prim::kPrimTranspose->name(), {ParaToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
+    {prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
     {prim::kPrimGather->name(),
-     {ParaToTensorDeco::GetCreator({1}), ParaToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}),
-      FixFormatDeco::Creator}},
-    {prim::kPrimConcat->name(), {ParaToTensorDeco::GetCreator({}, true), FixFormatDeco::Creator}},
-    {prim::kPrimConv2DFusion->name(), {ParaToTensorDeco::GetCreator({1}), SubstituteConv2D::Creator}},
-    {prim::kPrimMatMulFusion->name(), {ParaToTensorDeco::GetCreator({1}), MatmulPackB::Creator}},
+     {TensorToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}), FixFormatDeco::Creator}},
+    {prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
+    {prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
+    {prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},
+    {prim::kPrimMatMulFusion->name(), {MatmulPackB::Creator}},
     {prim::kPrimAvgPoolFusion->name(), {PoolLayoutDeco::Creator}},
    {prim::kPrimMaxPoolFusion->name(), {PoolLayoutDeco::Creator}},
   };
diff --git a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h
index 06cb41a81a3..c113310922c 100644
--- a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h
+++ b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_expander_lite.h
@@ -23,15 +23,15 @@
 #include "utils/hash_set.h"
 
 namespace mindspore::graphkernel {
-class ParaToValueDeco : public ExpanderDecorator {
+class TensorToValueDeco : public ExpanderDecorator {
  public:
-  ParaToValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
+  TensorToValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
       : ExpanderDecorator(decorated), input_idx_(input_idx) {}
-  ~ParaToValueDeco() = default;
+  ~TensorToValueDeco() = default;
 
   static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
     return [input_idx](const ExpanderPtr &decorated) {
-      return std::static_pointer_cast<Expander>(std::make_shared<ParaToValueDeco>(decorated, input_idx));
+      return std::static_pointer_cast<Expander>(std::make_shared<TensorToValueDeco>(decorated, input_idx));
     };
   }
   AnfNodePtr Run(const AnfNodePtr &node) override;
 
 protected:
  HashSet<size_t> input_idx_;
 };
 
-class ParaToTensorDeco : public ExpanderDecorator {
- public:
-  ParaToTensorDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx, bool convert_all = false)
-      : ExpanderDecorator(decorated), input_idx_(input_idx), convert_all_(convert_all) {}
-  ~ParaToTensorDeco() = default;
-
-  static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx, bool convert_all = false) {
-    return [input_idx, convert_all](const ExpanderPtr &decorated) {
-      return std::static_pointer_cast<Expander>(std::make_shared<ParaToTensorDeco>(decorated, input_idx, convert_all));
-    };
-  }
-  AnfNodePtr Run(const AnfNodePtr &node) override;
-
- protected:
-  HashSet<size_t> input_idx_;
-  bool convert_all_;
-};
-
 class FixFormatDeco : public ExpanderDecorator {
  public:
   explicit FixFormatDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
diff --git a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_optimization.cc b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_optimization.cc
index aa29c6df6ff..7d74ae2925f 100644
--- a/mindspore/lite/tools/graph_kernel/converter/graph_kernel_optimization.cc
+++ b/mindspore/lite/tools/graph_kernel/converter/graph_kernel_optimization.cc
@@ -36,6 +36,7 @@
 #include "tools/graph_kernel/converter/graph_kernel_expander_lite.h"
 #include "tools/graph_kernel/converter/insert_abstract.h"
 #include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h"
+#include "tools/graph_kernel/converter/parameter_to_tensor.h"
 
 namespace mindspore::graphkernel {
 using opt::GetitemTuple;
@@ -51,6 +52,7 @@ GkPassManagerPtr GraphKernelOptimizer::PreProcess() const {
   // Some ops may lose abstract in converter
   pm->Add(std::make_shared<InsertAbstract>(), OptLevel_1);
   pm->Add(std::make_shared(), OptLevel_1);
+  pm->Add(std::make_shared<ParameterToTensor>(), OptLevel_1);
   return pm;
 }
diff --git a/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.cc b/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.cc
new file mode 100644
index 00000000000..970131bcd14
--- /dev/null
+++ b/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.cc
@@ -0,0 +1,50 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#include "tools/graph_kernel/converter/parameter_to_tensor.h"
+
+namespace mindspore::graphkernel {
+bool ParameterToTensor::Run(const FuncGraphPtr &func_graph) {
+  auto todos = TopoSort(func_graph->output());
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+  for (auto &node : todos) {
+    auto cnode = node->cast<CNodePtr>();
+    if (cnode == nullptr) {
+      continue;
+    }
+
+    for (size_t idx = 1; idx < cnode->inputs().size(); idx++) {
+      if (cnode->input(idx)->isa<Parameter>()) {
+        auto default_param = cnode->input(idx)->cast<ParameterPtr>()->default_param();
+        if (default_param == nullptr) {
+          continue;
+        }
+        auto param_value = default_param->cast<tensor::TensorPtr>();
+        if (param_value == nullptr) {
+          continue;
+        }
+        auto value = NewValueNode(param_value);
+        value->set_abstract(param_value->ToAbstract());
+        (void)mng->Replace(cnode->input(idx), value);
+      }
+    }
+  }
+  return true;
+}
+}  // namespace mindspore::graphkernel
diff --git a/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.h b/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.h
new file mode 100644
index 00000000000..58bb03fb791
--- /dev/null
+++ b/mindspore/lite/tools/graph_kernel/converter/parameter_to_tensor.h
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2022 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
+#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
+#include "ir/func_graph.h"
+#include "backend/common/optimizer/pass.h"
+
+namespace mindspore::graphkernel {
+class ParameterToTensor : public opt::Pass {
+ public:
+  ParameterToTensor() : Pass("parameter_to_tensor") {}
+  ~ParameterToTensor() override = default;
+  bool Run(const FuncGraphPtr &func_graph) override;
+};
+}  // namespace mindspore::graphkernel
+#endif  // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
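-- 
Aside (illustration only, not part of the patch): the infer-value fold added in StridedSliceOnnxOp keeps, for each axis listed in axes, exactly the indices begin, begin+stride, ... that stay below end, then copies the surviving elements in row-major order. Below is a minimal standalone C++ sketch of that per-axis rule for a 1-D input, using plain std::vector instead of MindSpore tensors; all names are illustrative.

#include <cstdio>
#include <vector>

// Keep input[begin], input[begin+stride], ... while the index stays below end.
// Mirrors the non-negative begin/end/stride precondition checked in the patch.
std::vector<int> SliceOnnx1D(const std::vector<int> &input, int begin, int end, int stride) {
  std::vector<int> res;
  for (int i = begin; i < end; i += stride) {
    res.push_back(input[static_cast<size_t>(i)]);
  }
  return res;
}

int main() {
  std::vector<int> x{0, 1, 2, 3, 4, 5, 6, 7};
  auto y = SliceOnnx1D(x, 1, 7, 2);  // keeps indices 1, 3, 5
  for (int v : y) printf("%d ", v);  // prints: 1 3 5
  printf("\n");
  return 0;
}

For begin=1, end=7, stride=2 this keeps indices {1, 3, 5} — the same set the patch inserts into `pos` before gathering along that axis.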
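Similarly, a toy sketch of what the new ParameterToTensor pass does, assuming only the behavior visible in parameter_to_tensor.cc: CNode inputs that are Parameters carrying a default tensor are replaced by constant value nodes, while data-less parameters (real graph inputs) are untouched; this is what lets the former ParaToValueDeco/ParaToTensorDeco pair collapse into the single TensorToValueDeco. Toy types only, not the MindSpore API.

#include <cstdio>
#include <memory>
#include <optional>
#include <vector>

// A toy stand-in for an ANF node: either a Parameter (optionally carrying
// "default_param" data, i.e. a weight) or a constant value node.
struct ToyNode {
  bool is_parameter;
  std::optional<std::vector<float>> data;
};
using ToyNodePtr = std::shared_ptr<ToyNode>;

// Replace every Parameter that has data with a constant node holding that
// data — analogous to NewValueNode(param_value) + mng->Replace(...) above.
void ParamToTensorToy(std::vector<ToyNodePtr> *inputs) {
  for (auto &in : *inputs) {
    if (in->is_parameter && in->data.has_value()) {
      in = std::make_shared<ToyNode>(ToyNode{false, in->data});
    }
  }
}

int main() {
  std::vector<ToyNodePtr> inputs{
    std::make_shared<ToyNode>(ToyNode{true, std::nullopt}),                       // activation input: untouched
    std::make_shared<ToyNode>(ToyNode{true, std::vector<float>{1.f, 2.f, 3.f}}),  // weight: becomes a constant
  };
  ParamToTensorToy(&inputs);
  printf("input0 is_parameter=%d, input1 is_parameter=%d\n", inputs[0]->is_parameter, inputs[1]->is_parameter);
  return 0;
}

Running the pass once in PreProcess means every expander decorator downstream can assume constant operands already appear as value nodes rather than weight parameters.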