Return zero-like gradients for PyExecute bprop, and add an unsupported test case.

张清华 2023-01-03 15:54:06 +08:00
parent 60b419d560
commit 4f493ecc28
7 changed files with 155 additions and 34 deletions

View File

@@ -368,31 +368,31 @@ FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &arg
     params.push_back(fg->add_parameter());
   }
-  // make fprob first result, maketuple's forward result.
+  // Make fprop first result, make_tuple's forward result.
   AnfNodePtr out = fg->NewCNodeInOrder(params);
-  // make fprob second result, maketuple's backward function.
-  FuncGraphPtr b = std::make_shared<FuncGraph>();
+  // Make fprop second result, make_tuple's backward function.
+  FuncGraphPtr bprop = std::make_shared<FuncGraph>();
   ss.str(std::string());
   ss.clear();
   // ◀make_tuple_
   ss << "\u25C2make_tuple_" << tuple_size;
-  b->debug_info()->set_name(ss.str());
-  AnfNodePtr dout = b->add_parameter();
+  bprop->debug_info()->set_name(ss.str());
+  AnfNodePtr dout = bprop->add_parameter();
   std::vector<AnfNodePtr> grads;
   grads.push_back(NewValueNode(prim::kPrimMakeTuple));
-  grads.push_back(NewEnviron(b));
+  grads.push_back(NewEnviron(bprop));
   for (int64_t i = 0; i < tuple_size; ++i) {
-    grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
+    grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), dout, NewValueNode(i)}));
   }
-  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  b->set_output(b->NewCNodeInOrder(grads));
+  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  bprop->set_output(bprop->NewCNodeInOrder(grads));
   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
+  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
   (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeTuple));
   return fg;
 }
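The bprop graph assembled above just splits dout back into one gradient per tuple element. A minimal Python sketch of the fprop/bprop contract that MakeTupleGradient generates, where ENVIRON is a hypothetical stand-in for the EnvironCreate node and the function names are illustrative, not the C++ identifiers:

ENVIRON = object()  # hypothetical stand-in for the EnvironCreate node

def make_tuple_fprop(*args):
    out = tuple(args)  # forward result: the tuple itself

    def bprop(dout):
        # One TupleGetItem per element: gradient i is simply dout[i].
        return (ENVIRON,) + tuple(dout[i] for i in range(len(args)))

    return out, bprop

# dout is routed straight through to the per-element gradients.
out, bprop = make_tuple_fprop(1.0, 2.0)
assert bprop((0.1, 0.2))[1:] == (0.1, 0.2)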
@@ -412,35 +412,118 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
     params.push_back(fg->add_parameter());
   }
-  // make fprob first result, maketuple's forward result.
+  // Make fprop first result, make_list's forward result.
   AnfNodePtr out = fg->NewCNodeInOrder(params);
-  // make fprob second result, maketuple's backward function.
-  FuncGraphPtr b = std::make_shared<FuncGraph>();
+  // Make fprop second result, make_list's backward function.
+  FuncGraphPtr bprop = std::make_shared<FuncGraph>();
   ss.str(std::string());
   ss.clear();
   // ◀make_list_
   ss << "\u25C2make_list_" << list_size;
-  b->debug_info()->set_name(ss.str());
-  AnfNodePtr dout = b->add_parameter();
+  bprop->debug_info()->set_name(ss.str());
+  AnfNodePtr dout = bprop->add_parameter();
   std::vector<AnfNodePtr> grads;
   grads.push_back(NewValueNode(prim::kPrimMakeTuple));
-  grads.push_back(NewEnviron(b));
+  grads.push_back(NewEnviron(bprop));
   for (int64_t i = 0; i < list_size; ++i) {
-    grads.push_back(b->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
+    grads.push_back(bprop->NewCNodeInOrder({NewValueNode(prim::kPrimListGetItem), dout, NewValueNode(i)}));
   }
-  b->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  b->set_output(b->NewCNodeInOrder(grads));
+  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  bprop->set_output(bprop->NewCNodeInOrder(grads));
   fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
-  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(b)}));
+  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
   (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimMakeList));
   return fg;
 }

+FuncGraphPtr PyExecuteGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
+  int64_t args_size = SizeToLong(args_spec_list.size());
+  constexpr auto py_execute_grad_input_count = 3;
+  constexpr auto op_name = "PyExecute";
+  CheckArgsSize(op_name, args_spec_list, py_execute_grad_input_count);
+  std::ostringstream ss;
+  // ▶PyExecute
+  ss << "\u25B8PyExecute_" << args_size;
+  FuncGraphPtr fg = std::make_shared<FuncGraph>();
+  fg->debug_info()->set_name(ss.str());
+  std::vector<AnfNodePtr> params;
+  (void)params.emplace_back(NewValueNode(prim::kPrimPyExecute));
+  for (int64_t i = 0; i < args_size; ++i) {
+    (void)params.emplace_back(fg->add_parameter());
+  }
+  // Make fprop first result, PyExecute's forward result.
+  AnfNodePtr out = fg->NewCNodeInOrder(params);
+  // Make fprop second result, PyExecute's backward function.
+  FuncGraphPtr bprop = std::make_shared<FuncGraph>();
+  ss.str(std::string());
+  ss.clear();
+  // ◀PyExecute
+  ss << "\u25C2PyExecute_" << args_size;
+  bprop->debug_info()->set_name(ss.str());
+  (void)bprop->add_parameter();
+  std::vector<AnfNodePtr> grads;
+  (void)grads.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  (void)grads.emplace_back(NewEnviron(bprop));
+  // Propagate for script string.
+  (void)grads.emplace_back(params[1]);
+  // Propagate for local dict keys.
+  const auto &local_key_args = dyn_cast<abstract::AbstractTuple>(args_spec_list[1]);
+  MS_EXCEPTION_IF_NULL(local_key_args);
+  std::vector<AnfNodePtr> keys;
+  (void)keys.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  for (size_t i = 0; i < local_key_args->size(); ++i) {
+    constexpr auto keys_num = 2;
+    const auto &key_item =
+      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[keys_num], NewValueNode(SizeToLong(i))});
+    const auto &element = local_key_args->elements()[i];
+    const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
+    if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
+      (void)keys.emplace_back(key_item);
+    } else {
+      (void)keys.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), key_item}));
+    }
+  }
+  (void)grads.emplace_back(bprop->NewCNodeInOrder(keys));
+  // Propagate for local dict values.
+  constexpr auto values_arg_num = 2;
+  const auto &local_value_args = dyn_cast<abstract::AbstractTuple>(args_spec_list[values_arg_num]);
+  MS_EXCEPTION_IF_NULL(local_value_args);
+  std::vector<AnfNodePtr> values;
+  (void)values.emplace_back(NewValueNode(prim::kPrimMakeTuple));
+  for (size_t i = 0; i < local_value_args->size(); ++i) {
+    constexpr auto values_num = 3;
+    const auto &value_item =
+      bprop->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), params[values_num], NewValueNode(SizeToLong(i))});
+    const auto &element = local_value_args->elements()[i];
+    const auto &str_element = dyn_cast<abstract::AbstractScalar>(element);
+    if (str_element != nullptr && str_element->BuildType()->isa<String>()) {
+      (void)values.emplace_back(value_item);
+    } else {
+      (void)values.emplace_back(bprop->NewCNodeInOrder({NewValueNode(prim::GetPythonOps("zeros_like")), value_item}));
+    }
+  }
+  (void)grads.emplace_back(bprop->NewCNodeInOrder(values));
+  bprop->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  bprop->set_output(bprop->NewCNodeInOrder(grads));
+  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
+  fg->set_output(fg->NewCNodeInOrder({NewValueNode(prim::kPrimMakeTuple), out, NewValueNode(bprop)}));
+  (void)fg->transforms().emplace("primal", FuncGraphTransform(prim::kPrimPyExecute));
+  return fg;
+}
+
 namespace {
 bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
   MS_EXCEPTION_IF_NULL(tuple_arg);
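Note that the generated PyExecute bprop adds its dout parameter but never uses it: the gradients are rebuilt from the forward inputs, with the script string and any string keys or values passed through unchanged and zeros_like substituted for everything else. A minimal Python sketch of that behavior, where ENVIRON and zeros_like are illustrative stand-ins for the EnvironCreate node and MindSpore's zeros_like primitive:

ENVIRON = object()  # hypothetical stand-in for the EnvironCreate node

def zeros_like(x):
    # Illustrative stand-in for the real zeros_like primitive.
    return 0

def py_execute_bprop(script, keys, values, dout):
    # dout is accepted but unused: gradients come from the forward inputs.
    def keep(v):
        return v if isinstance(v, str) else zeros_like(v)
    return (ENVIRON, script,
            tuple(keep(k) for k in keys),
            tuple(keep(v) for v in values))

# String keys survive; numeric values become zero-like gradients.
grads = py_execute_bprop("a + b", ("a", "b"), (1.0, 2.0), dout=None)
assert grads[2:] == (("a", "b"), (0, 0))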
@@ -1146,9 +1229,9 @@ FuncGraphPtr VmapOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
     MS_LOG(EXCEPTION) << "'VmapOperation' requires a network or function as an input, while the input is empty.";
   }
-  constexpr auto kVmapOperationInputNum = 3;
+  constexpr auto vmap_operation_input_num = 3;
   const std::string op_name = "vmap";
-  CheckArgsSize(op_name, args_spec_list, kVmapOperationInputNum);
+  CheckArgsSize(op_name, args_spec_list, vmap_operation_input_num);
   auto fn_arg = args_spec_list[0];
   auto in_axes_arg = args_spec_list[1];

View File

@@ -142,6 +142,16 @@ class MakeListGradient : public MetaFuncGraph {
 };
 using MakeListGradientPtr = std::shared_ptr<MakeListGradient>;

+class PyExecuteGradient : public MetaFuncGraph {
+ public:
+  explicit PyExecuteGradient(const std::string &name) : MetaFuncGraph(name) {}
+  ~PyExecuteGradient() override = default;
+  MS_DECLARE_PARENT(PyExecuteGradient, MetaFuncGraph)
+  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
+  friend bool operator==(const PyExecuteGradient &lhs, const PyExecuteGradient &rhs) { return lhs.name_ == rhs.name_; }
+};
+using PyExecuteGradientPtr = std::shared_ptr<PyExecuteGradient>;
+
 class GradOperation : public MetaFuncGraph {
  public:
   explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false,

View File

@@ -95,6 +95,12 @@ MetaFuncGraphPtr KPrim::KMetaFuncGraph(const PrimitivePtr &prim) {
     return meta;
   }
+  if (IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
+    MetaFuncGraphPtr meta = std::make_shared<prim::PyExecuteGradient>("PyExecuteGradient");
+    bprop_registry_meta_[prim::kPrimPyExecute] = meta;
+    return meta;
+  }
+
   MS_LOG(EXCEPTION) << "Fail to find bprop function for " << prim->name() << ".";
 }

@@ -197,7 +203,9 @@ FuncGraphPtr KPrim::KPrimitive(const CNodePtr &cnode, const ValueNodePtr &value_
     auto fprop = GetFprop(prim);
     fprop->transforms().emplace("primal", FuncGraphTransform(prim::kPrimSwitchLayer));
     return fprop;
-  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList)) {
+  } else if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList) ||
+             IsPrimitiveEquals(prim, prim::kPrimPyExecute)) {
     // Return null to use Meta bprop.
     return nullptr;
   }

View File

@@ -223,7 +223,8 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
          meta_func_graph->isa<prim::UnpackCall>() || meta_func_graph->isa<prim::ZipOperation>() ||
          meta_func_graph->isa<prim::ListAppend>() || meta_func_graph->isa<prim::ListInsert>() ||
          meta_func_graph->isa<prim::DoSignatureMetaFuncGraph>() || meta_func_graph->isa<prim::VmapMatchOutAxis>() ||
-         meta_func_graph->isa<prim::VmapGeneralPreprocess>() || meta_func_graph->isa<prim::GradAux>();
+         meta_func_graph->isa<prim::VmapGeneralPreprocess>() || meta_func_graph->isa<prim::GradAux>() ||
+         meta_func_graph->isa<prim::PyExecuteGradient>();
 }

 /* inherit relation of MetaFuncGraph

@@ -237,6 +238,7 @@ inline bool Skip(const MetaFuncGraphPtr &meta_func_graph) {
  *         Tail
  *         MakeTupleGradient
  *         MakeListGradient
+ *         PyExecuteGradient
  *         GradOperation
  *         TupleAdd
  *         SequenceSlice

View File

@@ -924,7 +924,7 @@ class Parser:
         src = dedent(original_src)
         self.col_offset = \
             len(original_src.split('\n')[0]) - len(src.split('\n')[0])
-        logger.info("Get source: %s", src)
+        logger.debug("Get source: %s", src)
         try:
             ast_tokens = asttokens.ASTTokens(src, parse=True)
         except IndentationError as idt_err:

View File

@@ -225,9 +225,3 @@ def bprop_scalar_not(x, out, dout):
 def bprop_tensor_move(x, out, dout):
     """Backpropagator for primitive `TensorMove`."""
     return (dout,)
-
-
-@bprops.register("PyExecute")
-def get_bprop_py_execute(x, y, z, out, dout):
-    """Generate bprop for PyExecute"""
-    return x, y, z

View File

@@ -123,24 +123,48 @@ def test_dict_return_1():
 @pytest.mark.platform_arm_ascend_training
 @pytest.mark.platform_x86_ascend_training
 @pytest.mark.env_onecard
-def test_dict_get_1():
+def test_dict_return_2():
     """
     Feature: Return dict.
     Description: Support dict return.
     Expectation: No exception.
     """
     @ms.jit
-    def dict_net_1():
+    def dict_net_2():
         x = {'a': 1, 'b': 2}
         y = x.get('a')
         y_tensor = ms.Tensor([y])
         z = dict(a=y_tensor)
         return z

-    out = dict_net_1()
+    out = dict_net_2()
     print(f'out: {out}')


+@pytest.mark.skip(reason="No support None and Scalar in dict.")
+@pytest.mark.level1
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_dict_return_3():
+    """
+    Feature: Return dict.
+    Description: Support dict return.
+    Expectation: No exception.
+    """
+    @ms.jit
+    def dict_net_3():
+        x = {'a': 'a', 'b': 'b'}
+        y = x.get('a')
+        z = dict(y=y, u=9, v=False, w=None)
+        return z
+
+    out = dict_net_3()
+    print(f'out: {out}')
+    assert out == {'y': 'a', 'u': 9, 'v': False, 'w': None}
+
+
 @pytest.mark.level1
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.platform_arm_ascend_training