Remove ZerosLikeTensor and sub with ZerosLike

This commit is contained in:
BowenK 2020-06-01 19:52:38 +08:00
parent b9ba99bb13
commit 96379faa3a
18 changed files with 31 additions and 25 deletions

View File

@ -210,7 +210,7 @@ const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");

View File

@ -217,7 +217,7 @@ extern const PrimitivePtr kPrimGeluGrad;
extern const PrimitivePtr kPrimRelu;
extern const PrimitivePtr kPrimReluV2;
extern const PrimitivePtr kPrimActivation;
extern const PrimitivePtr kPrimZerosLikeTensor;
extern const PrimitivePtr kPrimZerosLike;
extern const PrimitivePtr kPrimFakeBprop;
extern const PrimitivePtr kPrimBpropCut;

View File

@ -271,8 +271,8 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
return args_spec_list[0]->Broaden();
}
AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list) {
// Inputs: a tensor.
CheckArgsSize(primitive->name(), args_spec_list, 1);
return args_spec_list[0]->Broaden();

View File

@ -53,7 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
// ops eliminate
item_tuple_eliminate_ =

View File

@ -120,8 +120,8 @@ class AddByZero : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
// {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X},
// {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}}
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
class TensorAddByZero : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -135,7 +135,7 @@ class TensorAddByZero : public AnfVisitor {
}
void Visit(const AnfNodePtr &node) override {
if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) {
if (IsPrimitive(node, prim::kPrimZerosLike)) {
is_zero_ = true;
return;
}
@ -153,7 +153,7 @@ class TensorAddByZero : public AnfVisitor {
AnfNodePtr x_{nullptr};
};
// {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
class OptUpdateZeroTensor : public AnfVisitor {
public:
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
@ -163,13 +163,13 @@ class OptUpdateZeroTensor : public AnfVisitor {
// {PrimMomentum, {...}, Y, Z, Xs}
auto &inputs = node->cast<CNodePtr>()->inputs();
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) {
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
return nullptr;
}
auto y = inputs[2];
auto z = inputs[3];
// {PrimZerosLikeTensor, X}
// {kPrimZerosLike, X}
if (inputs[1]->cast<CNodePtr>()->size() != 2) {
return nullptr;
}

View File

@ -177,7 +177,7 @@ class AddNZeroFilter : public AnfVisitor {
// {kPrimMakeTuple, X1, X2, ...}
filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
for (auto &x : Xs_) {
if (!IsPrimitiveCNode(x, prim::kPrimZerosLikeTensor)) {
if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
filtered_Xs_.push_back(x);
} else {
has_zero_like_ = true;

View File

@ -143,7 +143,7 @@ class ResetDeferInline : public AnfVisitor {
}
};
// {PrimZerosLikeTensor, Y} ->
// {PrimZerosLike, Y} ->
// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0}
class ZeroLikeFillZero : public AnfVisitor {
public:
@ -155,7 +155,7 @@ class ZeroLikeFillZero : public AnfVisitor {
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
y_ = nullptr;
AnfVisitor::Match(prim::kPrimZerosLikeTensor, {IsNode})(node);
AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node);
if (y_ == nullptr || node->func_graph() == nullptr) {
return nullptr;
}

View File

@ -75,7 +75,7 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
DROPOUT_GEN_MASK,
EMBED,
CREATINSTANCE,
ZEROSLIKETENSOR,
ZEROSLIKE,
ASSIGN,
REF_TO_EMBED,
STOP_GRADIENT};

View File

@ -263,7 +263,7 @@ constexpr char COL2IMV1[] = "col2im_v1";
constexpr char RESOLVE[] = "resolve";
constexpr char EMBED[] = "embed";
constexpr char CREATINSTANCE[] = "create_instance";
constexpr char ZEROSLIKETENSOR[] = "zeros_like_tensor";
constexpr char ZEROSLIKE[] = "ZerosLike";
constexpr char REF_TO_EMBED[] = "RefToEmbed";
constexpr char STOP_GRADIENT[] = "stop_gradient";

View File

@ -106,8 +106,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
{prim::kPrimRelu, {InferImplRelu, true}},
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},

View File

@ -206,10 +206,10 @@ AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
const AbstractBasePtrList &args_spec_list);
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,

View File

@ -53,7 +53,7 @@
const char SINGLE_OP_GRAPH[] = "single_op_graph";
// primitive unable to infer value for constant input in PyNative mode
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor", "HookBackward"};
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "HookBackward"};
namespace mindspore {
namespace pynative {

View File

@ -84,7 +84,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
g_norm = op_norm(gradient_fp32)
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
zeros = F.zeros_like_tensor(w_norm)
zeros = F.zeros_like(w_norm)
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
trust_ratio = op_select(
op_greater(w_norm, zeros),

View File

@ -296,7 +296,7 @@ env_get = MultitypeFuncGraph("env_get")
@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
"""Used to get env."""
return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like_tensor(parameter))
return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')

View File

@ -57,7 +57,7 @@ def _zeros_like_func(x):
@zeros_like_leaf.register("Tensor")
def _zeros_like_tensor(x):
"""Returns a tensor with the same shape and dtype as x and all elements ars 1."""
return F.zeros_like_tensor(x)
return F.zeros_like(x)
@zeros_like_leaf.register("TypeType")

View File

@ -130,7 +130,7 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
dot = Primitive('dot')
array_reduce = Primitive('array_reduce')
partial = Primitive('partial')
zeros_like_tensor = Primitive('zeros_like_tensor')
zeros_like = P.ZerosLike()
identity = Primitive('identity')
distribute = Primitive('distribute')
# depend: mount a node to another node

View File

@ -878,7 +878,7 @@ def test_addn_zero(tag):
fns = FnDict()
addn = P.AddN()
AddN = P.AddN
zero_tensor = Primitive('zeros_like_tensor')
zero_tensor = Primitive('ZerosLike')
@fns
def before_1(x, y, z, a):

View File

@ -278,3 +278,9 @@ def vm_impl_square(self):
return Tensor(x * x)
return vm_impl
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
"""Generate vm_impl function for ZerosLike"""
def vm_impl(x):
return Tensor(np.zeros_like(x.asnumpy()))