!26430 replace short-circuit eval with deferred evaluation of backward prop function.

Merge pull request !26430 from xychow/replace-shortcurit-eval-with-lazy-eval
This commit is contained in:
i-robot 2021-11-23 12:53:23 +00:00 committed by Gitee
commit 20be757f18
6 changed files with 158 additions and 74 deletions

View File

@ -344,7 +344,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
// Do forward computation
auto forward_app =
k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
forward_app->set_user_data<abstract::AbstractBase>("primal_abstract", cnode_morph->abstract());
// K:: cnode -> forward_app
auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
UpdateAdjoint(node_adjoint);

View File

@ -29,7 +29,7 @@ AbstractBasePtr AsyncAbstract::GetResult() {
if (ret != nullptr) {
return ret;
}
auto async_task = AsyncInferTask::MakeShared(shared_from_base<AsyncAbstract>());
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
ret = async_task->GetResult();
@ -173,6 +173,51 @@ void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID
std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
if (resolved_ != nullptr) {
return resolved_;
}
// Release GIL for C++;
py::gil_scoped_release infer_gil_release;
MS_LOG(DEBUG) << "Try to GetResult from async_abstract: " << async_abstract_->ToString();
const auto &result = async_abstract_->GetResult();
if (result->isa<AbstractFuncAtom>()) {
resolved_ = result->cast<AbstractFuncAtomPtr>();
} else if (result->isa<AbstractSequeue>()) {
const auto &abs_seq = result->cast<AbstractSequeuePtr>();
MS_EXCEPTION_IF_NULL(abs_seq);
const auto &elements = abs_seq->elements();
if (elements.size() < index_) {
MS_LOG(EXCEPTION) << "Elements of AsyncAbstract result: " << result->ToString()
<< " size is less than index: " << index_;
}
if (!elements[index_]->isa<AbstractFuncAtom>()) {
MS_LOG(EXCEPTION) << "AsyncAbstract result cannot resolve to AbstractFuncAtom, but: "
<< elements[index_]->ToString();
}
MS_LOG(DEBUG) << "Return Abstract: " << elements[index_]->ToString();
resolved_ = elements[index_]->cast<AbstractFuncAtomPtr>();
} else {
MS_LOG(EXCEPTION) << "AsyncAbstract cannot resolve to AbstractFuncAtom or AbstractSequence, but: "
<< result->ToString();
}
return resolved_;
}
std::string AsyncAbstractFuncAtom::ToString() const {
if (resolved_ == nullptr) {
return "AsyncAbstractFuncAtom(Not Resolved)";
}
std::ostringstream buffer;
buffer << "AsyncAbstractFuncAtom(";
buffer << resolved_->ToString();
buffer << ")";
return buffer.str();
}
void AnalysisResultCacheMgr::Clear() {
std::lock_guard<std::mutex> lock(lock_);
cache_.clear();

View File

@ -206,11 +206,11 @@ class NormalCache {
CacheType cache_;
};
class AsyncAbstract : public Base {
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
public:
AsyncAbstract() = default;
~AsyncAbstract() = default;
AbstractBasePtr GetResult();
AbstractBasePtr TryGetResult() {
std::lock_guard<std::mutex> lock(lock_);
return result_;
@ -225,6 +225,8 @@ class AsyncAbstract : public Base {
result_ = result;
}
AbstractBasePtr GetResult();
std::string ToString() {
std::ostringstream buffer;
std::lock_guard<std::mutex> lock(lock_);
@ -237,6 +239,49 @@ class AsyncAbstract : public Base {
AbstractBasePtr result_{nullptr};
};
// Wrap AsyncAbstract, so it can work with Join method of AbstractFunction.
class AsyncAbstractFuncAtom : public AbstractFuncAtom {
public:
AsyncAbstractFuncAtom(const AsyncAbstractPtr &async_abstract, std::size_t index)
: async_abstract_(async_abstract), index_(index) {}
~AsyncAbstractFuncAtom() = default;
MS_DECLARE_PARENT(AsyncAbstractFuncAtom, AbstractFuncAtom);
static std::shared_ptr<AsyncAbstractFuncAtom> MakeShared(const AsyncAbstractPtr &async_abstract, std::size_t index) {
MS_EXCEPTION_IF_NULL(async_abstract);
auto ret = std::make_shared<AsyncAbstractFuncAtom>(async_abstract, index);
MS_EXCEPTION_IF_NULL(ret);
return ret;
}
AbstractFunctionPtr Copy() const override { return MakeShared(async_abstract_, index_); }
bool operator==(const AbstractFunction &other) const override {
if (!other.isa<AsyncAbstractFuncAtom>()) {
return false;
}
auto other_async = static_cast<const AsyncAbstractFuncAtom *>(&other);
MS_EXCEPTION_IF_NULL(other_async);
return (async_abstract_ == other_async->async_abstract_ && index_ == other_async->index_);
}
std::size_t hash() const override {
return hash_combine(std::hash<AsyncAbstract *>{}(async_abstract_.get()), std::hash<std::size_t>{}(index_));
}
AbstractFunctionPtr GetUnique() override;
std::string ToString() const;
private:
// Resolved AbstractFunction after fully analyzed.
AbstractFunctionPtr resolved_{nullptr};
// Before resolved, use the following two items to track.
const AsyncAbstractPtr async_abstract_;
const std::size_t index_;
};
using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;
class AsyncInferTask {
public:
explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)

View File

@ -273,51 +273,6 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
return res;
}
EvalResultPtr BaseFuncGraphEvaluator::RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
AbstractBasePtrList args_spec_list;
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
[](const ConfigPtr &conf) -> AbstractBasePtr {
MS_EXCEPTION_IF_NULL(conf);
return conf->ObtainEvalResult()->abstract();
});
args_spec_list = NormalizeArgs(args_spec_list);
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
if (func_graph_evaluator == nullptr) {
MS_LOG(EXCEPTION) << "Only support for FuncGraphEvaluator, but it's " << ToString();
}
const auto &fg = func_graph_evaluator->GetFuncGraph(engine, args_spec_list);
MS_EXCEPTION_IF_NULL(fg);
const auto &output = fg->output();
MS_EXCEPTION_IF_NULL(output);
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
MS_LOG(DEBUG) << "FuncGraph output is not MakeTuple but: " << output->DebugString();
return nullptr;
}
const auto &output_cnode = output->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(output_cnode);
const auto &inputs = output_cnode->inputs();
if (inputs.size() != 3) {
MS_LOG(DEBUG) << "Size of func graph output is 3, but: " << output->DebugString();
return nullptr;
}
const auto &primal_abstract = inputs[1]->user_data<abstract::AbstractBase>("primal_abstract");
const auto &item_fg = GetValueNode<FuncGraphPtr>(inputs[2]);
if (primal_abstract != nullptr && item_fg != nullptr) {
MS_LOG(DEBUG) << "Try to build result from primal abstract: " << primal_abstract->ToString()
<< " and fg: " << item_fg->ToString();
auto context = parent_context_->NewContext(fg, args_spec_list);
const auto &item_fg_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(item_fg, context, inputs[2]);
AbstractBasePtrList abs_list{primal_abstract, item_fg_abstract};
const auto &tuple_abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
auto res = std::make_shared<EvalResult>(tuple_abstract, nullptr);
return res;
}
return nullptr;
}
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
MS_EXCEPTION_IF_NULL(broaded_args);
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),

View File

@ -51,10 +51,6 @@ class Evaluator : public Base {
// Run() will modify cache_ member, so it cannot marked as const;
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf);
virtual EvalResultPtr RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) {
MS_LOG(EXCEPTION) << "Not support for this evaluator: " << ToString();
}
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) = 0;
@ -212,8 +208,6 @@ class BaseFuncGraphEvaluator : public Evaluator {
~BaseFuncGraphEvaluator() override = default;
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
EvalResultPtr RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
const AnfNodeConfigPtr &out_conf) override;
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
const AnfNodeConfigPtr &out_conf) override;

View File

@ -300,7 +300,14 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
std::vector<EvaluatorPtr> evaluators;
auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
auto evaluator = this->GetEvaluatorFor(poss);
auto resolved_atom = poss;
if (poss->isa<AsyncAbstractFuncAtom>()) {
const auto &async_abs_func = poss->cast<AsyncAbstractFuncAtomPtr>();
const auto &resolved_func = async_abs_func->GetUnique();
resolved_atom = resolved_func->cast<AbstractFuncAtomPtr>();
MS_EXCEPTION_IF_NULL(resolved_atom);
}
auto evaluator = this->GetEvaluatorFor(resolved_atom);
evaluator->set_bound_node(cnode);
evaluators.push_back(evaluator);
};
@ -809,6 +816,62 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
}
}
namespace {
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
AbstractBasePtrList *out_specs) {
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
std::size_t len = branch_async_abstract_list.size();
for (size_t i = 0; i < len; ++i) {
auto result = branch_async_abstract_list[i]->TryGetResult();
if (result) {
out_specs->push_back(result);
} else {
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
}
}
if (first_result->isa<AbstractFunction>()) {
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], 0);
out_specs->push_back(async_func);
}
} else if (first_result->isa<AbstractSequeue>()) {
const auto &orig_abstract_seq = first_result->cast<AbstractSequeuePtr>();
MS_EXCEPTION_IF_NULL(orig_abstract_seq);
const auto &orig_elements = orig_abstract_seq->elements();
AbstractBasePtrList new_elements;
for (size_t i = 0; i < orig_elements.size(); ++i) {
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], i);
abs_func_list.push_back(async_func);
}
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
} else {
new_elements.push_back(orig_elements[i]);
}
}
AbstractBasePtr new_first_result;
if (first_result->isa<AbstractTuple>()) {
new_first_result = std::make_shared<AbstractTuple>(new_elements);
} else if (first_result->isa<AbstractList>()) {
new_first_result = std::make_shared<AbstractList>(new_elements);
} else {
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << first_result->ToString();
}
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
<< ", new: " << new_first_result->ToString();
std::replace_if(
out_specs->begin(), out_specs->end(), [first_result](const auto &elem) { return elem == first_result; },
new_first_result);
} else {
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
}
}
} // namespace
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
const AnfNodeConfigPtr &out_conf,
const ConfigPtrList &args_conf_list) {
@ -866,24 +929,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
AbstractBasePtrList out_specs;
size_t len = evaluators.size();
if (NeedWaitForBranches(firstResult)) {
for (size_t i = 0; i < len; ++i) {
// shortcircuit begin;
if (firstResult->isa<AbstractTuple>() && branchAsyncResults[i]->TryGetResult() == nullptr) {
MS_LOG(DEBUG) << "Try to run shortcircuit abstract for evalator: " << evaluators[i]->ToString();
const auto &result = evaluators[i]->RunShortCircuit(shared_from_this(), args_conf_list, out_conf);
if (result != nullptr) {
out_specs.push_back(result->abstract());
MS_LOG(DEBUG) << "i: " << i << ", result: " << result->abstract()->ToString();
continue;
}
}
// shortcircuit end;
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
auto result = branchAsyncResults[i]->GetResult();
MS_EXCEPTION_IF_NULL(result);
out_specs.push_back(result);
}
BuildPossibleSpecs(firstResult, branchAsyncResults, &out_specs);
} else {
for (size_t i = 0; i < len; ++i) {
// Not wait to get the result of branch.