fixbug in expander Concat; fix cmakelist for graphkenrel in mslite
This commit is contained in:
parent
4c5d3b58ef
commit
b5e1c7d638
|
@ -33,9 +33,9 @@ NodePtr GraphBuilder::Gather(const NodePtr ¶m, const NodePtr &indice, const
|
|||
return Emit("Gather", {param, indice}, {{"axis", axis_value}});
|
||||
}
|
||||
|
||||
NodePtr GraphBuilder::Concat(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const {
|
||||
NodePtr GraphBuilder::Concat(const NodePtrList &inputs, const int64_t &axis) const {
|
||||
auto axis_value = MakeValue(axis);
|
||||
return Emit("Concat", {param, indice}, {{"axis", axis_value}});
|
||||
return Emit("Concat", inputs, {{"axis", axis_value}});
|
||||
}
|
||||
|
||||
NodePtr GraphBuilder::Transpose(const NodePtr &input, const ShapeVector &perm) const {
|
||||
|
|
|
@ -67,7 +67,7 @@ class GraphBuilder : public LiteGraph::GraphBuilderBase {
|
|||
NodePtr Reshape(const NodePtr &input, const ShapeVector &shape) const;
|
||||
NodePtr BroadcastTo(const NodePtr &input, const ShapeVector &shape) const;
|
||||
NodePtr Gather(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const;
|
||||
NodePtr Concat(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const;
|
||||
NodePtr Concat(const NodePtrList &inputs, const int64_t &axis) const;
|
||||
NodePtr Transpose(const NodePtr &input, const ShapeVector &perm) const;
|
||||
|
||||
NodePtr ReduceSum(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;
|
||||
|
|
|
@ -377,7 +377,7 @@ endif()
|
|||
if(MSLITE_ENABLE_GRAPH_KERNEL)
|
||||
file(GLOB_RECURSE GRAPH_KERNEL_SRC
|
||||
${TOOLS_DIR}/graph_kernel/common/*.cc
|
||||
${TOOLS_DIR}/graph_kernel/litert/*.cc
|
||||
${TOOLS_DIR}/graph_kernel/runtime/*.cc
|
||||
)
|
||||
set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC})
|
||||
endif()
|
||||
|
|
|
@ -29,10 +29,8 @@ class Concat : public OpDesc {
|
|||
|
||||
protected:
|
||||
NodePtrList Expand(const NodePtrList &inputs) override {
|
||||
const auto &input_x = inputs[0];
|
||||
const auto &input_y = inputs[1];
|
||||
auto axis = GetValue<int64_t>(attrs_["axis"]);
|
||||
auto result = gb.Concat(input_x, input_y, axis);
|
||||
auto result = gb.Concat(inputs, axis);
|
||||
return {result};
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue