forked from mindspore-Ecosystem/mindspore
Remove ZerosLikeTensor and replace it with ZerosLike
This commit is contained in:
parent
b9ba99bb13
commit
96379faa3a
|
@ -210,7 +210,7 @@ const PrimitivePtr kPrimGelu = std::make_shared<Primitive>("Gelu");
|
|||
const PrimitivePtr kPrimGeluGrad = std::make_shared<Primitive>("GeluGrad");
|
||||
const PrimitivePtr kPrimRelu = std::make_shared<Primitive>("ReLU");
|
||||
const PrimitivePtr kPrimReluV2 = std::make_shared<Primitive>("ReLUV2");
|
||||
const PrimitivePtr kPrimZerosLikeTensor = std::make_shared<Primitive>("zeros_like_tensor");
|
||||
const PrimitivePtr kPrimZerosLike = std::make_shared<Primitive>("ZerosLike");
|
||||
const PrimitivePtr kPrimFakeBprop = std::make_shared<Primitive>("fake_bprop");
|
||||
const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
||||
|
||||
|
|
|
@ -217,7 +217,7 @@ extern const PrimitivePtr kPrimGeluGrad;
|
|||
extern const PrimitivePtr kPrimRelu;
|
||||
extern const PrimitivePtr kPrimReluV2;
|
||||
extern const PrimitivePtr kPrimActivation;
|
||||
extern const PrimitivePtr kPrimZerosLikeTensor;
|
||||
extern const PrimitivePtr kPrimZerosLike;
|
||||
extern const PrimitivePtr kPrimFakeBprop;
|
||||
extern const PrimitivePtr kPrimBpropCut;
|
||||
|
||||
|
|
|
@ -271,8 +271,8 @@ AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &pri
|
|||
return args_spec_list[0]->Broaden();
|
||||
}
|
||||
|
||||
AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list) {
|
||||
// Inputs: a tensor.
|
||||
CheckArgsSize(primitive->name(), args_spec_list, 1);
|
||||
return args_spec_list[0]->Broaden();
|
||||
|
|
|
@ -53,7 +53,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
|
|||
MakeSubstitution(SpecialOpEliminater(), "special_op_eliminate",
|
||||
{prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
|
||||
prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
|
||||
zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLike);
|
||||
|
||||
// ops eliminate
|
||||
item_tuple_eliminate_ =
|
||||
|
|
|
@ -120,8 +120,8 @@ class AddByZero : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// {prim::kPrimTensorAdd, {PrimZerosLikeTensor, Y}, X},
|
||||
// {prim::kPrimTensorAdd, X, {PrimZerosLikeTensor, Y}}
|
||||
// {prim::kPrimTensorAdd, {kPrimZerosLike, Y}, X},
|
||||
// {prim::kPrimTensorAdd, X, {kPrimZerosLike, Y}}
|
||||
class TensorAddByZero : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
|
@ -135,7 +135,7 @@ class TensorAddByZero : public AnfVisitor {
|
|||
}
|
||||
|
||||
void Visit(const AnfNodePtr &node) override {
|
||||
if (IsPrimitive(node, prim::kPrimZerosLikeTensor)) {
|
||||
if (IsPrimitive(node, prim::kPrimZerosLike)) {
|
||||
is_zero_ = true;
|
||||
return;
|
||||
}
|
||||
|
@ -153,7 +153,7 @@ class TensorAddByZero : public AnfVisitor {
|
|||
AnfNodePtr x_{nullptr};
|
||||
};
|
||||
|
||||
// {PrimMomentum, {PrimZerosLikeTensor, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
|
||||
// {PrimMomentum, {kPrimZerosLike, X}, Y, Z, Xs} -> {prim::kPrimMakeTuple, Z, Y}
|
||||
class OptUpdateZeroTensor : public AnfVisitor {
|
||||
public:
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
|
@ -163,13 +163,13 @@ class OptUpdateZeroTensor : public AnfVisitor {
|
|||
|
||||
// {PrimMomentum, {...}, Y, Z, Xs}
|
||||
auto &inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLikeTensor)) {
|
||||
if (inputs.size() < 4 || !IsPrimitiveCNode(inputs[1], prim::kPrimZerosLike)) {
|
||||
return nullptr;
|
||||
}
|
||||
auto y = inputs[2];
|
||||
auto z = inputs[3];
|
||||
|
||||
// {PrimZerosLikeTensor, X}
|
||||
// {kPrimZerosLike, X}
|
||||
if (inputs[1]->cast<CNodePtr>()->size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -177,7 +177,7 @@ class AddNZeroFilter : public AnfVisitor {
|
|||
// {kPrimMakeTuple, X1, X2, ...}
|
||||
filtered_Xs_.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (auto &x : Xs_) {
|
||||
if (!IsPrimitiveCNode(x, prim::kPrimZerosLikeTensor)) {
|
||||
if (!IsPrimitiveCNode(x, prim::kPrimZerosLike)) {
|
||||
filtered_Xs_.push_back(x);
|
||||
} else {
|
||||
has_zero_like_ = true;
|
||||
|
|
|
@ -143,7 +143,7 @@ class ResetDeferInline : public AnfVisitor {
|
|||
}
|
||||
};
|
||||
|
||||
// {PrimZerosLikeTensor, Y} ->
|
||||
// {PrimZerosLike, Y} ->
|
||||
// {PrimFill, {PrimDType, Y}, {PrimShape, Y}, 0}
|
||||
class ZeroLikeFillZero : public AnfVisitor {
|
||||
public:
|
||||
|
@ -155,7 +155,7 @@ class ZeroLikeFillZero : public AnfVisitor {
|
|||
|
||||
AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
|
||||
y_ = nullptr;
|
||||
AnfVisitor::Match(prim::kPrimZerosLikeTensor, {IsNode})(node);
|
||||
AnfVisitor::Match(prim::kPrimZerosLike, {IsNode})(node);
|
||||
if (y_ == nullptr || node->func_graph() == nullptr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -75,7 +75,7 @@ const std::set<std::string> BLACK_LIST = {TUPLE_GETITEM,
|
|||
DROPOUT_GEN_MASK,
|
||||
EMBED,
|
||||
CREATINSTANCE,
|
||||
ZEROSLIKETENSOR,
|
||||
ZEROSLIKE,
|
||||
ASSIGN,
|
||||
REF_TO_EMBED,
|
||||
STOP_GRADIENT};
|
||||
|
|
|
@ -263,7 +263,7 @@ constexpr char COL2IMV1[] = "col2im_v1";
|
|||
constexpr char RESOLVE[] = "resolve";
|
||||
constexpr char EMBED[] = "embed";
|
||||
constexpr char CREATINSTANCE[] = "create_instance";
|
||||
constexpr char ZEROSLIKETENSOR[] = "zeros_like_tensor";
|
||||
constexpr char ZEROSLIKE[] = "ZerosLike";
|
||||
constexpr char REF_TO_EMBED[] = "RefToEmbed";
|
||||
constexpr char STOP_GRADIENT[] = "stop_gradient";
|
||||
|
||||
|
|
|
@ -106,8 +106,8 @@ PrimitiveEvalImplMap &GetPrimitiveToEvalImplMap() {
|
|||
{prim::kPrimConv2DBackpropFilter, {InferImplConv2DBackpropFilter, true}},
|
||||
{prim::kPrimBiasAddGrad, {InferImplBiasAddGrad, true}},
|
||||
{prim::kPrimRelu, {InferImplRelu, true}},
|
||||
{prim::kPrimZerosLikeTensor, {InferImplZerosLikeTensor, true}},
|
||||
{prim::kPrimFakeBprop, {InferImplFakeBprop, false}},
|
||||
{prim::kPrimZerosLike, {InferImplZerosLike, true}},
|
||||
{prim::kPrimBpropCut, {InferImplBpropCut, true}},
|
||||
{prim::kPrimLayerNorm, {InferImplLayerNorm, true}},
|
||||
{prim::kPrimLayerNormGrad, {InferImplLayerNormGrad, true}},
|
||||
|
|
|
@ -206,10 +206,10 @@ AbstractBasePtr InferImplGeluGrad(const AnalysisEnginePtr &, const PrimitivePtr
|
|||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplRelu(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplZerosLikeTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplZerosLike(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplBpropCut(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
const AbstractBasePtrList &args_spec_list);
|
||||
AbstractBasePtr InferImplLayerNorm(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||
|
|
|
@ -53,7 +53,7 @@
|
|||
|
||||
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
||||
// primitive unable to infer value for constant input in PyNative mode
|
||||
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "zeros_like_tensor", "HookBackward"};
|
||||
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "HookBackward"};
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
|
|
@ -84,7 +84,7 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
|
|||
g_norm = op_norm(gradient_fp32)
|
||||
|
||||
g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
|
||||
zeros = F.zeros_like_tensor(w_norm)
|
||||
zeros = F.zeros_like(w_norm)
|
||||
ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
|
||||
trust_ratio = op_select(
|
||||
op_greater(w_norm, zeros),
|
||||
|
|
|
@ -296,7 +296,7 @@ env_get = MultitypeFuncGraph("env_get")
|
|||
@env_get.register("EnvType", "Tensor")
|
||||
def _tensor_env_get(env, parameter):
|
||||
"""Used to get env."""
|
||||
return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like_tensor(parameter))
|
||||
return F.env_getitem(env, F.ref_to_embed(parameter), F.zeros_like(parameter))
|
||||
|
||||
|
||||
_mp_cast_helper = MultitypeFuncGraph('mixed_precision_cast_helper')
|
||||
|
|
|
@ -57,7 +57,7 @@ def _zeros_like_func(x):
|
|||
@zeros_like_leaf.register("Tensor")
|
||||
def _zeros_like_tensor(x):
|
||||
"""Returns a tensor with the same shape and dtype as x and all elements are 0."""
|
||||
return F.zeros_like_tensor(x)
|
||||
return F.zeros_like(x)
|
||||
|
||||
|
||||
@zeros_like_leaf.register("TypeType")
|
||||
|
|
|
@ -130,7 +130,7 @@ broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
|||
dot = Primitive('dot')
|
||||
array_reduce = Primitive('array_reduce')
|
||||
partial = Primitive('partial')
|
||||
zeros_like_tensor = Primitive('zeros_like_tensor')
|
||||
zeros_like = P.ZerosLike()
|
||||
identity = Primitive('identity')
|
||||
distribute = Primitive('distribute')
|
||||
# depend: mount a node to another node
|
||||
|
|
|
@ -878,7 +878,7 @@ def test_addn_zero(tag):
|
|||
fns = FnDict()
|
||||
addn = P.AddN()
|
||||
AddN = P.AddN
|
||||
zero_tensor = Primitive('zeros_like_tensor')
|
||||
zero_tensor = Primitive('ZerosLike')
|
||||
|
||||
@fns
|
||||
def before_1(x, y, z, a):
|
||||
|
|
|
@ -278,3 +278,9 @@ def vm_impl_square(self):
|
|||
return Tensor(x * x)
|
||||
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.ZerosLike)
def vm_impl_zeros_like(self):
    """Generate vm_impl function for ZerosLike"""

    def vm_impl(x):
        # Convert the input to a NumPy array, build a zero-filled array of
        # the same shape and dtype, and wrap it back into a Tensor.
        return Tensor(np.zeros_like(x.asnumpy()))

    # The getter must hand back the callable; without this return the
    # function registered for P.ZerosLike would yield None.
    return vm_impl
|
||||
|
|
Loading…
Reference in New Issue