!952 Simplify the `ZeroLikeFillZero` optimization pass

Merge pull request !952 from thlinh/dev_May6th_improve_zero_fill_like_zero
This commit is contained in:
mindspore-ci-bot 2020-05-08 14:18:58 +08:00 committed by Gitee
commit c176bbe4c8
2 changed files with 21 additions and 6 deletions

View File

@ -52,8 +52,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
special_op_eliminate_ = MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ =
MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor, opt::FORCE_RENORM);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
// ops eliminate
item_tuple_eliminate_ =

View File

@ -30,6 +30,7 @@
namespace mindspore {
namespace opt {
namespace irpass {
class SpecialOpEliminater {
public:
SpecialOpEliminater()
@ -156,12 +157,27 @@ class ZeroLikeFillZero : public AnfVisitor {
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}
if ((y_->abstract() == nullptr) || !y_->abstract()->isa<abstract::AbstractTensor>()) {
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
}
auto fg = node->func_graph();
auto dtype = fg->NewCNode({NewValueNode(PrimDType_), y_});
auto shape = fg->NewCNode({NewValueNode(PrimShape_), y_});
abstract::AbstractTensorPtr tensor_abstract = y_->abstract()->cast<abstract::AbstractTensorPtr>();
return fg->NewCNode({NewValueNode(PrimFill_), dtype, shape, NewValueNode(MakeValue(0))});
TypePtr tensor_type_ptr = tensor_abstract->element()->BuildType();
std::vector<int> tensor_shape = tensor_abstract->shape()->shape();
tensor::TensorPtr new_tensor_ptr = std::make_shared<tensor::Tensor>(tensor_type_ptr->type_id(), tensor_shape);
size_t mem_size = GetTypeByte(tensor_type_ptr) * IntToSize(new_tensor_ptr->ElementsNum());
char *data = reinterpret_cast<char *>(new_tensor_ptr->data_c(true));
std::memset(data, 0, mem_size);
auto new_cnode = NewValueNode(new_tensor_ptr);
new_cnode->set_abstract(new_tensor_ptr->ToAbstract());
return new_cnode;
}
void Visit(const AnfNodePtr &node) override { y_ = node; }