!40313 Fix bug of Concat expander
Merge pull request !40313 from DeshiChen/0812_fix
This commit is contained in:
commit
6f47bc4cfa
|
@ -33,9 +33,9 @@ NodePtr GraphBuilder::Gather(const NodePtr ¶m, const NodePtr &indice, const
|
|||
return Emit("Gather", {param, indice}, {{"axis", axis_value}});
|
||||
}
|
||||
|
||||
NodePtr GraphBuilder::Concat(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const {
|
||||
NodePtr GraphBuilder::Concat(const NodePtrList &inputs, const int64_t &axis) const {
|
||||
auto axis_value = MakeValue(axis);
|
||||
return Emit("Concat", {param, indice}, {{"axis", axis_value}});
|
||||
return Emit("Concat", inputs, {{"axis", axis_value}});
|
||||
}
|
||||
|
||||
NodePtr GraphBuilder::Transpose(const NodePtr &input, const ShapeVector &perm) const {
|
||||
|
|
|
@ -72,7 +72,7 @@ class GraphBuilder : public LiteGraph::GraphBuilderBase {
|
|||
NodePtr Reshape(const NodePtr &input, const ShapeVector &shape) const;
|
||||
NodePtr BroadcastTo(const NodePtr &input, const ShapeVector &shape) const;
|
||||
NodePtr Gather(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const;
|
||||
NodePtr Concat(const NodePtr ¶m, const NodePtr &indice, const int64_t &axis) const;
|
||||
NodePtr Concat(const NodePtrList &inputs, const int64_t &axis) const;
|
||||
NodePtr Transpose(const NodePtr &input, const ShapeVector &perm) const;
|
||||
|
||||
NodePtr ReduceSum(const NodePtr &input, const std::vector<int64_t> &axis, const bool &keep_dims = false) const;
|
||||
|
|
|
@ -377,7 +377,7 @@ endif()
|
|||
if(MSLITE_ENABLE_GRAPH_KERNEL)
|
||||
file(GLOB_RECURSE GRAPH_KERNEL_SRC
|
||||
${TOOLS_DIR}/graph_kernel/common/*.cc
|
||||
${TOOLS_DIR}/graph_kernel/litert/*.cc
|
||||
${TOOLS_DIR}/graph_kernel/runtime/*.cc
|
||||
)
|
||||
set(LITE_SRC ${LITE_SRC} ${GRAPH_KERNEL_SRC})
|
||||
endif()
|
||||
|
|
|
@ -29,10 +29,8 @@ class Concat : public OpDesc {
|
|||
|
||||
protected:
|
||||
NodePtrList Expand(const NodePtrList &inputs) override {
|
||||
const auto &input_x = inputs[0];
|
||||
const auto &input_y = inputs[1];
|
||||
auto axis = GetValue<int64_t>(attrs_["axis"]);
|
||||
auto result = gb.Concat(input_x, input_y, axis);
|
||||
auto result = gb.Concat(inputs, axis);
|
||||
return {result};
|
||||
}
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue