para to tensor
This commit is contained in:
parent
4a00560a8a
commit
e9c5192d27
|
@ -18,6 +18,8 @@
|
|||
#include <cmath>
|
||||
#include <algorithm>
|
||||
#include <set>
|
||||
#include <unordered_set>
|
||||
#include <unordered_map>
|
||||
#include <sstream>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
|
@ -1002,6 +1004,133 @@ void StridedSliceOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBase
|
|||
SetAbastractsFromAttrs(primitive, convert_input_list, inputs_abstract, input_names_vec);
|
||||
}
|
||||
|
||||
template <typename TM>
|
||||
tensor::TensorPtr StridedSliceOnnxOp::CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
constexpr size_t input_index = 0;
|
||||
constexpr size_t begin_index = 1;
|
||||
constexpr size_t end_index = 2;
|
||||
constexpr size_t axes_index = 3;
|
||||
constexpr size_t stride_index = 4;
|
||||
|
||||
ShapeVector input_shape = inputs[input_index]->shape;
|
||||
std::vector<int> begin = ChangeDataToVec<int, int>(inputs[begin_index]);
|
||||
std::vector<int> end = ChangeDataToVec<int, int>(inputs[end_index]);
|
||||
std::vector<int> axes = ChangeDataToVec<int, int>(inputs[axes_index]);
|
||||
std::vector<int> stride = ChangeDataToVec<int, int>(inputs[stride_index]);
|
||||
|
||||
std::unordered_map<int, std::unordered_set<size_t>> info;
|
||||
for (size_t i = 0; i < axes.size(); i++) {
|
||||
int axis = axes[i] < 0 ? axes[i] + SizeToInt(input_shape.size()) : axes[i];
|
||||
if (begin[i] < 0 || end[i] < 0 || stride[i] < 0) {
|
||||
MS_LOG(INFO) << "Only do infervalue for StridedSliceOnnx when begin, end and stride are non-negative.";
|
||||
return nullptr;
|
||||
}
|
||||
std::unordered_set<size_t> pos;
|
||||
int index = begin[i];
|
||||
while (index < end[i]) {
|
||||
(void)pos.insert(IntToSize(index));
|
||||
index += stride[i];
|
||||
}
|
||||
(void)info.emplace(axis, pos);
|
||||
}
|
||||
|
||||
TM *input_x =
|
||||
static_cast<TM *>(std::static_pointer_cast<inner::ConstTensorNode>(inputs[input_index])->data()->data_c());
|
||||
|
||||
std::vector<TM> res;
|
||||
|
||||
std::function<void(size_t, size_t)> func;
|
||||
func = [&func, &input_x, &res, &info, &input_shape](size_t dim, size_t offset) {
|
||||
if ((dim + 1) == input_shape.size()) {
|
||||
for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
|
||||
if (info.count(SizeToInt(dim)) > 0) {
|
||||
if (info[SizeToInt(dim)].count(i) > 0) {
|
||||
(void)res.emplace_back(input_x[offset + i]);
|
||||
}
|
||||
} else {
|
||||
(void)res.emplace_back(input_x[offset + i]);
|
||||
}
|
||||
}
|
||||
} else if ((dim + 1) < input_shape.size()) {
|
||||
size_t accu = 1;
|
||||
for (size_t j = dim + 1; j < input_shape.size(); j++) {
|
||||
accu *= LongToSize(input_shape[j]);
|
||||
}
|
||||
for (size_t i = 0; i < LongToSize(input_shape[dim]); i++) {
|
||||
if (info.count(SizeToInt(dim)) > 0) {
|
||||
if (info[SizeToInt(dim)].count(i) > 0) {
|
||||
func(dim + 1, offset + i * accu);
|
||||
}
|
||||
} else {
|
||||
func(dim + 1, offset + i * accu);
|
||||
}
|
||||
}
|
||||
}
|
||||
return;
|
||||
};
|
||||
func(0, 0);
|
||||
return std::make_shared<tensor::Tensor>(this->type, this->shape, &res[0], this->type);
|
||||
}
|
||||
|
||||
NodePtr StridedSliceOnnxOp::InferValue(const NodePtrList &inputs, const DAttrs &attrs) {
|
||||
for (auto i : inputs) {
|
||||
if (i->NodeType() != NType::Value) {
|
||||
return nullptr;
|
||||
}
|
||||
}
|
||||
TypeId output_type = this->type;
|
||||
tensor::TensorPtr res = nullptr;
|
||||
switch (static_cast<int>(output_type)) {
|
||||
case TypeId::kNumberTypeUInt8: {
|
||||
res = CalcStridedSliceOnnx<uint8_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt8: {
|
||||
res = CalcStridedSliceOnnx<int8_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt16: {
|
||||
res = CalcStridedSliceOnnx<int16_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt32: {
|
||||
res = CalcStridedSliceOnnx<int32_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeInt64: {
|
||||
res = CalcStridedSliceOnnx<int64_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt16: {
|
||||
res = CalcStridedSliceOnnx<uint16_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt32: {
|
||||
res = CalcStridedSliceOnnx<uint32_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeUInt64: {
|
||||
res = CalcStridedSliceOnnx<uint64_t>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat16: {
|
||||
res = CalcStridedSliceOnnx<float16>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat32: {
|
||||
res = CalcStridedSliceOnnx<float>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
case TypeId::kNumberTypeFloat64: {
|
||||
res = CalcStridedSliceOnnx<double>(inputs, attrs);
|
||||
break;
|
||||
}
|
||||
default:
|
||||
return nullptr;
|
||||
}
|
||||
return res == nullptr ? nullptr : std::make_shared<ConstTensorNode>(res);
|
||||
}
|
||||
|
||||
void MatMulOp::RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) {
|
||||
if (primitive->HasAttr("dst_type")) {
|
||||
auto out_type = primitive->GetAttr("dst_type");
|
||||
|
|
|
@ -326,6 +326,22 @@ class StridedSliceOp : public OpaqueOp {
|
|||
void RectifyAbstract(const PrimitivePtr &primitive, AbstractBasePtrList *inputs_abstract) override;
|
||||
};
|
||||
|
||||
class StridedSliceOnnxOp : public OpaqueOp {
|
||||
public:
|
||||
explicit StridedSliceOnnxOp(const std::string &op) : OpaqueOp(op) {}
|
||||
~StridedSliceOnnxOp() = default;
|
||||
NodePtr InferValue(const NodePtrList &inputs, const DAttrs &attrs) override;
|
||||
|
||||
protected:
|
||||
template <typename TM>
|
||||
tensor::TensorPtr CalcStridedSliceOnnx(const NodePtrList &inputs, const DAttrs &attrs);
|
||||
std::vector<DShape> InferShape(const NodePtrList &, const DAttrs &attrs) override {
|
||||
return GetValue<std::vector<DShape>>(attrs.find("output_shape")->second);
|
||||
}
|
||||
std::vector<TypeId> InferType(const NodePtrList &inputs, const DAttrs &) override { return {inputs[0]->type}; }
|
||||
DFormat InferFormat(const NodePtrList &, const DAttrs &) override { return kOpFormat_DEFAULT; }
|
||||
};
|
||||
|
||||
class MatMulOp : public OpaqueOp {
|
||||
public:
|
||||
explicit MatMulOp(const std::string &op) : OpaqueOp(op) {}
|
||||
|
|
|
@ -113,6 +113,7 @@ OP_REGISTER("BatchMatMul", OpaqueOp);
|
|||
OP_REGISTER("CumSum", OpaqueOp);
|
||||
OP_REGISTER("OneHot", OpaqueOp);
|
||||
OP_REGISTER("StridedSlice", StridedSliceOp);
|
||||
OP_REGISTER("StridedSliceOnnx", StridedSliceOnnxOp);
|
||||
OP_REGISTER("Concat", ConcatOp);
|
||||
OP_REGISTER("Gather", GatherOp);
|
||||
OP_REGISTER("Shape", ShapeOp);
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "common/graph_kernel/expanders/op_desc_registry.h"
|
||||
|
||||
namespace mindspore::graphkernel::expanders {
|
||||
class StridedSlice : public OpDesc {
|
||||
public:
|
||||
StridedSlice() {}
|
||||
~StridedSlice() = default;
|
||||
|
||||
protected:
|
||||
NodePtrList Expand(const NodePtrList &inputs) override {
|
||||
const size_t onnx_slice_input_num = 5;
|
||||
if (inputs.size() != onnx_slice_input_num) return {};
|
||||
std::vector<inner::DShape> shp;
|
||||
(void)shp.emplace_back(outputs_info_[0].shape);
|
||||
auto result = gb.Emit("StridedSliceOnnx", inputs, {{"output_shape", MakeValue(shp)}});
|
||||
return {result};
|
||||
}
|
||||
};
|
||||
EXPANDER_OP_DESC_REGISTER("StridedSlice", StridedSlice);
|
||||
} // namespace mindspore::graphkernel::expanders
|
|
@ -35,44 +35,19 @@
|
|||
#include "tools/graph_kernel/converter/preprocess_weight.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
AnfNodePtr ParaToValueDeco::Run(const AnfNodePtr &node) {
|
||||
AnfNodePtr TensorToValueDeco::Run(const AnfNodePtr &node) {
|
||||
auto cnode = QuickCloneCNode(node);
|
||||
for (const auto &idx : input_idx_) {
|
||||
if (cnode->input(idx + 1)->isa<Parameter>()) {
|
||||
auto param_value = cnode->input(idx + 1)->cast<ParameterPtr>()->default_param()->cast<tensor::TensorPtr>();
|
||||
auto int_value = static_cast<int *>(param_value->data_ptr()->data());
|
||||
ShapeVector out_list;
|
||||
std::transform(int_value, int_value + param_value->data_ptr()->size(), std::back_inserter(out_list), IntToLong);
|
||||
auto value = std::make_shared<ValueNode>(MakeValue(out_list));
|
||||
cnode->set_input(idx + 1, value);
|
||||
}
|
||||
}
|
||||
return decorated_->Run(cnode);
|
||||
}
|
||||
|
||||
AnfNodePtr ParaToTensorDeco::Run(const AnfNodePtr &node) {
|
||||
auto cnode = QuickCloneCNode(node);
|
||||
HashSet<size_t> ids;
|
||||
if (convert_all_) {
|
||||
for (size_t i = 1; i < cnode->inputs().size(); i++) {
|
||||
(void)ids.insert(i - 1);
|
||||
}
|
||||
} else {
|
||||
ids = input_idx_;
|
||||
}
|
||||
for (const auto &idx : ids) {
|
||||
if (cnode->input(idx + 1)->isa<Parameter>()) {
|
||||
auto default_param = cnode->input(idx + 1)->cast<ParameterPtr>()->default_param();
|
||||
if (default_param == nullptr) {
|
||||
continue;
|
||||
if (cnode->input(idx + 1)->isa<ValueNode>()) {
|
||||
auto value = cnode->input(idx + 1)->cast<ValueNodePtr>()->value();
|
||||
if (value->isa<tensor::Tensor>()) {
|
||||
auto param_value = value->cast<tensor::TensorPtr>();
|
||||
auto int_value = static_cast<int *>(param_value->data_ptr()->data());
|
||||
ShapeVector out_list;
|
||||
std::transform(int_value, int_value + param_value->data_ptr()->size(), std::back_inserter(out_list), IntToLong);
|
||||
auto new_value = std::make_shared<ValueNode>(MakeValue(out_list));
|
||||
cnode->set_input(idx + 1, new_value);
|
||||
}
|
||||
auto param_value = default_param->cast<tensor::TensorPtr>();
|
||||
if (param_value == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto value = NewValueNode(param_value);
|
||||
value->set_abstract(param_value->ToAbstract());
|
||||
cnode->set_input(idx + 1, value);
|
||||
}
|
||||
}
|
||||
return decorated_->Run(cnode);
|
||||
|
@ -180,7 +155,8 @@ std::vector<PrimitivePtr> GraphKernelExpanderLite::InitOpList() {
|
|||
{kCPUDevice, OpLevel_1, prim::kPrimUnsqueeze}, {kCPUDevice, OpLevel_1, prim::kPrimGather},
|
||||
{kCPUDevice, OpLevel_1, prim::kPrimShape}, {kCPUDevice, OpLevel_1, prim::kPrimConcat},
|
||||
{kCPUDevice, OpLevel_1, prim::kPrimConstantOfShape}, {kCPUDevice, OpLevel_1, prim::kPrimConv2DFusion},
|
||||
{kCPUDevice, OpLevel_1, prim::kPrimAvgPoolFusion}, {kCPUDevice, OpLevel_1, prim::kPrimMaxPoolFusion}};
|
||||
{kCPUDevice, OpLevel_1, prim::kPrimAvgPoolFusion}, {kCPUDevice, OpLevel_1, prim::kPrimMaxPoolFusion},
|
||||
{kCPUDevice, OpLevel_1, prim::kPrimStridedSlice}};
|
||||
const auto &flags = GraphKernelFlags::GetInstance();
|
||||
return GkUtils::GetValidOps(expand_ops_with_level, flags.fusion_ops_level, flags.enable_expand_ops_only,
|
||||
flags.enable_expand_ops, flags.disable_expand_ops);
|
||||
|
@ -213,13 +189,13 @@ ExpanderPtr GraphKernelExpanderLite::InitExpander(const AnfNodePtr &node) {
|
|||
{prim::kPrimShape->name(), {FixFormatDeco::Creator}},
|
||||
{prim::kPrimReshape->name(), {InputToAttrDeco::GetCreator({1}), FixFormatDeco::Creator}},
|
||||
{prim::kPrimConstantOfShape->name(), {InputToAttrDeco::GetCreator({0}), FixFormatDeco::Creator}},
|
||||
{prim::kPrimTranspose->name(), {ParaToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
|
||||
{prim::kPrimTranspose->name(), {TensorToValueDeco::GetCreator({1}), InputToAttrDeco::GetCreator({1})}},
|
||||
{prim::kPrimGather->name(),
|
||||
{ParaToTensorDeco::GetCreator({1}), ParaToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}),
|
||||
FixFormatDeco::Creator}},
|
||||
{prim::kPrimConcat->name(), {ParaToTensorDeco::GetCreator({}, true), FixFormatDeco::Creator}},
|
||||
{prim::kPrimConv2DFusion->name(), {ParaToTensorDeco::GetCreator({1}), SubstituteConv2D::Creator}},
|
||||
{prim::kPrimMatMulFusion->name(), {ParaToTensorDeco::GetCreator({1}), MatmulPackB::Creator}},
|
||||
{TensorToValueDeco::GetCreator({2}), InputToAttrDeco::GetCreator({2}), FixFormatDeco::Creator}},
|
||||
{prim::kPrimConcat->name(), {FixFormatDeco::Creator}},
|
||||
{prim::kPrimStridedSlice->name(), {FixFormatDeco::Creator}},
|
||||
{prim::kPrimConv2DFusion->name(), {SubstituteConv2D::Creator}},
|
||||
{prim::kPrimMatMulFusion->name(), {MatmulPackB::Creator}},
|
||||
{prim::kPrimAvgPoolFusion->name(), {PoolLayoutDeco::Creator}},
|
||||
{prim::kPrimMaxPoolFusion->name(), {PoolLayoutDeco::Creator}},
|
||||
};
|
||||
|
|
|
@ -23,15 +23,15 @@
|
|||
#include "utils/hash_set.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
class ParaToValueDeco : public ExpanderDecorator {
|
||||
class TensorToValueDeco : public ExpanderDecorator {
|
||||
public:
|
||||
ParaToValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
|
||||
TensorToValueDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx)
|
||||
: ExpanderDecorator(decorated), input_idx_(input_idx) {}
|
||||
~ParaToValueDeco() = default;
|
||||
~TensorToValueDeco() = default;
|
||||
|
||||
static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx) {
|
||||
return [input_idx](const ExpanderPtr &decorated) {
|
||||
return std::static_pointer_cast<Expander>(std::make_shared<ParaToValueDeco>(decorated, input_idx));
|
||||
return std::static_pointer_cast<Expander>(std::make_shared<TensorToValueDeco>(decorated, input_idx));
|
||||
};
|
||||
}
|
||||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
@ -40,24 +40,6 @@ class ParaToValueDeco : public ExpanderDecorator {
|
|||
HashSet<size_t> input_idx_;
|
||||
};
|
||||
|
||||
class ParaToTensorDeco : public ExpanderDecorator {
|
||||
public:
|
||||
ParaToTensorDeco(const ExpanderPtr &decorated, const HashSet<size_t> &input_idx, bool convert_all = false)
|
||||
: ExpanderDecorator(decorated), input_idx_(input_idx), convert_all_(convert_all) {}
|
||||
~ParaToTensorDeco() = default;
|
||||
|
||||
static ExpanderCreatorFunc GetCreator(const HashSet<size_t> &input_idx, bool convert_all = false) {
|
||||
return [input_idx, convert_all](const ExpanderPtr &decorated) {
|
||||
return std::static_pointer_cast<Expander>(std::make_shared<ParaToTensorDeco>(decorated, input_idx, convert_all));
|
||||
};
|
||||
}
|
||||
AnfNodePtr Run(const AnfNodePtr &node) override;
|
||||
|
||||
protected:
|
||||
HashSet<size_t> input_idx_;
|
||||
bool convert_all_;
|
||||
};
|
||||
|
||||
class FixFormatDeco : public ExpanderDecorator {
|
||||
public:
|
||||
explicit FixFormatDeco(const ExpanderPtr &decorated) : ExpanderDecorator(decorated) {}
|
||||
|
|
|
@ -36,6 +36,7 @@
|
|||
#include "tools/graph_kernel/converter/graph_kernel_expander_lite.h"
|
||||
#include "tools/graph_kernel/converter/insert_abstract.h"
|
||||
#include "tools/graph_kernel/converter/graph_kernel_splitter_lite.h"
|
||||
#include "tools/graph_kernel/converter/parameter_to_tensor.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
using opt::GetitemTuple;
|
||||
|
@ -51,6 +52,7 @@ GkPassManagerPtr GraphKernelOptimizer::PreProcess() const {
|
|||
// Some ops may lose abstract in converter
|
||||
pm->Add(std::make_shared<InsertAbstract>(), OptLevel_1);
|
||||
pm->Add(std::make_shared<FormatRecognition>(), OptLevel_1);
|
||||
pm->Add(std::make_shared<ParameterToTensor>(), OptLevel_1);
|
||||
return pm;
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tools/graph_kernel/converter/parameter_to_tensor.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
bool ParameterToTensor::Run(const FuncGraphPtr &func_graph) {
|
||||
auto todos = TopoSort(func_graph->output());
|
||||
auto mng = func_graph->manager();
|
||||
if (mng == nullptr) {
|
||||
mng = Manage(func_graph, true);
|
||||
func_graph->set_manager(mng);
|
||||
}
|
||||
for (auto &node : todos) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode == nullptr) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (size_t idx = 1; i < cnode->inputs().size(); idx++) {
|
||||
if (cnode->input(idx)->isa<Parameter>()) {
|
||||
auto default_param = cnode->input(idx)->cast<ParameterPtr>()->default_param();
|
||||
if (default_param == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto param_value = default_param->cast<tensor::TensorPtr>();
|
||||
if (param_value == nullptr) {
|
||||
continue;
|
||||
}
|
||||
auto value = NewValueNode(param_value);
|
||||
value->set_abstract(param_value->ToAbstract());
|
||||
(void)mng->Replace(cnode->input(idx), value);
|
||||
}
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
} // namespace mindspore::graphkernel
|
|
@ -0,0 +1,29 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
|
||||
#define MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
|
||||
#include "ir/func_graph.h"
|
||||
#include "backend/common/optimizer/pass.h"
|
||||
|
||||
namespace mindspore::graphkernel {
|
||||
class ParameterToTensor : public opt::Pass {
|
||||
public:
|
||||
ParameterToTensor() : Pass("parameter_to_tensor") {}
|
||||
~ParameterToTensor() override = default;
|
||||
bool Run(const FuncGraphPtr &func_graph) override;
|
||||
};
|
||||
} // namespace mindspore::graphkernel
|
||||
#endif // MINDSPORE_LITE_TOOLS_GRAPH_KERNEL_CONVERTER_PARAMETER_TO_TENSOR_H_
|
Loading…
Reference in New Issue