forked from mindspore-Ecosystem/mindspore
Fix GradOperation description and support tuple/list for 'get_by_list'.
This commit is contained in: parent 057ed3b773, commit ce39353421
@ -13,7 +13,9 @@ mindspore/mindspore/ccsrc/frontend/optimizer/irpass.cc:mindspore::opt::irpass::O
mindspore/mindspore/ccsrc/frontend/parallel/ops_info/gather_v2_p_info.cc:mindspore::parallel::GatherV2PInfo::CheckStrategy
mindspore/mindspore/ccsrc/plugin/device/gpu/hal/device/gpu_kernel_runtime.cc:mindspore::device::gpu::GPUKernelRuntime::LaunchKernelDynamic
mindspore/mindspore/ccsrc/pipeline/jit/init.cc:PYBIND11_MODULE
mindspore/mindspore/ccsrc/pipeline/jit/parse/resolve.cc:mindspore::parse::ResolveObjectToNode
mindspore/mindspore/ccsrc/pipeline/jit/pipeline.cc:mindspore::pipeline::GraphExecutorPy::Compile
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/pybind_api/ir/log_adapter_py.cc:mindspore::PyExceptionInitializer::HandleExceptionPy
mindspore/mindspore/ccsrc/plugin/device/gpu/kernel/math/unary_op_gpu_kernel.h:mindspore::kernel::UnaryOpGpuKernel::Launch
mindspore/mindspore/ccsrc/plugin/device/ascend/optimizer/ir_fission/dynamic_rnn_grad_fission_v2.cc:mindspore::opt::AddLSTMInputGradNode

@ -32,7 +34,6 @@ mindspore/mindspore/lite/src/runtime/ios_reg_ops.cc:mindspore::lite::IosRegister
mindspore/mindspore/lite/src/runtime/ios_reg_kernels.h:mindspore::kernel::IosRegisterKernels
mindspore/mindspore/lite/src/runtime/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::QuantDTypeCast
mindspore/mindspore/lite/src/runtime/kernel/cpu/base/quant_dtype_cast.cc:mindspore::kernel::QuantDTypeCastCPUKernel::Run
mindspore/mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc:mindspore::abstract::ConvertAbstractToPython
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/infer/strided_slice_infer.c:StridedSliceInferShape
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/winograd_transform_fp16.c:WinogradInputTransformFp16
mindspore/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/fp16/pooling_fp16.c:AvgPoolingFp16
@ -45,6 +45,11 @@ mindspore.ops.GradOperation

4. Call the gradient function with the inputs of `net` to get the gradients with respect to all inputs and the given parameters: `gradient_function(x, y)`.

Note: for the gradient function produced above, the return value differs with the number of gradients returned:

1. If there is exactly one gradient result, a single value is returned;
2. If there are multiple gradient results, a tuple is returned;
3. If there is no gradient result, an empty tuple is returned.

We can configure the sensitivity (the gradient with respect to the output) by setting `sens_param` to True and passing an extra sensitivity input to the gradient function. This input must have the same shape and type as the output of `net` (see `GradNetWrtXYWithSensParam` in the examples).

1. Build a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`: `grad_op = GradOperation(get_all=True, sens_param=True)`.
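A minimal sketch of how these return-value rules and `sens_param` play out (illustrative only; the two-input `Net` cell below is hypothetical and not part of this commit):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.ops as ops
    from mindspore import Tensor

    class Net(nn.Cell):
        # Toy two-input network used only to illustrate GradOperation.
        def construct(self, x, y):
            return x * y

    x = Tensor(np.array([1.0, 2.0], np.float32))
    y = Tensor(np.array([3.0, 4.0], np.float32))

    # get_all=True: gradients for every input, returned as a tuple.
    grad_all = ops.GradOperation(get_all=True)
    dx, dy = grad_all(Net())(x, y)

    # Default get_all=False: only the gradient of the first input, returned as a single value.
    grad_first = ops.GradOperation()
    dx_only = grad_first(Net())(x, y)

    # sens_param=True: append a sensitivity with the same shape and dtype as the network output.
    grad_sens = ops.GradOperation(get_all=True, sens_param=True)
    sens = Tensor(np.array([1.0, 1.0], np.float32))
    dx_s, dy_s = grad_sens(Net())(x, y, sens)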
@ -555,9 +555,9 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradFirst) {
AnfNodePtr tuple_parameter = fg->add_parameter();
PrimitivePtr getitem_op = prim::kPrimTupleGetItem;
if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple(tuple_arg, enable_tuple_grad_first_)) {
fg->set_output(fg->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
fg->set_output(
fg->NewCNode({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(1))}));
} else {
fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
}

@ -572,14 +572,22 @@ FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, cons
if (tail_type_ == kGradAll) {
AnfNodePtr tuple_parameter = fg->add_parameter();
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
PrimitivePtr op = prim::kPrimTupleGetItem;
for (size_t i = 1; i < tuple_arg->size(); ++i) {
MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
if (CanGradArgument(tuple_arg, i)) {
elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
elements.push_back(
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(i))}));
}
}
if (elements.size() > 1) {
// We should deal with 'get_all=True' as other options later:
// "The returned result may vary for grad result element number.
// A single value if only one result, a tuple for multiple results, or a empty tuple for no result.
//
// Notice that even if the user set 'get_all=True' and pass multiple inputs,
// the 'CanGradArgument' may change it to only one gradient output or no gradient."
constexpr size_t args_least_size = 2;
if (elements.size() >= args_least_size) {
fg->set_output(fg->NewCNodeInOrder(elements));
return fg;
}

@ -744,12 +752,8 @@ void CheckPrimBpropReturnSparse(const FuncGraphPtr &primal_graph) {
if (has_sparse_bprop_prim) {
return EXCLUDE;
}
auto prim = GetCNodePrimitive(node);
auto prim = GetCNodePrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
auto do_signature = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_signature != nullptr) {
prim = dyn_cast<Primitive>(do_signature->function());
}
bool sparse_bprop = GetPrimitiveFlag(prim, GRAPH_FLAG_BPROP_RETURN_SPARSE);
if (sparse_bprop) {
MS_LOG(DEBUG) << "prim: " << prim->ToString() << " has attr 'bprop_return_sparse'";
@ -67,26 +67,16 @@ static AnfNodePtr GenerateUnpackGraphNode(const AnfNodePtr &origin_node, std::ve
return unpack_graph_node;
}

MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node) {
ValuePtr value;
if (IsValueNode<prim::DoSignaturePrimitive>(node)) {
value = GetValueNode(node)->cast<prim::DoSignaturePrimitivePtr>()->function();
} else {
value = GetValueNode(node);
}
if (value == nullptr) {
return nullptr;
}
return value->cast<MetaFuncGraphPtr>();
}

// check if node is a specific meta_fg_opration that registered in the meta_fg_ops
// Check if node is a specific meta_fg_opration that registered in the meta_fg_ops
bool CheckMetaFgOps(const AnfNodePtr &node) {
if (node == nullptr) {
return false;
}

auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
}
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
}
@ -143,16 +133,29 @@ AnfNodePtr MetaFgVarPrepare::operator()(const OptimizerPtr &, const AnfNodePtr &
// For general meta_fg_opration, ‘sens_param’ is not involved, and that of GradOperation obtained specifically.
bool sens_param = false;
if (grad_op_->Match(inputs_x[0])) {
auto meta_func = GetMetaFuncGraphOfValueNode(inputs_x[0]);
if (meta_func == nullptr) {
return nullptr;
}
auto value = GetValueWithoutDoSignature(inputs_x[0]);
MS_EXCEPTION_IF_NULL(value);
auto meta_func = value->cast<MetaFuncGraphPtr>();
MS_EXCEPTION_IF_NULL(meta_func);
auto grad_op_ptr = meta_func->cast<prim::GradOperationPtr>();
sens_param = grad_op_ptr->sens_param();

// Remove the tuple/list inputs from order list for grad(UnpackGraph(..), list/tuple)(..)
if (inputs_x.size() > inputs_x_minimum_size) {
constexpr size_t sequence_input_pos = 2;
auto seq_node = inputs_x[sequence_input_pos];
auto prim = GetCNodePrimitiveWithoutDoSignature(seq_node);
if (prim != nullptr &&
(IsPrimitiveEquals(prim, prim::kPrimMakeTuple) || IsPrimitiveEquals(prim, prim::kPrimMakeList))) {
auto seq_cnode = dyn_cast<CNode>(seq_node);
MS_EXCEPTION_IF_NULL(seq_cnode);
seq_cnode->func_graph()->EraseUnusedNodeInOrder(seq_cnode);
}
}
}

inputs_x[1] = GenerateUnpackGraphNode(node, inputs_y, func_node, is_unpack, sens_param);
// construct new meta_fg_opration
// Construct new meta_fg_opration
auto meta_fg_op_cnode = func_graph->NewCNodeBefore(node, inputs_x);
if (unpack_op_->Match(inputs_y[0])) {
inputs_y[1] = meta_fg_op_cnode;
@ -33,9 +33,6 @@
namespace mindspore {
namespace opt {
namespace irpass {
// get metagraph of value node
MetaFuncGraphPtr GetMetaFuncGraphOfValueNode(const AnfNodePtr &node);

class Matcher {
public:
Matcher() {}
@ -59,7 +56,11 @@ class MetaFgMatcher : public Matcher {
if (node == nullptr) {
return false;
}
auto meta_func_graph_ptr = GetMetaFuncGraphOfValueNode(node);
auto value = GetValueWithoutDoSignature(node);
if (value == nullptr) {
return false;
}
auto meta_func_graph_ptr = value->cast<MetaFuncGraphPtr>();
if (meta_func_graph_ptr == nullptr) {
return false;
}
@ -242,13 +242,15 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
}
MS_LOG(DEBUG) << "Add param graph:" << func_graph->ToString() << ", " << param->DebugString();
output = param;
*node = output;
return true;
} else if (py::hasattr(obj, "__parameter_tuple__")) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t it = 0; it < tuple.size(); ++it) {
for (size_t i = 0; i < tuple.size(); ++i) {
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[it], &out);
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
@ -258,27 +260,54 @@ bool ResolveObjectToNode(const AnfNodePtr &origin_node, const py::object &obj, A
// The ParameterTuple will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
} else {
ValuePtr convert_result = nullptr;
// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
*node = output;
return true;
} else if ((py::isinstance<py::tuple>(obj) || py::isinstance<py::list>(obj)) && py::len(obj) != 0) {
auto tuple = obj.cast<py::tuple>();
std::vector<AnfNodePtr> args;
args.push_back(NewValueNode(prim::kPrimMakeTuple));
bool all_parameter_sequence = true;
for (size_t i = 0; i < tuple.size(); ++i) {
if (!py::hasattr(tuple[i], "__parameter__") || !py::isinstance<tensor::MetaTensor>(tuple[i])) {
all_parameter_sequence = false;
break;
}
AnfNodePtr out = nullptr;
bool success = ResolveObjectToNode(origin_node, tuple[i], &out);
if (!success) {
MS_LOG(ERROR) << "Resolve object to node failed";
return false;
}
args.push_back(out);
}
MS_EXCEPTION_IF_NULL(convert_result);
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
}
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
if (all_parameter_sequence) {
// The Parameter tuple/list will not be added in order list,
// since we don't want to deal with its RefTensor elements during auto_monad procedure.
output = NewCNode(std::move(args), func_graph);
*node = output;
return true;
}
}

// When the cell is set recomputed, it should not use old scope from cache.
auto scope = origin_node->scope();
bool has_recompute_scope = (scope == nullptr) ? false : scope->name().find(kAttrRecompute) == 0;
ValuePtr convert_result = nullptr;
bool converted =
ConvertData(obj, &convert_result, python_adapter::UseSignatureInResolve(), nullptr, has_recompute_scope);
if (!converted) {
MS_LOG(ERROR) << "Convert data failed";
return false;
}
MS_EXCEPTION_IF_NULL(convert_result);
if (convert_result->isa<FuncGraph>() && has_recompute_scope) {
UpdateDebugInfo(convert_result->cast<FuncGraphPtr>(), origin_node->scope(), origin_node->debug_info());
}
ConvertLoadedGraph(func_graph, convert_result);
output = NewValueNode(convert_result);
if (convert_result->isa<tensor::Tensor>()) {
output = GetMixedPrecisionCastHelp(func_graph, output);
}
*node = output;
return true;
}
@ -139,25 +139,6 @@ bool IsKeepRef(const PrimitivePtr &prim) {
IsPrimitiveEquals(prim, prim::kPrimPull);
}

// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
}
return prim;
}

// Gets primitive from the given cnode, return nullptr if cnode.inputs[0] is not a primitive.
PrimitivePtr GetPrimitive(const CNodePtr &cnode) {
if (cnode == nullptr || cnode->inputs().empty()) {
return nullptr;
}
return GetPrimitive(cnode->input(0));
}

// Gets func_graph from the given cnode, return nullptr if it is not a func graph call.
FuncGraphPtr GetFuncGraph(const CNodePtr &cnode) {
if (cnode != nullptr && !cnode->inputs().empty()) {
@ -620,7 +601,7 @@ class SideEffectFinder {
EffectInfo TraceTupleCNodeEffectInfo(const CNodePtr &cnode, std::stack<int64_t> *tuple_indexes) {
MS_EXCEPTION_IF_NULL(tuple_indexes);
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
// Trace MakeTuple.
if (IsPrimitiveEquals(prim, prim::kPrimMakeTuple)) {
if (tuple_indexes->empty()) {

@ -724,7 +705,7 @@ class SideEffectFinder {
// Trace a cnode for effect info.
EffectInfo TraceEffectInfo(const CNodePtr &cnode) {
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (IsPrimitiveEquals(prim, prim::kPrimSwitch)) {
// Special handling for Switch primitive.
return TraceSwitchEffectInfo(cnode);

@ -778,7 +759,7 @@ class SideEffectFinder {
return {EffectInfo::kDetected, false, false, false};
}

// Trace an ANFNode for effect info.
// Trace an AnfNode for effect info.
EffectInfo TraceEffectInfo(const AnfNodePtr &node) {
if (node != nullptr) {
// Trace cnode.

@ -794,7 +775,7 @@ class SideEffectFinder {
}

// Trace primitive.
auto prim = GetPrimitive(node);
auto prim = GetPrimitiveWithoutDoSignature(node);
if (prim != nullptr) {
return GetPrimEffectInfo(prim);
}

@ -886,7 +867,7 @@ class SideEffectFinder {
// Detect effect info by depth first search.
EffectInfo DetectEffectInfo(const CNodePtr &cnode) {
// For primitive, get effect info from its attributes and inputs.
auto prim = GetPrimitive(cnode);
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim != nullptr) {
// Skip 'return' cnode.
if (IsPrimitiveEquals(prim, prim::kPrimReturn)) {
@ -58,17 +58,6 @@ void EvalFailLogging(const EvaluatorPtr &evaluator, const AbstractBasePtrList &,
}
}
}

PrimitivePtr GetRealPrimitive(const AnfNodePtr &node) {
const auto &primitive = GetCNodePrimitive(node);
if (primitive != nullptr) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(primitive);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
}
}
return primitive;
}
} // namespace

bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg) {

@ -92,7 +81,7 @@ bool CheckIfAlwaysEval(const AnfNodeConfigPtr &conf, const AbstractBasePtr &arg)

void BaseFuncGraphEvaluator::CollectSideEffectNodes(const AnfNodePtr &node,
std::vector<AnfNodePtr> *side_effect_nodes) {
const auto &primitive = GetRealPrimitive(node);
const auto &primitive = GetCNodePrimitiveWithoutDoSignature(node);
if (primitive != nullptr) {
auto effect_info = GetPrimEffectInfo(primitive);
if (effect_info.memory || effect_info.io) {
@ -201,21 +201,10 @@ class OrderEnforcer {
return IsPrimitiveCNode(node, prim::kPrimExpandDims) || IsPrimitiveCNode(node, prim::kPrimBatchNormGrad);
}

// Gets primitive if the node is a primitive value node.
PrimitivePtr GetPrimitive(const AnfNodePtr &node) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(node);
auto do_sig = dyn_cast<mindspore::prim::DoSignaturePrimitive>(prim);
if (do_sig) {
auto val = do_sig->function();
return dyn_cast<Primitive>(val);
}
return prim;
}

bool IsSpecialParallelPrimitive(const AnfNodePtr &node) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto prim = GetPrimitive(cnode->input(0));
auto prim = GetCNodePrimitiveWithoutDoSignature(cnode);
if (prim == nullptr) {
return false;
}
@ -218,11 +218,9 @@ bool IrExportBuilder::BuildPrimitives() {
prim_proto->set_name(it->second);
prim_proto->set_op_type(prim->name());

if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim = func->cast<PrimitivePtr>();
}
auto real_prim = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
if (real_prim != nullptr) {
prim = real_prim;
}

// Set primitive attributes
@ -249,6 +249,58 @@ PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node) {
return nullptr;
}

// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node) {
const auto &do_signature_prim = GetValueNode<prim::DoSignaturePrimitivePtr>(node);
if (do_signature_prim != nullptr) {
return dyn_cast<Primitive>(do_signature_prim->function());
}
return GetValueNode<PrimitivePtr>(node);
}

// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
}
return GetPrimitiveWithoutDoSignature(cnode->input(0));
}

// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const ValuePtr &value) {
auto do_signature_prim = dyn_cast<prim::DoSignaturePrimitive>(value);
if (do_signature_prim != nullptr) {
return do_signature_prim->function();
}
return value;
}

// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node) {
auto value = GetValueNode(node);
if (value == nullptr) {
return nullptr;
}
return GetValueWithoutDoSignature(value);
}

// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr || cnode->size() == 0) {
return nullptr;
}
return GetValueWithoutDoSignature(cnode->input(0));
}

std::string GetCNodeFuncName(const CNodePtr cnode) {
if (cnode->inputs().empty()) {
return "";
@ -1071,15 +1071,34 @@ static S GetValue(const ValuePtr &value) {

MS_CORE_API std::string GetCNodeFuncName(CNodePtr cnode);

// used to get FuncGraphPtr from a cnode first input
// Used to get FuncGraphPtr from a cnode first input
MS_CORE_API FuncGraphPtr GetCNodeFuncGraph(const AnfNodePtr &node);

// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
// Used to check whether an AnfNode is a cnode with a kind of Primitive as first input
MS_CORE_API bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);

// used to get PrimitivePtr from a cnode first input
// Used to get PrimitivePtr from a cnode first input
MS_CORE_API PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);

// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetPrimitiveWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function Primitive if DoSignaturePrimitive,
// otherwise return the Primitive directly.
MS_CORE_API PrimitivePtr GetCNodePrimitiveWithoutDoSignature(const AnfNodePtr &node);

// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const ValuePtr &value);
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetValueWithoutDoSignature(const AnfNodePtr &node);
// Check the first input of CNode.
// Return the function value if DoSignaturePrimitive,
// otherwise return the value directly.
MS_CORE_API ValuePtr GetCNodeValueWithoutDoSignature(const AnfNodePtr &node);

/// \brief Used to check whether the given node is a ValueNode with some Primitive value.
///
/// \param[in] node The input node.
@ -373,15 +373,8 @@ bool MSANFModelParser::SetNodeAbstractFromAttrProto(const mind_ir::AttributeProt
}

void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &node_proto, const CNodePtr &cnode_ptr) {
auto prim = GetCNodePrimitive(cnode_ptr);
auto prim_to_add_attr = prim;
auto prim_to_add_attr = GetCNodePrimitiveWithoutDoSignature(cnode_ptr);
if (prim_to_add_attr != nullptr) {
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
}
}
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
for (int i = 0; i < node_proto.attribute_size(); ++i) {

@ -393,7 +386,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
continue;
}
if (prim_to_add_attr != nullptr && !GetAttrValueForCNode(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
}
} else {
// ref_attr_name is removed in newer versions.

@ -402,7 +396,8 @@ void MSANFModelParser::SetCNodePrimAttrAndAbstract(const mind_ir::NodeProto &nod
continue;
}
if (prim_to_add_attr != nullptr && !SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
MS_LOG(ERROR) << "Parse prim: " << prim->ToString() << " attributes error: " << attr_proto.DebugString();
MS_LOG(ERROR) << "Parse prim: " << prim_to_add_attr->ToString()
<< ", attributes error: " << attr_proto.DebugString();
}
}
}
@ -1690,17 +1685,11 @@ bool MSANFModelParser::BuildPrimitiveNode(const mind_ir::PrimitiveProto &primiti
prim->set_instance_name(prim_type);
}
}
prim->set_attr("is_load", MakeValue(true));

// Set primitive attributes
auto prim_to_add_attr = prim;
if (prim->isa<prim::DoSignaturePrimitive>()) {
auto func = prim->cast<prim::DoSignaturePrimitivePtr>()->function();
if (func != nullptr && func->isa<Primitive>()) {
prim_to_add_attr = func->cast<PrimitivePtr>();
}
prim_to_add_attr->set_attr("is_load", MakeValue(true));
}
auto prim_to_add_attr = GetValueWithoutDoSignature(prim)->cast<PrimitivePtr>();
MS_EXCEPTION_IF_NULL(prim_to_add_attr);
prim_to_add_attr->set_attr("is_load", MakeValue(true));
for (int i = 0; i < primitive_proto.attribute_size(); ++i) {
const mind_ir::AttributeProto &attr_proto = primitive_proto.attribute(i);
if (!SetPrimitiveAttrWithType(prim_to_add_attr, attr_proto)) {
@ -264,10 +264,10 @@ class ForwardValueAndGrad(Cell):
if not isinstance(get_by_list, bool):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"the type of 'get_by_list' must be bool, but got '{type(get_by_list)}'")
if get_by_list and not isinstance(weights, ParameterTuple):
if get_by_list and not isinstance(weights, (ParameterTuple, tuple, list)):
raise TypeError(f"For 'ForwardValueAndGrad', "
f"when 'get_by_list' is set to True, the argument 'weights' must be "
f"ParameterTuple type, but got '{type(weights)}'")
f"Parameters array, but got '{type(weights)}'")
self.network = network
if isinstance(network, Cell):
self.network.set_grad()
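With this change, `weights` no longer has to be a `ParameterTuple`; a plain tuple or list of `Parameter` objects is accepted when `get_by_list=True`. A minimal construction sketch (the `net_with_loss` cell and the `data`, `label`, `sens` tensors are hypothetical; the new GPU tests later in this diff exercise the same path end to end):

    from mindspore.nn import ForwardValueAndGrad

    # A plain tuple of Parameters, not a ParameterTuple.
    weights = tuple(net_with_loss.trainable_params())
    fwd_and_grad = ForwardValueAndGrad(network=net_with_loss, weights=weights,
                                       get_by_list=True, sens_param=True)
    loss, grads = fwd_and_grad(data, label, sens)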
@ -168,6 +168,8 @@ class GradOperation(GradOperation_):
4. Call the gradient function with input function's inputs
to get the gradients with respect to all inputs and given parameters: `gradient_function(x, y)`.

To be noticed, for the above gradient functions, the returned result may vary with the number of gradient elements:
a single value if there is only one result, a tuple for multiple results, or an empty tuple for no result.

We can configure the sensitivity (gradient with respect to output) by setting `sens_param` as True and
passing an extra sensitivity input to the gradient function; the sensitivity input should have the
@ -188,11 +190,11 @@ class GradOperation(GradOperation_):

Args:
get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
get_by_list (bool): If True, get all the gradients with respect to Parameter free variables.
If get_all and get_by_list are both False, get the gradient with respect to first input.
If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
at the same time in the form of ((gradients with respect to inputs),
(gradients with respect to parameters)). Default: False.
If get_all and get_by_list are both True, get the gradients with respect to inputs and
Parameter free variables at the same time in the form of ("gradients with respect to inputs",
"gradients with respect to parameter free variables"). Default: False.
sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
Default: False.
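A short sketch of the combined `get_all=True, get_by_list=True` form described above (illustrative only; `net`, `x`, and `y` are hypothetical, and passing a plain tuple of Parameters instead of a ParameterTuple is exactly what this commit enables):

    import mindspore.ops as ops

    # May be a ParameterTuple or, with this commit, a plain tuple/list of Parameters.
    weights = tuple(net.trainable_params())
    grad_op = ops.GradOperation(get_all=True, get_by_list=True)

    # Returns (gradients w.r.t. inputs, gradients w.r.t. the Parameter free variables).
    grad_inputs, grad_weights = grad_op(net, weights)(x, y)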
@ -232,3 +232,67 @@ def test_train_lenet_with_new_interface(num_classes=10, epoch=20, batch_size=32)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet_with_new_interface_tuple(num_classes=10, epoch=20, batch_size=32):
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
net_with_criterion.set_train()

weights = tuple(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)

train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_train_lenet_with_new_interface_list(num_classes=10, epoch=20, batch_size=32):
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
network = LeNet5(num_classes)
criterion = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
net_with_criterion = WithLossCell(network, criterion)
net_with_criterion.set_train()

weights = list(network.trainable_params())
optimizer = nn.Momentum(weights, 0.1, 0.9)

train_network = ForwardValueAndGrad(network=net_with_criterion, weights=weights, get_by_list=True, sens_param=True)
losses = []
for i in range(0, epoch):
data = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([batch_size]).astype(np.int32))
sens = Tensor(np.ones([1]).astype(np.float32))
loss, grads = train_network(data, label, sens)
grads = F.identity(grads)
optimizer(grads)
losses.append(loss)
assert losses[-1].asnumpy() < 0.01
assert losses[-1].asnumpy() > 0.001
@ -177,3 +177,69 @@ def test_ascend_lenet2():
loss_output = test_ascend_lenet()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003


class GradWrapTuple(nn.Cell):
"""
GradWrapTuple definition
"""

def __init__(self, network):
super(GradWrapTuple, self).__init__()
self.network = network
self.weights = tuple(filter(lambda x: x.requires_grad, network.get_parameters()))

def construct(self, x, label):
weights = self.weights
return grad_by_list(self.network, weights)(x, label)


def test_ascend_lenet_grad_by_list_tuple():
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
epoch_size = 20
batch_size = 32
inputs = Tensor(np.ones([batch_size, 1, 32, 32]).astype(np.float32))
labels = Tensor(np.ones([batch_size]).astype(np.int32))

net = LeNet()
criterion = CrossEntropyLoss()
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.1, 0.9)

net_with_criterion = WithLossCell(net, criterion)
train_network = GradWrapTuple(net_with_criterion)
train_network.set_train()
total_time = 0

for epoch in range(0, epoch_size):
start_time = time.time()
fw_output = net(inputs)
loss_output = criterion(fw_output, labels)
grads = train_network(inputs, labels)
optimizer(grads)
end_time = time.time()
cost_time = end_time - start_time
total_time = total_time + cost_time

print("======epoch: ", epoch, " loss: ", loss_output.asnumpy(), " cost time: ", cost_time)
return loss_output


@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_ascend_lenet_grad_by_list_tuple1():
"""
Feature: GradOperation get_by_list pass tuple/list
Description: Grad with Parameters as input type and fv. list or tuple as fv of grad.
Expectation: No exception.
"""
os.environ['GRAPH_OP_RUN'] = str(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
loss_output = test_ascend_lenet_grad_by_list_tuple()
assert loss_output.asnumpy() < 0.004
assert loss_output.asnumpy() > 0.003
@ -53,16 +53,16 @@ def test_isinstance():
is_tuple_var = isinstance((x, 1, 1.0, y), tuple)
is_list_const = isinstance(self.list_member, list)
is_list_var = isinstance([x, 1, 1.0, y], list)
is_empty_list = isinstance(self.empty_list, list)
is_dict_const = isinstance(self.dict_member, dict)
is_dict_var = isinstance({"x": x, "y": y}, dict)
is_empty_dic = isinstance(self.empty_dict, dict)
is_list_or_tensor = isinstance([x, y], (Tensor, list))
is_int_or_float_or_tensor_or_tuple = isinstance(x, (Tensor, tuple, int, float))
is_list_or_tensor = isinstance([x, y], (Tensor, list))
float_is_int = isinstance(self.float_member, int)
bool_is_string = isinstance(self.bool_member, str)
tensor_is_tuple = isinstance(x, tuple)
tuple_is_list = isinstance(self.tuple_member, list)
is_empty_list = isinstance(self.empty_list, list)
return is_int, is_float, is_bool, bool_is_int, is_string, is_parameter, \
parameter_is_tensor, is_tensor_const, is_tensor_var, \
is_tuple_const, is_tuple_var, is_list_const, is_list_var, is_empty_list, \
@ -229,7 +229,6 @@ def test_grad_parameter_tuple(mode):
GradCellWithParameterTuple(TestCell2(x1, x2))(z)


@pytest.mark.skip(reason='Not support list or tuple of parameters as GradOperation inputs by now')
@pytest.mark.parametrize('mode', [context.PYNATIVE_MODE, context.GRAPH_MODE])
def test_grad_parameter_list_or_tuple(mode):
"""