Merge pull request !36426 from DeshiChen/0623_fixbug
This commit is contained in:
i-robot 2022-06-24 09:10:38 +00:00 committed by Gitee
commit 260112518e
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 3 additions and 2 deletions

View File

@ -198,6 +198,7 @@ tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &o
{"Abs", [](const std::vector<TM> &n) { return n[0] <= TM(0) ? (TM(0) - n[0]) : n[0]; }},
{"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
{"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
{"Reshape", [](const std::vector<TM> &n) { return n[0]; }},
};
if (func_map.find(op) == func_map.end()) {
return nullptr;

View File

@ -50,7 +50,7 @@ void ScatterNdCheckShape(const PrimitivePtr &prim, const std::vector<AbstractBas
TypePtr ScatterNdInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
auto dtype = input_args[kInputIndex1]->BuildType();
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates", dtype, common_valid_types, prim->name());
(void)CheckAndConvertUtils::CheckSubClass("updates", dtype, {kTensorType}, prim->name());
return dtype;
}

View File

@ -384,7 +384,7 @@ class GraphSplitByPattern:
self.unique_id = unique_id
self.reach_tab = reach_tab
self.checkers = []
if self.pattern == PrimLib.RESHAPE:
if self.pattern == PrimLib.RESHAPE and init_op.inputs: # reshape's input may be empty (const value)
self.checkers.append(ReshapeElimChecker(init_op))
elif self.pattern == PrimLib.REDUCE:
self.checkers.append(ReduceOutFuseChecker(init_op))