forked from mindspore-Ecosystem/mindspore
!2148 fix hook and bprop debug issue in pynative
Merge pull request !2148 from wangqiuliang/fix-hook-bprop-issue
This commit is contained in:
commit
9ba6f61d01
|
@ -113,6 +113,24 @@ def bool_or(x, y):
|
|||
"""Implement `bool_or`."""
|
||||
return x or y
|
||||
|
||||
def vm_compare(*args):
|
||||
"""Implement `vm_compare` for tensor."""
|
||||
obj_str = args[-1]
|
||||
if obj_str == "shape":
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return fn
|
||||
if len(args) == 2:
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
return Tensor(fn())
|
||||
if isinstance(args[0], Tensor):
|
||||
fn = getattr(args[0].asnumpy(), obj_str)
|
||||
y = args[1].asnumpy() if isinstance(args[1], Tensor) else args[1]
|
||||
else:
|
||||
obj_str = "__r" + obj_str[2:]
|
||||
fn = getattr(args[1].asnumpy(), obj_str)
|
||||
y = args[0]
|
||||
return Tensor(np.array(fn(y)))
|
||||
|
||||
|
||||
def make_list(*xs):
|
||||
"""Implement `make_list`."""
|
||||
|
|
|
@ -41,6 +41,35 @@ using TensorPtr = mindspore::tensor::TensorPtr;
|
|||
using MetaTensor = mindspore::tensor::MetaTensor;
|
||||
using MetaTensorPtr = mindspore::tensor::MetaTensorPtr;
|
||||
|
||||
FuncGraphPtr ConvertToBpropCut(const py::object &obj) {
|
||||
std::vector<std::string> results = data_converter::GetObjKey(obj);
|
||||
std::string obj_key = results[0];
|
||||
py::function bprop_func = py::getattr(obj, CUSTOM_BPROP_NAME);
|
||||
|
||||
auto bprop_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
fake_bprop->set_hook(bprop_func);
|
||||
(void)fake_bprop->AddAttr(CUSTOM_BPROP_NAME, MakeValue(true));
|
||||
outputs.push_back(NewValueNode(fake_bprop));
|
||||
|
||||
py::object code_obj = py::getattr(bprop_func, "__code__");
|
||||
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
auto param = bprop_graph->add_parameter();
|
||||
outputs.push_back(param);
|
||||
}
|
||||
auto p1 = bprop_graph->add_parameter();
|
||||
auto p2 = bprop_graph->add_parameter();
|
||||
outputs.push_back(p1);
|
||||
outputs.push_back(p2);
|
||||
|
||||
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
|
||||
data_converter::SetObjGraphValue(obj_key, bprop_graph);
|
||||
return bprop_graph;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool ConvertTuple(const py::object &obj, ValuePtr *const data, bool use_signature) {
|
||||
MS_LOG(DEBUG) << "Converting python tuple";
|
||||
|
@ -231,35 +260,6 @@ bool ConvertSlice(const py::object &obj, ValuePtr *const data) {
|
|||
return true;
|
||||
}
|
||||
|
||||
FuncGraphPtr ConvertToBpropCut(py::object obj) {
|
||||
std::vector<std::string> results = data_converter::GetObjKey(obj);
|
||||
std::string obj_key = results[0];
|
||||
py::function bprop_func = py::getattr(obj, "bprop");
|
||||
|
||||
FuncGraphPtr bprop_graph = std::make_shared<FuncGraph>();
|
||||
std::vector<AnfNodePtr> outputs;
|
||||
|
||||
auto fake_bprop = std::make_shared<PrimitivePy>("bprop_cut", py::object());
|
||||
fake_bprop->set_hook(bprop_func);
|
||||
(void)fake_bprop->AddAttr("bprop", MakeValue(true));
|
||||
outputs.push_back(NewValueNode(fake_bprop));
|
||||
|
||||
py::object code_obj = py::getattr(bprop_func, "__code__");
|
||||
size_t inputs_num = py::cast<int>(py::getattr(code_obj, "co_argcount")) - 3;
|
||||
for (size_t i = 0; i < inputs_num; ++i) {
|
||||
auto param = bprop_graph->add_parameter();
|
||||
outputs.push_back(param);
|
||||
}
|
||||
auto p1 = bprop_graph->add_parameter();
|
||||
auto p2 = bprop_graph->add_parameter();
|
||||
outputs.push_back(p1);
|
||||
outputs.push_back(p2);
|
||||
|
||||
bprop_graph->set_output(bprop_graph->NewCNode(outputs));
|
||||
data_converter::SetObjGraphValue(obj_key, bprop_graph);
|
||||
return bprop_graph;
|
||||
}
|
||||
|
||||
bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
||||
FuncGraphPtr func_graph = ConvertToFuncGraph(obj);
|
||||
if (func_graph == nullptr) {
|
||||
|
@ -267,7 +267,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
|||
return false;
|
||||
}
|
||||
// if the cell object has specified bprop, it has user-defined bprop function parse and record it
|
||||
if (py::hasattr(obj, "bprop")) {
|
||||
if (py::hasattr(obj, CUSTOM_BPROP_NAME)) {
|
||||
FuncGraphPtr bprop_graph = nullptr;
|
||||
bool enable_bprop_debug = py::cast<bool>(py::getattr(obj, "bprop_debug"));
|
||||
if (enable_bprop_debug) {
|
||||
|
@ -276,7 +276,7 @@ bool ConvertCellObjToFuncGraph(py::object obj, ValuePtr *const data) {
|
|||
bprop_graph = ConvertToFuncGraph(obj, PYTHON_MOD_GET_BPROP_METHOD);
|
||||
}
|
||||
if (bprop_graph != nullptr) {
|
||||
(void)func_graph->transforms().insert(std::make_pair("bprop", FuncGraphTransform(bprop_graph)));
|
||||
(void)func_graph->transforms().insert(std::make_pair(CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
|
||||
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(func_graph)));
|
||||
func_graph->set_flags(FUNC_GRAPH_FLAG_DEFER_INLINE, true);
|
||||
}
|
||||
|
|
|
@ -51,6 +51,7 @@ void ClearObjectCache();
|
|||
} // namespace data_converter
|
||||
|
||||
ClassPtr ParseDataClass(const py::object &cls_obj);
|
||||
FuncGraphPtr ConvertToBpropCut(const py::object &obj);
|
||||
|
||||
void CleanDataClassToClassMap();
|
||||
|
||||
|
|
|
@ -109,6 +109,7 @@ const char PYTHON_EXTERN_MINDSPORE_FLAG[] = "_mindspore_flags";
|
|||
|
||||
// define the parse constant
|
||||
const int MAX_COMPARISON_OPS_SUPPORTED = 1;
|
||||
const char CUSTOM_BPROP_NAME[] = "bprop";
|
||||
|
||||
// define the Namespace name
|
||||
const char RESOLVE_NAMESPACE_NAME_AST[] = "Ast"; // for ast type namespace
|
||||
|
|
|
@ -45,7 +45,7 @@ enum PynativeStatusCode {
|
|||
PYNATIVE_UNKNOWN_STATE = 0XFF
|
||||
};
|
||||
|
||||
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_INPUT_MASK, PY_ARGS_NUM };
|
||||
enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
|
||||
|
||||
struct OpExecInfo {
|
||||
PrimitivePyPtr py_primitive;
|
||||
|
|
|
@ -110,9 +110,15 @@ py::object GetTupleObj(const py::object &obj) {
|
|||
return obj_tuple;
|
||||
}
|
||||
|
||||
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
||||
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *out_args) {
|
||||
auto &py_args = *out_args;
|
||||
py::tuple input_mask(args.size());
|
||||
for (size_t i = 0; i < args.size(); ++i) {
|
||||
if (py::hasattr(args[i], "__parameter__")) {
|
||||
input_mask[i] = true;
|
||||
} else {
|
||||
input_mask[i] = false;
|
||||
}
|
||||
py_args[i] = GetTupleObj(args[i]);
|
||||
}
|
||||
auto signature = prim->signatures();
|
||||
|
@ -121,7 +127,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
|
|||
[](const Signature &sig) { return sig.dtype; });
|
||||
int empty_dtype_count = std::count(dtypes.begin(), dtypes.end(), SignatureEnumDType::kDTypeEmptyDefaultValue);
|
||||
if (dtypes.size() == 0 || static_cast<int>(dtypes.size()) == empty_dtype_count) {
|
||||
return;
|
||||
return input_mask;
|
||||
}
|
||||
std::map<SignatureEnumDType, std::vector<size_t>> type_indexs;
|
||||
for (size_t i = 0; i < dtypes.size(); ++i) {
|
||||
|
@ -160,6 +166,7 @@ void ConvertInputs(const PrimitivePyPtr &prim, const py::list &args, py::tuple *
|
|||
continue;
|
||||
}
|
||||
}
|
||||
return input_mask;
|
||||
}
|
||||
|
||||
void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecInfo *const op_exec_info) {
|
||||
|
@ -167,7 +174,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < size; i++) {
|
||||
ValuePtr input_value = PyAttrValue(py_args[i]);
|
||||
if (input_value->isa<tensor::Tensor>()) {
|
||||
if (!py::hasattr(prim->GetPyObj(), "const_value") && input_value->isa<tensor::Tensor>()) {
|
||||
args_spec_list.emplace_back(abstract::FromValueInside(input_value, true));
|
||||
} else {
|
||||
args_spec_list.emplace_back(abstract::FromValueInside(input_value, false));
|
||||
|
@ -179,7 +186,7 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
|
|||
|
||||
OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
||||
if (args.size() != PY_ARGS_NUM) {
|
||||
MS_LOG(ERROR) << "Four args are needed by RunOp";
|
||||
MS_LOG(ERROR) << "Three args are needed by RunOp";
|
||||
return nullptr;
|
||||
}
|
||||
auto op_exec_info = std::make_shared<OpExecInfo>();
|
||||
|
@ -195,14 +202,13 @@ OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
|
|||
size_t input_num = a.size();
|
||||
op_exec_info->op_inputs = py::tuple(input_num);
|
||||
|
||||
ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
|
||||
op_exec_info->inputs_mask = ConvertInputs(prim, args[PY_INPUTS], &op_exec_info->op_inputs);
|
||||
// use python infer method
|
||||
if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) {
|
||||
PynativeInfer(prim, op_exec_info->op_inputs, op_exec_info.get());
|
||||
}
|
||||
op_exec_info->py_primitive = prim;
|
||||
op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
|
||||
op_exec_info->inputs_mask = args[PY_INPUT_MASK];
|
||||
if (op_exec_info->op_inputs.size() != op_exec_info->inputs_mask.size()) {
|
||||
MS_LOG(ERROR) << "Op:" << op_exec_info->op_name << " inputs size not equal op_mask";
|
||||
return nullptr;
|
||||
|
@ -488,14 +494,14 @@ py::object RunOpWithBackendPolicy(MsBackendPolicy backend_policy, const OpExecIn
|
|||
return result;
|
||||
}
|
||||
|
||||
AnfNodePtr PynativeExecutor::MakeCNode(const py::args &args, const py::tuple &out) {
|
||||
AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out) {
|
||||
if (!grad_flag_ || graph_info_map_.size() == 0) {
|
||||
return nullptr;
|
||||
}
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
|
||||
auto prim = op_exec_info->py_primitive;
|
||||
inputs.push_back(NewValueNode(prim));
|
||||
py::tuple op_masks = args[PY_INPUT_MASK];
|
||||
py::tuple op_masks = op_exec_info->inputs_mask;
|
||||
py::list op_args = args[PY_INPUTS];
|
||||
AbstractBasePtrList args_spec_list;
|
||||
for (size_t i = 0; i < op_args.size(); i++) {
|
||||
|
@ -584,7 +590,7 @@ py::tuple RunOp(const py::args &args) {
|
|||
return err_ret;
|
||||
}
|
||||
|
||||
auto node = PynativeExecutor::GetInstance()->MakeCNode(args, result);
|
||||
auto node = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, args, result);
|
||||
if (node != nullptr) {
|
||||
node->set_abstract(op_exec_info->abstract);
|
||||
MS_LOG(DEBUG) << "RunOp MakeCnode,new node is: " << node->DebugString();
|
||||
|
@ -705,7 +711,7 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|||
}
|
||||
cell_graph_map_[cell_id] = curr_g_;
|
||||
auto out_id = GetId(out);
|
||||
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id)) {
|
||||
if (!graph_info_map_[curr_g_].obj_node_map.count(out_id) && !graph_info_map_[curr_g_].param_map.count(out_id)) {
|
||||
// cell construct return x, y
|
||||
if (py::isinstance<py::tuple>(out)) {
|
||||
std::vector<AnfNodePtr> args;
|
||||
|
@ -727,12 +733,26 @@ void PynativeExecutor::EndGraph(const py::object &cell, const py::object &out, c
|
|||
}
|
||||
}
|
||||
|
||||
auto output_node = GetObjNode(out);
|
||||
AnfNodePtr output_node;
|
||||
if (graph_info_map_[curr_g_].param_map.count(out_id)) {
|
||||
output_node = graph_info_map_[curr_g_].param_map[out_id];
|
||||
} else {
|
||||
output_node = GetObjNode(out);
|
||||
}
|
||||
curr_g_->set_output(output_node);
|
||||
std::vector<AnfNodePtr> inputs;
|
||||
inputs.push_back(NewValueNode(curr_g_));
|
||||
MS_LOG(DEBUG) << "Current graph" << curr_g_->output()->DebugString();
|
||||
resource_->manager()->AddFuncGraph(curr_g_);
|
||||
// custom bprop debug
|
||||
if (py::hasattr(cell, parse::CUSTOM_BPROP_NAME)) {
|
||||
MS_LOG(DEBUG) << "Use cell custom bprop function.";
|
||||
FuncGraphPtr bprop_graph = parse::ConvertToBpropCut(cell);
|
||||
if (bprop_graph != nullptr) {
|
||||
(void)curr_g_->transforms().insert(std::make_pair(parse::CUSTOM_BPROP_NAME, FuncGraphTransform(bprop_graph)));
|
||||
(void)bprop_graph->transforms().insert(std::make_pair("primal", FuncGraphTransform(curr_g_)));
|
||||
}
|
||||
}
|
||||
auto newfg = ad::Grad(curr_g_, resource_, curr_g_ == top_g_);
|
||||
if (curr_g_ != top_g_) {
|
||||
Popp();
|
||||
|
|
|
@ -44,7 +44,7 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
|
|||
|
||||
py::tuple RunOp(const py::args &args);
|
||||
|
||||
void ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
|
||||
py::tuple ConvertInputs(const PrimitivePyPtr &prim, const py::list &py_args, py::tuple *out_args);
|
||||
|
||||
void ClearPyNativeSession();
|
||||
|
||||
|
@ -83,7 +83,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
|
|||
void set_obj_node_map(FuncGraphPtr g, const std::string obj, AnfNodePtr node, int index) {
|
||||
graph_info_map_[g].obj_node_map[obj] = std::make_pair(node, index);
|
||||
}
|
||||
AnfNodePtr MakeCNode(const py::args &args, const py::tuple &out);
|
||||
AnfNodePtr MakeCNode(const OpExecInfoPtr &op_exec_info, const py::args &args, const py::tuple &out);
|
||||
py::object Run(const py::tuple &args, const py::object &phase);
|
||||
|
||||
void Pushp();
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
"""Registry the relation."""
|
||||
|
||||
from collections import UserDict
|
||||
from .. import context
|
||||
|
||||
|
||||
class Registry(UserDict):
|
||||
|
@ -27,9 +28,16 @@ class Registry(UserDict):
|
|||
|
||||
def get(self, obj_str):
|
||||
"""Get the value by str."""
|
||||
if isinstance(obj_str, str):
|
||||
if not isinstance(obj_str, str):
|
||||
raise TypeError("key for tensor registry must be string.")
|
||||
if context.get_context("enable_ge"):
|
||||
def wrap(*args):
|
||||
new_args = list(args)
|
||||
new_args.append(obj_str)
|
||||
return self["vm_compare"](*new_args)
|
||||
obj = wrap
|
||||
else:
|
||||
obj = self[obj_str]
|
||||
return obj
|
||||
|
||||
|
||||
tensor_operator_registry = Registry()
|
||||
|
|
|
@ -19,7 +19,6 @@ from .._c_expression import Tensor as Tensor_
|
|||
from .._c_expression import MetaTensor
|
||||
from .._checkparam import check_type, check_typename
|
||||
from . import dtype as mstype
|
||||
from .. import context
|
||||
from ._register_for_tensor import tensor_operator_registry
|
||||
|
||||
__all__ = ['Tensor', 'MetaTensor']
|
||||
|
@ -76,17 +75,19 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
if not isinstance(other, (int, float, Tensor)):
|
||||
return False
|
||||
# The GE backend don't support single `Equal` operator execution.
|
||||
# bool type is not supported for `Equal` operator in backend.
|
||||
if context.get_context("enable_ge") or self.dtype == mstype.bool_ or other.dtype == mstype.bool_:
|
||||
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
|
||||
return Tensor(np.array(self.asnumpy() == other.asnumpy()))
|
||||
return tensor_operator_registry.get('__eq__')(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
if not isinstance(other, Tensor):
|
||||
if not isinstance(other, (int, float, Tensor)):
|
||||
return True
|
||||
# bool type is not supported for `NotEqual` operator in backend.
|
||||
if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
|
||||
return Tensor(np.array(self.asnumpy() != other.asnumpy()))
|
||||
return tensor_operator_registry.get('__ne__')(self, other)
|
||||
|
||||
def __hash__(self):
|
||||
|
@ -105,7 +106,7 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __radd__(self, other):
|
||||
out = tensor_operator_registry.get('__add__')(other, self)
|
||||
out = tensor_operator_registry.get('__add__')(self, other)
|
||||
return out
|
||||
|
||||
def __imul__(self, other):
|
||||
|
@ -113,15 +114,15 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __rmul__(self, other):
|
||||
out = tensor_operator_registry.get('__mul__')(other, self)
|
||||
out = tensor_operator_registry.get('__mul__')(self, other)
|
||||
return out
|
||||
|
||||
def __truediv__(self, other):
|
||||
out = tensor_operator_registry.get('__div__')(self, other)
|
||||
out = tensor_operator_registry.get('__truediv__')(self, other)
|
||||
return out
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
out = tensor_operator_registry.get('__div__')(other, self)
|
||||
out = tensor_operator_registry.get('__truediv__')(other, self)
|
||||
return out
|
||||
|
||||
def __sub__(self, other):
|
||||
|
@ -160,7 +161,7 @@ class Tensor(Tensor_):
|
|||
return out
|
||||
|
||||
def __len__(self):
|
||||
out = tensor_operator_registry.get('__shape__')(self)
|
||||
out = tensor_operator_registry.get('shape')(self)
|
||||
if not out:
|
||||
return 1
|
||||
return out[0]
|
||||
|
|
|
@ -819,4 +819,4 @@ class Cell:
|
|||
|
||||
"""
|
||||
self._backward_hook = HookBackward(fn, self.cls_name + "(" + str(id(self)) + ")")
|
||||
self._enable_hook = True
|
||||
self.enable_hook = True
|
||||
|
|
|
@ -140,6 +140,11 @@ class SequentialCell(Cell):
|
|||
def __len__(self):
|
||||
return len(self._cells)
|
||||
|
||||
def set_grad(self, flag=True):
|
||||
self.requires_grad = flag
|
||||
for cell in self._cells.values():
|
||||
cell.set_grad(flag)
|
||||
|
||||
def construct(self, input_data):
|
||||
for cell in self.cell_list:
|
||||
input_data = cell(input_data)
|
||||
|
@ -246,5 +251,10 @@ class CellList(_CellListBase, Cell):
|
|||
self._cells[str(len(self))] = cell
|
||||
return self
|
||||
|
||||
def set_grad(self, flag=True):
|
||||
self.requires_grad = flag
|
||||
for cell in self._cells.values():
|
||||
cell.set_grad(flag)
|
||||
|
||||
def construct(self, *inputs):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -112,7 +112,7 @@ class GradOperation(GradOperation_):
|
|||
grad_ = GradOperation('grad', self.get_all, self.get_by_list, self.sens_param)
|
||||
if self.grad_fn is None or self.fn != fn:
|
||||
if self.get_by_list:
|
||||
if context.get_context("mode") == context.GRAPH_MODE or fn.bprop_debug:
|
||||
if context.get_context("mode") == context.GRAPH_MODE:
|
||||
@ms_function(obj=fn)
|
||||
def after_grad(*args):
|
||||
return grad_(fn, weights)(*args)
|
||||
|
|
|
@ -21,6 +21,7 @@ from mindspore.common._register_for_tensor import tensor_operator_registry
|
|||
from .primitive import Primitive
|
||||
from . import operations as P
|
||||
from .operations import _grad_ops
|
||||
from .._extends import builtin_operations as BP
|
||||
|
||||
typeof = Primitive('typeof')
|
||||
hastype = Primitive('hastype')
|
||||
|
@ -155,7 +156,7 @@ stop_gradient = Primitive("stop_gradient")
|
|||
tensor_operator_registry.register('__add__', tensor_add)
|
||||
tensor_operator_registry.register('__sub__', tensor_sub)
|
||||
tensor_operator_registry.register('__mul__', tensor_mul)
|
||||
tensor_operator_registry.register('__div__', tensor_div)
|
||||
tensor_operator_registry.register('__truediv__', tensor_div)
|
||||
#ms cannot support Tensor(True) compare
|
||||
tensor_operator_registry.register('__eq__', equal)
|
||||
tensor_operator_registry.register('__ne__', not_equal)
|
||||
|
@ -164,4 +165,6 @@ tensor_operator_registry.register('__lt__', tensor_lt)
|
|||
tensor_operator_registry.register('__le__', tensor_le)
|
||||
tensor_operator_registry.register('__gt__', tensor_gt)
|
||||
tensor_operator_registry.register('__ge__', tensor_ge)
|
||||
tensor_operator_registry.register('__shape__', shape)
|
||||
tensor_operator_registry.register('shape', shape)
|
||||
#support GE backend for no compare operators
|
||||
tensor_operator_registry.register('vm_compare', BP.vm_compare)
|
||||
|
|
|
@ -863,6 +863,8 @@ class TupleToArray(PrimitiveWithInfer):
|
|||
args = list()
|
||||
if isinstance(x, range):
|
||||
args.append(tuple(x))
|
||||
else:
|
||||
args.append(x)
|
||||
return _run_op(self, self.name, args)
|
||||
|
||||
|
||||
|
|
|
@ -341,13 +341,7 @@ def constexpr(fn=None, get_instance=True, name=None):
|
|||
@_wrap_func
|
||||
def _run_op(obj, op_name, args):
|
||||
"""Single op execution function supported by ge in PyNative mode."""
|
||||
op_mask = [0] * len(args)
|
||||
op_inputs = []
|
||||
for i, arg in enumerate(args):
|
||||
if hasattr(arg, '__parameter__'):
|
||||
op_mask[i] = 1
|
||||
op_inputs.append(arg)
|
||||
output = real_run_op(obj, op_name, args, tuple(op_mask))
|
||||
output = real_run_op(obj, op_name, args)
|
||||
if not output:
|
||||
raise RuntimeError("Pynative run op %s failed!" % op_name)
|
||||
if len(output) == 1:
|
||||
|
|
|
@ -63,8 +63,7 @@ OpExecInfoPtr ConstructOpExecInfo() {
|
|||
|
||||
auto conv_obj = prim::GetPythonOps("conv2d_prim", "gtest_input.pynative");
|
||||
py::none py_none;
|
||||
py::tuple op_mask = py::make_tuple(0, 1);
|
||||
return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs, op_mask));
|
||||
return GenerateOpExecInfo(py::make_tuple(conv_obj, op_name, op_inputs));
|
||||
}
|
||||
|
||||
TEST_F(TestPynativeExecute, TestRunOpInVM) {
|
||||
|
@ -79,7 +78,7 @@ TEST_F(TestPynativeExecute, TestRunOp) {
|
|||
py::none py_none;
|
||||
auto op_exec_info_ptr = ConstructOpExecInfo();
|
||||
py::tuple outputs = pynative::RunOp(py::make_tuple(op_exec_info_ptr->py_primitive, op_exec_info_ptr->op_name,
|
||||
op_exec_info_ptr->op_inputs, op_exec_info_ptr->inputs_mask));
|
||||
op_exec_info_ptr->op_inputs));
|
||||
if (outputs.size() == 0) {
|
||||
FAIL();
|
||||
} else {
|
||||
|
|
|
@ -452,5 +452,5 @@ def test_tensor_operation():
|
|||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
res = 8 / x
|
||||
assert np.all(res.asnumpy() == np.ones((3, 3)) * 2)
|
||||
with pytest.raises(TypeError):
|
||||
with pytest.raises(ValueError):
|
||||
res = x * (2, 3)
|
||||
|
|
|
@ -8,6 +8,9 @@ from mindspore.nn import WithLossCell, Momentum
|
|||
from mindspore.ops import composite as C
|
||||
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
cell_hook_done = False
|
||||
var_hook_done = False
|
||||
cell_bprop_done = False
|
||||
|
||||
|
||||
def conv(in_channels, out_channels, kernel_size, stride=1, padding=0):
|
||||
|
@ -32,15 +35,35 @@ def weight_variable():
|
|||
|
||||
def cell_hook_function(cell_id, grad_input, grad_output):
|
||||
print(cell_id)
|
||||
global cell_hook_done
|
||||
cell_hook_done = True
|
||||
assert (grad_output[0].asnumpy().shape == (32, 6, 14, 14))
|
||||
assert (grad_input[0].asnumpy().shape == (32, 16, 10, 10))
|
||||
|
||||
|
||||
def var_hook_function(grad_out):
|
||||
print("grad:", grad_out)
|
||||
global var_hook_done
|
||||
var_hook_done = True
|
||||
assert (grad_out[0].asnumpy().shape == (32, 120))
|
||||
|
||||
|
||||
class Block(nn.Cell):
|
||||
def __init__(self):
|
||||
super(Block, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
def bprop(self, x, out, dout):
|
||||
global cell_bprop_done
|
||||
cell_bprop_done = True
|
||||
grad = out.asnumpy() * dout.asnumpy()
|
||||
grad = Tensor(grad)
|
||||
return (grad,)
|
||||
|
||||
class LeNet5(nn.Cell):
|
||||
"""
|
||||
Lenet network
|
||||
|
@ -59,6 +82,7 @@ class LeNet5(nn.Cell):
|
|||
self.conv1 = conv(1, 6, 5)
|
||||
self.conv2 = conv(6, 16, 5)
|
||||
self.conv2.register_backward_hook(cell_hook_function)
|
||||
self.block = Block()
|
||||
self.fc1 = fc_with_initialize(16 * 5 * 5, 120)
|
||||
self.fc2 = fc_with_initialize(120, 84)
|
||||
self.fc3 = fc_with_initialize(84, self.num_class)
|
||||
|
@ -72,7 +96,7 @@ class LeNet5(nn.Cell):
|
|||
x = self.relu(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.conv2(x)
|
||||
x = self.relu(x)
|
||||
x = self.block(x)
|
||||
x = self.max_pool2d(x)
|
||||
x = self.reshape(x, (self.batch_size, -1))
|
||||
x = self.fc1(x)
|
||||
|
@ -110,6 +134,9 @@ def test_hook():
|
|||
loss_output = criterion(output, label)
|
||||
grads = train_network(input_data, label)
|
||||
success = optimizer(grads)
|
||||
assert cell_hook_done
|
||||
assert var_hook_done
|
||||
assert cell_bprop_done
|
||||
print(loss_output.asnumpy().shape)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue