forked from mindspore-Ecosystem/mindspore
Return zero like gradients for PyExecute bprop, and add an unsupported test case.
This commit is contained in:
parent
60b419d560
commit
4f493ecc28
|
@ -368,31 +368,31 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
|
|||
params.push_back(fg->add_parameter());
|
||||
}
|
||||
|
||||
// make fprob first result, maketuple's forward result.
|
||||
// Make fprop first result, make_tuple's forward result.
|
||||
AnfNodePtr out = fg->NewCNodeInOrder(params);
|
||||
|
||||
// make fprob second result, maketuple's backward function.
|
||||
FuncGraphPtr b = std::make_shared<FuncGraph>();
|
||||
// Make fprop second result, make_tuple's backward function.
|
||||
FuncGraphPtr bprop = std::make_shared<FuncGraph>();
|
||||
|
||||
ss.str(std::string());
|
||||
ss.clear();
|
||||
// ◀make_tuple_
|
||||
ss << "\u25C2make_tuple_" << tuple_size;
|
||||
b->debug_info()->set_name(ss.str());
|
||||
AnfNodePtr dout = b->add_parameter();
|
||||
bprop->debug_info()->set_name(ss.str());
|
||||
AnfNodePtr dout = bprop->add_parameter();
|
||||
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
grads.push_back(NewEnviron(b));
|
||||
grads.push_back(NewEnviron(bprop));
|
||||
for (int64_t i = 0; i < tuple_size; ++i) {
|
||||
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
|
||||
grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
|
||||
}
|
||||
|
||||
b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
b->set_output(b->NewCNodeInOrder(grads));
|
||||
bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop->set_output(bprop->NewCNodeInOrder(grads));
|
||||
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
|
||||
fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
|
||||
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
|
||||
return fg;
|
||||
}
|
||||
|
@ -412,35 +412,118 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
params.push_back(fg->add_parameter());
|
||||
}
|
||||
|
||||
// make fprob first result, maketuple's forward result.
|
||||
// Make fprop first result, make_list's forward result.
|
||||
AnfNodePtr out = fg->NewCNodeInOrder(params);
|
||||
|
||||
// make fprob second result, maketuple's backward function.
|
||||
FuncGraphPtr b = std::make_shared<FuncGraph>();
|
||||
// Make fprop second result, make_list's backward function.
|
||||
FuncGraphPtr bprop = std::make_shared<FuncGraph>();
|
||||
|
||||
ss.str(std::string());
|
||||
ss.clear();
|
||||
// ◀make_list_
|
||||
ss << "\u25C2make_list_" << list_size;
|
||||
b->debug_info()->set_name(ss.str());
|
||||
AnfNodePtr dout = b->add_parameter();
|
||||
bprop->debug_info()->set_name(ss.str());
|
||||
AnfNodePtr dout = bprop->add_parameter();
|
||||
|
||||
std::vector<AnfNodePtr> grads;
|
||||
grads.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
grads.push_back(NewEnviron(b));
|
||||
grads.push_back(NewEnviron(bprop));
|
||||
for (int64_t i = 0; i < list_size; ++i) {
|
||||
grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
|
||||
grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
|
||||
}
|
||||
|
||||
b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
b->set_output(b->NewCNodeInOrder(grads));
|
||||
bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop->set_output(bprop->NewCNodeInOrder(grads));
|
||||
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
|
||||
fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
|
||||
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
|
||||
return fg;
|
||||
}
|
||||
|
||||
FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
int64_t args_size = SizeToLong(args_spec_list.size());
|
||||
constexpr auto py_execute_grad_input_count = 3;
|
||||
constexpr auto op_name = "PyExecute";
|
||||
CheckArgsSize(op_name, args_spec_list, py_execute_grad_input_count);
|
||||
|
||||
std::ostringstream ss;
|
||||
// ▶PyExecute
|
||||
ss << "\u25B8PyExecute_" << args_size;
|
||||
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||
fg->debug_info()->set_name(ss.str());
|
||||
|
||||
std::vector<AnfNodePtr> params;
|
||||
(void)params.emplace_back(NewValueNode(prim::kPrimPyExecute));
|
||||
for (int64_t i = 0; i < args_size; ++i) {
|
||||
(void)params.emplace_back(fg->add_parameter());
|
||||
}
|
||||
|
||||
// Make fprop first result, PyExecute's forward result.
|
||||
AnfNodePtr out = fg->NewCNodeInOrder(params);
|
||||
|
||||
// make fprop second result, PyExecute's backward function.
|
||||
FuncGraphPtr bprop = std::make_shared<FuncGraph>();
|
||||
|
||||
ss.str(std::string());
|
||||
ss.clear();
|
||||
// ◀PyExecute
|
||||
ss << "\u25C2PyExecute_" << args_size;
|
||||
bprop->debug_info()->set_name(ss.str());
|
||||
(void)bprop->add_parameter();
|
||||
|
||||
std::vector<AnfNodePtr> grads;
|
||||
(void)grads.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
(void)grads.emplace_back(NewEnviron(bprop));
|
||||
// Propagate for script string.
|
||||
(void)grads.emplace_back(params[1]);
|
||||
// Propagate for local dict keys.
|
||||
const auto &local_key_args = dyn_cast<abstract::AbstractTuple>(args_spec_list[1]);
|
||||
MS_EXCEPTION_IF_NULL(local_key_args);
|
||||
std::vector<AnfNodePtr> keys;
|
||||
(void)keys.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (size_t i = 0; i < local_key_args->size(); ++i) {
|
||||
constexpr auto keys_num = 2;
|
||||
const auto &key_item =
|
||||
bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[keys_num], NewValueNode(SizeToLong(i))});
|
||||
const auto &element = local_key_args->elements()[i];
|
||||
const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
|
||||
if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
|
||||
(void)keys.emplace_back(key_item);
|
||||
} else {
|
||||
(void)keys.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), key_item}));
|
||||
}
|
||||
}
|
||||
(void)grads.emplace_back(bprop->NewCNodeInOrder(keys));
|
||||
// Propagate for local dict values.
|
||||
constexpr auto values_arg_num = 2;
|
||||
const auto &local_value_args = dyn_cast<abstract::AbstractTuple>(args_spec_list[values_arg_num]);
|
||||
MS_EXCEPTION_IF_NULL(local_value_args);
|
||||
std::vector<AnfNodePtr> values;
|
||||
(void)values.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (size_t i = 0; i < local_value_args->size(); ++i) {
|
||||
constexpr auto values_num = 3;
|
||||
const auto &value_item =
|
||||
bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[values_num], NewValueNode(SizeToLong(i))});
|
||||
const auto &element = local_value_args->elements()[i];
|
||||
const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
|
||||
if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
|
||||
(void)values.emplace_back(value_item);
|
||||
} else {
|
||||
(void)values.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), value_item}));
|
||||
}
|
||||
}
|
||||
(void)grads.emplace_back(bprop->NewCNodeInOrder(values));
|
||||
|
||||
bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
bprop->set_output(bprop->NewCNodeInOrder(grads));
|
||||
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
|
||||
(void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimPyExecute));
|
||||
return fg;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||
|
@ -1146,9 +1229,9 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
MS_LOG(EXCEPTION) << "'VmapOperation' requires a network or function as an input, while the input is empty.";
|
||||
}
|
||||
|
||||
constexpr auto kVmapOperationInputNum = 3;
|
||||
constexpr auto vmap_operation_input_num = 3;
|
||||
const std::string op_name = "vmap";
|
||||
CheckArgsSize(op_name, args_spec_list, kVmapOperationInputNum);
|
||||
CheckArgsSize(op_name, args_spec_list, vmap_operation_input_num);
|
||||
|
||||
auto fn_arg = args_spec_list[0];
|
||||
auto in_axes_arg = args_spec_list[1];
|
||||
|
|
|
@ -142,6 +142,16 @@ class MakeListGradient : public MetaFuncGraph {
|
|||
};
|
||||
using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;
|
||||
|
||||
class PyExecuteGradient : public MetaFuncGraph {
|
||||
public:
|
||||
explicit PyExecuteGradient(const std::string &name) : MetaFuncGraph(name) {}
|
||||
~PyExecuteGradient() override = default;
|
||||
MS_DECLARE_PARENT(PyExecuteGradient, MetaFuncGraph)
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
friend bool operator==(const PyExecuteGradient &lhs, const PyExecuteGradient &rhs) { return lhs.name_ == rhs.name_; }
|
||||
};
|
||||
using PyExecuteGradientPtr = std::shared_ptr<PyExecuteGradient>;
|
||||
|
||||
class GradOperation : public MetaFuncGraph {
|
||||
public:
|
||||
explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,
|
||||
|
|
|
@ -95,6 +95,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
|
|||
return meta;
|
||||
}
|
||||
|
||||
if (IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
|
||||
MetaFuncGraphPtr meta = std::make_shared<prim::PyExecuteGradient>("PyExecuteGradient");
|
||||
bprop_registry_meta_[prim::kPrimPyExecute] = meta;
|
||||
return meta;
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
|
||||
}
|
||||
|
||||
|
@ -197,7 +203,9 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
|
|||
auto fprop = GetFprop(prim);
|
||||
fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
|
||||
return fprop;
|
||||
} else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
|
||||
} else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList) ||
|
||||
IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
|
||||
// Return null to use Meta bprop.
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -223,7 +223,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
|||
meta_func_graph->isa<prim::UnpackCall>() || meta_func_graph->isa<prim::ZipOperation>() ||
|
||||
meta_func_graph->isa<prim::ListAppend>() || meta_func_graph->isa<prim::ListInsert>() ||
|
||||
meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() || meta_func_graph->isa<prim::VmapMatchOutAxis>() ||
|
||||
meta_func_graph->isa<prim::VmapGeneralPreprocess>() || meta_func_graph->isa<prim::GradAux>();
|
||||
meta_func_graph->isa<prim::VmapGeneralPreprocess>() || meta_func_graph->isa<prim::GradAux>() ||
|
||||
meta_func_graph->isa<prim::PyExecuteGradient>();
|
||||
}
|
||||
|
||||
/* inherit relation of MetaFuncGraph
|
||||
|
@ -237,6 +238,7 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
|
|||
* ├── Tail
|
||||
* ├── MakeTupleGradient
|
||||
* ├── MakeListGradient
|
||||
* ├── PyExecuteGradient
|
||||
* ├── GradOperation
|
||||
* ├── TupleAdd
|
||||
* └── SequenceSlice
|
||||
|
|
|
@ -924,7 +924,7 @@ class Parser:
|
|||
src = dedent(original_src)
|
||||
self.col_offset = \
|
||||
len(original_src.split('\n')[0]) - len(src.split('\n')[0])
|
||||
logger.info("Get source: %s", src)
|
||||
logger.debug("Get source: %s", src)
|
||||
try:
|
||||
ast_tokens = asttokens.ASTTokens(src, parse=True)
|
||||
except IndentationError as idt_err:
|
||||
|
|
|
@ -225,9 +225,3 @@ def bprop_scalar_not(x, out, dout):
|
|||
def bprop_tensor_move(x, out, dout):
|
||||
"""Backpropagator for primitive `TensorMove`."""
|
||||
return (dout,)
|
||||
|
||||
|
||||
@bprops.register("PyExecute")
|
||||
def get_bprop_py_execute(x, y, z, out, dout):
|
||||
"""Generate bprop for PyExecute"""
|
||||
return x, y, z
|
||||
|
|
|
@ -123,24 +123,48 @@ def test_dict_return_1():
|
|||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dict_get_1():
|
||||
def test_dict_return_2():
|
||||
"""
|
||||
Feature: Return dict.
|
||||
Description: Support dict return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.jit
|
||||
def dict_net_1():
|
||||
def dict_net_2():
|
||||
x = {'a': 1, 'b': 2}
|
||||
y = x.get('a')
|
||||
y_tensor = ms.Tensor([y])
|
||||
z = dict(a=y_tensor)
|
||||
return z
|
||||
|
||||
out = dict_net_1()
|
||||
out = dict_net_2()
|
||||
print(f'out: {out}')
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="No support None and Scalar in dict.")
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_dict_return_3():
|
||||
"""
|
||||
Feature: Return dict.
|
||||
Description: Support dict return.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
@ms.jit
|
||||
def dict_net_3():
|
||||
x = {'a': 'a', 'b': 'b'}
|
||||
y = x.get('a')
|
||||
z = dict(y=y, u=9, v=False, w=None)
|
||||
return z
|
||||
|
||||
out = dict_net_3()
|
||||
print(f'out: {out}')
|
||||
assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
|
|
Loading…
Reference in New Issue