* index error in graph_split when reshape's input is empty (it's a const value in anfgraph),
  in this case, we also add the reshape into CalcByOperator so that graphkernel can optimize it.
* scatter_nd's input data supports the tensor of all dtype
This commit is contained in:
dayschan 2022-06-23 17:55:39 +08:00
parent 09e010a0d5
commit c3027943dc
3 changed files with 3 additions and 2 deletions

View File

@ -198,6 +198,7 @@ tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &o
{"Abs", [](const std::vector<TM> &n) { return n[0] <= TM(0) ? (TM(0) - n[0]) : n[0]; }},
{"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
{"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
{"Reshape", [](const std::vector<TM> &n) { return n[0]; }},
};
if (func_map.find(op) == func_map.end()) {
return nullptr;

View File

@ -50,7 +50,7 @@ void ScatterNdCheckShape(const PrimitivePtr &prim, const std::vector<AbstractBas
TypePtr ScatterNdInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto dtype = input_args[kInputIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates", dtype, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckSubClass("updates", dtype, {kTensorType}, prim->name());
return dtype;
}

View File

@ -384,7 +384,7 @@ class GraphSplitByPattern:
self.unique_id = unique_id
self.reach_tab = reach_tab
self.checkers = []
if self.pattern == PrimLib.RESHAPE:
if self.pattern == PrimLib.RESHAPE and init_op.inputs: # reshape's input may be empty (const value)
self.checkers.append(ReshapeElimChecker(init_op))
elif self.pattern == PrimLib.REDUCE:
self.checkers.append(ReduceOutFuseChecker(init_op))