!41551 tensor-scatter-max update

Merge pull request !41551 from wuweikang/tensor-scatter-max
This commit is contained in:
i-robot 2022-09-15 06:28:51 +00:00 committed by Gitee
commit 2ad89058fc
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
2 changed files with 0 additions and 18 deletions

View File

@ -107,21 +107,5 @@ const BaseRef TensorScatterMinFission::DefinePattern() const {
VarPtr updates = std::make_shared<Var>();
return VectorRef({prim::kPrimTensorScatterMin, input, indices, updates});
}
const AnfNodePtr TensorScatterMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
MS_EXCEPTION_IF_NULL(scatter_nd_node);
common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMin"), scatter_nd_node);
return scatter_nd_node;
}
const AnfNodePtr TensorScatterMaxFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
const EquivPtr &) const {
auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
MS_EXCEPTION_IF_NULL(scatter_nd_node);
common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMax"), scatter_nd_node);
return scatter_nd_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -74,7 +74,6 @@ class TensorScatterMaxFission : public TensorScatterFission {
: TensorScatterFission(multigraph, name) {}
~TensorScatterMaxFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
protected:
ValueNodePtr GetScatterNdPrimNode() const override;
@ -86,7 +85,6 @@ class TensorScatterMinFission : public TensorScatterFission {
: TensorScatterFission(multigraph, name) {}
~TensorScatterMinFission() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
protected:
ValueNodePtr GetScatterNdPrimNode() const override;