commit
260112518e
|
@ -198,6 +198,7 @@ tensor::TensorPtr CalcByOperator(const NodePtrList &inputs, const std::string &o
|
||||||
{"Abs", [](const std::vector<TM> &n) { return n[0] <= TM(0) ? (TM(0) - n[0]) : n[0]; }},
|
{"Abs", [](const std::vector<TM> &n) { return n[0] <= TM(0) ? (TM(0) - n[0]) : n[0]; }},
|
||||||
{"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
|
{"Sqrt", [](const std::vector<TM> &n) { return sqrt(n[0]); }},
|
||||||
{"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
|
{"Rsqrt", [](const std::vector<TM> &n) { return TM(1) / sqrt(n[0]); }},
|
||||||
|
{"Reshape", [](const std::vector<TM> &n) { return n[0]; }},
|
||||||
};
|
};
|
||||||
if (func_map.find(op) == func_map.end()) {
|
if (func_map.find(op) == func_map.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
@ -50,7 +50,7 @@ void ScatterNdCheckShape(const PrimitivePtr &prim, const std::vector<AbstractBas
|
||||||
|
|
||||||
TypePtr ScatterNdInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr ScatterNdInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
auto dtype = input_args[kInputIndex1]->BuildType();
|
auto dtype = input_args[kInputIndex1]->BuildType();
|
||||||
(void)CheckAndConvertUtils::CheckTensorTypeValid("updates", dtype, common_valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckSubClass("updates", dtype, {kTensorType}, prim->name());
|
||||||
return dtype;
|
return dtype;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -384,7 +384,7 @@ class GraphSplitByPattern:
|
||||||
self.unique_id = unique_id
|
self.unique_id = unique_id
|
||||||
self.reach_tab = reach_tab
|
self.reach_tab = reach_tab
|
||||||
self.checkers = []
|
self.checkers = []
|
||||||
if self.pattern == PrimLib.RESHAPE:
|
if self.pattern == PrimLib.RESHAPE and init_op.inputs: # reshape's input may be empty (const value)
|
||||||
self.checkers.append(ReshapeElimChecker(init_op))
|
self.checkers.append(ReshapeElimChecker(init_op))
|
||||||
elif self.pattern == PrimLib.REDUCE:
|
elif self.pattern == PrimLib.REDUCE:
|
||||||
self.checkers.append(ReduceOutFuseChecker(init_op))
|
self.checkers.append(ReduceOutFuseChecker(init_op))
|
||||||
|
|
Loading…
Reference in New Issue