forked from mindspore-Ecosystem/mindspore
move hook function to primtivePy class
This commit is contained in:
parent
444d9484d7
commit
38436f929f
|
@ -13,7 +13,6 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""builtin_operations"""
|
||||
import functools
|
||||
import numpy as np
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype
|
||||
|
@ -124,17 +123,8 @@ def list_len(x):
|
|||
"""Implement `list_len`."""
|
||||
return len(x)
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def partial(*args):
|
||||
"""Implement `partial`."""
|
||||
func = args[0].__call__
|
||||
partial_func = functools.partial(func, *args[1:])
|
||||
return partial_func
|
||||
|
||||
|
||||
# only used in PyNative mode
|
||||
def depend(value, expr):
|
||||
def Depend(value, expr):
|
||||
"""Implement `Depend`."""
|
||||
return value
|
||||
|
||||
# only used in PyNative mode
|
||||
|
|
|
@ -49,6 +49,8 @@ class PrimitivePy : public Primitive {
|
|||
void AddPyAttr(const py::str &name, const py::object &obj);
|
||||
|
||||
py::dict GetAttrDict();
|
||||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
|
||||
const bool parse_info_ = true;
|
||||
const py::object &GetPyObj() const { return python_obj_; }
|
||||
|
@ -56,6 +58,7 @@ class PrimitivePy : public Primitive {
|
|||
|
||||
private:
|
||||
py::object python_obj_;
|
||||
py::function hook_;
|
||||
std::vector<Signature> signatures_;
|
||||
};
|
||||
|
||||
|
|
|
@ -89,9 +89,6 @@ class Primitive : public Named {
|
|||
return iter == attrs_.cend() ? nullptr : iter->second;
|
||||
}
|
||||
|
||||
void set_hook(const py::function &hook) { hook_ = hook; }
|
||||
py::function hook() const { return hook_; }
|
||||
|
||||
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
|
||||
std::unordered_map<std::string, ValuePtr> &evaluate_added_attrs() { return evaluate_added_attrs_; }
|
||||
|
||||
|
@ -124,7 +121,6 @@ class Primitive : public Named {
|
|||
|
||||
private:
|
||||
std::string instance_name_;
|
||||
py::function hook_;
|
||||
bool is_base_;
|
||||
bool has_signature_;
|
||||
PrimType prim_type_;
|
||||
|
|
|
@ -220,7 +220,7 @@ const PrimitivePtr kPrimBpropCut = std::make_shared<Primitive>("bprop_cut");
|
|||
|
||||
// Other miscellaneous
|
||||
const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
|
||||
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("partial");
|
||||
const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
|
||||
const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
|
||||
const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
|
||||
const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
|
||||
|
@ -237,7 +237,7 @@ const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
|
|||
const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
|
||||
|
||||
const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
|
||||
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("depend");
|
||||
const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
|
||||
const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
|
||||
|
||||
const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
|
||||
|
|
|
@ -238,8 +238,12 @@ FuncGraphPtr KPrim::BpropCut(const ValueNodePtr &value_node, const pipeline::Res
|
|||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto bprop_cut = std::make_shared<Primitive>("bprop_cut");
|
||||
bprop_cut->set_hook(prim->hook());
|
||||
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
if (!prim->is_base()) {
|
||||
PrimitivePyPtr prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
bprop_cut->set_hook(prim_py->hook());
|
||||
}
|
||||
|
||||
auto cell_id = GetValue<std::string>(prim->GetAttr("cell_id"));
|
||||
if (cell_id != "") {
|
||||
(void)bprop_cut->AddAttr("cell_hook", MakeValue(true));
|
||||
|
|
|
@ -72,7 +72,7 @@ constexpr char OP[] = "op";
|
|||
constexpr char IDENTITY_INFO[] = "identity_info";
|
||||
constexpr char DIVISOR[] = "divisor";
|
||||
constexpr char NONE[] = "None";
|
||||
constexpr char DEPEND[] = "depend";
|
||||
constexpr char DEPEND[] = "Depend";
|
||||
constexpr char BATCH_PARALLEL[] = "BatchParallel";
|
||||
|
||||
constexpr char ACTIVATION_TYPE[] = "activation_type";
|
||||
|
|
|
@ -217,7 +217,7 @@ FuncGraphPtr ConvertToBpropCut(py::object obj) {
|
|||
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto fake_bprop = std::make_shared<Primitive>("bprop_cut");
|
||||
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
fake_bprop->set_hook(bprop_func);
|
||||
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
|
||||
outputs.push_back(NewValueNode(fake_bprop));
|
||||
|
|
|
@ -59,7 +59,7 @@ struct OpExecInfo {
|
|||
using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
|
||||
|
||||
const std::set<std::string> ignore_infer_prim = {"partial", "make_ref"};
|
||||
const std::set<std::string> ignore_infer_prim = {"make_ref"};
|
||||
} // namespace pynative
|
||||
} // namespace mindspore
|
||||
|
||||
|
|
|
@ -53,7 +53,7 @@
|
|||
|
||||
const char SINGLE_OP_GRAPH[] = "single_op_graph";
|
||||
// primitive unable to infer value for constant input in PyNative mode
|
||||
const std::set<std::string> vm_operators = {"partial", "depend", "make_ref", "HookBackward"};
|
||||
const std::set<std::string> vm_operators = {"make_ref", "HookBackward"};
|
||||
|
||||
namespace mindspore {
|
||||
namespace pynative {
|
||||
|
|
|
@ -959,8 +959,8 @@ void DfGraphConvertor::TraceOutput(const AnfNodePtr node) {
|
|||
for (unsigned int i = 1; i < c->inputs().size(); i++) {
|
||||
TraceOutput(c->input(i));
|
||||
}
|
||||
} else if (name == "depend") {
|
||||
if (c->inputs().size() < 3) { // "depend" primitive have 3 inputs
|
||||
} else if (name == "Depend") {
|
||||
if (c->inputs().size() < 3) { // "Depend" primitive have 3 inputs
|
||||
MS_LOG(EXCEPTION) << "length of inputs is " << c->inputs().size() << ", which is less than 3";
|
||||
}
|
||||
TraceOutput(c->input(1));
|
||||
|
@ -1183,7 +1183,7 @@ void DfGraphConvertor::SetOpInput(const OpAdapterPtr &adpt, const CNodePtr &node
|
|||
auto &inputs = node->inputs();
|
||||
for (size_t i = 1; i < inputs.size(); i++) {
|
||||
auto pred = inputs[i];
|
||||
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "depend") {
|
||||
while (pred->isa<CNode>() && GetCNodeFuncName(pred->cast<CNodePtr>()) == "Depend") {
|
||||
pred = pred->cast<CNodePtr>()->input(1);
|
||||
}
|
||||
// skip the None input
|
||||
|
@ -1362,7 +1362,7 @@ AnfNodePtr DfGraphConvertor::TraceTupleGetItem(const CNodePtr &node, unsigned in
|
|||
|
||||
AnfNodePtr DfGraphConvertor::TraceDepend(const CNodePtr &node) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
if (cnode->inputs().size() < 3) { // "depend" primitive have 3 inputs
|
||||
if (cnode->inputs().size() < 3) { // "Depend" primitive have 3 inputs
|
||||
MS_LOG(EXCEPTION) << "length of inputs of depend is less than 3";
|
||||
}
|
||||
return cnode->inputs()[1];
|
||||
|
@ -1483,7 +1483,7 @@ AnfNodePtr DfGraphConvertor::GetRealOpNode(AnfNodePtr node) {
|
|||
// depend apply inputs: depend,output,depended_node
|
||||
if (IsPrimitiveCNode(node, prim::kPrimDepend)) {
|
||||
auto depend_inputs = node->cast<CNodePtr>()->inputs();
|
||||
if (depend_inputs.size() != 3) { // "depend" primitive have 3 inputs
|
||||
if (depend_inputs.size() != 3) { // "Depend" primitive have 3 inputs
|
||||
MS_LOG(ERROR) << "depend input items not correct";
|
||||
error_ = FAILED;
|
||||
return node;
|
||||
|
@ -1700,7 +1700,7 @@ void DfGraphConvertor::ConvertControlDependNode(const CNodePtr node) {
|
|||
|
||||
bool DfGraphConvertor::CheckCNode(const std::string &name, const CNodePtr node) {
|
||||
// ignore apply node of return
|
||||
if (name == "return" || name == "depend") {
|
||||
if (name == "return" || name == "Depend") {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
|
@ -585,8 +585,8 @@ void FinalVM::InstPushPrim(const VectorRef &args) {
|
|||
return;
|
||||
}
|
||||
|
||||
VectorRef tuple;
|
||||
auto prim = utils::cast<PrimitivePtr>(args[0]);
|
||||
VectorRef tuple;
|
||||
for (size_t i = 1; i < args.size(); ++i) {
|
||||
auto index = utils::cast<int>(args[i]);
|
||||
tuple.push_back(Ref(index));
|
||||
|
@ -618,6 +618,7 @@ void FinalVM::SyncData(const py::object &arg) {
|
|||
|
||||
BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
||||
MS_LOG(DEBUG) << "input for operation:";
|
||||
auto prim_py = dyn_cast<PrimitivePy>(prim);
|
||||
std::size_t args_size = args.size();
|
||||
auto py_args = py::tuple(args_size);
|
||||
size_t i = 0;
|
||||
|
@ -631,7 +632,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
bool is_bprop = prim->HasAttr("bprop");
|
||||
if (is_bprop) {
|
||||
SyncData(py_args);
|
||||
py::function fn_bprop = prim->hook();
|
||||
py::function fn_bprop = prim_py->hook();
|
||||
obj = fn_bprop(*py_args);
|
||||
return obj;
|
||||
}
|
||||
|
@ -647,7 +648,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
hook_args[0] = cell_id;
|
||||
hook_args[1] = py::make_tuple(_hook_grad[cell_id]);
|
||||
hook_args[2] = py::make_tuple(py_args[2]);
|
||||
py::function fn_hook = prim->hook();
|
||||
py::function fn_hook = prim_py->hook();
|
||||
obj = fn_hook(*hook_args);
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
|
@ -659,7 +660,7 @@ BaseRef FinalVM::RunHook(const PrimitivePtr &prim, const VectorRef &args) {
|
|||
}
|
||||
} else {
|
||||
// Hook operator for execute variable hook function
|
||||
py::function fn_hook = prim->hook();
|
||||
py::function fn_hook = prim_py->hook();
|
||||
obj = fn_hook(py::make_tuple(py_args[2]));
|
||||
if (py::isinstance<py::none>(obj)) {
|
||||
obj = py_args[2];
|
||||
|
|
|
@ -78,6 +78,8 @@ class Tensor(Tensor_):
|
|||
def __eq__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
return False
|
||||
# The GE backend don't support single `Equal` operator execution.
|
||||
# bool type is not supported for `Equal` operator in backend.
|
||||
if context.get_context("enable_ge") or self.dtype() == mstype.bool_ or other.dtype() == mstype.bool_:
|
||||
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
|
||||
return tensor_operator_registry.get('__eq__')(self, other)
|
||||
|
|
|
@ -195,7 +195,7 @@ def bprop_array_reduce(fn, x, shp, out, dout):
|
|||
return F.distribute(dout, F.shape(x)), C.zeros_like(shp)
|
||||
|
||||
|
||||
@bprops.register("depend")
|
||||
@bprops.register("Depend")
|
||||
def bprop_depend(x, y, out, dout):
|
||||
"""Backpropagator for primitive `depend`."""
|
||||
return dout, C.zeros_like(y)
|
||||
|
@ -236,7 +236,6 @@ def bprop_control_depend(x, y, out, dout):
|
|||
"""Backpropagator for primitive `Control_depend`."""
|
||||
return C.zeros_like(x), C.zeros_like(y)
|
||||
|
||||
|
||||
@bprops.register("switch")
|
||||
def bprop_switch(cond, tb, fb, out, dout):
|
||||
"""Backpropagator for primitive `switch`."""
|
||||
|
|
|
@ -22,7 +22,7 @@ from mindspore import context
|
|||
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, MultitypeFuncGraph_, Tail_, TensorSlice_, \
|
||||
TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
|
||||
from ...common import dtype as mstype
|
||||
from ...common.api import ms_function, _pynative_exec
|
||||
from ...common.api import ms_function, _pynative_exec, _wrap_func
|
||||
from .. import functional as F
|
||||
from ...common.parameter import Parameter
|
||||
|
||||
|
@ -117,6 +117,7 @@ class GradOperation(GradOperation_):
|
|||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
else:
|
||||
@_wrap_func
|
||||
def after_grad(*args):
|
||||
if fn.is_run and not fn.requires_grad:
|
||||
raise ValueError("obj must set_grad.")
|
||||
|
|
|
@ -77,6 +77,9 @@ gather_nd = P.GatherNd()
|
|||
scatter_update = P.ScatterUpdate()
|
||||
scatter_nd_update = P.ScatterNdUpdate()
|
||||
pack = P.Pack()
|
||||
partial = P.Partial()
|
||||
# depend: mount a node to another node
|
||||
depend = P.Depend()
|
||||
|
||||
|
||||
tuple_setitem = Primitive('tuple_setitem')
|
||||
|
@ -131,12 +134,9 @@ mixed_precision_cast = Primitive("mixed_precision_cast")
|
|||
broadcast_gradient_args = Primitive('BroadcastGradientArgs')
|
||||
dot = Primitive('dot')
|
||||
array_reduce = Primitive('array_reduce')
|
||||
partial = Primitive('partial')
|
||||
zeros_like = P.ZerosLike()
|
||||
identity = Primitive('identity')
|
||||
distribute = Primitive('distribute')
|
||||
# depend: mount a node to another node
|
||||
depend = Primitive('depend')
|
||||
embed = Primitive('embed')
|
||||
ref_to_embed = _grad_ops.RefToEmbed()
|
||||
env_setitem = Primitive('env_setitem')
|
||||
|
|
|
@ -74,7 +74,7 @@ from .nn_ops import (LSTM, SGD, Adam, ApplyMomentum, BatchNorm,
|
|||
ApplyProximalAdagrad, SparseApplyProximalAdagrad,
|
||||
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell)
|
||||
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
|
||||
CheckValid, MakeRefKey, CheckBprop, ConfusionMatrix)
|
||||
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix)
|
||||
from . import _quant_ops
|
||||
from ._quant_ops import *
|
||||
from .thor_ops import *
|
||||
|
@ -213,6 +213,8 @@ __all__ = [
|
|||
'NMSWithMask',
|
||||
'IOU',
|
||||
'MakeRefKey',
|
||||
'Partial',
|
||||
'Depend',
|
||||
'AvgPool',
|
||||
# Back Primitive
|
||||
'Equal',
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
# ============================================================================
|
||||
|
||||
"""Other operators."""
|
||||
import functools
|
||||
from ..._c_expression import signature_rw as sig_rw
|
||||
from ..._c_expression import signature_kind as sig_kind
|
||||
from ..._c_expression import signature_dtype as sig_dtype
|
||||
|
@ -304,6 +305,46 @@ class MakeRefKey(Primitive):
|
|||
pass
|
||||
|
||||
|
||||
class Partial(Primitive):
|
||||
"""
|
||||
Make a partial function instance, used for pynative mode.
|
||||
|
||||
Inputs:
|
||||
- **args** (Union[FunctionType, Tensor]) - The function and bind arguments.
|
||||
|
||||
Outputs:
|
||||
FunctionType, partial function binded with arguments.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, *args):
|
||||
func = args[0].__call__
|
||||
partial_func = functools.partial(func, *args[1:])
|
||||
return partial_func
|
||||
|
||||
class Depend(Primitive):
|
||||
"""
|
||||
Depend is used for process side-effect operations.
|
||||
|
||||
Inputs:
|
||||
- **value** (Tensor) - the real value to return for depend operator.
|
||||
- **expr** (Expression) - the expression to execute with no outputs.
|
||||
|
||||
Outputs:
|
||||
Tensor, the value passed by last operator.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self, value, expr):
|
||||
return value
|
||||
|
||||
|
||||
class CheckBprop(PrimitiveWithInfer):
|
||||
"""
|
||||
Checks whether data type and shape of corresponding element from tuple x and y are the same.
|
||||
|
|
|
@ -341,7 +341,7 @@ TEST_F(TestOps, ResolveTest) {
|
|||
}
|
||||
|
||||
TEST_F(TestOps, PartialTest) {
|
||||
auto prim = std::make_shared<Primitive>("partial");
|
||||
auto prim = std::make_shared<Primitive>("Partial");
|
||||
ASSERT_EQ(prim->name(), kPrimPartial->name());
|
||||
}
|
||||
|
||||
|
|
|
@ -636,7 +636,7 @@ def test_tuple_get_set_item(tag):
|
|||
def test_partial(tag):
|
||||
""" test_partial """
|
||||
fns = FnDict()
|
||||
partail = Primitive('partial')
|
||||
partail = P.Partial()
|
||||
|
||||
def f(x, y):
|
||||
return scalar_add(x, y)
|
||||
|
@ -655,7 +655,7 @@ def test_partial(tag):
|
|||
def test_replace_applicator(tag):
|
||||
""" test_replace_applicator """
|
||||
fns = FnDict()
|
||||
partail = Primitive('partial')
|
||||
partail = P.Partial()
|
||||
|
||||
def app1(x, y):
|
||||
return scalar_add(x, y)
|
||||
|
|
|
@ -22,7 +22,7 @@ four2five = Primitive('Four2Five')
|
|||
five2four = Primitive('Five2Four')
|
||||
transdata = Primitive("TransData")
|
||||
cast = Primitive('Cast')
|
||||
depend = Primitive('depend')
|
||||
depend = P.Depend()
|
||||
|
||||
|
||||
class FnDict:
|
||||
|
|
|
@ -16,13 +16,13 @@ import mindspore.common.dtype as mstype
|
|||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
AssignSub = P.AssignSub()
|
||||
Mul = P.Mul()
|
||||
Sub = P.Sub()
|
||||
make_tuple = Primitive('make_tuple')
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
depend = Primitive('depend')
|
||||
BatchNorm = P.BatchNorm()
|
||||
Cast = P.Cast()
|
||||
BNTrainingReduce = Primitive('BNTrainingReduce')
|
||||
|
@ -54,8 +54,8 @@ def test_fused_batch_norm_fusion(tag):
|
|||
mul1 = Mul(sub1, constant1)
|
||||
assign_sub0 = AssignSub(var0, mul0)
|
||||
assign_sub1 = AssignSub(var1, mul1)
|
||||
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = depend(depend0, assign_sub1)
|
||||
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = F.depend(depend0, assign_sub1)
|
||||
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
@ -69,8 +69,8 @@ def test_fused_batch_norm_fusion(tag):
|
|||
mul1 = Mul(sub1, constant1)
|
||||
assign_sub0 = AssignSub(var0, Cast(mul0, mstype.float32))
|
||||
assign_sub1 = AssignSub(var1, Cast(mul1, mstype.float32))
|
||||
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = depend(depend0, assign_sub1)
|
||||
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = F.depend(depend0, assign_sub1)
|
||||
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
@ -84,8 +84,8 @@ def test_fused_batch_norm_fusion(tag):
|
|||
mul1 = Mul(Cast(sub1, mstype.float32), constant1)
|
||||
assign_sub0 = AssignSub(var0, mul0)
|
||||
assign_sub1 = AssignSub(var1, mul1)
|
||||
depend0 = depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = depend(depend0, assign_sub1)
|
||||
depend0 = F.depend(tuple_getitem(batch_norm, 0), assign_sub0)
|
||||
depend1 = F.depend(depend0, assign_sub1)
|
||||
outputs = make_tuple(depend1, tuple_getitem(batch_norm, 3), tuple_getitem(batch_norm, 4))
|
||||
output = tuple_getitem(outputs, 0)
|
||||
return output
|
||||
|
|
|
@ -16,7 +16,7 @@ from mindspore.ops import Primitive
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
depend = Primitive('depend')
|
||||
depend = P.Depend()
|
||||
addn = P.AddN()
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
|
|
|
@ -16,7 +16,7 @@ from mindspore.ops import Primitive
|
|||
from mindspore.ops import operations as P
|
||||
|
||||
tuple_getitem = Primitive('tuple_getitem')
|
||||
depend = Primitive('depend')
|
||||
depend = P.Depend()
|
||||
addn = P.AddN()
|
||||
add = P.TensorAdd()
|
||||
sub = P.Sub()
|
||||
|
|
|
@ -15,7 +15,7 @@
|
|||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
depend = Primitive('depend')
|
||||
depend = P.Depend()
|
||||
TransData = Primitive('TransData')
|
||||
add = P.TensorAdd()
|
||||
make_tuple = Primitive('make_tuple')
|
||||
|
|
|
@ -20,9 +20,9 @@ import mindspore.context as context
|
|||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops import Primitive
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops.operations import _grad_ops as G
|
||||
from mindspore.ops import prim_attr_register, PrimitiveWithInfer
|
||||
from ..ut_filter import non_graph_engine
|
||||
|
@ -358,7 +358,7 @@ class StateNet(nn.Cell):
|
|||
self.assign = P.Assign()
|
||||
|
||||
def construct(self, x):
|
||||
x = Primitive('depend')(x, self.assign(self.s1, x + self.s1))
|
||||
x = F.depend(x, self.assign(self.s1, x + self.s1))
|
||||
self.s1 = self.sub(self.s1, x)
|
||||
self.s2 = self.sub(self.s2, x)
|
||||
return x
|
||||
|
|
|
@ -132,7 +132,7 @@ def test_hypermap_add3_easy():
|
|||
|
||||
|
||||
add3 = C.MultitypeFuncGraph('add')
|
||||
partial = Primitive('partial')
|
||||
partial = P.Partial()
|
||||
|
||||
|
||||
@add3.register("Number", "Number", "Number")
|
||||
|
|
|
@ -284,3 +284,21 @@ def vm_impl_zeros_like(self):
|
|||
"""Generate vm_impl function for ZerosLike"""
|
||||
def vm_impl(x):
|
||||
return Tensor(np.zeros_like(x.asnumpy()))
|
||||
|
||||
@vm_impl_getters.register(P.Partial)
|
||||
def vm_impl_partial(self):
|
||||
"""Generate vm_impl function for Partial"""
|
||||
def vm_impl(*args):
|
||||
func = args[0].__call__
|
||||
partial_func = functools.partial(func, *args[1:])
|
||||
return partial_func
|
||||
|
||||
return vm_impl
|
||||
|
||||
@vm_impl_getters.register(P.Depend)
|
||||
def vm_impl_depend(self):
|
||||
"""Generate vm_impl function for Depend"""
|
||||
def vm_impl(value, expr):
|
||||
return value
|
||||
|
||||
return vm_impl
|
||||
|
|
Loading…
Reference in New Issue