!46740 Fix pins for PyNative

Merge pull request !46740 from zjun/fix_xxx_alpha
This commit is contained in:
i-robot 2022-12-14 13:53:04 +00:00 committed by Gitee
commit 60a89a9fc8
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
9 changed files with 224 additions and 115 deletions

View File

@ -61,7 +61,6 @@ struct FrontendOpRunInfo {
int mix_type{0};
size_t op_index = 0;
size_t input_size = 0;
size_t custom_bprop_cell_count = 0;
PrimitivePyPtr op_prim{nullptr};
ValuePtr out_value{nullptr};
std::string op_info;
@ -92,7 +91,7 @@ struct InputArgsInfo {
bool has_sens{false};
bool grad_is_running{false};
bool use_dynamic_shape_process{false};
PrimitivePyPtr custom_bprp_prim{nullptr};
PrimitivePyPtr custom_bprop_prim{nullptr};
ValuePtr out_value{nullptr};
std::string cell_id;
std::string already_run_cell_id;

View File

@ -266,7 +266,6 @@ ValuePtr CastOperation::DoAutoCast(const FrontendOpRunInfoPtr &op_run_info, cons
constexpr auto input_size = 2;
const auto &cast_run_info = std::make_shared<FrontendOpRunInfo>();
cast_run_info->grad_flag = op_run_info->grad_flag;
cast_run_info->custom_bprop_cell_count = op_run_info->custom_bprop_cell_count;
MS_EXCEPTION_IF_NULL(cast_prim_);
cast_run_info->op_prim = cast_prim_;
cast_run_info->base_op_run_info.op_name = prim::kPrimCast->name();

View File

@ -180,10 +180,8 @@ void ForwardExecutor::RunOpForward(const FrontendOpRunInfoPtr &op_run_info) {
return;
}
// 4. Do op grad and record op info
// If the cell has a custom bprop, there is no need to do op grad. Otherwise, it is needed.
// If ms function is compiling, the real run op must not be recorded
bool is_custom_bprop = op_run_info->custom_bprop_cell_count <= 0;
if (!is_ms_function_compiling_ && is_custom_bprop) {
// If ms function is compiling, op info will not be found in the second training step
if (!is_ms_function_compiling_ && grad()->custom_bprop_cell_count() <= 0) {
grad()->ProcessOpGradInfo(op_run_info);
}
}
@ -195,7 +193,6 @@ FrontendOpRunInfoPtr ForwardExecutor::GenerateOpRunInfo(const py::args &args) co
const auto &op_run_info = std::make_shared<FrontendOpRunInfo>();
// Used for async run
op_run_info->grad_flag = grad()->grad_flag();
op_run_info->custom_bprop_cell_count = grad()->custom_bprop_cell_count();
op_run_info->base_op_run_info.op_name = args[static_cast<size_t>(RunOpArgsEnum::PY_NAME)].cast<std::string>();
op_run_info->base_op_run_info.use_dynamic_shape_process =
grad()->use_dynamic_shape_process() &&
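For readers following the change above: per-op grad recording is now skipped whenever execution is inside a cell that defines a custom bprop, which the grad executor tracks with a simple counter instead of copying the count into every FrontendOpRunInfo. Below is a minimal standalone sketch of that counting idea; GradCounterSketch and its methods are illustrative names, not MindSpore APIs.

#include <cstddef>
#include <iostream>

// Sketch only: a counter incremented when entering a cell that has a custom
// bprop and decremented when leaving it; per-op grad recording happens only
// when the counter is zero, i.e. not inside any custom-bprop cell.
class GradCounterSketch {
 public:
  void EnterCustomBpropCell() { ++custom_bprop_cell_count_; }
  void LeaveCustomBpropCell() {
    if (custom_bprop_cell_count_ > 0) {
      --custom_bprop_cell_count_;
    }
  }
  // Mirrors the guard in RunOpForward: skip while an ms_function is compiling
  // or while inside a custom-bprop cell.
  bool ShouldRecordOpGrad(bool is_ms_function_compiling) const {
    return !is_ms_function_compiling && custom_bprop_cell_count_ == 0;
  }

 private:
  std::size_t custom_bprop_cell_count_{0};
};

int main() {
  GradCounterSketch grad;
  std::cout << grad.ShouldRecordOpGrad(false) << "\n";  // 1: record op grad
  grad.EnterCustomBpropCell();
  std::cout << grad.ShouldRecordOpGrad(false) << "\n";  // 0: skip inside custom bprop
  grad.LeaveCustomBpropCell();
  return 0;
}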

View File

@ -1,7 +1,7 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2021-2022 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -43,6 +43,7 @@ const std::map<SpecialType, std::shared_ptr<Primitive>> kValueType{{SpecialType:
const std::vector<PrimitivePtr> kGradBlackList{
prim::kPrimMakeTuple, prim::kPrimTupleGetItem, prim::kPrimStopGradient, prim::kPrimUpdateState,
prim::kPrimNPUAllocFloatStatus, prim::kPrimNPUGetFloatStatus, prim::kPrimNPUClearFloatStatus};
void ClearDeviceAddress(const ValuePtr &value) {
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(value, &tensors);
@ -169,7 +170,7 @@ AnfNodePtr BuildSpecialLikeValue(const FuncGraphPtr &tape, const ValuePtr &value
special_like_value->set_abstract(value->ToAbstract()->Broaden());
return special_like_value;
} else {
MS_EXCEPTION(TypeError) << "For value" << value->type()->ToString() << "`, the type is not tensor or sequence";
MS_EXCEPTION(TypeError) << "For value" << value->ToString() << ", the type is not tensor or sequence";
}
}
@ -280,9 +281,10 @@ AnfNodePtr FunctionNode::HyperAdd(const AnfNodePtr &left_node, const AnfNodePtr
}
}
void FunctionNode::AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din) {
void FunctionNode::AddNextEdge(const AnfNodePtr &next_node, const AnfNodePtr &din) {
MS_EXCEPTION_IF_NULL(next_node);
MS_EXCEPTION_IF_NULL(din);
// next_node and its corresponding din
(void)next_edges_.emplace_back(std::make_pair(next_node, din));
if (din == fake_dout_) {
(void)need_replace_edges_.emplace_back(next_edges_.size() - 1);
@ -338,6 +340,7 @@ AutoGradCellImpl::AutoGradCellImpl(const AnfNodePtrList &cell_inputs, const std:
bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
MS_LOG(DEBUG) << "Forward cnode: " << grad_param->cnode->DebugString();
auto prim = GetCNodePrimitive(grad_param->cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << grad_param->cnode->DebugString();
@ -346,44 +349,44 @@ bool AutoGradCellImpl::KPynativeOp(const GradParamPtr &grad_param) {
MS_LOG(DEBUG) << "Prim " << prim->name() << " not need do op grad";
return true;
}
bool is_custom_prim =
IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook);
// anfnode_to_variable_adjoint_ holds the out value; to avoid device memory not being released, clear its device_address
auto cloned_value = ShallowCopyTensorValue(grad_param->out);
ClearDeviceAddress(cloned_value);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, cloned_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, cloned_value);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
// A custom forward cnode does not need to be recorded in the bprop graph, because it is a flag cnode for running
// python. So just creating the bprop_cut grad op is ok
if (!grad_param->grad_by_value && !is_custom_prim) {
variable_adjoint->set_k_node(BuildKNode(grad_param));
need_do_manager_replace_ = true;
}
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint);
CNodePtr input_node = ConstructBpropGraphInput(grad_param, dout, variable_adjoint, is_custom_prim);
MS_LOG(DEBUG) << "Construct input cnode: " << input_node->DebugString();
// Gradient outputs
std::vector<CNodePtr> outputs;
#ifndef ENABLE_TEST
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
BuildBPropCutCNode(input_node, &outputs);
if (is_custom_prim) {
BuildBPropCutCNode(input_node, prim, &outputs);
} else {
#ifndef ENABLE_TEST
mindspore::BuildBprop(input_node, &outputs, &users_);
if (outputs.empty()) {
MS_LOG(DEBUG) << "Expander has no bprop of this prim: " << grad_param->cnode->DebugString();
BuildCustomBpropCNode(input_node, &outputs);
BuildCustomBpropCNode(input_node, prim, &outputs);
}
}
#else
if (IsPrimitiveEquals(prim, prim::kPrimHookBackward) || IsPrimitiveEquals(prim, prim::kPrimCellBackwardHook)) {
BuildBPropCutCNode(input_node, &outputs);
} else {
BuildCustomBpropCNode(input_node, &outputs);
}
BuildCustomBpropCNode(input_node, prim, &outputs);
#endif
if (!outputs.empty()) {
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
} else {
}
if (outputs.empty()) {
MS_LOG(DEBUG) << "This op has not custom bprop: " << grad_param->cnode->DebugString();
BuildFakeBpropCNode(input_node, &outputs);
variable_adjoint->set_is_fake_bprop(true);
variable_adjoint->set_fake_prim_name(prim->name());
}
UpdateNextEdges(variable_adjoint, grad_param->cnode, outputs, grad_param->op_args);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
// Record last_node for backpropagation
last_node_ = grad_param->cnode;
@ -395,6 +398,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
AnfNodePtrList args_node_list;
CNodePtr bprop_cnode = nullptr;
AnfNodePtr k_node = nullptr;
AnfNodePtr dout = nullptr;
if (grad_param->grad_by_value) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
@ -412,14 +416,13 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
}
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
} else {
k_node = BuildKNode(grad_param);
BuildKNodeListFromPrimalCNode(grad_param->cnode, grad_param->op_args, &args_node_list);
bprop_cnode = GetBPropFromFProp(grad_param->fprop_fg, args_node_list, grad_param->out, &dout);
}
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, grad_param->out);
if (!grad_param->grad_by_value) {
BuildKNode(grad_param, variable_adjoint);
}
variable_adjoint->set_k_node(k_node);
std::vector<CNodePtr> outputs;
for (size_t i = 1; i < grad_param->cnode->size(); ++i) {
// bprop_app[0] env
@ -427,7 +430,7 @@ bool AutoGradCellImpl::KPynativeWithFProp(const GradParamPtr &grad_param) {
din->set_abstract(grad_param->op_args[i - 1]->ToAbstract()->Broaden());
(void)outputs.emplace_back(din);
}
UpdateNextEdges(fn, grad_param->cnode, outputs, grad_param->op_args);
UpdateNextEdges(variable_adjoint, grad_param->cnode, outputs, grad_param->op_args);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(grad_param->cnode, variable_adjoint));
need_do_manager_replace_ = true;
return true;
@ -492,7 +495,6 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
if (!last_node_->isa<ValueNode>() && !last_node_->isa<Parameter>()) {
(void)BackPropagate();
}
SetOutput(weights, grad_position, grad_attr);
// Replace Parameter of primal funcgraph with parameter of tape_;
ReplacePrimalParameter(weights, grad_attr.has_sens);
@ -501,12 +503,12 @@ FuncGraphPtr AutoGradCellImpl::Finish(const AnfNodePtrList &weights, const std::
}
CNodePtr AutoGradCellImpl::ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint) {
const VariableAdjointPtr &variable_adjoint, bool is_custom_prim) {
MS_EXCEPTION_IF_NULL(grad_param);
std::vector<AnfNodePtr> node_list;
(void)node_list.emplace_back(grad_param->cnode->input(0));
auto out_abs = grad_param->out->ToAbstract()->Broaden();
if (grad_param->grad_by_value) {
if (grad_param->grad_by_value || is_custom_prim) {
for (size_t i = 0; i < grad_param->op_args.size(); ++i) {
const auto &v = grad_param->op_args[i];
auto node = grad_param->cnode->input(i + 1);
@ -541,7 +543,6 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
std::vector<AnfNodePtr> *const node_list) {
MS_EXCEPTION_IF_NULL(cnode);
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
if (cnode->input(i)->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
@ -553,10 +554,11 @@ void AutoGradCellImpl::BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, cons
cnode->input(i)->set_abstract(op_args[i - 1]->ToAbstract()->Broaden());
(void)node_list->emplace_back(cnode->input(i));
}
MS_LOG(DEBUG) << "Get knode for node " << cnode->input(i)->DebugString();
}
}
void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint) {
AnfNodePtr AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param) {
MS_EXCEPTION_IF_NULL(grad_param);
AnfNodePtrList node_list;
for (size_t i = 0; i < grad_param->cnode->inputs().size(); ++i) {
@ -564,8 +566,8 @@ void AutoGradCellImpl::BuildKNode(const GradParamPtr &grad_param, const Variable
}
auto k_node = tape_->NewCNode(node_list);
k_node->set_abstract(grad_param->out->ToAbstract()->Broaden());
variable_adjoint->set_k_node(k_node);
MS_LOG(DEBUG) << "Build knode " << k_node->DebugString();
return k_node;
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_node) {
@ -573,6 +575,11 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
if (input_node->isa<CNode>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(input_node);
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
if (IsPrimitiveCNode(input_node, prim::kPrimMakeTuple)) {
return BuildKNodeForMakeTuple(input_node);
} else if (IsPrimitiveCNode(input_node, prim::kPrimTupleGetItem)) {
return BuildKNodeForTupleGetItem(input_node);
}
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << input_node->DebugString();
}
return input_adjoint_iter->second->k_node();
@ -581,22 +588,70 @@ AnfNodePtr AutoGradCellImpl::BuildKNodeForCNodeInput(const AnfNodePtr &input_nod
}
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForMakeTuple(const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(DEBUG) << "Build knode for MakeTuple " << input_node->DebugString();
const auto &cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
ValuePtrList op_args;
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
(void)inputs.emplace_back(BuildKNodeForCNodeInput(cnode->input(i)));
if (cnode->input(i)->isa<CNode>() || cnode->input(i)->isa<Parameter>()) {
const auto input_adjoint_iter = anfnode_to_variable_adjoint_.find(cnode->input(i));
if (input_adjoint_iter == anfnode_to_variable_adjoint_.end()) {
MS_LOG(EXCEPTION) << "Cannot find input in adjoint map, inp: " << cnode->input(i)->DebugString();
}
(void)op_args.emplace_back(input_adjoint_iter->second->out_value());
} else {
auto value_node = cnode->input(i)->cast<ValueNodePtr>();
MS_EXCEPTION_IF_NULL(value_node);
(void)op_args.emplace_back(value_node->value());
}
}
auto out_value = MakeValue(op_args);
AnfNodePtr dout = BuildSpecialLikeValue(tape_, out_value, SpecialType::kZerosLikeType);
auto fn = std::make_shared<FunctionNode>(tape_, dout);
auto variable_adjoint = std::make_shared<VariableAdjoint>(fn, out_value);
auto k_node = tape_->NewCNode(inputs);
k_node->set_abstract(input_node->abstract());
variable_adjoint->set_k_node(k_node);
(void)anfnode_to_variable_adjoint_.insert(std::make_pair(input_node, variable_adjoint));
return k_node;
}
AnfNodePtr AutoGradCellImpl::BuildKNodeForTupleGetItem(const AnfNodePtr &input_node) {
MS_EXCEPTION_IF_NULL(input_node);
MS_LOG(DEBUG) << "Build knode for TupleGetItem " << input_node->DebugString();
const auto &tuple_item_cnode = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_item_cnode);
const auto &make_tuple_cnode = tuple_item_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple_cnode);
auto index_value = GetValueNode<Int64ImmPtr>(tuple_item_cnode->input(2));
auto real_node = make_tuple_cnode->input(LongToSize(index_value->value()));
return BuildKNodeForCNodeInput(real_node);
}
bool GradPynativeOp(const AutoGradCellImplPtr &k_cell, const GradParamPtr &grad_param) {
return k_cell->KPynativeOp(grad_param);
}
void AutoGradCellImpl::UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode,
void AutoGradCellImpl::UpdateNextEdges(const VariableAdjointPtr &variable, const CNodePtr &cnode,
const std::vector<CNodePtr> &dins, const ValuePtrList &op_args) {
MS_EXCEPTION_IF_NULL(cnode);
if (dins.size() != op_args.size()) {
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args";
MS_LOG(EXCEPTION) << "The size of dins is not same as op_args, cnode: " << cnode->DebugString();
}
const auto &fn = variable->fn();
for (size_t i = 0; i < op_args.size(); ++i) {
const auto &node = cnode->input(i + 1);
const auto &din = dins[i];
MS_LOG(DEBUG) << "Node " << node->DebugString() << ", din " << din->DebugString();
UpdateNextEdge(fn, node, din, op_args[i]);
}
if (fn->next_edges().empty()) {
variable->set_is_need_grad(false);
}
}
void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din,
@ -604,12 +659,16 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(din);
MS_EXCEPTION_IF_NULL(op_arg);
if (anfnode_to_variable_adjoint_.find(node) != anfnode_to_variable_adjoint_.end()) {
auto variable = anfnode_to_variable_adjoint_.at(node);
auto real_din = GetRealDin(fn, variable->out_value(), op_arg, din);
fn->AddEdge(node, real_din);
const auto it = anfnode_to_variable_adjoint_.find(node);
if (it != anfnode_to_variable_adjoint_.end()) {
if (!it->second->is_need_grad()) {
return;
}
auto real_din = GetRealDin(fn, it->second->out_value(), op_arg, din);
fn->AddNextEdge(node, real_din);
} else if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
const auto &cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (IsPrimitiveCNode(cnode, prim::kPrimStopGradient) || IsPrimitiveCNode(cnode, prim::kPrimUpdateState)) {
return;
}
@ -624,8 +683,10 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
CNodePtr new_din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), din, NewValueNode(SizeToLong(i))});
new_din->set_abstract(sub_value->ToAbstract()->Broaden());
if (din == fn->fake_dout()) {
// The new_din's index input is fn->fake_dout()
AddUser(fn->fake_dout(), new_din, 1);
}
// Add next edge to fn
UpdateNextEdge(fn, input_node, new_din, sub_value);
}
} else if (IsPrimitiveCNode(cnode, prim::kPrimTupleGetItem)) {
@ -636,7 +697,7 @@ void AutoGradCellImpl::UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePt
}
UpdateNextEdge(fn, src_node, din, op_arg);
} else {
MS_LOG(EXCEPTION) << "cnode should be tuplegetitem or maketuple " << cnode->DebugString();
MS_LOG(EXCEPTION) << "Cnode should be tuplegetitem or maketuple " << cnode->DebugString();
}
} else if (node->isa<Parameter>()) {
auto param = node->cast<ParameterPtr>();
@ -685,11 +746,14 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
MS_EXCEPTION_IF_NULL(din);
const auto &out_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(out_value);
const auto &sub_value_id = pynative::PyNativeAlgo::Common::GetIdByValue(sub_value);
// The node's corresponding output tensor is the same as the currently used tensor
if (out_value_id == sub_value_id) {
return din;
} else if (out_value->isa<tensor::Tensor>()) {
// out_value is not the sub_value being used; it may be one of multiple outputs
return BuildZerosLikeNode(tape_, out_value);
} else if (out_value->isa<ValueSequence>()) {
// The corresponding output of the node is a ValueSequence, but only one element of it is used
std::vector<AnfNodePtr> inputs;
if (out_value->isa<ValueTuple>()) {
(void)inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
@ -699,6 +763,8 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
auto value_seq = out_value->cast<ValueSequencePtr>();
int index = -1;
for (const auto &value : value_seq->value()) {
// Find the value's din. If value equals sub_value, the value is used and it will get din; otherwise the value's
// din is zero, which is set by the second branch condition above
auto real_din = GetRealDin(fn, value, sub_value, din);
(void)inputs.emplace_back(real_din);
@ -717,12 +783,9 @@ AnfNodePtr AutoGradCellImpl::GetRealDin(const FunctionNodePtr &fn, const ValuePt
return nullptr;
}
void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
std::vector<CNodePtr> *outputs) {
MS_EXCEPTION_IF_NULL(prim);
auto prim_py = prim->cast<PrimitivePyPtr>();
MS_EXCEPTION_IF_NULL(prim_py);
auto bprop_cut = std::make_shared<PrimitivePy>("bprop_cut");
@ -738,29 +801,35 @@ void AutoGradCellImpl::BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNo
if (prim->HasAttr("custom_op_bprop")) {
(void)bprop_cut->AddAttr("custom_op_bprop", MakeValue(true));
}
// Create gradient outputs cnode
std::vector<AnfNodePtr> inputs{NewValueNode(bprop_cut)};
auto output = tape_->NewCNode(inputs);
AbstractBasePtrList abs;
size_t args_size = cnode->size() - 2;
// Get input, get output, get dout
for (size_t i = 1; i < cnode->inputs().size(); ++i) {
(void)inputs.emplace_back(cnode->input(i));
}
auto bprop_cut_cnode = tape_->NewCNode(inputs);
size_t input_num = cnode->size() - 2;
AbstractBasePtrList abs_list;
for (size_t i = 1; i < cnode->size(); ++i) {
output->add_input(cnode->input(i));
AddUser(cnode->input(i), output, i);
if (i < args_size) {
auto din = tape_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), output, NewValueNode(SizeToLong(i - 1))});
din->set_abstract(cnode->input(i)->abstract()->Broaden());
// The bprop_cut_cnode's ith input uses cnode->input(i)
AddUser(cnode->input(i), bprop_cut_cnode, i);
if (i < input_num) {
auto din = tape_->NewCNode(
{NewValueNode(prim::kPrimTupleGetItem), bprop_cut_cnode, NewValueNode(static_cast<int64_t>(i - 1))});
MS_EXCEPTION_IF_NULL(cnode->input(i)->abstract());
din->set_abstract(cnode->input(i)->abstract());
abs_list.emplace_back(cnode->input(i)->abstract());
(void)outputs->emplace_back(din);
(void)abs.emplace_back(din->abstract());
}
}
output->set_abstract(std::make_shared<abstract::AbstractTuple>(abs));
bprop_cut_cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim,
std::vector<CNodePtr> *outputs) {
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(DEBUG) << "Build custom bprop: " << prim->name();
auto prim_py = prim->cast<PrimitivePyPtr>();
{
py::gil_scoped_acquire gil;
@ -780,7 +849,23 @@ void AutoGradCellImpl::BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<
prim_py->AddBackwardHookFn(0, fn);
prim_py->AddAttr("custom_op_bprop", MakeValue(True));
}
BuildBPropCutCNode(cnode, outputs);
BuildBPropCutCNode(cnode, prim, outputs);
}
void AutoGradCellImpl::BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs) {
auto prim = GetCNodePrimitive(cnode);
if (prim == nullptr) {
MS_LOG(EXCEPTION) << "Should be primitive, but: " << cnode->DebugString();
}
size_t dout_index = cnode->size() - 1;
const auto &dout = cnode->input(dout_index);
const auto &dout_cnode = dout->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(dout_cnode);
// Size is same as op_arg size
size_t input_size = cnode->size() - 2;
for (size_t i = 1; i < input_size; ++i) {
(void)outputs->emplace_back(dout_cnode);
}
}
void AutoGradCellImpl::SetSensAndWeights(const AnfNodePtrList &weights, bool has_sens_arg) {
@ -840,13 +925,12 @@ void AutoGradCellImpl::BackPropagate() {
for (auto iter = last_node_reverse_iter; iter != anfnode_to_variable_adjoint_.rend(); ++iter) {
MS_LOG(DEBUG) << "BackPropagate cnode: " << iter->first->DebugString();
const auto &variable = iter->second;
if (!variable->is_need_propagate()) {
if (!variable->is_need_propagate() || !variable->is_need_grad()) {
MS_LOG(DEBUG) << "No need grad";
continue;
}
if (variable->is_fake_bprop()) {
MS_LOG(WARNING) << variable->fake_prim_name() << " op has no corresponding bprop!";
continue;
MS_LOG(EXCEPTION) << variable->fake_prim_name() << " op has no corresponding bprop!";
}
if (!has_primc && iter->first->isa<CNode>() && GetCNodePrimitive(iter->first) != nullptr) {
has_primc = true;
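As background for the is_need_grad flag and the UpdateNextEdges signature change in this file: once all dins of a cnode have been processed, a variable whose function node collected no next edges cannot contribute to any gradient, so backpropagation can skip it. The following is a hedged, self-contained sketch of that pruning idea; FunctionNodeSketch and VariableSketch are illustrative stand-ins, not the actual MindSpore classes.

#include <memory>
#include <string>
#include <utility>
#include <vector>

// Illustrative stand-in for an ANF node.
struct NodeSketch {
  std::string name;
};
using NodeSketchPtr = std::shared_ptr<NodeSketch>;

// Records (next_node, din) pairs, in the spirit of FunctionNode::AddNextEdge.
class FunctionNodeSketch {
 public:
  void AddNextEdge(const NodeSketchPtr &next_node, const NodeSketchPtr &din) {
    next_edges_.emplace_back(next_node, din);
  }
  const std::vector<std::pair<NodeSketchPtr, NodeSketchPtr>> &next_edges() const { return next_edges_; }

 private:
  std::vector<std::pair<NodeSketchPtr, NodeSketchPtr>> next_edges_;
};

// A variable whose function node has no next edges needs no grad, so the
// backward pass can skip it.
class VariableSketch {
 public:
  explicit VariableSketch(std::shared_ptr<FunctionNodeSketch> fn) : fn_(std::move(fn)) {}
  void FinishEdges() {
    if (fn_->next_edges().empty()) {
      is_need_grad_ = false;
    }
  }
  bool is_need_grad() const { return is_need_grad_; }

 private:
  std::shared_ptr<FunctionNodeSketch> fn_;
  bool is_need_grad_{true};
};

int main() {
  auto fn = std::make_shared<FunctionNodeSketch>();
  VariableSketch variable(fn);
  variable.FinishEdges();           // no edges were ever added
  return variable.is_need_grad();   // 0: backprop would skip this variable
}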

View File

@ -1,5 +1,5 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -64,7 +64,7 @@ class FunctionNode {
public:
FunctionNode(const FuncGraphPtr &tape, const AnfNodePtr &dout)
: tape_(tape), accumulate_dout_(dout), fake_dout_(dout) {}
void AddEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
void AddNextEdge(const AnfNodePtr &next_node, const AnfNodePtr &din);
void UpdateAccumulativeDout(const AnfNodePtr &new_dout);
const std::vector<std::pair<AnfNodePtr, AnfNodePtr>> &next_edges() const { return next_edges_; }
const FuncGraphPtr tape() { return tape_; }
@ -91,6 +91,7 @@ using FunctionNodePtr = std::shared_ptr<FunctionNode>;
// Variable represents a parameter or an output of a middle cnode
class VariableAdjoint {
public:
VariableAdjoint() = default;
VariableAdjoint(const FunctionNodePtr &fn, const ValuePtr &out_value) : fn_(fn), out_value_(out_value) {}
ValuePtr out_value() const { return out_value_; }
@ -101,20 +102,24 @@ class VariableAdjoint {
void set_is_fake_bprop(bool is_fake_bprop) { is_fake_bprop_ = is_fake_bprop; }
bool is_need_propagate() const { return is_need_propagate_; }
void set_is_need_propagate(bool is_need_grad) { is_need_propagate_ = is_need_grad; }
bool is_need_grad() const { return is_need_grad_; }
void set_is_need_grad(bool is_need_grad) { is_need_grad_ = is_need_grad; }
AnfNodePtr k_node() const { return k_node_; }
void set_k_node(const AnfNodePtr &k_node) { k_node_ = k_node; }
AnfNodePtr RealDout();
private:
// Abstract bprop function
FunctionNodePtr fn_;
ValuePtr out_value_;
FunctionNodePtr fn_{nullptr};
ValuePtr out_value_{nullptr};
// If the node has no bprop, we record its prim name
std::string fake_prim_name_;
// Record that this node is a fake bprop
bool is_fake_bprop_{false};
// Flag to judge whether propagation is needed
bool is_need_propagate_{false};
// Flag to judge whether the variable needs grad
bool is_need_grad_{true};
// K mapped cnode for primal CNode; primal CNode is owned by primal funcgraph, this is owned by tape_;
AnfNodePtr k_node_{nullptr};
};
@ -161,9 +166,9 @@ class AutoGradCellImpl {
// construct input as cnode for expander
CNodePtr ConstructBpropGraphInput(const GradParamPtr &grad_param, const AnfNodePtr &dout,
const VariableAdjointPtr &variable_adjoint);
const VariableAdjointPtr &variable_adjoint, bool is_custom_prim);
// Back propagate for one node;
void UpdateNextEdges(const FunctionNodePtr &fn, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
void UpdateNextEdges(const VariableAdjointPtr &variable, const CNodePtr &cnode, const std::vector<CNodePtr> &dins,
const ValuePtrList &op_args);
void UpdateNextEdge(const FunctionNodePtr &fn, const AnfNodePtr &node, const AnfNodePtr &din, const ValuePtr &op_arg);
@ -172,8 +177,8 @@ class AutoGradCellImpl {
void AddParameterNode(const AnfNodePtr &parameter, const ValuePtr &tensor);
AnfNodePtr GetRealDin(const FunctionNodePtr &fn, const ValuePtr &out_value, const ValuePtr &sub_value,
const AnfNodePtr &din);
void BuildBPropCutCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildCustomBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
void BuildBPropCutCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);
void BuildCustomBpropCNode(const CNodePtr &cnode, const PrimitivePtr &prim, std::vector<CNodePtr> *outputs);
void BuildFakeBpropCNode(const CNodePtr &cnode, std::vector<CNodePtr> *outputs);
// Replace input or weights parameter from primal funcgraph to parameters of tape_;
void ReplacePrimalParameter(const AnfNodePtrList &weights, bool has_sens_arg);
@ -195,10 +200,12 @@ class AutoGradCellImpl {
void ElimateTupleGetItem();
// Fbprop
void BuildKNode(const GradParamPtr &grad_param, const VariableAdjointPtr &variable_adjoint);
AnfNodePtr BuildKNode(const GradParamPtr &grad_param);
void BuildKNodeListFromPrimalCNode(const CNodePtr &cnode, const ValuePtrList &op_args,
std::vector<AnfNodePtr> *const node_list);
AnfNodePtr BuildKNodeForCNodeInput(const AnfNodePtr &input_node);
AnfNodePtr BuildKNodeForMakeTuple(const AnfNodePtr &input_node);
AnfNodePtr BuildKNodeForTupleGetItem(const AnfNodePtr &input_node);
};
using AutoGradCellImplPtr = std::shared_ptr<AutoGradCellImpl>;

View File

@ -142,9 +142,8 @@ ValuePtr ConvertOutputValueToTensor(const ValuePtr &v) {
MS_EXCEPTION_IF_NULL(value_tuple);
return opt::CreateTupleTensor(value_tuple);
}
PyNativeAlgo::Common::FilterSensValues(v);
MS_LOG(DEBUG) << "Output is value sequence, but have tensor and other type mixed. Its value is " << v->ToString();
return v;
return PyNativeAlgo::Common::FilterSensValues(v);
} else if (v->isa<FloatImm>()) {
double input_value = v->cast<FP32ImmPtr>()->value();
return std::make_shared<tensor::Tensor>(input_value, kFloat32);
@ -552,11 +551,10 @@ void GradExecutor::MakeNewTopGraph(const InputArgsInfoPtr &input_args_info) {
input_args_info->already_run_cell_id, resource, fg);
top_cell_->set_forward_already_run(true);
top_cell_->set_input_args_id(input_args_info->input_args_id);
top_cell_->set_is_cell_id_in_dynamic_detect_nodes_map(
cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
PushHighOrderGraphStack(top_cell_);
(void)top_cell_list_.emplace_back(top_cell_);
is_cell_id_in_dynamic_detect_nodes_map_ =
(cell_id_with_dynamic_detect_nodes_.find(obj_id_with_grad_order) != cell_id_with_dynamic_detect_nodes_.end());
MS_LOG(DEBUG) << "New top graph, fg ptr " << fg.get() << " resource ptr " << resource.get();
}
@ -603,9 +601,7 @@ void GradExecutor::EndGraphInner(const py::object &obj, const py::object &out, c
void GradExecutor::UpdateInputArgsInfo(const InputArgsInfoPtr &input_args_info, const py::object &obj,
const py::object &out, const py::args &args) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (input_args_info->has_custom_bprop) {
GetCustomBpropPrim(obj, args, out, input_args_info);
}
GetCustomBpropPrim(obj, args, out, input_args_info);
// Used at the master thread; change it at the master thread
if (input_args_info->is_grad_topest_cell) {
grad_flag_ = false;
@ -678,27 +674,30 @@ void GradExecutor::DoGradForCustomBprop(const InputArgsInfoPtr &input_args_info,
return;
}
MS_LOG(DEBUG) << "Do grad for custom bprop";
MS_EXCEPTION_IF_NULL(input_args_info->custom_bprp_prim);
MS_EXCEPTION_IF_NULL(input_args_info->custom_bprop_prim);
auto op_run_info = std::make_shared<FrontendOpRunInfo>();
op_run_info->grad_flag = true;
op_run_info->base_op_run_info.op_name = input_args_info->custom_bprp_prim->name();
op_run_info->op_prim = input_args_info->custom_bprp_prim;
op_run_info->base_op_run_info.op_name = input_args_info->custom_bprop_prim->name();
op_run_info->op_prim = input_args_info->custom_bprop_prim;
op_run_info->input_value = input_args_info->input_arg_value_vec;
op_run_info->input_size = input_args_info->input_arg_value_vec.size();
op_run_info->input_value_id = input_args_info->input_arg_id_vec;
auto cnode = ConstructForwardGraph(op_run_info);
if (!input_args_info->grad_is_running || bprop_grad_stack_.top().second) {
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
}
cnode->set_abstract(input_args_info->out_value->ToAbstract()->Broaden());
DoOpGrad(op_run_info, cnode, input_args_info->out_value);
(void)CheckGraphDynamic(cnode);
SaveOutputNodeMap(out_id, op_run_info, cnode);
}
void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &args, const py::object &out,
const InputArgsInfoPtr &input_args_info) {
MS_EXCEPTION_IF_NULL(input_args_info);
if (!input_args_info->has_custom_bprop) {
return;
}
custom_bprop_cell_count_ -= 1;
input_args_info->custom_bprop_cell_count = custom_bprop_cell_count_;
if (custom_bprop_cell_count_ != 0) {
input_args_info->custom_bprop_cell_count -= 1;
if (input_args_info->custom_bprop_cell_count != 0) {
return;
}
MS_LOG(DEBUG) << "Get custom bprop prim";
@ -744,7 +743,7 @@ void GradExecutor::GetCustomBpropPrim(const py::object &obj, const py::args &arg
(void)input_args_info->input_arg_value_vec.emplace_back(PyNativeAlgo::DataConvert::PyObjToValue(args[i]));
}
}
input_args_info->custom_bprp_prim = fake_prim;
input_args_info->custom_bprop_prim = fake_prim;
}
void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info) {
@ -772,6 +771,7 @@ void GradExecutor::CheckNeedCompileGraph(const InputArgsInfoPtr &input_args_info
EraseTopCellFromTopCellList(pre_top_cell);
}
already_run_top_cell_[already_top_cell_id] = new_top_cell;
new_top_cell->set_force_top_cell_compile(false);
} else {
MS_LOG(DEBUG) << "No need to compile graph again";
pre_top_cell->set_input_args_id(new_top_cell->input_args_id());
@ -1042,8 +1042,7 @@ void GradExecutor::SetGradOrder(const std::string &obj_id) {
// top_cell_->obj_id_with_grad_order() includes obj_id and grad_order
// If top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos and the cell info stack is not empty, the
// current cell is not the top cell, and high-order grad comes in
if (top_cell_ == nullptr ||
(!input_args_info_stack_.empty() && top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos)) {
if (top_cell_ == nullptr || top_cell_->obj_id_with_grad_order().find(obj_id) == std::string::npos) {
IncreaseGradOrder();
}
if (!grad_is_running_) {
@ -1107,7 +1106,11 @@ py::object GradExecutor::RunGradGraph() {
MS_LOG(DEBUG) << "Eval run " << backend;
grad_is_running_ = true;
top_cell()->set_auto_grad_cell_ptr(nullptr);
// In custom bprop, when running the bprop function, top_input_args_info_ will be changed.
// So, copy it here and restore it after running finishes.
auto top_input_args_info = top_input_args_info_;
BaseRef out_value = (*run)(arg_list);
top_input_args_info_ = top_input_args_info;
grad_is_running_ = false;
MS_LOG(DEBUG) << "Eval run end " << out_value.ToString();
const auto &cur_run_bprop_graph = resource->func_graph();
@ -1139,13 +1142,28 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
auto cnode = curr_g()->NewCNode(inputs);
auto out_value = PyNativeAlgo::DataConvert::BaseRefToValue(out);
// Get output values
if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}
const auto &out_id = PyNativeAlgo::Common::GetIdByValue(out_value);
top_cell()->SetNodeMapInGraphInfoMap(out_id, cnode);
cnode->set_abstract(out_value->ToAbstract()->Broaden());
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString();
MS_LOG(DEBUG) << "Nested make cnode is " << cnode->DebugString() << ", out id " << out_id;
// High grad hit cache
bool need_do_grad = true;
if (!cur_vm_compile) {
if (already_run_top_cell_.find(top_cell()->already_run_cell_id()) != already_run_top_cell_.end()) {
const auto &dynamic_nodes = cell_id_with_dynamic_detect_nodes_[top_cell()->obj_id_with_grad_order()];
MS_LOG(DEBUG) << "Cur op index " << (top_cell()->op_index() + 1) << ", outer graph all op size "
<< dynamic_nodes.size();
if (top_cell()->op_index() + 1 == dynamic_nodes.size()) {
need_do_grad = false;
}
}
}
if (!use_dynamic_shape_process_ && !need_do_grad) {
return;
}
@ -1170,11 +1188,7 @@ void GradExecutor::MakeNestedCnode(bool has_custom_bprop, const std::vector<Valu
// Get input values
ValuePtrList input_args(forward_args);
(void)input_args.insert(input_args.end(), weights_args.cbegin(), weights_args.cend());
// Get output values
if (has_custom_bprop && !out_value->isa<ValueSequence>()) {
std::vector<ValuePtr> out_v{out_value};
out_value = std::make_shared<ValueTuple>(out_v);
}
auto grad_param = std::make_shared<ad::GradParam>(cnode, input_args, out_value, second_grad_fg,
!top_cell()->is_high_order_top_cell());
if (!top_cell()->auto_grad_cell_ptr()->KPynativeWithFProp(grad_param)) {
@ -1285,7 +1299,6 @@ void GradExecutor::ClearRes() {
need_renormalize_ = false;
eliminate_forward_ = true;
use_dynamic_shape_process_ = false;
is_cell_id_in_dynamic_detect_nodes_map_ = false;
custom_bprop_cell_count_ = 0;
grad_order_ = 0;
top_cell_ = nullptr;
@ -1388,6 +1401,7 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
return nullptr;
}
ValuePtrList input_args;
abstract::AbstractBasePtrList abs_list;
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimMakeTuple)};
const auto &obj_tuple = v->cast<ValueSequencePtr>();
const auto &v_list = obj_tuple->value();
@ -1400,10 +1414,12 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
(void)input_args.emplace_back(v_arg);
const std::string &id = PyNativeAlgo::Common::GetIdByValue(v_arg);
(void)inputs.emplace_back(GetInput(v_arg, id));
(void)abs_list.emplace_back(v_arg->ToAbstract()->Broaden());
(void)GetValueSequenceInput(v_arg, id);
}
// Create make tuple node and record to graph info map.
auto cnode = curr_g()->NewCNode(inputs);
cnode->set_abstract(std::make_shared<abstract::AbstractTuple>(abs_list));
MS_LOG(DEBUG) << "Create make tuple node: " << cnode->DebugString();
(void)CheckGraphDynamic(cnode);
top_cell()->SetNodeMapInGraphInfoMap(obj_id, cnode, -1, false);
@ -1412,9 +1428,9 @@ AnfNodePtr GradExecutor::GetValueSequenceInput(const ValuePtr &v, const std::str
AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
const std::pair<AnfNodePtr, std::vector<int64_t>> &out) const {
MS_LOG(DEBUG) << "Output size: " << out.second.size();
auto c_node = out.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(c_node);
MS_LOG(DEBUG) << "Cnode " << c_node->DebugString() << ", id " << obj_id << ", out second " << out.second;
auto abs = c_node->abstract();
// Create tuple get item node
for (const auto &idx : out.second) {
@ -1430,12 +1446,11 @@ AnfNodePtr GradExecutor::CreateTupleGetItemNode(const std::string &obj_id,
}
auto prim_abs = elements[static_cast<size_t>(idx)];
MS_EXCEPTION_IF_NULL(prim_abs);
MS_LOG(DEBUG) << "Set tuple getitem abs " << prim_abs->ToString();
c_node->set_abstract(prim_abs);
}
}
(void)CheckGraphDynamic(c_node);
MS_LOG(DEBUG) << "Get input node " << c_node->ToString() << ", id " << obj_id;
MS_LOG(DEBUG) << "Create tuple getitem node " << c_node->ToString() << ", abs " << c_node->abstract()->ToString();
return c_node;
}
@ -1524,7 +1539,10 @@ void GradExecutor::SaveOutputNodeMap(const std::string &obj_id, const FrontendOp
void GradExecutor::DoOpGrad(const FrontendOpRunInfoPtr &op_run_info, const CNodePtr &cnode,
const ValuePtr &op_out) const {
MS_EXCEPTION_IF_NULL(op_run_info);
if (grad_is_running_ && !bprop_grad_stack_.top().second) {
MS_LOG(DEBUG) << "Custom bprop, no need do op grad";
return;
}
// Clone to avoid the out that exists in the tape bprop being modified.
ValuePtrList cloned_op_args;
(void)std::transform(op_run_info->input_value.begin(), op_run_info->input_value.end(),
@ -1826,7 +1844,7 @@ void GradExecutor::SaveDynamicDetectNodeInfoInFirstTime(const AnfNodePtr &anf_no
bool GradExecutor::IsGraphDynamic(const AnfNodePtr &anf_node, const size_t node_idx, bool is_ms_function_node,
const std::string &graph_phase) const {
MS_EXCEPTION_IF_NULL(anf_node);
if (!is_cell_id_in_dynamic_detect_nodes_map_) {
if (!top_cell()->is_cell_id_in_dynamic_detect_nodes_map()) {
SaveDynamicDetectNodeInfoInFirstTime(anf_node, node_idx, is_ms_function_node, graph_phase);
// The net is regarded as a static net by default in the first time.
return false;

View File

@ -200,7 +200,7 @@ class GradExecutor {
bool eliminate_forward_{true};
mutable bool use_dynamic_shape_process_{false};
mutable bool is_cell_id_in_dynamic_detect_nodes_map_{false};
int custom_bprop_cell_count_{0};
size_t custom_bprop_cell_count_{0};
size_t obj_order_{0};
// If grad_order=1, indicate first derivative; grad_order=2, indicate second derivative; ...
size_t grad_order_{0};

View File

@ -89,6 +89,10 @@ class TopCellInfo {
inline bool is_high_order_top_cell() const { return is_high_order_top_cell_; }
inline void set_need_do_final_opt(bool need_do_final_opt) { need_do_final_opt_ = need_do_final_opt; }
inline bool need_do_final_opt() const { return need_do_final_opt_; }
inline void set_is_cell_id_in_dynamic_detect_nodes_map(bool is_cell_id_in_dynamic_detect_nodes_map) {
is_cell_id_in_dynamic_detect_nodes_map_ = is_cell_id_in_dynamic_detect_nodes_map;
}
inline bool is_cell_id_in_dynamic_detect_nodes_map() const { return is_cell_id_in_dynamic_detect_nodes_map_; }
inline pipeline::ResourcePtr resource() const { return resource_; }
inline FuncGraphPtr fg() const {
MS_EXCEPTION_IF_NULL(fg_);
@ -162,8 +166,9 @@ class TopCellInfo {
bool force_top_cell_compile_{false};
bool is_high_order_top_cell_{false};
bool need_do_final_opt_{false};
size_t op_index_{0};
bool is_cell_id_in_dynamic_detect_nodes_map_{false};
size_t grad_order_{0};
size_t op_index_{0};
std::string obj_id_with_grad_order_;
std::string cell_id_;
std::string already_run_cell_id_;

View File

@ -190,7 +190,7 @@ def train_process_bert_thor(q, device_id, epoch_size, device_num):
q.put({'loss': loss_list, 'cost': per_step_mseconds})
@pytest.mark.level0
@pytest.mark.level1
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_single