fixbug in expander Concat; fix cmakelist for graphkenrel in mslite

This commit is contained in:
dayschan 2022-08-12 10:53:21 +08:00
parent 4c5d3b58ef
commit b5e1c7d638
4 changed files with 5 additions and 7 deletions

View File

@ -33,9 +33,9 @@ NodePtr GraphBuilder::Gather(const NodePtr &param, const NodePtr &indice, const
return Emit("Gather", {param, indice}, {{"axis", axis_value}});
}
NodePtr GraphBuilder::Concat(const NodePtr &param, const NodePtr &indice, const int64_t &axis) const {
NodePtr GraphBuilder::Concat(const NodePtrList &inputs, const int64_t &axis) const {
auto axis_value = MakeValue(axis);
return Emit("Concat", {param, indice}, {{"axis", axis_value}});
return Emit("Concat", inputs, {{"axis", axis_value}});
}
NodePtr GraphBuilder::Transpose(const NodePtr &input, const ShapeVector &perm) const {

View File

@ -67,7 +67,7 @@ class GraphBuilder : public LiteGraph::GraphBuilderBase {
NodePtr Reshape(const NodePtr &input, const ShapeVector &shape) const;
NodePtr BroadcastTo(const NodePtr &input, const ShapeVector &shape) const;
NodePtr Gather(const NodePtr &param, const NodePtr &indice, const int64_t &axis) const;
NodePtr Concat(const NodePtr &param, const NodePtr &indice, const int64_t &axis) const;
NodePtr Concat(const NodePtrList &inputs, const int64_t &axis) const;
NodePtr Transpose(const NodePtr &input, const ShapeVector &perm) const;
NodePtr ReduceSum(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;

View File

@ -377,7 +377,7 @@ endif()
if(MSLITE_ENABLE_GRAPH_KERNEL)
file(GLOB_RECURSE GRAPH_KERNEL_SRC
${TOOLS_DIR}/graph_kernel/common/*.cc
${TOOLS_DIR}/graph_kernel/litert/*.cc
${TOOLS_DIR}/graph_kernel/runtime/*.cc
)
set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC})
endif()

View File

@ -29,10 +29,8 @@ class Concat : public OpDesc {
protected:
NodePtrList Expand(const NodePtrList &inputs) override {
const auto &input_x = inputs[0];
const auto &input_y = inputs[1];
auto axis = GetValue<int64_t>(attrs_["axis"]);
auto result = gb.Concat(input_x, input_y, axis);
auto result = gb.Concat(inputs, axis);
return {result};
}
};