forked from mindspore-Ecosystem/mindspore
!41551 tensor-scatter-max update
Merge pull request !41551 from wuweikang/tensor-scatter-max
This commit is contained in:
commit
2ad89058fc
|
@ -107,21 +107,5 @@ const BaseRef TensorScatterMinFission::DefinePattern() const {
|
|||
VarPtr updates = std::make_shared<Var>();
|
||||
return VectorRef({prim::kPrimTensorScatterMin, input, indices, updates});
|
||||
}
|
||||
|
||||
const AnfNodePtr TensorScatterMinFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(scatter_nd_node);
|
||||
common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMin"), scatter_nd_node);
|
||||
return scatter_nd_node;
|
||||
}
|
||||
|
||||
const AnfNodePtr TensorScatterMaxFission::Process(const FuncGraphPtr &graph, const AnfNodePtr &node,
|
||||
const EquivPtr &) const {
|
||||
auto scatter_nd_node = TensorScatterFission::Process(graph, node, nullptr);
|
||||
MS_EXCEPTION_IF_NULL(scatter_nd_node);
|
||||
common::AnfAlgo::SetNodeAttr("cust_aicpu", MakeValue<std::string>("ScatterNdMax"), scatter_nd_node);
|
||||
return scatter_nd_node;
|
||||
}
|
||||
} // namespace opt
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -74,7 +74,6 @@ class TensorScatterMaxFission : public TensorScatterFission {
|
|||
: TensorScatterFission(multigraph, name) {}
|
||||
~TensorScatterMaxFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
ValueNodePtr GetScatterNdPrimNode() const override;
|
||||
|
@ -86,7 +85,6 @@ class TensorScatterMinFission : public TensorScatterFission {
|
|||
: TensorScatterFission(multigraph, name) {}
|
||||
~TensorScatterMinFission() override = default;
|
||||
const BaseRef DefinePattern() const override;
|
||||
const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, const EquivPtr &) const override;
|
||||
|
||||
protected:
|
||||
ValueNodePtr GetScatterNdPrimNode() const override;
|
||||
|
|
Loading…
Reference in New Issue