Optimize grad and tail generating procedure.
This commit is contained in:
parent
c29d6bb764
commit
141335dc3b
mindspore
ccsrc
frontend/operator/composite
pipeline/jit/static_analysis
python/mindspore/nn/grad
|
@ -38,11 +38,10 @@ namespace mindspore {
|
|||
// namespace to support composite operators definition
|
||||
namespace prim {
|
||||
constexpr auto kStepDefault = 1;
|
||||
using AbstractTensor = mindspore::abstract::AbstractTensor;
|
||||
using FuncGraphAbstractClosure = mindspore::abstract::FuncGraphAbstractClosure;
|
||||
|
||||
using mindspore::abstract::AbstractAttribute;
|
||||
using mindspore::abstract::AbstractBase;
|
||||
using mindspore::abstract::AbstractBasePtr;
|
||||
using mindspore::abstract::AbstractDictionary;
|
||||
using mindspore::abstract::AbstractDictionaryPtr;
|
||||
using mindspore::abstract::AbstractEllipsis;
|
||||
|
@ -50,10 +49,17 @@ using mindspore::abstract::AbstractEllipsisPtr;
|
|||
using mindspore::abstract::AbstractFunction;
|
||||
using mindspore::abstract::AbstractFunctionPtr;
|
||||
using mindspore::abstract::AbstractList;
|
||||
using mindspore::abstract::AbstractListPtr;
|
||||
using mindspore::abstract::AbstractNone;
|
||||
using mindspore::abstract::AbstractScalar;
|
||||
using mindspore::abstract::AbstractSequence;
|
||||
using mindspore::abstract::AbstractSequencePtr;
|
||||
using mindspore::abstract::AbstractSlice;
|
||||
using mindspore::abstract::AbstractTensor;
|
||||
using mindspore::abstract::AbstractTuple;
|
||||
using mindspore::abstract::AbstractTuplePtr;
|
||||
using mindspore::abstract::AbstractUndetermined;
|
||||
using mindspore::abstract::FuncGraphAbstractClosure;
|
||||
|
||||
void HyperMap::Init() {
|
||||
if (fn_leaf_) {
|
||||
|
@ -344,170 +350,6 @@ abstract::AbstractBasePtrList HyperMap::NormalizeArgs(const AbstractBasePtrList
|
|||
return broadened;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool CheckSequenceAllTensor(const abstract::AbstractTuplePtr &tuple) {
|
||||
MS_EXCEPTION_IF_NULL(tuple);
|
||||
for (size_t i = 0; i < tuple->size(); ++i) {
|
||||
if (!(*tuple)[i]->isa<abstract::AbstractUndetermined>() &&
|
||||
!((*tuple)[i]->isa<abstract::AbstractTuple>() &&
|
||||
CheckSequenceAllTensor((*tuple)[i]->cast<abstract::AbstractTuplePtr>()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EnableGradForScalar(const abstract::AbstractBasePtr &abs) {
|
||||
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
|
||||
abs->BuildType()->isa<Number>();
|
||||
}
|
||||
|
||||
bool EnableGradForTuple(const abstract::AbstractBasePtr &abs, bool enable_tuple_grad) {
|
||||
return abs->isa<abstract::AbstractTuple>() && enable_tuple_grad &&
|
||||
CheckSequenceAllTensor(abs->cast<abstract::AbstractTuplePtr>());
|
||||
}
|
||||
|
||||
bool CheckTailGradFristSequence(const abstract::AbstractSequencePtr &sequeue, bool enable_tuple_grad) {
|
||||
MS_EXCEPTION_IF_NULL(sequeue);
|
||||
return sequeue->size() > 1 && (*sequeue)[1] != nullptr &&
|
||||
((*sequeue)[1]->isa<abstract::AbstractUndetermined>() || (*sequeue)[1]->BuildValue() == kAnyValue ||
|
||||
EnableGradForScalar((*sequeue)[1]) || EnableGradForTuple((*sequeue)[1], enable_tuple_grad));
|
||||
}
|
||||
|
||||
void GenerateSequenceFuncGraphByPosition(const FuncGraphPtr &res, const abstract::AbstractSequencePtr &sequeue,
|
||||
const abstract::AbstractSequencePtr &pos, bool enable_tuple_grad) {
|
||||
if (pos == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
|
||||
}
|
||||
AnfNodePtr tuple_parameter = res->add_parameter();
|
||||
std::vector<AnfNodePtr> pos_elements;
|
||||
PrimitivePtr pos_op = nullptr;
|
||||
if (pos->isa<AbstractTuple>()) {
|
||||
pos_elements.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
pos_op = prim::kPrimTupleGetItem;
|
||||
} else {
|
||||
pos_elements.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
pos_op = prim::kPrimListGetItem;
|
||||
}
|
||||
AnfNodePtr pos_value = nullptr;
|
||||
AnfNodePtr pos_value_adjust = nullptr;
|
||||
auto pos_parameter = res->add_parameter();
|
||||
if (pos->size() == 1) {
|
||||
pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(0))});
|
||||
pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
|
||||
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad)) {
|
||||
res->set_output(res->NewCNode({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
|
||||
} else {
|
||||
res->set_output(NewValueNode(std::make_shared<ValueTuple>(std::vector<ValuePtr>{})));
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < pos->size(); ++i) {
|
||||
pos_value = res->NewCNode({NewValueNode(pos_op), pos_parameter, NewValueNode(SizeToLong(i))});
|
||||
pos_value_adjust = res->NewCNode({NewValueNode(prim::kPrimScalarAdd), pos_value, NewValueNode(SizeToLong(1))});
|
||||
pos_elements.push_back(res->NewCNodeInOrder({NewValueNode(pos_op), tuple_parameter, pos_value_adjust}));
|
||||
}
|
||||
if (pos_elements.size() > 1) {
|
||||
res->set_output(res->NewCNodeInOrder(pos_elements));
|
||||
} else if (pos->isa<AbstractTuple>()) { // Empty tuple.
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
res->set_output(empty_tuple);
|
||||
} else { // Empty list.
|
||||
auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
|
||||
auto empty_list = NewValueNode(empty_list_value);
|
||||
res->set_output(empty_list);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr Tail::GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue,
|
||||
const abstract::AbstractSequencePtr &pos) const {
|
||||
MS_EXCEPTION_IF_NULL(sequeue);
|
||||
|
||||
FuncGraphPtr res = std::make_shared<FuncGraph>();
|
||||
res->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
res->debug_info()->set_name("tail");
|
||||
|
||||
if (tail_type_ == kGradFirst) {
|
||||
AnfNodePtr tuple_parameter = res->add_parameter();
|
||||
PrimitivePtr getitem_op = nullptr;
|
||||
if (sequeue->isa<AbstractTuple>()) {
|
||||
getitem_op = prim::kPrimTupleGetItem;
|
||||
} else {
|
||||
getitem_op = prim::kPrimListGetItem;
|
||||
}
|
||||
if (CheckTailGradFristSequence(sequeue, enable_tuple_grad_)) {
|
||||
res->set_output(res->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
|
||||
} else {
|
||||
res->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
if (tail_type_ == kGradByPosition) {
|
||||
GenerateSequenceFuncGraphByPosition(res, sequeue, pos, enable_tuple_grad_);
|
||||
return res;
|
||||
}
|
||||
|
||||
AnfNodePtr tuple_parameter = res->add_parameter();
|
||||
std::vector<AnfNodePtr> elements;
|
||||
PrimitivePtr op = nullptr;
|
||||
if (sequeue->isa<AbstractTuple>()) {
|
||||
elements.push_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
op = prim::kPrimTupleGetItem;
|
||||
} else {
|
||||
elements.push_back(NewValueNode(prim::kPrimMakeList));
|
||||
op = prim::kPrimListGetItem;
|
||||
}
|
||||
for (size_t i = 1; i < sequeue->size(); ++i) {
|
||||
if (tail_type_ == kGradAll) {
|
||||
MS_EXCEPTION_IF_NULL((*sequeue)[i]);
|
||||
if ((*sequeue)[i]->isa<abstract::AbstractUndetermined>() || (*sequeue)[i]->BuildValue() == kAnyValue ||
|
||||
EnableGradForScalar((*sequeue)[i])) {
|
||||
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
} else {
|
||||
elements.push_back(res->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
}
|
||||
if (elements.size() > 1) {
|
||||
res->set_output(res->NewCNodeInOrder(elements));
|
||||
return res;
|
||||
} else if (sequeue->isa<AbstractTuple>()) { // Empty tuple.
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
res->set_output(empty_tuple);
|
||||
return res;
|
||||
} else { // Empty list.
|
||||
auto empty_list_value = std::make_shared<ValueList>(ValuePtrList());
|
||||
auto empty_list = NewValueNode(empty_list_value);
|
||||
res->set_output(empty_list);
|
||||
return res;
|
||||
}
|
||||
}
|
||||
|
||||
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
if (args_spec_list.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "Tail requires a non-empty tuple.";
|
||||
}
|
||||
|
||||
AbstractBasePtr a = args_spec_list[0];
|
||||
if (a->isa<AbstractTuple>() || a->isa<AbstractList>()) {
|
||||
if (args_spec_list.size() > 1) {
|
||||
AbstractBasePtr pos = args_spec_list[1];
|
||||
if (pos->isa<AbstractTuple>() || pos->isa<AbstractList>()) {
|
||||
return GenerateSequenceFuncGraph(a->cast<abstract::AbstractSequencePtr>(),
|
||||
pos->cast<abstract::AbstractSequencePtr>());
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg1 must be AbstractTuple or AbstractList, but got " << pos->ToString();
|
||||
}
|
||||
return GenerateSequenceFuncGraph(a->cast<abstract::AbstractSequencePtr>());
|
||||
}
|
||||
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be AbstractTuple or AbstractList, but got " << a->ToString();
|
||||
}
|
||||
|
||||
FuncGraphPtr MakeTupleGradient::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
int64_t tuple_size = SizeToLong(args_spec_list.size());
|
||||
|
||||
|
@ -592,6 +434,191 @@ FuncGraphPtr MakeListGradient::GenerateFuncGraph(const AbstractBasePtrList &args
|
|||
return fg;
|
||||
}
|
||||
|
||||
namespace {
|
||||
bool IsTupleAllTensor(const AbstractTuplePtr &tuple_arg) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||
for (size_t i = 0; i < tuple_arg->size(); ++i) {
|
||||
if (!(*tuple_arg)[i]->isa<AbstractUndetermined>() &&
|
||||
!((*tuple_arg)[i]->isa<AbstractTuple>() && IsTupleAllTensor((*tuple_arg)[i]->cast<AbstractTuplePtr>()))) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool EnableGradFirstForTuple(const AbstractBasePtr &abs, bool enable_tuple_grad) {
|
||||
return abs->isa<AbstractTuple>() && enable_tuple_grad && IsTupleAllTensor(abs->cast<AbstractTuplePtr>());
|
||||
}
|
||||
|
||||
bool EnableGradForScalar(const AbstractBasePtr &abs) {
|
||||
return MsContext::GetInstance()->get_param<bool>(MS_CTX_GRAD_FOR_SCALAR) && abs->BuildType() != nullptr &&
|
||||
abs->BuildType()->isa<Number>();
|
||||
}
|
||||
|
||||
bool CanGradArgument(const AbstractTuplePtr &tuple_arg, size_t pos) {
|
||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||
return tuple_arg->size() > pos && (*tuple_arg)[pos] != nullptr &&
|
||||
((*tuple_arg)[pos]->isa<AbstractUndetermined>() || (*tuple_arg)[pos]->BuildValue() == kAnyValue ||
|
||||
EnableGradForScalar((*tuple_arg)[pos]));
|
||||
}
|
||||
|
||||
void GenerateFuncGraphByPosition(const FuncGraphPtr &fg, const AbstractTuplePtr &tuple_arg,
|
||||
const AbstractTuplePtr &pos) {
|
||||
if (pos == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Return grad by position, but the grad_position is empty!";
|
||||
}
|
||||
AnfNodePtr tuple_parameter = fg->add_parameter();
|
||||
(void)fg->add_parameter(); // The 'grad_position' parameter.
|
||||
// Collect all parameters by 'grad_position'.
|
||||
std::vector<AnfNodePtr> pos_elements = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
CNodePtr current_element = nullptr;
|
||||
for (size_t i = 0; i < pos->size(); ++i) {
|
||||
auto val = pos->elements()[i]->BuildValue();
|
||||
MS_EXCEPTION_IF_NULL(val);
|
||||
auto int_val = LongToSize(dyn_cast<Int64Imm>(val)->value());
|
||||
++int_val; // Ignore the env position.
|
||||
if (!CanGradArgument(tuple_arg, int_val)) {
|
||||
continue;
|
||||
}
|
||||
current_element =
|
||||
fg->NewCNodeInOrder({NewValueNode(prim::kPrimTupleGetItem), tuple_parameter, NewValueNode(SizeToLong(int_val))});
|
||||
pos_elements.push_back(current_element);
|
||||
}
|
||||
|
||||
// The returned result may vary for grad result element number.
|
||||
// A single value if only one result, a tuple for multiple results, or a empty tuple for no result.
|
||||
//
|
||||
// Notice that even if the user set 'grad_position' as multiple choices,
|
||||
// the 'CanGradArgument' may change it to only one choice or none choice.
|
||||
constexpr size_t args_least_size = 2;
|
||||
if (pos_elements.size() == args_least_size) {
|
||||
fg->set_output(current_element);
|
||||
} else if (pos_elements.size() > args_least_size) {
|
||||
fg->set_output(fg->NewCNodeInOrder(pos_elements));
|
||||
} else { // The 'pos' is empty AbstractTuple or empty AbstractList.
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
fg->set_output(empty_tuple);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
FuncGraphPtr Tail::GenerateTailFuncGraph(const AbstractSequencePtr &sequence_arg) const {
|
||||
MS_EXCEPTION_IF_NULL(sequence_arg);
|
||||
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->debug_info()->set_name("tail");
|
||||
|
||||
AnfNodePtr tuple_parameter = fg->add_parameter();
|
||||
std::vector<AnfNodePtr> elements;
|
||||
PrimitivePtr op = nullptr;
|
||||
if (sequence_arg->isa<AbstractTuple>()) {
|
||||
elements.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
op = prim::kPrimTupleGetItem;
|
||||
} else {
|
||||
elements.emplace_back(NewValueNode(prim::kPrimMakeList));
|
||||
op = prim::kPrimListGetItem;
|
||||
}
|
||||
|
||||
// Remove the first element to make a new sequence.
|
||||
for (size_t i = 1; i < sequence_arg->size(); ++i) {
|
||||
elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
if (elements.size() > 1) {
|
||||
fg->set_output(fg->NewCNodeInOrder(elements));
|
||||
return fg;
|
||||
}
|
||||
|
||||
// No element left, return empty tuple.
|
||||
if (sequence_arg->isa<AbstractTuple>()) {
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
fg->set_output(empty_tuple);
|
||||
}
|
||||
// No element left, return empty list.
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
fg->set_output(empty_tuple);
|
||||
return fg;
|
||||
}
|
||||
|
||||
FuncGraphPtr Tail::GenerateGradFuncGraph(const AbstractTuplePtr &tuple_arg, const AbstractTuplePtr &position) const {
|
||||
MS_EXCEPTION_IF_NULL(tuple_arg);
|
||||
FuncGraphPtr fg = std::make_shared<FuncGraph>();
|
||||
fg->set_flag(FUNC_GRAPH_FLAG_CORE, true);
|
||||
fg->debug_info()->set_name("grad_tail");
|
||||
|
||||
if (tail_type_ == kGradFirst) {
|
||||
AnfNodePtr tuple_parameter = fg->add_parameter();
|
||||
PrimitivePtr getitem_op = prim::kPrimTupleGetItem;
|
||||
if (CanGradArgument(tuple_arg, 1) || EnableGradFirstForTuple((*tuple_arg)[1], enable_tuple_grad_first_)) {
|
||||
fg->set_output(fg->NewCNode({NewValueNode(getitem_op), tuple_parameter, NewValueNode(SizeToLong(1))}));
|
||||
} else {
|
||||
fg->set_output(NewValueNode(std::make_shared<ValueTuple>(ValuePtrList())));
|
||||
}
|
||||
return fg;
|
||||
}
|
||||
|
||||
if (tail_type_ == kGradByPosition) {
|
||||
GenerateFuncGraphByPosition(fg, tuple_arg, position);
|
||||
return fg;
|
||||
}
|
||||
|
||||
if (tail_type_ == kGradAll) {
|
||||
AnfNodePtr tuple_parameter = fg->add_parameter();
|
||||
std::vector<AnfNodePtr> elements = {NewValueNode(prim::kPrimMakeTuple)};
|
||||
PrimitivePtr op = prim::kPrimTupleGetItem;
|
||||
for (size_t i = 1; i < tuple_arg->size(); ++i) {
|
||||
MS_EXCEPTION_IF_NULL((*tuple_arg)[i]);
|
||||
if (CanGradArgument(tuple_arg, i)) {
|
||||
elements.push_back(fg->NewCNodeInOrder({NewValueNode(op), tuple_parameter, NewValueNode(SizeToLong(i))}));
|
||||
}
|
||||
}
|
||||
if (elements.size() > 1) {
|
||||
fg->set_output(fg->NewCNodeInOrder(elements));
|
||||
return fg;
|
||||
}
|
||||
// Empty tuple.
|
||||
auto empty_tuple_value = std::make_shared<ValueTuple>(ValuePtrList());
|
||||
auto empty_tuple = NewValueNode(empty_tuple_value);
|
||||
fg->set_output(empty_tuple);
|
||||
return fg;
|
||||
}
|
||||
MS_LOG(EXCEPTION) << "'tail_type_' is not for GradOperation, but " << tail_type_;
|
||||
}
|
||||
|
||||
FuncGraphPtr Tail::GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) {
|
||||
// To handle normal tail.
|
||||
if (args_spec_list.size() < 1) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' requires at least 1 argument, but got " << args_spec_list.size();
|
||||
}
|
||||
if (tail_type_ >= kNotGrad) {
|
||||
AbstractSequencePtr sequence_arg = dyn_cast<AbstractSequence>(args_spec_list[0]);
|
||||
if (sequence_arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple or list, but got " << sequence_arg->ToString();
|
||||
}
|
||||
return GenerateTailFuncGraph(sequence_arg);
|
||||
}
|
||||
|
||||
// To handle for GradOperation tail.
|
||||
constexpr size_t args_max_size = 2;
|
||||
if (args_spec_list.size() > args_max_size) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' requires at most 2 arguments for GradOperation, but got " << args_spec_list.size();
|
||||
}
|
||||
AbstractTuplePtr tuple_arg = dyn_cast<AbstractTuple>(args_spec_list[0]);
|
||||
if (tuple_arg == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg0 must be tuple, but got " << tuple_arg->ToString();
|
||||
}
|
||||
if (args_spec_list.size() == args_max_size) {
|
||||
AbstractTuplePtr pos = dyn_cast<AbstractTuple>(args_spec_list[1]);
|
||||
if (pos == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "'Tail' arg1 'position' must be tuple, but got " << pos->ToString();
|
||||
}
|
||||
return GenerateGradFuncGraph(tuple_arg, pos);
|
||||
}
|
||||
return GenerateGradFuncGraph(tuple_arg);
|
||||
}
|
||||
|
||||
GradOperation::GradOperation(const std::string &name, bool get_all, bool get_by_list, bool sens_param,
|
||||
bool get_by_position)
|
||||
: MetaFuncGraph(name),
|
||||
|
@ -705,7 +732,7 @@ void GradOperation::GradByParameter(const FuncGraphPtr &k_child, const AnfNodePt
|
|||
// b_app returns (EnvInstance(grads wrt params), grads wrt input0, grads wrt input1, ...),
|
||||
// so obtain first input grad by setting tail_type of Tail to kGradFirst.
|
||||
TailPtr tail_grad_first = std::make_shared<Tail>("tail_grad_first", kGradFirst);
|
||||
tail_grad_first->set_enable_tuple_grad(enable_tuple_grad);
|
||||
tail_grad_first->set_enable_tuple_grad_first(enable_tuple_grad);
|
||||
k_child->set_output(k_child->NewCNodeInOrder({NewValueNode(tail_grad_first), b_app}));
|
||||
}
|
||||
|
||||
|
@ -761,7 +788,8 @@ FuncGraphPtr GradOperation::GenerateFuncGraph(const AbstractBasePtrList &args_sp
|
|||
FuncGraphPtr k_child = nullptr;
|
||||
{
|
||||
TraceGuard guard(std::make_shared<TraceGradOperation>(forward_graph->debug_info()));
|
||||
k_child = GetGrad(j, weights, position, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad"));
|
||||
k_child =
|
||||
GetGrad(j, weights, position, forward_graph->parameters(), forward_graph->has_flag("enable_tuple_grad_first"));
|
||||
}
|
||||
grad_fg->set_output(NewValueNode(k_child));
|
||||
|
||||
|
@ -819,9 +847,9 @@ ValuePtr CheckAxes(const AbstractBasePtr &axes_abs, const bool &is_in_axes = fal
|
|||
ValuePtr axes_value = nullptr;
|
||||
auto axes_name = is_in_axes ? "in_axes" : "out_axes";
|
||||
|
||||
auto axes_abs_sequence = dyn_cast<abstract::AbstractSequence>(axes_abs);
|
||||
auto axes_abs_sequence = dyn_cast<AbstractSequence>(axes_abs);
|
||||
if (axes_abs_sequence != nullptr) {
|
||||
axes_value = axes_abs->cast<abstract::AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
|
||||
axes_value = axes_abs->cast<AbstractSequencePtr>()->ElementsBuildValue<ValueTuple>();
|
||||
MS_EXCEPTION_IF_NULL(axes_value);
|
||||
if (is_in_axes) {
|
||||
ValueSequencePtr in_axes_seq = dyn_cast<ValueSequence>(axes_value);
|
||||
|
@ -991,8 +1019,8 @@ FuncGraphPtr TupleAdd::GenerateFuncGraph(const AbstractBasePtrList &args_spec_li
|
|||
AbstractBasePtr abs_a = args_spec_list[0];
|
||||
AbstractBasePtr abs_b = args_spec_list[1];
|
||||
|
||||
abstract::AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
|
||||
abstract::AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
|
||||
AbstractTuplePtr a_tuple = dyn_cast<AbstractTuple>(abs_a);
|
||||
AbstractTuplePtr b_tuple = dyn_cast<AbstractTuple>(abs_b);
|
||||
if (a_tuple == nullptr || b_tuple == nullptr) {
|
||||
TypePtrList types;
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(types),
|
||||
|
@ -1055,7 +1083,7 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
|
|||
return default_value;
|
||||
}
|
||||
|
||||
if (member->isa<abstract::AbstractTensor>()) {
|
||||
if (member->isa<AbstractTensor>()) {
|
||||
MS_EXCEPTION(TypeError)
|
||||
<< "The argument of SliceMember operator must be a Scalar or None or constant Tensor, but got a variable Tensor";
|
||||
}
|
||||
|
@ -1064,7 +1092,7 @@ int64_t CheckSliceMember(const AbstractBasePtr &member, int64_t default_value, c
|
|||
<< member->BuildType()->ToString();
|
||||
}
|
||||
|
||||
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract::AbstractSequencePtr &sequence,
|
||||
std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const AbstractSequencePtr &sequence,
|
||||
const AbstractSlicePtr &slice) {
|
||||
MS_EXCEPTION_IF_NULL(sequence);
|
||||
MS_EXCEPTION_IF_NULL(slice);
|
||||
|
@ -1120,7 +1148,7 @@ std::tuple<int64_t, int64_t, int64_t> GenerateTupleSliceParameter(const abstract
|
|||
void SequenceSliceGetItem::CheckArgs(const AbstractBasePtrList &args_spec_list) {
|
||||
constexpr size_t arg_size = 2;
|
||||
abstract::CheckArgsSize(this->name(), args_spec_list, arg_size);
|
||||
sequence_ = abstract::CheckArg<abstract::AbstractSequence>(this->name(), args_spec_list, 0);
|
||||
sequence_ = abstract::CheckArg<AbstractSequence>(this->name(), args_spec_list, 0);
|
||||
slice_ = abstract::CheckArg<AbstractSlice>(this->name(), args_spec_list, 1);
|
||||
}
|
||||
|
||||
|
|
|
@ -103,20 +103,22 @@ enum TailType { kGradAll, kGradFirst, kGradByPosition, kNotGrad };
|
|||
class Tail : public MetaFuncGraph {
|
||||
public:
|
||||
explicit Tail(const std::string &name, TailType tail_type = kNotGrad)
|
||||
: MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_(false) {}
|
||||
: MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_first_(false) {}
|
||||
~Tail() override = default;
|
||||
MS_DECLARE_PARENT(Tail, MetaFuncGraph)
|
||||
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
FuncGraphPtr GenerateSequenceFuncGraph(const abstract::AbstractSequencePtr &sequeue,
|
||||
const abstract::AbstractSequencePtr &pos = nullptr) const;
|
||||
|
||||
friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; }
|
||||
void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; }
|
||||
void set_enable_tuple_grad_first(bool enable_tuple_grad_first) { enable_tuple_grad_first_ = enable_tuple_grad_first; }
|
||||
|
||||
private:
|
||||
FuncGraphPtr GenerateTailFuncGraph(const abstract::AbstractSequencePtr &sequence_arg) const;
|
||||
FuncGraphPtr GenerateGradFuncGraph(const abstract::AbstractTuplePtr &sequeue,
|
||||
const abstract::AbstractTuplePtr &position = nullptr) const;
|
||||
|
||||
TailType tail_type_;
|
||||
bool enable_tuple_grad_;
|
||||
bool enable_tuple_grad_first_;
|
||||
};
|
||||
using TailPtr = std::shared_ptr<Tail>;
|
||||
|
||||
|
@ -148,7 +150,7 @@ class GradOperation : public MetaFuncGraph {
|
|||
MS_DECLARE_PARENT(GradOperation, MetaFuncGraph)
|
||||
|
||||
FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position,
|
||||
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad,
|
||||
const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad_first,
|
||||
const std::vector<AnfNodePtr> &weight_args = {});
|
||||
|
||||
FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override;
|
||||
|
@ -163,7 +165,7 @@ class GradOperation : public MetaFuncGraph {
|
|||
|
||||
private:
|
||||
void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop,
|
||||
const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad);
|
||||
const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad_first);
|
||||
};
|
||||
using GradOperationPtr = std::shared_ptr<GradOperation>;
|
||||
|
||||
|
|
|
@ -657,22 +657,12 @@ EvalResultPtr AnalysisEngine::ExecuteEvaluators(const std::vector<EvaluatorPtr>
|
|||
}
|
||||
}
|
||||
|
||||
void AnalysisEngine::SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg) {
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
void AnalysisEngine::SetUndeterminedFlag(const FuncGraphPtr &possible_parent_fg) {
|
||||
MS_EXCEPTION_IF_NULL(possible_parent_fg);
|
||||
static std::mutex fg_lock;
|
||||
std::lock_guard<std::mutex> infer_lock(fg_lock);
|
||||
if (possible_parent_fg != nullptr) {
|
||||
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined: " << possible_parent_fg->ToString();
|
||||
return;
|
||||
}
|
||||
auto fg_eval = evaluator->cast<FuncGraphEvaluatorPtr>();
|
||||
if (fg_eval == nullptr) {
|
||||
return;
|
||||
}
|
||||
auto fg = fg_eval->func_graph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
MS_LOG(EXCEPTION) << "cannot set Undetermined flag for fg: " << fg->ToString();
|
||||
possible_parent_fg->set_flag(kFuncGraphFlagUndetermined, true);
|
||||
MS_LOG(DEBUG) << "Set graph undetermined flag for " << possible_parent_fg->ToString();
|
||||
}
|
||||
|
||||
EvaluatorPtr AnalysisEngine::HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators,
|
||||
|
@ -960,7 +950,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
static std::atomic<int> id_count{0};
|
||||
std::string thread_id = AnalysisSchedule::thread_id() + "." + std::to_string(id_count.fetch_add(1));
|
||||
MS_EXCEPTION_IF_NULL(evaluator);
|
||||
SetUndeterminedFlag(evaluator, possible_parent_fg);
|
||||
SetUndeterminedFlag(possible_parent_fg);
|
||||
AsyncAbstractPtr async_result_branch = std::make_shared<AsyncAbstract>();
|
||||
// Control the order to run.
|
||||
AsyncAbstractPtr control_run_order = std::make_shared<AsyncAbstract>();
|
||||
|
@ -1032,7 +1022,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluators(const std::vector<Evalua
|
|||
auto possible_parent_fg = out_conf->node()->func_graph();
|
||||
for (auto eval : evaluators) {
|
||||
MS_EXCEPTION_IF_NULL(eval);
|
||||
SetUndeterminedFlag(eval, possible_parent_fg);
|
||||
SetUndeterminedFlag(possible_parent_fg);
|
||||
const auto current_inf = EvaluatorArgs(eval, args_spec_list);
|
||||
MS_LOG(DEBUG) << "Check Evaluator " << eval->ToString();
|
||||
// If current evaluator is under tracing, then skip current evaluator to avoid recursively evaluating.
|
||||
|
|
|
@ -346,7 +346,7 @@ class AnalysisEngine : public std::enable_shared_from_this<AnalysisEngine> {
|
|||
static EvalResultPtr ProcessEvalResults(const AbstractBasePtrList &out_specs, const AnfNodePtr &node);
|
||||
|
||||
private:
|
||||
void SetUndeterminedFlag(const EvaluatorPtr &evaluator, const FuncGraphPtr &possible_parent_fg);
|
||||
void SetUndeterminedFlag(const FuncGraphPtr &possible_parent_fg);
|
||||
EvaluatorPtr HandleNestedRecursion(const std::vector<EvaluatorPtr> &evaluators, const EvaluatorPtr &eval,
|
||||
const AbstractBasePtrList &args_spec_list, const EvalTraceRevIter &it,
|
||||
bool *continue_flag);
|
||||
|
|
|
@ -102,9 +102,9 @@ class Jvp(Cell):
|
|||
self.fn = fn
|
||||
self.oneslike = P.OnesLike()
|
||||
self.first_grad = _FirstGrad(fn)
|
||||
self.first_grad.add_flags(enable_tuple_grad=True)
|
||||
self.first_grad.add_flags(enable_tuple_grad_first=True)
|
||||
self.first_grad_single_value = _FirstGradSingleValue(fn)
|
||||
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
|
||||
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
|
||||
self.second_grad_op = C.GradOperation(sens_param=True)
|
||||
self.issubclass_ = P.IsSubClass()
|
||||
self.typeof = Primitive('typeof')
|
||||
|
@ -142,9 +142,9 @@ class _JvpInner(Cell):
|
|||
super(_JvpInner, self).__init__()
|
||||
self.oneslike = P.OnesLike()
|
||||
self.first_grad = _JvpFirstGrad()
|
||||
self.first_grad.add_flags(enable_tuple_grad=True)
|
||||
self.first_grad.add_flags(enable_tuple_grad_first=True)
|
||||
self.first_grad_single_value = _JvpFirstGradSingleValue()
|
||||
self.first_grad_single_value.add_flags(enable_tuple_grad=True)
|
||||
self.first_grad_single_value.add_flags(enable_tuple_grad_first=True)
|
||||
self.second_grad_op = C.GradOperation(sens_param=True)
|
||||
self.issubclass_ = P.IsSubClass()
|
||||
self.typeof = Primitive('typeof')
|
||||
|
|
Loading…
Reference in New Issue