forked from mindspore-Ecosystem/mindspore
!952 Simplify the `ZeroLikeFillZero` optimization pass
Merge pull request !952 from thlinh/dev_May6th_improve_zero_fill_like_zero
This commit is contained in:
commit
c176bbe4c8
|
@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
zero_like_fill_zero_ =
|
||||
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
|
||||
|
||||
// ops eliminate
|
||||
item_tuple_eliminate_ =
|
||||
|
|
|
@ -30,6 +30,7 @@
|
|||
namespace mindspore {
|
||||
namespace opt {
|
||||
namespace irpass {
|
||||
|
||||
class SpecialOpEliminater {
|
||||
public:
|
||||
SpecialOpEliminater()
|
||||
|
@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
|
|||
if (y_ == nullptr || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
|
||||
auto fg = node->func_graph();
|
||||
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
|
||||
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
|
||||
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
|
||||
}
|
||||
|
||||
auto fg = node->func_graph();
|
||||
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
|
||||
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
|
||||
abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();
|
||||
|
||||
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
|
||||
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
|
||||
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
|
||||
|
||||
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
|
||||
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
|
||||
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
|
||||
std::memset(data, 0, mem_size);
|
||||
|
||||
auto new_cnode = NewValueNode(new_tensor_ptr);
|
||||
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());
|
||||
|
||||
return new_cnode;
|
||||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override { y_ = node; }
|
||||
|
|
Loading…
Reference in New Issue