Fix GradOperation description and support tuple/list for 'get_by_list'.

Zhang Qinghua 2022-06-01 11:44:47 +08:00
parent 057ed3b773
commit ce39353421
19 changed files with 331 additions and 140 deletions


@@ -13,7 +13,9 @@ mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::O
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::ResolveObjectToNode
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode
@@ -32,7 +34,6 @@ mindspore/mindspore/lite/src/runtime/ios_reg_ops.cc:mindspore::lite::IosRegister
mindspore/mindspore/lite/src/runtime/ios_reg_kernels.h:mindspore::kernel::IosRegisterKernels
mindspore/mindspore/lite/src/runtime/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::QuantDTypeCast
mindspore/mindspore/lite/src/runtime/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::Run
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/strided_slice_infer.c:StridedSliceInferShape
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/winograd_transform_fp16.c:WinogradInputTransformFp16
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/pooling_fp16.c:AvgPoolingFp16


@@ -45,6 +45,11 @@ mindspore.ops.GradOperation
4. Call the gradient function with `net`'s inputs to get the gradients with respect to all inputs and the given parameters: `gradient_function(x, y)`
Note: for the gradient functions produced above, the return value varies with the number of gradient results:
1. A single value if there is exactly one gradient result;
2. A tuple if there are multiple gradient results;
3. An empty tuple if there is no gradient result.
We can configure the sensitivity (the gradient with respect to the output) by setting `sens_param` to True and passing an extra sensitivity input to the gradient function. This input must have the same shape and type as the output of `net` (see `GradNetWrtXYWithSensParam` in the examples).
1. Build a `GradOperation` higher-order function with the `get_all=True` and `sens_param=True` arguments: `grad_op = GradOperation(get_all=True, sens_param=True)`
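For illustration, a minimal Python sketch of these return and sensitivity conventions (the two-input `Net` below is a hypothetical toy network, not part of this commit):

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor

class Net(nn.Cell):
    """Toy two-input network, for illustration only."""
    def construct(self, x, y):
        return x * y

net = Net()
x = Tensor(np.array([2.0], np.float32))
y = Tensor(np.array([3.0], np.float32))
# Default: the gradient w.r.t. the first input, returned as a single value.
gx = ops.GradOperation()(net)(x, y)
# get_all=True: gradients w.r.t. all inputs, returned as a tuple.
gx, gy = ops.GradOperation(get_all=True)(net)(x, y)
# sens_param=True: pass an extra sensitivity with the same shape and type as the output.
sens = Tensor(np.array([1.0], np.float32))
grads = ops.GradOperation(get_all=True, sens_param=True)(net)(x, y, sens)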


@@ -555,9 +555,9 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradFirst) {
AnfNodePtr tuple_parameter = fg->add_parameter();
PrimitivePtr getitem_op = prim::kPrimTupleGetItem;
if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple(tuple_arg, enable_tuple_grad_first_)) {
fg->set_output(fg->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
fg->set_output(
fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(1))}));
} else {
fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
}
@@ -572,14 +572,22 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradAll) {
AnfNodePtr tuple_parameter = fg->add_parameter();
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
PrimitivePtr op = prim::kPrimTupleGetItem;
for (size_t i = 1; i < tuple_arg->size(); ++i) {
MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
if (CanGradArgument(tuple_arg, i)) {
elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
elements.push_back(
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
}
if (elements.size() > 1) {
// We should deal with 'get_all=True' in the same way as the other options later:
// "The returned result may vary with the number of gradient results.
// A single value if only one result, a tuple for multiple results, or an empty tuple for no result.
//
// Notice that even if the user sets 'get_all=True' and passes multiple inputs,
// 'CanGradArgument' may change it to only one gradient output or no gradient."
constexpr size_t args_least_size = 2;
if (elements.size() >= args_least_size) {
fg->set_output(fg->NewCNodeInOrder(elements));
return fg;
}
@@ -744,12 +752,8 @@ void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
if (has_sparse_bprop_prim) {
return EXCLUDE;
}
auto prim = GetCNodePrimitive(node);
auto prim = GetCNodePrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
auto do_signature = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_signature != nullptr) {
prim = dyn_cast<Primitive>(do_signature->function());
}
bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
if (sparse_bprop) {
MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";


@@ -67,26 +67,16 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
return unpack_graph_node;
}
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
} else {
value = GetValueNode(node);
}
if (value == nullptr) {
return nullptr;
}
return value->cast<MetaFuncGraphPtr>();
}
// check if node is a specific meta_fg_opration that registered in the meta_fg_ops
// Check if node is a specific meta_fg_operation that is registered in the meta_fg_ops
bool CheckMetaFgOps(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
}
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
}
@@ -143,16 +133,29 @@ AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &
// For a general meta_fg_operation, sens_param is not involved; for GradOperation it is obtained specifically.
bool sens_param = false;
if (grad_op_->Match(inputs_x[0])) {
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
}
auto value = GetValueWithoutDoSignature(inputs_x[0]);
MS_EXCEPTION_IF_NULL(value);
auto meta_func = value->cast<MetaFuncGraphPtr>();
MS_EXCEPTION_IF_NULL(meta_func);
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
sens_param = grad_op_ptr->sens_param();
// Remove the tuple/list inputs from order list for grad(UnpackGraph(..), list/tuple)(..)
if (inputs_x.size() > inputs_x_minimum_size) {
constexpr size_t sequence_input_pos = 2;
auto seq_node = inputs_x[sequence_input_pos];
auto prim = GetCNodePrimitiveWithoutDoSignature(seq_node);
if (prim != nullptr &&
(IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList))) {
auto seq_cnode = dyn_cast<CNode>(seq_node);
MS_EXCEPTION_IF_NULL(seq_cnode);
seq_cnode->func_graph()->EraseUnusedNodeInOrder(seq_cnode);
}
}
}
inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
// construct new meta_fg_opration
// Construct the new meta_fg_operation
auto meta_fg_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
if (unpack_op_->Match(inputs_y[0])) {
inputs_y[1] = meta_fg_op_cnode;


@@ -33,9 +33,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node);
class Matcher {
public:
Matcher() {}
@@ -59,7 +56,11 @@ class MetaFgMatcher : public Matcher {
if (node == nullptr) {
return false;
}
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
}
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
}


@@ -242,13 +242,15 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
}
MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
output = param;
*node = output;
return true;
} else if (py::hasattr(obj, "__parameter_tuple__")) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t it = 0; it < tuple.size(); ++it) {
for (size_t i = 0; i < tuple.size(); ++i) {
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[it], &out);
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
@@ -258,27 +260,54 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
// The ParameterTuple will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
} else {
ValuePtr convert_result = nullptr;
// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
*node = output;
return true;
} else if ((py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) && py::len(obj) != 0) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
bool all_parameter_sequence = true;
for (size_t i = 0; i < tuple.size(); ++i) {
if (!py::hasattr(tuple[i], "__parameter__") || !py::isinstance<tensor::MetaTensor>(tuple[i])) {
all_parameter_sequence = false;
break;
}
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
}
args.push_back(out);
}
MS_EXCEPTION_IF_NULL(convert_result);
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
}
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
if (all_parameter_sequence) {
// The Parameter tuple/list will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
*node = output;
return true;
}
}
// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
ValuePtr convert_result = nullptr;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
}
MS_EXCEPTION_IF_NULL(convert_result);
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
}
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
}
*node = output;
return true;
}
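At the Python level, the new all-Parameter branch above roughly covers cells like the following (a hedged sketch; `ParamHolder` and its attribute names are hypothetical):

import numpy as np
import mindspore.nn as nn
from mindspore import Parameter, Tensor

class ParamHolder(nn.Cell):
    """Keeps a plain tuple of Parameters (not a ParameterTuple)."""
    def __init__(self):
        super(ParamHolder, self).__init__()
        self.w1 = Parameter(Tensor(np.ones([2], np.float32)), name="w1")
        self.w2 = Parameter(Tensor(np.ones([2], np.float32)), name="w2")
        # A non-empty tuple whose elements are all Parameters: resolved as a
        # MakeTuple of parameter nodes and kept out of the order list.
        self.params = (self.w1, self.w2)

    def construct(self, x):
        return x * self.params[0] + self.params[1]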


@@ -139,25 +139,6 @@ bool IsKeepRef(const PrimitivePtr &prim) {
IsPrimitiveEquals(prim, prim::kPrimPull);
}
// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
}
return prim;
}
// Gets primitive from the given cnode, return nullptr if cnode.inputs[0] is not a primitive.
PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->inputs().empty()) {
return nullptr;
}
return GetPrimitive(cnode->input(0));
}
// Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
if (cnode != nullptr && !cnode->inputs().empty()) {
@@ -620,7 +601,7 @@ class SideEffectFinder {
EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
MS_EXCEPTION_IF_NULL(tuple_indexes);
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
if (tuple_indexes->empty()) {
@@ -724,7 +705,7 @@ class SideEffectFinder {
// Trace a cnode for effect info.
EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
// Special handling for Switch primitive.
return TraceSwitchEffectInfo(cnode);
@@ -778,7 +759,7 @@ class SideEffectFinder {
return {EffectInfo::kDetected, false, false, false};
}
// Trace an ANFNode for effect info.
// Trace an AnfNode for effect info.
EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
if (node != nullptr) {
// Trace cnode.
@@ -794,7 +775,7 @@ class SideEffectFinder {
}
// Trace primitive.
auto prim = GetPrimitive(node);
auto prim = GetPrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
return GetPrimEffectInfo(prim);
}
@@ -886,7 +867,7 @@ class SideEffectFinder {
// Detect effect info by depth first search.
EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
// For primitive, get effect info from its attributes and inputs.
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim != nullptr) {
// Skip 'return' cnode.
if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {


@@ -58,17 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
}
}
}
PrimitivePtr GetRealPrimitive(const AnfNodePtr &node) {
const auto &primitive = GetCNodePrimitive(node);
if (primitive != nullptr) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(primitive);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
}
}
return primitive;
}
} // namespace
bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {
@@ -92,7 +81,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)
void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
std::vector<AnfNodePtr> *side_effect_nodes) {
const auto &primitive = GetRealPrimitive(node);
const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
if (primitive != nullptr) {
auto effect_info = GetPrimEffectInfo(primitive);
if (effect_info.memory || effect_info.io) {


@@ -201,21 +201,10 @@ class OrderEnforcer {
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
}
// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
}
return prim;
}
bool IsSpecialParallelPrimitive(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode->input(0));
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim == nullptr) {
return false;
}


@@ -218,11 +218,9 @@ bool IrExportBuilder::BuildPrimitives() {
prim_proto->set_name(it->second);
prim_proto->set_op_type(prim->name());
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim = func->cast<PrimitivePtr>();
}
auto real_prim = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
if (real_prim != nullptr) {
prim = real_prim;
}
// Set primitive attributes


@@ -249,6 +249,58 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
return nullptr;
}
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node) {
const auto &do_signature_prim = GetValueNode<prim::DoSignaturePrimitivePtr>(node);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
}
return GetValueNode<PrimitivePtr>(node);
}
// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
}
return GetPrimitiveWithoutDoSignature(cnode->input(0));
}
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const ValuePtr &value) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(value);
if (do_signature_prim != nullptr) {
return do_signature_prim->function();
}
return value;
}
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node) {
auto value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
}
return GetValueWithoutDoSignature(value);
}
// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
}
return GetValueWithoutDoSignature(cnode->input(0));
}
std::string GetCNodeFuncName(const CNodePtr cnode) {
if (cnode->inputs().empty()) {
return "";


@@ -1071,15 +1071,34 @@ static S GetValue(const ValuePtr &value) {
MS_CORE_API std::string GetCNodeFuncName(CNodePtr cnode);
// used to get FuncGraphPtr from a cnode first input
// Used to get FuncGraphPtr from a cnode first input
MS_CORE_API FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
// Used to check whether an AnfNode is a cnode with a kind of Primitive as first input
MS_CORE_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
// used to get PrimitivePtr from a cnode first input
// Used to get PrimitivePtr from a cnode first input
MS_CORE_API PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const ValuePtr &value);
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node);
/// \brief Used to check whether the given node is a ValueNode with some Primitive value.
///
/// \param[in] node The input node.


@@ -373,15 +373,8 @@ bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProt
}
void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
auto prim = GetCNodePrimitive(cnode_ptr);
auto prim_to_add_attr = prim;
auto prim_to_add_attr = GetCNodePrimitiveWithoutDoSignature(cnode_ptr);
if (prim_to_add_attr != nullptr) {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
}
}
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
for (int i = 0; i < node_proto.attribute_size(); ++i) {
@@ -393,7 +386,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
continue;
}
if (prim_to_add_attr != nullptr && !GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
}
} else {
// ref_attr_name is removed in newer versions.
@@ -402,7 +396,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
continue;
}
if (prim_to_add_attr != nullptr && !SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
}
}
}
@@ -1690,17 +1685,11 @@ bool MSANFModelParser::BuildPrimitiveNode(const mind_ir::PrimitiveProto &primiti
prim->set_instance_name(prim_type);
}
}
prim->set_attr("is_load", MakeValue(true));
// Set primitive attributes
auto prim_to_add_attr = prim;
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
}
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
auto prim_to_add_attr = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim_to_add_attr);
prim_to_add_attr->set_attr("is_load", MakeValue(true));
for (int i = 0; i < primitive_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = primitive_proto.attribute(i);
if (!SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {


@@ -264,10 +264,10 @@ class ForwardValueAndGrad(Cell):
if not isinstance(get_by_list, bool):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'")
if get_by_list and not isinstance(weights, ParameterTuple):
if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"when 'get_by_list' is set to True, the argument 'weights' must be "
f"ParameterTuple type, but got '{type(weights)}'")
f"Parameters array, but got '{type(weights)}'")
self.network = network
if isinstance(network, Cell):
self.network.set_grad()
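A hedged sketch of what the relaxed check accepts (assuming a `net_with_loss` Cell defined elsewhere):

import mindspore.nn as nn

# Any of these forms is now accepted for 'weights' when get_by_list=True
# (previously only ParameterTuple was allowed):
weights = list(net_with_loss.trainable_params())              # plain list
# weights = tuple(net_with_loss.trainable_params())           # or a plain tuple
# weights = ParameterTuple(net_with_loss.trainable_params())  # still valid
fwd_grad = nn.ForwardValueAndGrad(network=net_with_loss, weights=weights,
                                  get_by_list=True, sens_param=True)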


@@ -168,6 +168,8 @@ class GradOperation(GradOperation_):
4. Call the gradient function with the input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.
To be noticed, for the above gradient functions, the returned gradient result may vary with the number of gradient results:
a single value if only one result, a tuple for multiple results, or an empty tuple for no result.
We can configure the sensitivity (gradient with respect to output) by setting `sens_param` as True and
passing an extra sensitivity input to the gradient function; the sensitivity input should have the
@@ -188,11 +190,11 @@ class GradOperation(GradOperation_):
Args:
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
get_by_list (bool): If True, get all the gradients with respect to Parameter free variables.
If get_all and get_by_list are both False, get the gradient with respect to first input.
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False.
If get_all and get_by_list are both True, get the gradients with respect to inputs and
Parameter free variables at the same time in the form of ("gradients with respect to inputs",
"gradients with respect to parameter free variables"). Default: False.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
Default: False.
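As a hedged illustration of passing a plain list for `get_by_list` (`net` and `x` are placeholders; before this commit, `weights` had to be a ParameterTuple):

import mindspore.ops as ops

grad_by_list = ops.GradOperation(get_by_list=True)
weights = list(net.trainable_params())   # a plain list or tuple now works
grads = grad_by_list(net, weights)(x)    # gradients w.r.t. the listed Parameters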


@@ -232,3 +232,67 @@ def test_train_lenet_with_new_interface(num_classes=10, epoch=20, batch_size=32)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet_with_new_interface_tuple(num_classes=10, epoch=20, batch_size=32):
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
net_with_criterion.set_train()
weights = tuple(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet_with_new_interface_list(num_classes=10, epoch=20, batch_size=32):
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
net_with_criterion.set_train()
weights = list(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)
train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001


@@ -177,3 +177,69 @@ def test_ascend_lenet2():
loss_output = test_ascend_lenet()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
class GradWrapTuple(nn.Cell):
"""
GradWrapTuple definition
"""
def __init__(self, network):
super(GradWrapTuple, self).__init__()
self.network = network
self.weights = tuple(filter(lambda x: x.requires_grad, network.get_parameters()))
def construct(self, x, label):
weights = self.weights
return grad_by_list(self.network, weights)(x, label)
def test_ascend_lenet_grad_by_list_tuple():
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
epoch_size = 20
batch_size = 32
inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
labels = Tensor(np.ones([batch_size]).astype(np.int32))
net = LeNet()
criterion = CrossEntropyLoss()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)
net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrapTuple(net_with_criterion)
train_network.set_train()
total_time = 0
for epoch in range(0, epoch_size):
start_time = time.time()
fw_output = net(inputs)
loss_output = criterion(fw_output, labels)
grads = train_network(inputs, labels)
optimizer(grads)
end_time = time.time()
cost_time = end_time - start_time
total_time = total_time + cost_time
print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
return loss_output
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet_grad_by_list_tuple1():
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
os.environ['GRAPH_OP_RUN'] = str(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
loss_output = test_ascend_lenet_grad_by_list_tuple()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003


@@ -53,16 +53,16 @@ def test_isinstance():
is_tuple_var = isinstance((x, 1, 1.0, y), tuple)
is_list_const = isinstance(self.list_member, list)
is_list_var = isinstance([x, 1, 1.0, y], list)
is_empty_list = isinstance(self.empty_list, list)
is_dict_const = isinstance(self.dict_member, dict)
is_dict_var = isinstance({"x": x, "y": y}, dict)
is_empty_dic = isinstance(self.empty_dict, dict)
is_list_or_tensor = isinstance([x, y], (Tensor, list))
is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float))
is_list_or_tensor = isinstance([x, y], (Tensor, list))
float_is_int = isinstance(self.float_member, int)
bool_is_string = isinstance(self.bool_member, str)
tensor_is_tuple = isinstance(x, tuple)
tuple_is_list = isinstance(self.tuple_member, list)
is_empty_list = isinstance(self.empty_list, list)
return is_int, is_float, is_bool, bool_is_int, is_string, is_parameter, \
parameter_is_tensor, is_tensor_const, is_tensor_var, \
is_tuple_const, is_tuple_var, is_list_const, is_list_var, is_empty_list, \


@@ -229,7 +229,6 @@ def test_grad_parameter_tuple(mode):
GradCellWithParameterTuple(TestCell2(x1, x2))(z)
@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_list_or_tuple(mode):
"""