forked from mindspore-Ecosystem/mindspore
!9716 Add pynative second derivative function
From: @zjun3021 Reviewed-by: @chujinjin,@kisnwang Signed-off-by: @chujinjin
Commit 2c5123300f
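For context, the feature this patch enables is computing a gradient of a gradient in PyNative mode. A minimal usage sketch, assuming a standard MindSpore install; the cell, the nesting, and the input below are illustrative, not code from this patch:

    # Hedged sketch: second derivative in PyNative mode via nested GradOperation.
    # Net, FirstGrad, SecondGrad and the input value are illustrative assumptions.
    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    class Net(nn.Cell):
        def construct(self, x):
            return x * x * x                      # f(x) = x^3

    class FirstGrad(nn.Cell):
        def __init__(self, net):
            super().__init__()
            self.grad = C.GradOperation()
            self.net = net
        def construct(self, x):
            return self.grad(self.net)(x)         # f'(x) = 3x^2

    class SecondGrad(nn.Cell):
        def __init__(self, first_grad):
            super().__init__()
            self.grad = C.GradOperation()
            self.first_grad = first_grad
        def construct(self, x):
            return self.grad(self.first_grad)(x)  # f''(x) = 6x

    x = Tensor(np.array([2.0], np.float32))
    print(SecondGrad(FirstGrad(Net()))(x))        # expected: [12.]

The inner grad records a forward graph that must itself stay differentiable, which is why the patch threads a need_replace_forward flag through the hunks below: forward results embedded in the grad graph are regenerated rather than silently reused.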
@@ -276,7 +276,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }
 
-ValuePtr GenNewTensorInner(const ValuePtr &value) {
+ValuePtr DFunctor::GenNewTensorInner(const ValuePtr &value) {
   std::vector<ValuePtr> value_list;
   if (value->isa<tensor::Tensor>()) {
     auto tensor = value->cast<tensor::TensorPtr>();
@@ -294,11 +294,18 @@ ValuePtr GenNewTensorInner(const ValuePtr &value) {
   return value;
 }
 
-ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value) {
+ValuePtr DFunctor::GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value,
+                                bool need_replace_forward) {
   ValuePtr out = value;
   auto ref_size = mng->node_users()[node].size();
   if (ref_size < 2) {
-    out = GenNewTensorInner(value);
+    if (need_replace_forward) {
+      out = GenNewTensorInner(value);
+    } else {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      tensor->set_device_address(nullptr);
+      return tensor;
+    }
   }
   return out;
 }
@@ -333,8 +340,13 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   MS_EXCEPTION_IF_NULL(func_graph);
   auto manager = Manage({fg, func_graph}, false);
 
-  auto forward_value = GenNewTensor(manager, equivdout, forward);
+  auto need_replace_forward = pynative::PynativeExecutor::GetInstance()->need_replace_forward();
+  auto forward_value = GenNewTensor(manager, equivdout, forward, need_replace_forward);
+  if (!need_replace_forward) {
+    cnode_morph->clear_inputs_value();
+    MS_LOG(DEBUG) << "No need replace forward result";
+    return;
+  }
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
   auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
@@ -373,7 +385,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
   }
   auto out_node = c_input->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(out_node);
-  out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
+  out_node->set_value(GenNewTensor(manager, out_node, out_node->value(), need_replace_forward));
   // clear resource
   cnode_morph->clear_inputs_value();
   fg->ClearAllManagerInfo();
@@ -99,9 +99,13 @@ class DFunctor : public std::enable_shared_from_this<DFunctor> {
   // Update k hole with adjoint_definition, only applied in recursive case.
   void UpdateAdjoint(const AdjointPtr &adjoint_definition);
   void CallDoutHoleOnTape();
-  void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
   // Replace the primal graph with k graph
   void EliminatePrimalGraph();
+  // Pynative specialize
+  void ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph);
+  ValuePtr GenNewTensorInner(const ValuePtr &value);
+  ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value,
+                        bool need_replace_forward);
 
   std::unordered_map<AnfNodePtr, AdjointPtr> anfnode_to_adjoin_;
   // Cache for indirect fv backpropagation, K o K can only do backprop layer by layer.
File diff suppressed because it is too large.
@@ -58,7 +58,7 @@ py::tuple RunOp(const py::args &args);
 void ClearPyNativeSession();
 
 struct GraphInfo {
-  std::unordered_set<std::string> params;  // hold input parameters and cell weights
+  std::unordered_map<std::string, AnfNodePtr> params;  // hold input parameters and cell weights
   std::unordered_map<std::string, std::pair<AnfNodePtr, std::vector<int64_t>>> node_map;
   AnfNodePtr output;
   std::vector<std::string> objects;
@@ -77,22 +77,22 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   PynativeExecutor(const PynativeExecutor &) = delete;
   PynativeExecutor &operator=(const PynativeExecutor &) = delete;
 
-  bool grad_flag() { return grad_flag_; }
+  bool need_replace_forward() const { return need_replace_forward_; }
+  bool grad_flag() const { return grad_flag_; }
   void set_grad_flag(bool flag) { grad_flag_ = flag; }
 
   py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info);
   OpExecInfoPtr GenerateOpExecInfo(const py::args &args);
   void NewGraph(const py::object &cell, const py::args &args);
-  py::object Run(const py::tuple &args, const py::object &phase);
+  py::object Run(const py::object &cell, const py::tuple &args, const py::object &phase);
   py::object CheckGraph(const py::object &cell, const py::args &args);
   void EndGraph(const py::object &cell, const py::object &out, const py::args &args);
   void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args);
 
   // Call by python
   void Clear(const std::string &flag = "");
-  // Abnormal existed
-  void Clean();
-  // Destrcut call
+  // Abnormal existed
   void ClearRes();
   // Sync stream
   void Sync();
@@ -100,7 +100,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
  private:
   PynativeExecutor() = default;
 
-  // check cell struct
+  // Check cell struct
   bool IsDynamicCell(const py::object &cell);
   std::string GetCellInfo(const py::object &cell);
   void ParseInputArgs(const std::shared_ptr<parse::ParseAst> &ast, const py::object &fn_node);
@@ -110,14 +110,13 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   bool ParseForExprNode(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node);
   std::string ParseNodeName(const std::shared_ptr<parse::ParseAst> &ast, const py::object &node,
                             parse::AstMainType type);
 
   py::object DoParamMixPrecisionCast(bool *is_cast, const py::object obj, const std::string &op_name, size_t index);
   py::object DoParamMixPrecisionCastTuple(bool *is_cast, const py::tuple tuple, const std::string &op_name,
                                           size_t index);
   py::object DoAutoCast(const py::object &arg, const TypeId &type_id, const std::string &op_name, size_t index);
   void DoSignatrueCast(const PrimitivePyPtr &prim, const std::map<SignatureEnumDType, TypeId> &dst_type,
                        const std::vector<SignatureEnumDType> &dtypes, const OpExecInfoPtr &op_exec_info);
-  // run op
+  // Run op
   AnfNodePtr GetInput(const py::object &obj, bool op_mask);
   MsBackendPolicy InitEnv(const OpExecInfoPtr &op_exec_info);
   py::tuple RunOpWithInitBackendPolicy(const OpExecInfoPtr &op_exec_info);
@@ -134,7 +133,7 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                          bool *is_find);
   void SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode);
 
-  // replace for grad graph
+  // Replace for grad graph
   ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple);
   void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real);
@@ -143,34 +142,61 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   void SaveTensorsInValueNode(const ResourcePtr &resource);
   void CleanTensorsInValueNode();
 
-  // construct grad graph
+  // Construct grad graph
   void PushCurrentGraphToStack();
   void PopGraphStack();
+  FuncGraphPtr GetDfbuilder(const std::string &cell_id = "");
+  ResourcePtr GetResource(const std::string &cell_id = "");
+  void AddNestedGradCount() { ++grad_count_; }
+  void SubNestedGradCount();
+  bool IsNotNestedGrad() const;
+  bool IsTopGraph(const std::string &cell_id);
+  bool grad_running() const { return grad_is_running_; }
+  void set_grad_runing(bool grad_runing) { grad_is_running_ = grad_runing; }
+  void set_need_replace_forward(bool need_replace_forward) { need_replace_forward_ = need_replace_forward; }
   bool need_construct_graph() { return !graph_stack_.empty() && grad_flag_; }
+  bool CheckCellGraph(const std::string &cell_id, bool is_grad = false);
+  void UpdateCellGraph(const std::string &cell_id, bool need_cloned = false, bool is_grad = false);
   void NewGraphInner(const py::object &cell, const py::args &args);
+  void MakeNewTopGraph(const string &cell_id, const py::args &args, const FuncGraphPtr &g);
   void EndGraphInner(const py::object &cell, const py::object &out, const py::args &args);
-  void EndGraphByOutId(const std::string &out_id, const py::object &cell, const py::object &out, const py::args &args);
-  FuncGraphPtr MakeGradGraph(const py::object &cell, const py::args &args);
+  void EndGraphByOutId(const py::object &cell, const std::string &cell_id, const py::object &out,
+                       const std::string &out_id, const py::args &args);
+  FuncGraphPtr MakeGradGraph(const py::object &cell, const py::args &args, const FuncGraphPtr &g, const ResourcePtr &r,
+                             bool is_top);
+  std::string GetGradCellId(bool has_sens, const py::object &cell, const py::args &args, py::object *forward_args,
+                            py::object *sens = nullptr);
   void GradNetInner(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
                     const py::args &args);
   std::string GetCellId(const py::object &obj, const py::args &args);
-  std::string CheckCellChanged(const GradOperationPtr &grad, const py::object &cell, const py::object &weights,
-                               const py::args &args, std::pair<bool, bool> *sens_weights_changed);
-  void SetGradGraphParams(size_t size, const std::string &cell_id, const std::pair<bool, bool> &sens_weights_changed);
-  void GradGraph(FuncGraphPtr g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
-                 size_t arg_size);
-  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights);
-  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args);
+  std::pair<bool, bool> CheckCellChanged(const std::string &cell_id, const py::object &weights, const py::object &sens);
+  void SetGradGraphParams(const FuncGraphPtr &df_builder, const ResourcePtr &resource, size_t size);
+  bool CloneDfbuiler(const std::string &cell_id, const FuncGraphPtr &df_builder, const ResourcePtr &resource);
+  void GradGraph(const FuncGraphPtr &g, const GradOperationPtr &grad_op, const std::vector<AnfNodePtr> &weights,
+                 size_t arg_size, const std::string &cell_id);
+  std::vector<AnfNodePtr> GetWeightsArgs(const py::object &weights, const FuncGraphPtr &df_builder);
+  abstract::AbstractBasePtrList GetArgsSpec(const py::args &args, const FuncGraphPtr &df_builder);
+  void UpdateGraphInfoMap(const std::string &cell_id);
+  void SetNestedTopGraph(const py::object &cell, const py::args &args, const std::string &cell_id);
+  void MakeNestedCnode(const std::string &cell_id, const py::args &args, const ResourcePtr &resource,
+                       const py::object &out, bool has_sens);
 
-  // hold graph(forward and grad) info
-  void SetPyObjInGraphInfoMap(FuncGraphPtr g, const std::string obj) { graph_info_map_[g].objects.push_back(obj); }
+  // Hold graph(forward and grad) info
+  void SetPyObjInGraphInfoMap(const FuncGraphPtr &g, const std::string &obj) {
+    graph_info_map_[g].second.objects.push_back(obj);
+  }
+  void SetTupleArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &args, const AnfNodePtr &node,
+                                  bool is_param = false);
-  void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, int64_t index = -1) {
-    graph_info_map_[g].node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
+  void SetParamNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const ParameterPtr &param) {
+    graph_info_map_[g].second.params.emplace(std::make_pair(id, param));
   }
-  void SetNodeMapInGraphInfoMap(FuncGraphPtr g, const std::string id, AnfNodePtr node, std::vector<int64_t> index) {
-    graph_info_map_[g].node_map[id] = std::make_pair(node, index);
+  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
+                                int64_t index = -1) {
+    graph_info_map_[g].second.node_map[id] = std::make_pair(node, std::vector<int64_t>{index});
   }
+  void SetNodeMapInGraphInfoMap(const FuncGraphPtr &g, const std::string &id, const AnfNodePtr &node,
+                                const std::vector<int64_t> &index) {
+    graph_info_map_[g].second.node_map[id] = std::make_pair(node, index);
+  }
+  void SetTupleItemArgsToGraphInfoMap(const FuncGraphPtr &g, const py::object &id, const AnfNodePtr &node,
+                                      const std::vector<int64_t> &index_sequence, bool is_param = false);
@@ -178,39 +204,35 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   static std::shared_ptr<PynativeExecutor> executor_;
   static std::mutex instance_lock_;
   static int64_t graph_id_;
+  int64_t grad_count_{0};
   bool grad_flag_{false};
   bool dynamic_cell_{false};
-  bool grad_is_running{false};
+  bool grad_is_running_{false};
+  bool need_replace_forward_{true};
 
-  FuncGraphPtr top_g_{nullptr};
+  // Used for constructing grad graph
   FuncGraphPtr curr_g_{nullptr};
-  FuncGraphPtr df_builder_{nullptr};
-  ResourcePtr resource_{nullptr};
+  // Records forward graphs, the bottom is the top graph
   std::stack<FuncGraphPtr> graph_stack_;
-  std::unordered_set<std::string> top_graph_cells_;
 
-  // record all info of a graph
-  std::unordered_map<FuncGraphPtr, GraphInfo> graph_info_map_;
   std::unordered_set<std::string> cell_input_args_;
   std::unordered_map<std::string, bool> cell_dynamic_map_;
-  std::unordered_map<std::string, ResourcePtr> cell_resource_map_;
-  std::unordered_map<std::string, std::pair<FuncGraphPtr, bool>> cell_graph_map_;
-  // key: cell_id, value: (sens_id, weights_id), cache for sens and weight change
+  // Record all info for all cells
+  std::unordered_map<FuncGraphPtr, std::pair<std::string, GraphInfo>> graph_info_map_;
+  // key: cell_id, value: (sens_id, weights_id), cache for sens and weight change
   std::unordered_map<std::string, std::pair<std::string, std::string>> cell_sw_map_;
-  // key: cell_id, value: (forward graph, grad graph)
-  std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>> df_builder_map_;
+  // key: cell_id, value: (forward graph, whether grad), use vector to keep order
+  std::vector<std::pair<std::string, std::pair<FuncGraphPtr, bool>>> cell_graph_list_;
+  // key: cell_id, value: (resource, (df_builder, grad graph)), use vector to keep order
+  std::vector<std::pair<std::string, std::pair<ResourcePtr, std::pair<FuncGraphPtr, FuncGraphPtr>>>> top_cell_list_;
 
-  // used for runop and replace forward result of grad graph
+  // Used for runop and replace forward result of grad graph
   std::unordered_map<std::string, size_t> op_index_map_;
   std::unordered_map<std::string, std::string> obj_to_forward_id_;
   std::unordered_map<std::string, std::vector<std::string>> op_index_with_tensor_id_;
   std::unordered_map<std::string, std::vector<tensor::TensorPtr>> tensor_id_with_tensor_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::unordered_map<std::string, AbstractListMap> prim_abs_list_;
   const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional";
   const inline static std::string kMSDtypeModelName = "mindspore.common.dtype";
 };
 
 using PynativeExecutorPtr = std::shared_ptr<PynativeExecutor>;
@@ -311,8 +311,8 @@ class _PynativeExecutor:
     def grad(self, grad, obj, weights, *args, **kwargs):
         self._executor.grad_net(grad, obj, weights, *args, *(kwargs.values()))
 
-    def clear(self, flag=""):
-        self._executor.clear(flag)
+    def clear(self, cell_id=""):
+        self._executor.clear(cell_id)
 
     def sync(self):
         self._executor.sync()
@@ -320,9 +320,9 @@ class _PynativeExecutor:
     def set_grad_flag(self, flag):
         self._executor.set_grad_flag(flag)
 
-    def __call__(self, *args, **kwargs):
+    def __call__(self, obj, *args, **kwargs):
         args = args + tuple(kwargs.values())
-        return self._executor(args, "")
+        return self._executor(obj, args, "")
 
 
 class _Executor:
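The __call__ change above threads the cell object through to the executor when the recorded grad graph is run; GradOperation in the next hunk is the in-tree caller. A hedged first-order sketch that exercises this path end to end, where the Dense cell and input shape are assumptions:

    # Hedged sketch: first derivative w.r.t. inputs in PyNative mode.
    # after_grad (next hunk) ends up invoking _pynative_exec(fn, *args) internally.
    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    net = nn.Dense(3, 1)                      # illustrative cell
    x = Tensor(np.ones([2, 3], np.float32))   # illustrative input
    dx = C.GradOperation()(net)(x)            # gradient of the output w.r.t. x
    print(dx.shape)                           # (2, 3)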
@@ -347,7 +347,7 @@ class GradOperation(GradOperation_):
             fn.already_run = False
             self._pynative_forward_run(args, kwargs, fn)
             _pynative_exec.grad(grad_, fn, weights, *args, **kwargs)
-            out = _pynative_exec(*args, **kwargs)
+            out = _pynative_exec(fn, *args, **kwargs)
             _pynative_exec.clear()
             return out
         self.grad_fn = after_grad
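The updated after_grad drives the full sequence: run the forward pass, build the grad graph via _pynative_exec.grad(...), execute it via _pynative_exec(fn, *args), then clear per-run resources. A hedged sketch taking gradients with respect to cell weights instead of inputs; the cell and shapes are assumptions:

    # Hedged sketch: gradients w.r.t. cell weights in PyNative mode.
    import numpy as np
    import mindspore.nn as nn
    from mindspore import ParameterTuple, Tensor, context
    from mindspore.ops import composite as C

    context.set_context(mode=context.PYNATIVE_MODE)

    net = nn.Dense(3, 1)                          # illustrative cell
    weights = ParameterTuple(net.trainable_params())
    grad_by_weights = C.GradOperation(get_by_list=True)

    x = Tensor(np.ones([2, 3], np.float32))       # illustrative input
    w_grads = grad_by_weights(net, weights)(x)    # tuple: (d/dW, d/db)
    print([g.shape for g in w_grads])             # [(1, 3), (1,)]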