Optimize the grad and tail generating procedures.

This commit is contained in:
Zhang Qinghua 2022-05-06 20:28:18 +08:00
parent c29d6bb764
commit 141335dc3b
5 changed files with 223 additions and 203 deletions
mindspore
ccsrc
frontend/operator/composite
pipeline/jit/static_analysis
python/mindspore/nn/grad

View File

@ -38,11 +38,10 @@ namespace mindspore {
// namespace to support composite operators definition
namespace prim {
constexpr auto kStepDefault = 1;
using AbstractTensor = mindspore::abstract::AbstractTensor;
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
using mindspore::abstract::AbstractAttribute;
using mindspore::abstract::AbstractBase;
using mindspore::abstract::AbstractBasePtr;
using mindspore::abstract::AbstractDictionary;
using mindspore::abstract::AbstractDictionaryPtr;
using mindspore::abstract::AbstractEllipsis;
@ -50,10 +49,17 @@ using mindspore::abstract::AbstractEllipsisPtr;
using mindspore::abstract::AbstractFunction;
using mindspore::abstract::AbstractFunctionPtr;
using mindspore::abstract::AbstractList;
using mindspore::abstract::AbstractListPtr;
using mindspore::abstract::AbstractNone;
using mindspore::abstract::AbstractScalar;
using mindspore::abstract::AbstractSequence;
using mindspore::abstract::AbstractSequencePtr;
using mindspore::abstract::AbstractSlice;
using mindspore::abstract::AbstractTensor;
using mindspore::abstract::AbstractTuple;
using mindspore::abstract::AbstractTuplePtr;
using mindspore::abstract::AbstractUndetermined;
using mindspore::abstract::FuncGraphAbstractClosure;
void HyperMap::Init() {
if (fn_leaf_) {
@ -344,170 +350,6 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList
return broadened;
}
namespace {
// Recursively checks that every element of 'tuple' is tensor-like
// (AbstractUndetermined), descending into nested tuples.
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
  MS_EXCEPTION_IF_NULL(tuple);
  for (size_t index = 0; index < tuple->size(); ++index) {
    const auto &element = (*tuple)[index];
    if (element->isa<abstract::AbstractUndetermined>()) {
      continue;
    }
    if (element->isa<abstract::AbstractTuple>() &&
        CheckSequenceAllTensor(element->cast<abstract::AbstractTuplePtr>())) {
      continue;
    }
    return false;
  }
  return true;
}
// Scalar gradients are produced only when the MS_CTX_GRAD_FOR_SCALAR context
// flag is on and the abstract value's built type is a Number.
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
  const bool grad_for_scalar = MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR);
  auto build_type = abs->BuildType();
  return grad_for_scalar && build_type != nullptr && build_type->isa<Number>();
}
// Tuple gradients require the feature switch and that the tuple consists
// entirely of (possibly nested) tensors.
bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
  if (!enable_tuple_grad || !abs->isa<abstract::AbstractTuple>()) {
    return false;
  }
  return CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
}
// Checks whether the element at index 1 of the bprop result sequence (index 0
// is the env slot) should be returned as the first-input gradient.
// NOTE(review): "Frist" is a typo for "First"; kept because callers in this
// file use this exact name.
bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
  MS_EXCEPTION_IF_NULL(sequeue);
  // Gradable when the element is tensor-like (AbstractUndetermined), a
  // variable value (BuildValue() == kAnyValue), a scalar with grad-for-scalar
  // enabled, or an all-tensor tuple with tuple grad enabled.
  return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
         ((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
          EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
}
// Builds the output of 'res' for grad-by-position over the bprop result
// 'sequeue'. 'pos' holds the user-requested positions; each is shifted by 1 at
// runtime (kPrimScalarAdd) to skip the env slot at index 0.
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
                                         const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
  if (pos == nullptr) {
    MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
  }
  AnfNodePtr tuple_parameter = res->add_parameter();
  std::vector<AnfNodePtr> pos_elements;
  PrimitivePtr pos_op = nullptr;
  // Choose the make/getitem primitives matching the position container kind.
  if (pos->isa<AbstractTuple>()) {
    pos_elements.push_back(NewValueNode(prim::kPrimMakeTuple));
    pos_op = prim::kPrimTupleGetItem;
  } else {
    pos_elements.push_back(NewValueNode(prim::kPrimMakeList));
    pos_op = prim::kPrimListGetItem;
  }
  AnfNodePtr pos_value = nullptr;
  AnfNodePtr pos_value_adjust = nullptr;
  auto pos_parameter = res->add_parameter();
  if (pos->size() == 1) {
    // Single position: return that one gradient, or an empty tuple when the
    // first input is not gradable.
    pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(0))});
    // +1 skips the env element at index 0 of the bprop result.
    pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad)) {
      res->set_output(res->NewCNode({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
    } else {
      res->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
    }
  } else {
    // Multiple positions: gather each requested gradient into a new sequence.
    for (size_t i = 0; i < pos->size(); ++i) {
      pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(i))});
      pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
      pos_elements.push_back(res->NewCNodeInOrder({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
    }
    if (pos_elements.size() > 1) {
      res->set_output(res->NewCNodeInOrder(pos_elements));
    } else if (pos->isa<AbstractTuple>()) {  // Empty tuple.
      auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
      auto empty_tuple = NewValueNode(empty_tuple_value);
      res->set_output(empty_tuple);
    } else {  // Empty list.
      auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
      auto empty_list = NewValueNode(empty_list_value);
      res->set_output(empty_list);
    }
  }
}
} // namespace
// Generates the func graph for Tail over a sequence argument.
// kGradFirst: return only the first real gradient (index 1; index 0 is env).
// kGradByPosition: delegate to GenerateSequenceFuncGraphByPosition.
// Otherwise (kGradAll / kNotGrad): drop element 0 and rebuild the sequence.
FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue,
                                             const abstract::AbstractSequencePtr &pos) const {
  MS_EXCEPTION_IF_NULL(sequeue);
  FuncGraphPtr res = std::make_shared<FuncGraph>();
  res->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  res->debug_info()->set_name("tail");
  if (tail_type_ == kGradFirst) {
    AnfNodePtr tuple_parameter = res->add_parameter();
    PrimitivePtr getitem_op = nullptr;
    // Pick the getitem primitive matching the sequence kind.
    if (sequeue->isa<AbstractTuple>()) {
      getitem_op = prim::kPrimTupleGetItem;
    } else {
      getitem_op = prim::kPrimListGetItem;
    }
    if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
      res->set_output(res->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
    } else {
      // First input is not gradable; return an empty tuple.
      res->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
    }
    return res;
  }
  if (tail_type_ == kGradByPosition) {
    GenerateSequenceFuncGraphByPosition(res, sequeue, pos, enable_tuple_grad_);
    return res;
  }
  AnfNodePtr tuple_parameter = res->add_parameter();
  std::vector<AnfNodePtr> elements;
  PrimitivePtr op = nullptr;
  if (sequeue->isa<AbstractTuple>()) {
    elements.push_back(NewValueNode(prim::kPrimMakeTuple));
    op = prim::kPrimTupleGetItem;
  } else {
    elements.push_back(NewValueNode(prim::kPrimMakeList));
    op = prim::kPrimListGetItem;
  }
  // Start from index 1: element 0 is dropped (env slot for grads, head for
  // a plain tail).
  for (size_t i = 1; i < sequeue->size(); ++i) {
    if (tail_type_ == kGradAll) {
      MS_EXCEPTION_IF_NULL((*sequeue)[i]);
      // Keep only gradable elements: tensor-like, variable, or scalar with
      // grad-for-scalar enabled.
      if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
          EnableGradForScalar((*sequeue)[i])) {
        elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
      }
    } else {
      elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
    }
  }
  if (elements.size() > 1) {
    res->set_output(res->NewCNodeInOrder(elements));
    return res;
  } else if (sequeue->isa<AbstractTuple>()) {  // Empty tuple.
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    auto empty_tuple = NewValueNode(empty_tuple_value);
    res->set_output(empty_tuple);
    return res;
  } else {  // Empty list.
    auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
    auto empty_list = NewValueNode(empty_list_value);
    res->set_output(empty_list);
    return res;
  }
}
// Entry point of Tail: validates that arg0 (and optional arg1, the positions)
// are tuples or lists, then delegates to GenerateSequenceFuncGraph.
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  if (args_spec_list.size() < 1) {
    MS_LOG(EXCEPTION) << "Tail requires a non-empty tuple.";
  }
  AbstractBasePtr first_arg = args_spec_list[0];
  if (!first_arg->isa<AbstractTuple>() && !first_arg->isa<AbstractList>()) {
    MS_LOG(EXCEPTION) << "'Tail' arg0 must be AbstractTuple or AbstractList, but got " << first_arg->ToString();
  }
  auto first_sequence = first_arg->cast<abstract::AbstractSequencePtr>();
  if (args_spec_list.size() <= 1) {
    return GenerateSequenceFuncGraph(first_sequence);
  }
  AbstractBasePtr pos_arg = args_spec_list[1];
  if (!pos_arg->isa<AbstractTuple>() && !pos_arg->isa<AbstractList>()) {
    MS_LOG(EXCEPTION) << "'Tail' arg1 must be AbstractTuple or AbstractList, but got " << pos_arg->ToString();
  }
  return GenerateSequenceFuncGraph(first_sequence, pos_arg->cast<abstract::AbstractSequencePtr>());
}
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
int64_t tuple_size = SizeToLong(args_spec_list.size());
@ -592,6 +434,191 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
return fg;
}
namespace {
bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
MS_EXCEPTION_IF_NULL(tuple_arg);
for (size_t i = 0; i < tuple_arg->size(); ++i) {
if (!(*tuple_arg)[i]->isa<AbstractUndetermined>() &&
!((*tuple_arg)[i]->isa<AbstractTuple>() && IsTupleAllTensor((*tuple_arg)[i]->cast<AbstractTuplePtr>()))) {
return false;
}
}
return true;
}
// kGradFirst may return a tuple gradient only if the switch is on and the
// abstract is a tuple composed entirely of tensors.
bool EnableGradFirstForTuple(const AbstractBasePtr &abs, bool enable_tuple_grad) {
  if (!enable_tuple_grad || !abs->isa<AbstractTuple>()) {
    return false;
  }
  return IsTupleAllTensor(abs->cast<AbstractTuplePtr>());
}
bool EnableGradForScalar(const AbstractBasePtr &abs) {
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
abs->BuildType()->isa<Number>();
}
// Whether the element at 'pos' of the bprop result tuple is gradable: it must
// exist, be non-null, and be tensor-like, a variable value, or a scalar with
// grad-for-scalar enabled.
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  if (tuple_arg->size() <= pos) {
    return false;
  }
  const auto &arg = (*tuple_arg)[pos];
  if (arg == nullptr) {
    return false;
  }
  return arg->isa<AbstractUndetermined>() || arg->BuildValue() == kAnyValue || EnableGradForScalar(arg);
}
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,
const AbstractTuplePtr &pos) {
if (pos == nullptr) {
MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
}
AnfNodePtr tuple_parameter = fg->add_parameter();
(void)fg->add_parameter(); // The 'grad_position' parameter.
// Collect all parameters by 'grad_position'.
std::vector<AnfNodePtr> pos_elements = {NewValueNode(prim::kPrimMakeTuple)};
CNodePtr current_element = nullptr;
for (size_t i = 0; i < pos->size(); ++i) {
auto val = pos->elements()[i]->BuildValue();
MS_EXCEPTION_IF_NULL(val);
auto int_val = LongToSize(dyn_cast<Int64Imm>(val)->value());
++int_val; // Ignore the env position.
if (!CanGradArgument(tuple_arg, int_val)) {
continue;
}
current_element =
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(int_val))});
pos_elements.push_back(current_element);
}
// The returned result may vary for grad result element number.
// A single value if only one result, a tuple for multiple results, or a empty tuple for no result.
//
// Notice that even if the user set 'grad_position' as multiple choices,
// the 'CanGradArgument' may change it to only one choice or none choice.
constexpr size_t args_least_size = 2;
if (pos_elements.size() == args_least_size) {
fg->set_output(current_element);
} else if (pos_elements.size() > args_least_size) {
fg->set_output(fg->NewCNodeInOrder(pos_elements));
} else { // The 'pos' is empty AbstractTuple or empty AbstractList.
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
auto empty_tuple = NewValueNode(empty_tuple_value);
fg->set_output(empty_tuple);
}
}
} // namespace
// Generates the func graph for a plain (non-grad) Tail: drop the first element
// of the input sequence and return the rest as a new sequence of the same kind
// (tuple -> tuple, list -> list).
FuncGraphPtr Tail::GenerateTailFuncGraph(const AbstractSequencePtr &sequence_arg) const {
  MS_EXCEPTION_IF_NULL(sequence_arg);
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->debug_info()->set_name("tail");
  AnfNodePtr tuple_parameter = fg->add_parameter();
  std::vector<AnfNodePtr> elements;
  PrimitivePtr op = nullptr;
  if (sequence_arg->isa<AbstractTuple>()) {
    elements.emplace_back(NewValueNode(prim::kPrimMakeTuple));
    op = prim::kPrimTupleGetItem;
  } else {
    elements.emplace_back(NewValueNode(prim::kPrimMakeList));
    op = prim::kPrimListGetItem;
  }
  // Remove the first element to make a new sequence.
  for (size_t i = 1; i < sequence_arg->size(); ++i) {
    elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
  }
  if (elements.size() > 1) {
    fg->set_output(fg->NewCNodeInOrder(elements));
    return fg;
  }
  // No element left, return empty tuple.
  // Fix: the tuple branch previously fell through (missing 'return fg;') so the
  // output was set twice, and the list branch built a ValueTuple where a
  // ValueList was intended.
  if (sequence_arg->isa<AbstractTuple>()) {
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    fg->set_output(NewValueNode(empty_tuple_value));
    return fg;
  }
  // No element left, return empty list.
  auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
  fg->set_output(NewValueNode(empty_list_value));
  return fg;
}
// Generates the func graph for GradOperation's tail handling.
// 'tuple_arg' is the bprop result: (env, grad_of_input0, grad_of_input1, ...).
// 'position' is only used for kGradByPosition.
FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &position) const {
  MS_EXCEPTION_IF_NULL(tuple_arg);
  FuncGraphPtr fg = std::make_shared<FuncGraph>();
  fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
  fg->debug_info()->set_name("grad_tail");
  if (tail_type_ == kGradFirst) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    PrimitivePtr getitem_op = prim::kPrimTupleGetItem;
    // Fix: guard size and nullness before touching (*tuple_arg)[1]; the
    // original indexed element 1 unconditionally when CanGradArgument was
    // false, which is out of range for a bprop result with no gradients.
    if (CanGradArgument(tuple_arg, 1) ||
        (tuple_arg->size() > 1 && (*tuple_arg)[1] != nullptr &&
         EnableGradFirstForTuple((*tuple_arg)[1], enable_tuple_grad_first_))) {
      fg->set_output(fg->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
    } else {
      // First input is not gradable; return an empty tuple.
      fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
    }
    return fg;
  }
  if (tail_type_ == kGradByPosition) {
    GenerateFuncGraphByPosition(fg, tuple_arg, position);
    return fg;
  }
  if (tail_type_ == kGradAll) {
    AnfNodePtr tuple_parameter = fg->add_parameter();
    std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
    PrimitivePtr op = prim::kPrimTupleGetItem;
    // Start from index 1 to skip the env slot; keep only gradable elements.
    for (size_t i = 1; i < tuple_arg->size(); ++i) {
      MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
      if (CanGradArgument(tuple_arg, i)) {
        elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
      }
    }
    if (elements.size() > 1) {
      fg->set_output(fg->NewCNodeInOrder(elements));
      return fg;
    }
    // Empty tuple.
    auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
    auto empty_tuple = NewValueNode(empty_tuple_value);
    fg->set_output(empty_tuple);
    return fg;
  }
  MS_LOG(EXCEPTION) << "'tail_type_' is not for GradOperation, but " << tail_type_;
}
// Entry point of Tail: dispatches to the plain-tail path (tail_type_ >=
// kNotGrad) or the GradOperation path, validating argument kinds first.
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
  // To handle normal tail.
  if (args_spec_list.size() < 1) {
    MS_LOG(EXCEPTION) << "'Tail' requires at least 1 argument, but got " << args_spec_list.size();
  }
  if (tail_type_ >= kNotGrad) {
    AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_spec_list[0]);
    if (sequence_arg == nullptr) {
      // Fix: report args_spec_list[0]; 'sequence_arg' is null in this branch,
      // so the original 'sequence_arg->ToString()' dereferenced a null pointer.
      MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << args_spec_list[0]->ToString();
    }
    return GenerateTailFuncGraph(sequence_arg);
  }
  // To handle for GradOperation tail.
  constexpr size_t args_max_size = 2;
  if (args_spec_list.size() > args_max_size) {
    MS_LOG(EXCEPTION) << "'Tail' requires at most 2 arguments for GradOperation, but got " << args_spec_list.size();
  }
  AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
  if (tuple_arg == nullptr) {
    // Fix: same null-dereference pattern as above.
    MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << args_spec_list[0]->ToString();
  }
  if (args_spec_list.size() == args_max_size) {
    AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_spec_list[1]);
    if (pos == nullptr) {
      // Fix: same null-dereference pattern as above.
      MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << args_spec_list[1]->ToString();
    }
    return GenerateGradFuncGraph(tuple_arg, pos);
  }
  return GenerateGradFuncGraph(tuple_arg);
}
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
bool get_by_position)
: MetaFuncGraph(name),
@ -705,7 +732,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
// so obtain first input grad by setting tail_type of Tail to kGradFirst.
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
}
@ -761,7 +788,8 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
FuncGraphPtr k_child = nullptr;
{
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
k_child = GetGrad(j, weights, position, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
k_child =
GetGrad(j, weights, position, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad_first"));
}
grad_fg->set_output(NewValueNode(k_child));
@ -819,9 +847,9 @@ ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = fal
ValuePtr axes_value = nullptr;
auto axes_name = is_in_axes ? "in_axes" : "out_axes";
auto axes_abs_sequence = dyn_cast<abstract::AbstractSequence>(axes_abs);
auto axes_abs_sequence = dyn_cast<AbstractSequence>(axes_abs);
if (axes_abs_sequence != nullptr) {
axes_value = axes_abs->cast<abstract::AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
axes_value = axes_abs->cast<AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
MS_EXCEPTION_IF_NULL(axes_value);
if (is_in_axes) {
ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
@ -991,8 +1019,8 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
AbstractBasePtr abs_a = args_spec_list[0];
AbstractBasePtr abs_b = args_spec_list[1];
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
if (a_tuple == nullptr || b_tuple == nullptr) {
TypePtrList types;
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
@ -1055,7 +1083,7 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
return default_value;
}
if (member->isa<abstract::AbstractTensor>()) {
if (member->isa<AbstractTensor>()) {
MS_EXCEPTION(TypeError)
<< "The argument of SliceMember operator must be a Scalar or None or constant Tensor, but got a variable Tensor";
}
@ -1064,7 +1092,7 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
<< member->BuildType()->ToString();
}
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &sequence,
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const AbstractSequencePtr &sequence,
const AbstractSlicePtr &slice) {
MS_EXCEPTION_IF_NULL(sequence);
MS_EXCEPTION_IF_NULL(slice);
@ -1120,7 +1148,7 @@ std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract
// Validates the two arguments of SequenceSlice getitem: arg0 must be a
// sequence (tuple/list) and arg1 a slice. Both checked values are cached in
// members for later graph generation.
void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
  constexpr size_t arg_size = 2;
  abstract::CheckArgsSize(this->name(), args_spec_list, arg_size);
  sequence_ = abstract::CheckArg<AbstractSequence>(this->name(), args_spec_list, 0);
  slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_spec_list, 1);
}

View File

@ -103,20 +103,22 @@ enum TailType { kGradAll, kGradFirst, kGradByPosition, kNotGrad };
// Tail: meta func graph that drops the first element of a sequence, or, for
// GradOperation (tail_type_ < kNotGrad), selects the requested gradients from
// the bprop result tuple.
// NOTE(review): this span is a commit-diff rendering, so it interleaves
// removed lines (enable_tuple_grad_, GenerateSequenceFuncGraph,
// set_enable_tuple_grad) with their replacements (enable_tuple_grad_first_,
// GenerateTailFuncGraph/GenerateGradFuncGraph, set_enable_tuple_grad_first);
// only one of each pair exists in either revision.
class Tail : public MetaFuncGraph {
 public:
  explicit Tail(const std::string &name, TailType tail_type = kNotGrad)
      : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_(false) {}
      : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_first_(false) {}
  ~Tail() override = default;
  MS_DECLARE_PARENT(Tail, MetaFuncGraph)

  FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
  FuncGraphPtr GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue,
                                         const abstract::AbstractSequencePtr &pos = nullptr) const;

  friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
  void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
  void set_enable_tuple_grad_first(bool enable_tuple_grad_first) { enable_tuple_grad_first_ = enable_tuple_grad_first; }

 private:
  FuncGraphPtr GenerateTailFuncGraph(const abstract::AbstractSequencePtr &sequence_arg) const;
  FuncGraphPtr GenerateGradFuncGraph(const abstract::AbstractTuplePtr &sequeue,
                                     const abstract::AbstractTuplePtr &position = nullptr) const;

  // Which tail behavior this instance implements (kGradAll/kGradFirst/
  // kGradByPosition/kNotGrad).
  TailType tail_type_;
  bool enable_tuple_grad_;
  bool enable_tuple_grad_first_;
};
using TailPtr = std::shared_ptr<Tail>;
@ -148,7 +150,7 @@ class GradOperation : public MetaFuncGraph {
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first,
const std::vector<AnfNodePtr> &weight_args = {});
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
@ -163,7 +165,7 @@ class GradOperation : public MetaFuncGraph {
private:
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad);
const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad_first);
};
using GradOperationPtr = std::shared_ptr<GradOperation>;

View File

@ -657,22 +657,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
}
}
// Marks 'possible_parent_fg' with kFuncGraphFlagUndetermined under a lock.
// NOTE(review): this span is a commit-diff rendering; it shows both the old
// two-parameter signature (with its evaluator fallback path) and the new
// single-parameter replacement that unconditionally sets the flag.
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
  MS_EXCEPTION_IF_NULL(evaluator);
void AnalysisEngine::SetUndeterminedFlag(const FuncGraphPtr &possible_parent_fg) {
  MS_EXCEPTION_IF_NULL(possible_parent_fg);
  // Serialize flag updates across analysis threads.
  static std::mutex fg_lock;
  std::lock_guard<std::mutex> infer_lock(fg_lock);
  if (possible_parent_fg != nullptr) {
    possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
    MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString();
    return;
  }
  // Old fallback path: derive the graph from the evaluator when no parent
  // graph was supplied.
  auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
  if (fg_eval == nullptr) {
    return;
  }
  auto fg = fg_eval->func_graph();
  MS_EXCEPTION_IF_NULL(fg);
  MS_LOG(EXCEPTION) << "cannot set Undetermined flag for fg: " << fg->ToString();
  possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
  MS_LOG(DEBUG) << "Set graph undetermined flag for " << possible_parent_fg->ToString();
}
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
@ -960,7 +950,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
static std::atomic<int> id_count{0};
std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
MS_EXCEPTION_IF_NULL(evaluator);
SetUndeterminedFlag(evaluator, possible_parent_fg);
SetUndeterminedFlag(possible_parent_fg);
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
// Control the order to run.
AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
@ -1032,7 +1022,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
auto possible_parent_fg = out_conf->node()->func_graph();
for (auto eval : evaluators) {
MS_EXCEPTION_IF_NULL(eval);
SetUndeterminedFlag(eval, possible_parent_fg);
SetUndeterminedFlag(possible_parent_fg);
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.

View File

@ -346,7 +346,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
private:
void SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
void SetUndeterminedFlag(const FuncGraphPtr &possible_parent_fg);
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
bool *continue_flag);

View File

@ -102,9 +102,9 @@ class Jvp(Cell):
self.fn = fn
self.oneslike = P.OnesLike()
self.first_grad = _FirstGrad(fn)
self.first_grad.add_flags(enable_tuple_grad=True)
self.first_grad.add_flags(enable_tuple_grad_first=True)
self.first_grad_single_value = _FirstGradSingleValue(fn)
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
self.second_grad_op = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.typeof = Primitive('typeof')
@ -142,9 +142,9 @@ class _JvpInner(Cell):
super(_JvpInner, self).__init__()
self.oneslike = P.OnesLike()
self.first_grad = _JvpFirstGrad()
self.first_grad.add_flags(enable_tuple_grad=True)
self.first_grad.add_flags(enable_tuple_grad_first=True)
self.first_grad_single_value = _JvpFirstGradSingleValue()
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
self.second_grad_op = C.GradOperation(sens_param=True)
self.issubclass_ = P.IsSubClass()
self.typeof = Primitive('typeof')