forked from mindspore-Ecosystem/mindspore
!26430 replace short-circuit eval with deferred evaluation of backward prop function.
Merge pull request !26430 from xychow/replace-shortcurit-eval-with-lazy-eval
This commit is contained in:
commit
20be757f18
|
@ -344,7 +344,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
|
|||
// Do forward computation
|
||||
auto forward_app =
|
||||
k_graph_->NewCNode({NewValueNode(prim::kPrimTupleGetItem), k_app, NewValueNode(static_cast<int64_t>(0))});
|
||||
forward_app->set_user_data<abstract::AbstractBase>("primal_abstract", cnode_morph->abstract());
|
||||
// K:: cnode -> forward_app
|
||||
auto node_adjoint = std::make_shared<Adjoint>(morph, forward_app, tape_);
|
||||
UpdateAdjoint(node_adjoint);
|
||||
|
|
|
@ -29,7 +29,7 @@ AbstractBasePtr AsyncAbstract::GetResult() {
|
|||
if (ret != nullptr) {
|
||||
return ret;
|
||||
}
|
||||
auto async_task = AsyncInferTask::MakeShared(shared_from_base<AsyncAbstract>());
|
||||
auto async_task = AsyncInferTask::MakeShared(shared_from_this());
|
||||
MS_LOG(DEBUG) << GetInferThread() << " is waiting for async: " << async_task.get();
|
||||
AnalysisSchedule::GetInstance().Add2Schedule(async_task);
|
||||
ret = async_task->GetResult();
|
||||
|
@ -173,6 +173,51 @@ void AnalysisSchedule::SetThreadID(const std::string &threadID) { localThreadID
|
|||
|
||||
std::string &AnalysisSchedule::GetThreadID() { return localThreadID; }
|
||||
|
||||
AbstractFunctionPtr AsyncAbstractFuncAtom::GetUnique() {
|
||||
if (resolved_ != nullptr) {
|
||||
return resolved_;
|
||||
}
|
||||
// Release GIL for C++;
|
||||
py::gil_scoped_release infer_gil_release;
|
||||
|
||||
MS_LOG(DEBUG) << "Try to GetResult from async_abstract: " << async_abstract_->ToString();
|
||||
const auto &result = async_abstract_->GetResult();
|
||||
if (result->isa<AbstractFuncAtom>()) {
|
||||
resolved_ = result->cast<AbstractFuncAtomPtr>();
|
||||
} else if (result->isa<AbstractSequeue>()) {
|
||||
const auto &abs_seq = result->cast<AbstractSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(abs_seq);
|
||||
const auto &elements = abs_seq->elements();
|
||||
if (elements.size() < index_) {
|
||||
MS_LOG(EXCEPTION) << "Elements of AsyncAbstract result: " << result->ToString()
|
||||
<< " size is less than index: " << index_;
|
||||
}
|
||||
if (!elements[index_]->isa<AbstractFuncAtom>()) {
|
||||
MS_LOG(EXCEPTION) << "AsyncAbstract result cannot resolve to AbstractFuncAtom, but: "
|
||||
<< elements[index_]->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << "Return Abstract: " << elements[index_]->ToString();
|
||||
resolved_ = elements[index_]->cast<AbstractFuncAtomPtr>();
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "AsyncAbstract cannot resolve to AbstractFuncAtom or AbstractSequence, but: "
|
||||
<< result->ToString();
|
||||
}
|
||||
return resolved_;
|
||||
}
|
||||
|
||||
std::string AsyncAbstractFuncAtom::ToString() const {
|
||||
if (resolved_ == nullptr) {
|
||||
return "AsyncAbstractFuncAtom(Not Resolved)";
|
||||
}
|
||||
|
||||
std::ostringstream buffer;
|
||||
buffer << "AsyncAbstractFuncAtom(";
|
||||
buffer << resolved_->ToString();
|
||||
buffer << ")";
|
||||
|
||||
return buffer.str();
|
||||
}
|
||||
|
||||
void AnalysisResultCacheMgr::Clear() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
cache_.clear();
|
||||
|
|
|
@ -206,11 +206,11 @@ class NormalCache {
|
|||
CacheType cache_;
|
||||
};
|
||||
|
||||
class AsyncAbstract : public Base {
|
||||
class AsyncAbstract : public std::enable_shared_from_this<AsyncAbstract> {
|
||||
public:
|
||||
AsyncAbstract() = default;
|
||||
~AsyncAbstract() = default;
|
||||
AbstractBasePtr GetResult();
|
||||
|
||||
AbstractBasePtr TryGetResult() {
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
return result_;
|
||||
|
@ -225,6 +225,8 @@ class AsyncAbstract : public Base {
|
|||
result_ = result;
|
||||
}
|
||||
|
||||
AbstractBasePtr GetResult();
|
||||
|
||||
std::string ToString() {
|
||||
std::ostringstream buffer;
|
||||
std::lock_guard<std::mutex> lock(lock_);
|
||||
|
@ -237,6 +239,49 @@ class AsyncAbstract : public Base {
|
|||
AbstractBasePtr result_{nullptr};
|
||||
};
|
||||
|
||||
// Wrap AsyncAbstract, so it can work with Join method of AbstractFunction.
|
||||
class AsyncAbstractFuncAtom : public AbstractFuncAtom {
|
||||
public:
|
||||
AsyncAbstractFuncAtom(const AsyncAbstractPtr &async_abstract, std::size_t index)
|
||||
: async_abstract_(async_abstract), index_(index) {}
|
||||
~AsyncAbstractFuncAtom() = default;
|
||||
MS_DECLARE_PARENT(AsyncAbstractFuncAtom, AbstractFuncAtom);
|
||||
|
||||
static std::shared_ptr<AsyncAbstractFuncAtom> MakeShared(const AsyncAbstractPtr &async_abstract, std::size_t index) {
|
||||
MS_EXCEPTION_IF_NULL(async_abstract);
|
||||
auto ret = std::make_shared<AsyncAbstractFuncAtom>(async_abstract, index);
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
return ret;
|
||||
}
|
||||
|
||||
AbstractFunctionPtr Copy() const override { return MakeShared(async_abstract_, index_); }
|
||||
|
||||
bool operator==(const AbstractFunction &other) const override {
|
||||
if (!other.isa<AsyncAbstractFuncAtom>()) {
|
||||
return false;
|
||||
}
|
||||
auto other_async = static_cast<const AsyncAbstractFuncAtom *>(&other);
|
||||
MS_EXCEPTION_IF_NULL(other_async);
|
||||
return (async_abstract_ == other_async->async_abstract_ && index_ == other_async->index_);
|
||||
}
|
||||
|
||||
std::size_t hash() const override {
|
||||
return hash_combine(std::hash<AsyncAbstract *>{}(async_abstract_.get()), std::hash<std::size_t>{}(index_));
|
||||
}
|
||||
|
||||
AbstractFunctionPtr GetUnique() override;
|
||||
|
||||
std::string ToString() const;
|
||||
|
||||
private:
|
||||
// Resolved AbstractFunction after fully analyzed.
|
||||
AbstractFunctionPtr resolved_{nullptr};
|
||||
// Before resolved, use the following two items to track.
|
||||
const AsyncAbstractPtr async_abstract_;
|
||||
const std::size_t index_;
|
||||
};
|
||||
using AsyncAbstractFuncAtomPtr = std::shared_ptr<AsyncAbstractFuncAtom>;
|
||||
|
||||
class AsyncInferTask {
|
||||
public:
|
||||
explicit AsyncInferTask(const std::string &threadId, const AsyncAbstractPtr &abstract)
|
||||
|
|
|
@ -273,51 +273,6 @@ EvalResultPtr BaseFuncGraphEvaluator::Eval(AnalysisEnginePtr engine, const Abstr
|
|||
return res;
|
||||
}
|
||||
|
||||
EvalResultPtr BaseFuncGraphEvaluator::RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
AbstractBasePtrList args_spec_list;
|
||||
(void)std::transform(args_conf_list.begin(), args_conf_list.end(), std::back_inserter(args_spec_list),
|
||||
[](const ConfigPtr &conf) -> AbstractBasePtr {
|
||||
MS_EXCEPTION_IF_NULL(conf);
|
||||
return conf->ObtainEvalResult()->abstract();
|
||||
});
|
||||
args_spec_list = NormalizeArgs(args_spec_list);
|
||||
args_spec_list = BroadenUndeterminedArgs(args_spec_list);
|
||||
|
||||
auto func_graph_evaluator = dyn_cast<FuncGraphEvaluator>(shared_from_base<BaseFuncGraphEvaluator>());
|
||||
if (func_graph_evaluator == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Only support for FuncGraphEvaluator, but it's " << ToString();
|
||||
}
|
||||
const auto &fg = func_graph_evaluator->GetFuncGraph(engine, args_spec_list);
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
const auto &output = fg->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
if (!IsPrimitiveCNode(output, prim::kPrimMakeTuple)) {
|
||||
MS_LOG(DEBUG) << "FuncGraph output is not MakeTuple but: " << output->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
const auto &output_cnode = output->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(output_cnode);
|
||||
const auto &inputs = output_cnode->inputs();
|
||||
if (inputs.size() != 3) {
|
||||
MS_LOG(DEBUG) << "Size of func graph output is 3, but: " << output->DebugString();
|
||||
return nullptr;
|
||||
}
|
||||
const auto &primal_abstract = inputs[1]->user_data<abstract::AbstractBase>("primal_abstract");
|
||||
const auto &item_fg = GetValueNode<FuncGraphPtr>(inputs[2]);
|
||||
if (primal_abstract != nullptr && item_fg != nullptr) {
|
||||
MS_LOG(DEBUG) << "Try to build result from primal abstract: " << primal_abstract->ToString()
|
||||
<< " and fg: " << item_fg->ToString();
|
||||
auto context = parent_context_->NewContext(fg, args_spec_list);
|
||||
const auto &item_fg_abstract = std::make_shared<abstract::FuncGraphAbstractClosure>(item_fg, context, inputs[2]);
|
||||
AbstractBasePtrList abs_list{primal_abstract, item_fg_abstract};
|
||||
const auto &tuple_abstract = std::make_shared<abstract::AbstractTuple>(abs_list);
|
||||
auto res = std::make_shared<EvalResult>(tuple_abstract, nullptr);
|
||||
return res;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void BroadenArgs(const AbstractBasePtrList &args_spec_list, AbstractBasePtrList *broaded_args) {
|
||||
MS_EXCEPTION_IF_NULL(broaded_args);
|
||||
(void)std::transform(args_spec_list.begin(), args_spec_list.end(), std::back_inserter(*broaded_args),
|
||||
|
|
|
@ -51,10 +51,6 @@ class Evaluator : public Base {
|
|||
// Run() will modify cache_ member, so it cannot marked as const;
|
||||
virtual EvalResultPtr Run(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf);
|
||||
virtual EvalResultPtr RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) {
|
||||
MS_LOG(EXCEPTION) << "Not support for this evaluator: " << ToString();
|
||||
}
|
||||
|
||||
virtual EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodeConfigPtr &out_conf) = 0;
|
||||
|
@ -212,8 +208,6 @@ class BaseFuncGraphEvaluator : public Evaluator {
|
|||
~BaseFuncGraphEvaluator() override = default;
|
||||
MS_DECLARE_PARENT(BaseFuncGraphEvaluator, Evaluator);
|
||||
|
||||
EvalResultPtr RunShortCircuit(AnalysisEnginePtr engine, const ConfigPtrList &args_conf_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
EvalResultPtr Eval(AnalysisEnginePtr engine, const AbstractBasePtrList &args_spec_list,
|
||||
const AnfNodeConfigPtr &out_conf) override;
|
||||
|
||||
|
|
|
@ -300,7 +300,14 @@ EvalResultPtr AnalysisEngine::EvalCNode(const CNodePtr &cnode, const AnfNodeConf
|
|||
|
||||
std::vector<EvaluatorPtr> evaluators;
|
||||
auto build_evaluator = [this, &evaluators, &cnode](const AbstractFuncAtomPtr &poss) {
|
||||
auto evaluator = this->GetEvaluatorFor(poss);
|
||||
auto resolved_atom = poss;
|
||||
if (poss->isa<AsyncAbstractFuncAtom>()) {
|
||||
const auto &async_abs_func = poss->cast<AsyncAbstractFuncAtomPtr>();
|
||||
const auto &resolved_func = async_abs_func->GetUnique();
|
||||
resolved_atom = resolved_func->cast<AbstractFuncAtomPtr>();
|
||||
MS_EXCEPTION_IF_NULL(resolved_atom);
|
||||
}
|
||||
auto evaluator = this->GetEvaluatorFor(resolved_atom);
|
||||
evaluator->set_bound_node(cnode);
|
||||
evaluators.push_back(evaluator);
|
||||
};
|
||||
|
@ -809,6 +816,62 @@ void ExecEvaluator(EvaluatorPtr eval, AnalysisEnginePtr engine, ConfigPtrList ar
|
|||
}
|
||||
}
|
||||
|
||||
namespace {
|
||||
void BuildPossibleSpecs(const AbstractBasePtr &first_result,
|
||||
const std::vector<AsyncAbstractPtr> &branch_async_abstract_list,
|
||||
AbstractBasePtrList *out_specs) {
|
||||
std::vector<AsyncAbstractPtr> pending_async_abstract_list;
|
||||
std::size_t len = branch_async_abstract_list.size();
|
||||
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
auto result = branch_async_abstract_list[i]->TryGetResult();
|
||||
if (result) {
|
||||
out_specs->push_back(result);
|
||||
} else {
|
||||
pending_async_abstract_list.push_back(branch_async_abstract_list[i]);
|
||||
}
|
||||
}
|
||||
if (first_result->isa<AbstractFunction>()) {
|
||||
for (std::size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], 0);
|
||||
out_specs->push_back(async_func);
|
||||
}
|
||||
} else if (first_result->isa<AbstractSequeue>()) {
|
||||
const auto &orig_abstract_seq = first_result->cast<AbstractSequeuePtr>();
|
||||
MS_EXCEPTION_IF_NULL(orig_abstract_seq);
|
||||
const auto &orig_elements = orig_abstract_seq->elements();
|
||||
AbstractBasePtrList new_elements;
|
||||
for (size_t i = 0; i < orig_elements.size(); ++i) {
|
||||
if (orig_elements[i]->isa<AbstractFuncAtom>()) {
|
||||
AbstractFuncAtomPtrList abs_func_list{orig_elements[i]->cast<AbstractFuncAtomPtr>()};
|
||||
for (size_t j = 0; j < pending_async_abstract_list.size(); ++j) {
|
||||
auto async_func = AsyncAbstractFuncAtom::MakeShared(pending_async_abstract_list[j], i);
|
||||
abs_func_list.push_back(async_func);
|
||||
}
|
||||
new_elements.push_back(AbstractFunction::MakeAbstractFunction(abs_func_list));
|
||||
} else {
|
||||
new_elements.push_back(orig_elements[i]);
|
||||
}
|
||||
}
|
||||
AbstractBasePtr new_first_result;
|
||||
if (first_result->isa<AbstractTuple>()) {
|
||||
new_first_result = std::make_shared<AbstractTuple>(new_elements);
|
||||
} else if (first_result->isa<AbstractList>()) {
|
||||
new_first_result = std::make_shared<AbstractList>(new_elements);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "FirstResult is not AbstractTuple or AbstractList, but: " << first_result->ToString();
|
||||
}
|
||||
MS_LOG(DEBUG) << GetInferThread() << " Try to replace old first with new one, old: " << first_result->ToString()
|
||||
<< ", new: " << new_first_result->ToString();
|
||||
std::replace_if(
|
||||
out_specs->begin(), out_specs->end(), [first_result](const auto &elem) { return elem == first_result; },
|
||||
new_first_result);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << GetInferThread() << " wait for normal async result";
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::vector<EvaluatorPtr> &evaluators,
|
||||
const AnfNodeConfigPtr &out_conf,
|
||||
const ConfigPtrList &args_conf_list) {
|
||||
|
@ -866,24 +929,7 @@ EvalResultPtr AnalysisEngine::ExecuteMultipleEvaluatorsMultiThread(const std::ve
|
|||
AbstractBasePtrList out_specs;
|
||||
size_t len = evaluators.size();
|
||||
if (NeedWaitForBranches(firstResult)) {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
// shortcircuit begin;
|
||||
if (firstResult->isa<AbstractTuple>() && branchAsyncResults[i]->TryGetResult() == nullptr) {
|
||||
MS_LOG(DEBUG) << "Try to run shortcircuit abstract for evalator: " << evaluators[i]->ToString();
|
||||
const auto &result = evaluators[i]->RunShortCircuit(shared_from_this(), args_conf_list, out_conf);
|
||||
if (result != nullptr) {
|
||||
out_specs.push_back(result->abstract());
|
||||
MS_LOG(DEBUG) << "i: " << i << ", result: " << result->abstract()->ToString();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// shortcircuit end;
|
||||
|
||||
MS_LOG(DEBUG) << GetInferThread() << "async waiting for " << evaluators[i]->ToString();
|
||||
auto result = branchAsyncResults[i]->GetResult();
|
||||
MS_EXCEPTION_IF_NULL(result);
|
||||
out_specs.push_back(result);
|
||||
}
|
||||
BuildPossibleSpecs(firstResult, branchAsyncResults, &out_specs);
|
||||
} else {
|
||||
for (size_t i = 0; i < len; ++i) {
|
||||
// Not wait to get the result of branch.
|
||||
|
|
Loading…
Reference in New Issue