forked from mindspore-Ecosystem/mindspore
!5918 optimizer pynative memory
Merge pull request !5918 from flywind/optimizer_pynative_memory
commit c57a472748

@@ -390,18 +390,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
   // run op
   Execute(graph, false);
   // get output
-  if (op_run_info.value != nullptr) {
-    std::vector<tensor::TensorPtr> pre_output_tensors;
-    TensorValueToTensor(op_run_info.value, &pre_output_tensors);
-    for (auto &pre_output : pre_output_tensors) {
-      tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
-      tensor->set_device_address(pre_output->device_address());
-      tensor->set_sync_status(kNoNeedSync);
-      outputs->emplace_back(tensor);
-    }
-  } else {
-    UpdateOutputs(graph, outputs, input_tensors);
-  }
+  UpdateOutputs(graph, outputs, input_tensors);
   RunOpMemoryClear(graph.get());
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
 }

@@ -337,18 +337,7 @@ void GPUSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info
   LoadInputData(kernel_graph, input_tensors);
   Execute(kernel_graph);
   // Fetch outputs
-  if (op_run_info.value != nullptr) {
-    std::vector<tensor::TensorPtr> pre_output_tensors;
-    TensorValueToTensor(op_run_info.value, &pre_output_tensors);
-    for (auto &pre_output : pre_output_tensors) {
-      tensor::TensorPtr tensor = std::make_shared<tensor::Tensor>(pre_output->data_type(), pre_output->shape());
-      tensor->set_device_address(pre_output->device_address());
-      tensor->set_sync_status(kNoNeedSync);
-      outputs->emplace_back(tensor);
-    }
-  } else {
-    UpdateOutputs(kernel_graph, outputs, input_tensors);
-  }
+  UpdateOutputs(kernel_graph, outputs, input_tensors);
   RunOpClearMemory(kernel_graph.get());
 }

@@ -30,6 +30,8 @@
 #include "frontend/operator/ops.h"
 #include "utils/symbolic.h"
 #include "utils/ms_context.h"
+#include "pipeline/jit/action.h"
+#include "pipeline/jit/parse/resolve.h"

 namespace mindspore {
 namespace ad {

@@ -183,6 +185,7 @@ void DFunctor::BackPropagate(const CNodePtr &cnode_morph, const CNodePtr &k_app,

 // Map a morphism.
 AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
+  MS_LOG(DEBUG) << "start MapMorphism:" << morph->DebugString(4);
   // MapMorphism All type except CNode should already be mapped by MapObject.
   if (!morph->isa<CNode>()) {
     return nullptr;

@@ -238,9 +241,54 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {

   // Do sens backpropagation
   BackPropagate(cnode_morph, k_app, node_adjoint);
-  MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << ".";
+  MS_LOG(DEBUG) << "MapMorphism node " << morph->DebugString(4) << ".";
   return node_adjoint;
 }

+void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
+  MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
+  if (value->isa<tensor::Tensor>()) {
+    auto tnode = value->cast<tensor::TensorPtr>();
+    if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) {
+      MS_LOG(DEBUG) << "Set tensor" << tnode->device_address();
+      (*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address());
+    }
+  }
+  if (value->isa<ValueTuple>()) {
+    auto tuple = value->cast<ValueTuplePtr>();
+    for (size_t i = 0; i < tuple->size(); i++) {
+      MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString();
+      TensorSetAddress((*tuple)[i], tuple_tensors);
+    }
+  }
+}
+
+ValuePtr GenNewTensorInner(const ValuePtr &value) {
+  std::vector<ValuePtr> value_list;
+  if (value->isa<tensor::Tensor>()) {
+    auto tensor = value->cast<tensor::TensorPtr>();
+    // return std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
+    auto new_tensor = std::make_shared<tensor::Tensor>(*tensor);
+    new_tensor->set_device_address(nullptr);
+    return new_tensor;
+  }
+  if (value->isa<ValueTuple>()) {
+    auto tuple = value->cast<ValueTuplePtr>();
+    for (size_t i = 0; i < tuple->size(); i++) {
+      value_list.push_back(GenNewTensorInner((*tuple)[i]));
+    }
+    return std::make_shared<ValueTuple>(value_list);
+  }
+  return value;
+}
+
+ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, const ValuePtr &value) {
+  ValuePtr out = value;
+  auto ref_size = mng->node_users()[node].size();
+  if (ref_size < 2) {
+    out = GenNewTensorInner(value);
+  }
+  return out;
+}
+
 void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
   auto forward = cnode_morph->forward().first;

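Note: the GenNewTensorInner/GenNewTensor helpers added above copy a forward value and drop its device address when the node has fewer than two users, so PyNative stops holding device memory for values the backward graph no longer reads. A framework-free sketch of the same copy-and-strip recursion (the Value type and names below are illustrative placeholders, not MindSpore APIs):

#include <memory>
#include <vector>

// Illustrative stand-in for a value that is either a "tensor" or a tuple of values.
struct Value {
  std::vector<float> data;                       // tensor payload
  void *device_address = nullptr;                // pretend device-side handle
  std::vector<std::shared_ptr<Value>> elements;  // non-empty => tuple
};

// Recursively clone a value while clearing device addresses, mirroring the idea of GenNewTensorInner.
std::shared_ptr<Value> CloneWithoutDeviceAddress(const std::shared_ptr<Value> &v) {
  auto out = std::make_shared<Value>(*v);
  out->device_address = nullptr;
  for (auto &e : out->elements) {
    e = CloneWithoutDeviceAddress(e);
  }
  return out;
}
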
@@ -266,6 +314,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
   if (!IsValueNode<FuncGraph>(input_fg)) {
     return;
   }
+  std::map<std::string, tensor::TensorPtr> tuple_tensors;
   auto equivdout = cnode_input->cast<CNodePtr>();
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
   auto manager = Manage({fg, func_graph}, false);

@@ -273,15 +322,10 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
   auto forward_value = forward;
   if (!forward_id.empty() && ref_size > 1) {
     auto inst = pynative::PynativeExecutor::GetInstance();
-    inst->SaveOpForwardValue(forward_id, forward_value);
+    inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors);
   }
-  if (ref_size < 2) {
-    auto tensor = forward->cast<tensor::TensorPtr>();
-    if (tensor != nullptr) {
-      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->data_type(), tensor->shape());
-      forward_value = new_tensor;
-    }
-  }
+
+  forward_value = GenNewTensor(manager, equivdout, forward);
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
   auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);

@@ -300,13 +344,43 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
     if (para_ref_size > 0 && input_value.first != nullptr) {
       MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
       auto inst = pynative::PynativeExecutor::GetInstance();
-      inst->SaveOpForwardValue(input_value.second, input_value.first);
+      if (!input_value.second.empty()) {
+        inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors);
+      }
       auto input_value_node = NewValueNode(input_value.first);
       input_value_node->set_has_new_value(true);
       manager->Replace(paras[i], input_value_node);
     }
   }
+  MS_LOG(DEBUG) << "Start opt node" << fg->output()->DebugString(4);
+  auto res = std::make_shared<pipeline::Resource>();
+  res->set_manager(manager);
+  res->set_func_graph(fg);
+  PynativeElimOpt(res);
+  auto out = fg->output()->cast<CNodePtr>();
+  auto c_input = out->input(1);
+  if (!c_input->isa<ValueNode>()) {
+    return;
+  }
+
+  auto out_node = c_input->cast<ValueNodePtr>();
+  out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));

   cnode_morph->clear_inputs_value();
+
+  if (tuple_tensors.size() != 0) {
+    MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4);
+    for (auto &g : manager->func_graphs()) {
+      for (auto &node : g->value_nodes()) {
+        MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString();
+        auto vnode = node.first->cast<ValueNodePtr>()->value();
+        TensorSetAddress(vnode, &tuple_tensors);
+      }
+    }
+  }
+
+  fg->ClearAllManagerInfo();
+  func_graph->ClearAllManagerInfo();
   return;
 }

@@ -59,6 +59,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
     MakeSubstitution(std::make_shared<SpecialOpEliminater>(), "special_op_eliminate",
                      {prim::kPrimInsertGradientOf, prim::kPrimStopGradient, prim::kPrimHookBackward,
                       prim::kPrimPrintShapeType, prim::kPrimGetRefValue, prim::kPrimMirror, prim::kPrimVirtualDiv});
+  pynative_eliminate_ = MakeSubstitution(std::make_shared<PynativeEliminater>(), "pynative_eliminate", IsCNodeDup);
   zero_like_fill_zero_ =
     MakeSubstitution(std::make_shared<ZeroLikeFillZero>(), "zero_like_fill_zero", prim::kPrimZerosLike);
   adjust_all_reduce_mul_add_ =

@@ -123,6 +123,9 @@ class OptimizeIRPassLib {

   // SwitchLayer defer inline
   SubstitutionPtr switch_layer_defer_inline_;
+
+  // Pynative Eliminate
+  SubstitutionPtr pynative_eliminate_;
 };

 // the collection of irpass for resolve action

@@ -21,6 +21,7 @@
 #include <algorithm>
 #include <memory>
 #include <vector>
+#include <string>

 #include "frontend/optimizer/optimizer_caller.h"
 #include "ir/pattern_matcher.h"

@@ -31,6 +32,7 @@
 #include "frontend/optimizer/optimizer.h"
 #include "utils/comm_manager.h"
 #include "frontend/parallel/context.h"
+#include "pipeline/jit/parse/resolve.h"

 namespace mindspore {
 namespace opt {

@@ -206,6 +208,153 @@ class DependValueElim : public OptimizerCaller {
   }
 };

+// {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy))
+// {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy))
+// {{prim::resolve, CommonOPS, getitem}, (tensor0, tensor1,...), 0} -> tensor0
+class PynativeEliminater : public OptimizerCaller {
+  bool CheckNameSpaceVNode(const AnfNodePtr &node, const std::string &str_value) {
+    ValueNodePtr value_node = node->cast<ValueNodePtr>();
+    if (value_node == nullptr) {
+      return false;
+    }
+    return GetValueNode<parse::NameSpacePtr>(value_node)->module() == str_value;
+  }
+
+  bool CheckSymbolVNode(const AnfNodePtr &node, const std::string &str_value) {
+    ValueNodePtr value_node = node->cast<ValueNodePtr>();
+    if (value_node == nullptr) {
+      return false;
+    }
+    return GetValueNode<parse::SymbolPtr>(value_node)->symbol() == str_value;
+  }
+
+  bool CheckStrVNode(const AnfNodePtr &node, const std::string &str_value) {
+    ValueNodePtr value_node = node->cast<ValueNodePtr>();
+    if (value_node == nullptr) {
+      return false;
+    }
+    return GetValueNode<StringImmPtr>(value_node)->value() == str_value;
+  }
+
+  ValuePtr FillGetItem(const ValuePtr &value, const ValuePtr &idx) {
+    MS_LOG(DEBUG) << "Start FillGetItem" << value->ToString() << idx->ToString();
+    if (!idx->isa<Int32Imm>()) {
+      MS_LOG(EXCEPTION) << "Getitem idx must int:" << idx->ToString();
+    }
+
+    if (!value->isa<ValueTuple>()) {
+      MS_LOG(EXCEPTION) << "Getitem value must tuple:" << value->ToString();
+    }
+
+    auto value_tuple = value->cast<ValueTuplePtr>();
+    int idx_t = idx->cast<Int32ImmPtr>()->value();
+    MS_LOG(DEBUG) << "Fill getitem" << idx_t << (*value_tuple)[idx_t]->ToString();
+    return (*value_tuple)[idx_t];
+  }
+
+  ValuePtr FillZero(const ValuePtr &value) {
+    MS_LOG(DEBUG) << "Start FillZero";
+    ValuePtr out = nullptr;
+    if (value->isa<Int32Imm>()) {
+      return MakeValue(0);
+    }
+
+    if (value->isa<tensor::Tensor>()) {
+      MS_LOG(DEBUG) << "Start FillZero Tensor";
+      auto tensor = value->cast<tensor::TensorPtr>();
+      tensor::TensorPtr out_t = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
+      char *data = reinterpret_cast<char *>(out_t->data_c());
+      std::fill(data, data + out_t->data().nbytes(), 0);
+      out = out_t;
+    }
+
+    std::vector<ValuePtr> value_list;
+    if (value->isa<ValueTuple>()) {
+      MS_LOG(DEBUG) << "Start FillZero Tuple" << value->ToString();
+      auto value_tuple = value->cast<ValueTuplePtr>();
+      for (size_t i = 0; i < value_tuple->size(); i++) {
+        value_list.push_back(FillZero((*value_tuple)[i]));
+      }
+      out = std::make_shared<ValueTuple>(value_list);
+    }
+    if (out == nullptr) {
+      MS_LOG(EXCEPTION) << "FillZero failed:" << value->ToString();
+    }
+    MS_LOG(DEBUG) << "Result: " << out->ToString();
+    return out;
+  }
+
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
+    PatternNode<AnfNodePtr> symbol_str_vnode, c_vnode, zeros_like_vnode, getitem_vnode, arg, arg1;
+    auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
+    auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
+    auto pattern = PCNode(getattr, arg);
+
+    if ((pattern).TryCapture(node) &&
+        (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
+         CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
+      auto rep = (arg).GetNode(node);
+      if (rep != nullptr) {
+        if (rep->isa<ValueNode>()) {
+          auto value_node = rep->cast<ValueNodePtr>();
+          value_node->set_value(FillZero(value_node->value()));
+          MS_LOG(DEBUG) << "Zeros_like replace ok " << rep->DebugString(4);
+          return rep;
+        }
+      }
+    }
+
+    MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
+    auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
+    auto pattern1 = PCNode(resolve1, arg);
+
+    if ((pattern1).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
+                                        CheckSymbolVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
+      auto rep = (arg).GetNode(node);
+      if (rep != nullptr) {
+        if (rep->isa<ValueNode>()) {
+          auto value_node = rep->cast<ValueNodePtr>();
+          value_node->set_value(FillZero(value_node->value()));
+          MS_LOG(DEBUG) << "Zeros_like replace ok 2 " << rep->DebugString(4);
+          return rep;
+        }
+      }
+    }
+
+    // resolve(CommonOPS, getitem)((tensors), 3)
+    auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
+    auto pattern2 = PCNode(resolve2, arg, arg1);
+    if ((pattern2).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "CommonOPS") &&
+                                        CheckSymbolVNode(getitem_vnode.GetNode(node), "getitem"))) {
+      auto rep = (arg).GetNode(node);
+      if (rep != nullptr) {
+        if (rep->isa<ValueNode>()) {
+          MS_LOG(DEBUG) << "Rep is " << rep->DebugString(4);
+          ValueNodePtr new_node;
+          auto value_node = rep->cast<ValueNodePtr>();
+          auto rep1 = (arg1).GetNode(node);
+          if (rep1 != nullptr) {
+            if (rep1->isa<ValueNode>()) {
+              auto idx = rep1->cast<ValueNodePtr>();
+              if (!value_node->value()->isa<ValueTuple>()) {
+                return nullptr;
+              }
+              new_node = NewValueNode(FillGetItem(value_node->value(), idx->value()));
+              new_node->set_has_new_value(value_node->has_new_value());
+            }
+          }
+          MS_LOG(DEBUG) << "Fill getitem replace ok " << new_node->DebugString(4);
+          return new_node;
+        }
+      }
+    }
+
+    MS_LOG(DEBUG) << "End Replace " << node->DebugString(4);
+    return nullptr;
+  }
+};
+
 class AllReduceConstElim : public OptimizerCaller {
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {

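For orientation, PynativeEliminater::FillZero above swaps a captured constant for a zero value of the same type and shape: scalars become 0, tensors get a zero-filled buffer, and tuples recurse element-wise. A simplified, self-contained sketch of the tensor case (plain std::vector standing in for tensor::Tensor; names are illustrative only):

#include <cstddef>
#include <vector>

// Build a zero-filled buffer with the same element count as the given shape,
// the dense analogue of FillZero's tensor branch.
std::vector<float> ZeroLike(const std::vector<std::size_t> &shape) {
  std::size_t count = 1;
  for (std::size_t d : shape) {
    count *= d;
  }
  return std::vector<float>(count, 0.0f);  // value-initialized to zero
}
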
@@ -185,9 +185,11 @@ class Optimizer : public std::enable_shared_from_this<Optimizer> {
           MS_LOG(DEBUG) << "The opt " << name_ << " round " << counter << " OptPass " << pass_names_[i] << " end.";
           auto fg_name =
             "opt_substep_" + name_ + "_r" + std::to_string(counter) + "_" + std::to_string(i) + "_" + pass_names_[i];
-          func_graph->DumpFuncGraph(fg_name);
           DumpIR(fg_name + ".ir", func_graph);
-          ExportIR(fg_name + ".dat", "", func_graph);
+          if (MsContext::GetInstance()->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+            func_graph->DumpFuncGraph(fg_name);
+            ExportIR(fg_name + ".dat", "", func_graph);
+          }
           MS_LOG(DEBUG) << "Dump " << pass_names_[i] << " func graph.";
         }
       }

@@ -314,6 +314,16 @@ bool VmOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kVmPa

 bool PynativeOptimizeAction(const ResourcePtr &res) { return OptimizeAction(res, kPynativePasses); }

+bool PynativeElimOpt(const ResourcePtr &res) {
+  if (res->manager() == nullptr) {
+    MS_LOG(EXCEPTION) << "PynativeElimOpt error, manager is null.";
+  }
+  if (res->func_graph() == nullptr) {
+    MS_LOG(EXCEPTION) << "PynativeElimOpt error, graph is null.";
+  }
+  return PynativeOptPass(res);
+}
+
 static bool IsCtrlSink() {
   auto ms_ctx = MsContext::GetInstance();
   if (ms_ctx->get_param<int>(MS_CTX_EXECUTION_MODE) != kGraphMode) {

@@ -36,6 +36,7 @@ bool AbstractSpecializeAction(const ResourcePtr &res);
 bool GeOptimizeAction(const ResourcePtr &res);
 bool VmOptimizeAction(const ResourcePtr &res);
 bool PynativeOptimizeAction(const ResourcePtr &res);
+bool PynativeElimOpt(const ResourcePtr &res);
 bool TaskEmitAction(const ResourcePtr &res);
 bool ExecuteAction(const ResourcePtr &res);
 bool StartPSWorkerAction(const ResourcePtr &res);

@@ -215,6 +215,17 @@ OptPassGroupMap GetOptPassesB(const opt::irpass::OptimizeIRPassLib &irpass) {
   return map;
 }

+OptPassGroupMap GetOptPassesPynativeElim(const opt::irpass::OptimizeIRPassLib &irpass) {
+  opt::OptPassConfig pynative_eliminate = opt::OptPassConfig({
+    irpass.pynative_eliminate_,
+  });
+
+  OptPassGroupMap map({
+    {"pynative_eliminate", pynative_eliminate},
+  });
+  return map;
+}
+
 OptPassGroupMap GetOptPassesGraphKernelA(const opt::irpass::OptimizeIRPassLib &irpass) {
   opt::OptPassConfig interface_fusion = opt::OptPassConfig({
     irpass.mark_interface_fusion_,

@@ -422,6 +433,16 @@ bool InferenceOptPreparePass(const ResourcePtr &res) {
   return true;
 }

+bool PynativeOptPass(const ResourcePtr &res) {
+  FuncGraphPtr func_graph = res->func_graph();
+  MS_EXCEPTION_IF_NULL(func_graph);
+  opt::irpass::OptimizeIRPassLib irpass;
+  auto pynative_opt = GetOptPassesPynativeElim(irpass);
+  auto pynative_opt_opt = opt::Optimizer::MakeOptimizer("pynative_opt", res, pynative_opt);
+  (void)pynative_opt_opt->step(func_graph, false);
+  return true;
+}
+
 std::vector<PassItem> kVmPasses = {{"simplify_data_structures", SimplifyDataStructuresPass},
                                    {"opt_a", OptPassAGroup},
                                    {"clean_after_opta", CleanAfterOptAPass},

@@ -38,6 +38,7 @@ bool ConvertPrepareAdapt(const ResourcePtr &res);
 bool AddControlDependPass(const ResourcePtr &res);
 bool InferenceOptPreparePass(const ResourcePtr &res);
 void ReclaimOptimizer();
+bool PynativeOptPass(const ResourcePtr &res);
 }  // namespace pipeline
 }  // namespace mindspore

@@ -206,7 +206,7 @@ AbstractBasePtr AnalysisEngine::EvalValueNode(const ValueNodePtr &value_node, co
   MS_EXCEPTION_IF_NULL(conf);
   MS_EXCEPTION_IF_NULL(value_node);
   auto out = ToAbstract(value_node->value(), conf->context(), conf);
-  if (value_node->has_new_value()) {
+  if (value_node->has_new_value() && out->isa<AbstractTensor>()) {
     out = out->Broaden();
   }
   return out;

@@ -59,6 +59,8 @@
 #include "pipeline/pynative/pynative_execute_ge.h"
 #endif

+#include "debug/anf_ir_dump.h"
+
 using mindspore::tensor::TensorPy;

 const char SINGLE_OP_GRAPH[] = "single_op_graph";

@@ -780,19 +782,79 @@ void PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, const py::ob
   set_pyobj(curr_g_, obj_id);
 }

-void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value) {
-  auto iter = op_forward_map_.find(id);
-  if (iter != op_forward_map_.end()) {
+void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) {
+  if (t_map == nullptr) {
     return;
   }
-  auto tuple_info_iter = obj_to_forward_id_tuple_info_.find(id);
-  ValuePtr temp_value = value;
-  if (tuple_info_iter != obj_to_forward_id_tuple_info_.end()) {
-    temp_value = tuple_info_iter->second;
+  for (size_t i = 0; i < tuple->size(); i++) {
+    ValuePtr tuple_i = (*tuple)[i];
+    if (tuple_i->isa<tensor::Tensor>()) {
+      auto t = tuple_i->cast<tensor::TensorPtr>();
+      (*t_map)[t->id()] = t;
+    } else if (tuple_i->isa<ValueTuple>()) {
+      GenTupleMap(tuple_i->cast<ValueTuplePtr>(), t_map);
+    }
+  }
+  MS_LOG(DEBUG) << "End GenTupleMap" << tuple->ToString();
+}
+
+ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple) {
+  std::vector<ValuePtr> value_list;
+  for (size_t i = 0; i < tuple->size(); i++) {
+    ValuePtr tuple_i = (*tuple)[i];
+    if (tuple_i->isa<tensor::Tensor>()) {
+      auto t = tuple_i->cast<tensor::TensorPtr>();
+      auto new_tensor = std::make_shared<tensor::Tensor>(*t);
+      new_tensor->set_device_address(nullptr);
+      value_list.push_back(new_tensor);
+    } else if (tuple_i->isa<ValueTuple>()) {
+      value_list.push_back(CleanTupleAddr(tuple_i->cast<ValueTuplePtr>()));
+    } else {
+      MS_LOG(DEBUG) << "in value" << tuple_i->ToString();
+      value_list.push_back(tuple_i);
+    }
+  }
+  MS_LOG(DEBUG) << "End CleanTupleAddr";
+  return std::make_shared<ValueTuple>(value_list);
+}
+
+void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value,
+                                          std::map<std::string, tensor::TensorPtr> *t_map) {
+  if (op_forward_map_.find(id) != op_forward_map_.end()) {
+    if (op_forward_map_[id]->isa<ValueTuple>()) {
+      // for one op have multi outputs but save only one tensor
+      if (value->isa<tensor::Tensor>()) {
+        auto tuple = op_forward_map_[id]->cast<ValueTuplePtr>();
+        auto value_t = value->cast<tensor::TensorPtr>();
+        for (size_t i = 0; i < tuple->size(); i++) {
+          if ((*tuple)[i]->isa<tensor::Tensor>()) {
+            auto tuple_t = (*tuple)[i]->cast<tensor::TensorPtr>();
+            if (value_t->id() == tuple_t->id()) {
+              tuple_t->set_device_address(value_t->device_address());
+              MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString();
+              break;
+            }
+          }
+        }
+      }
+    }
+
+    if (value->isa<ValueTuple>() && t_map != nullptr) {
+      GenTupleMap(op_forward_map_[id]->cast<ValueTuplePtr>(), t_map);
+    }
+    MS_LOG(DEBUG) << "Save op forward value: "
+                  << "(" << id << "), " << op_forward_map_[id]->ToString();
+    return;
+  }
+
+  if (value->isa<ValueTuple>() && t_map == nullptr) {
+    // make cnode gen all tuple node and set device_address be null
+    op_forward_map_[id] = CleanTupleAddr(value->cast<ValueTuplePtr>());
+  } else {
+    op_forward_map_[id] = value;
   }
-  op_forward_map_[id] = temp_value;
   MS_LOG(DEBUG) << "Save op forward value: "
-                << "(" << id << "), " << temp_value;
+                << "(" << id << "), " << value->ToString();
 }

 void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) {

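The reworked SaveOpForwardValue keeps one cached value per op id: tuples are stored once with device addresses stripped (CleanTupleAddr), and later saves only re-attach device addresses to the cached tensors, matched by tensor id (GenTupleMap). A minimal sketch of that id-keyed re-attachment, using placeholder types rather than MindSpore's tensor::Tensor:

#include <map>
#include <memory>
#include <string>

// Placeholder for a tensor that carries a stable id and a device-side address.
struct FakeTensor {
  std::string id;
  void *device_address = nullptr;
};

using TensorMap = std::map<std::string, std::shared_ptr<FakeTensor>>;

// Look up a cached tensor by id and patch its device address from a freshly
// executed tensor, mirroring what the tuple branch above does.
void ReattachAddress(TensorMap *cache, const std::shared_ptr<FakeTensor> &fresh) {
  auto it = cache->find(fresh->id);
  if (it != cache->end()) {
    it->second->device_address = fresh->device_address;
  }
}
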
@@ -828,7 +890,7 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN
       auto tuple_item_id = GetId(tuple_item[i]);
       obj_to_forward_id_[tuple_item_id] = op_id;
     }
-    obj_to_forward_id_tuple_info_[op_id] = value;
+    SaveOpForwardValue(op_id, value, nullptr);
   }
   obj_to_forward_id_[out_id] = op_id;
 }

@@ -840,12 +902,24 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
   if (out.second.size() == 1 && out.second[0] == -1) {
     return out.first;
   }
-  auto node = out.first;
+  CNodePtr node = out.first->cast<CNodePtr>();
   MS_LOG(DEBUG) << "output size " << out.second.size() << node->DebugString();
   auto abs = node->abstract();
+  ValuePtr out_obj = nullptr;
+  if (node->forward().first != nullptr) {
+    out_obj = node->forward().first;
+  } else {
+    out_obj = PyAttrValue(obj);
+  }
   for (auto &idx : out.second) {
     std::vector<AnfNodePtr> tuple_get_item_inputs{NewValueNode(prim::kPrimTupleGetItem), node, NewValueNode(idx)};
     node = curr_g_->NewCNode(tuple_get_item_inputs);
+    if (out_obj->isa<ValueTuple>()) {
+      node->add_input_value(out_obj, "");
+      node->add_input_value(MakeValue(idx), "");
+      out_obj = (*out_obj->cast<ValueTuplePtr>())[idx];
+      node->set_forward(out_obj, "");
+    }
     if (abs != nullptr && abs->isa<abstract::AbstractTuple>()) {
       auto prim_abs = dyn_cast<abstract::AbstractTuple>(abs)->elements()[idx];
       MS_LOG(DEBUG) << "set tuple getitem abs" << prim_abs->ToString();

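GetObjNode now threads the cached forward value through each generated TupleGetItem node: when the value is a tuple, the element selected by idx becomes the forward value of the new node. A small sketch of walking a nested tuple with such an index path (illustrative types only, not MindSpore APIs):

#include <cstddef>
#include <memory>
#include <vector>

// Illustrative nested value: a leaf payload or a tuple of child values.
struct Val {
  int payload = 0;
  std::vector<std::shared_ptr<Val>> elements;
};

// Follow a path of indices into nested tuples, the way GetObjNode follows
// out.second when attaching forward values to TupleGetItem nodes.
std::shared_ptr<Val> IndexInto(std::shared_ptr<Val> value, const std::vector<std::size_t> &path) {
  for (std::size_t idx : path) {
    value = value->elements.at(idx);
  }
  return value;
}
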
@@ -856,7 +930,6 @@ AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj) {
       node_abs_map_[id] = node->abstract();
     }
   }
   MS_LOG(DEBUG) << "GetObjNode output" << node->DebugString(6);
-  node->cast<CNodePtr>()->set_forward(PyAttrValue(obj), "");
   return node;
 }

|
||||||
}
|
}
|
||||||
set_obj_node_map(graph_prev, GetId(out), out_cnode);
|
set_obj_node_map(graph_prev, GetId(out), out_cnode);
|
||||||
} else {
|
} else {
|
||||||
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||||
|
DumpIR("before_resolve.ir", newfg);
|
||||||
|
}
|
||||||
parse::ResolveFuncGraph(newfg, resource_);
|
parse::ResolveFuncGraph(newfg, resource_);
|
||||||
|
if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
|
||||||
|
DumpIR("after_resolve.ir", newfg);
|
||||||
|
}
|
||||||
resource_->set_func_graph(newfg);
|
resource_->set_func_graph(newfg);
|
||||||
Popp();
|
Popp();
|
||||||
}
|
}
|
||||||
|
@@ -1438,7 +1517,13 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje
     MS_LOG(EXCEPTION) << "Could not find top graph by cellid: " << forward_cell_id;
   }
   top_g_ = cell_graph_map_[forward_cell_id];
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+    DumpIR("before_grad.ir", resource_->func_graph());
+  }
   auto g = GradGraph(resource_->func_graph(), grad, w_args, size);
+  if (MsContext::GetInstance()->get_param<bool>(MS_CTX_SAVE_GRAPHS_FLAG)) {
+    DumpIR("after_grad.ir", g);
+  }
   resource_->set_func_graph(g);
   resource_->manager()->KeepRoots({g});

@@ -25,6 +25,7 @@
 #include <mutex>
 #include <stack>
 #include <set>
+#include <map>

 #include "pybind11/pybind11.h"
 #include "pybind11/numpy.h"

@@ -121,7 +122,8 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
                    abstract::AbstractBasePtrList *args_spec_list);
   void MakeCNode(const OpExecInfoPtr &op_exec_info, const py::object &out, const AnfNodePtr &cnode);
   ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info);
-  void SaveOpForwardValue(const std::string &id, const ValuePtr &value);
+  void SaveOpForwardValue(const std::string &id, const ValuePtr &value,
+                          std::map<std::string, tensor::TensorPtr> *t_map);
   void SaveForwardResult(const CNodePtr &cnode, const py::object &out);
   void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out);

@@ -154,7 +156,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> {
   std::unordered_map<std::string, ValuePtr> op_forward_map_;
   std::unordered_map<std::string, size_t> op_id_map_;
   std::unordered_map<std::string, std::string> obj_to_forward_id_;
-  std::unordered_map<std::string, ValuePtr> obj_to_forward_id_tuple_info_;
   std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_;
   std::unordered_map<std::string, FuncGraphPtr> df_builder_map_;
   // the stack that records the context of graph created, the bottom is the top graph

@@ -85,6 +85,8 @@ void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
                                       const std::vector<tensor::TensorPtr> &input_tensors,
                                       session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(mem_manager_);
+  mem_manager_->ResetDynamicMemory();
   RunOpAssignInputMemory(input_tensors, graph);
   AssignStaticMemoryValueNode(graph);
   RunOpAssignOutputNodeMemory(pre_output_value, graph);

@@ -268,7 +270,8 @@ void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value
     MS_EXCEPTION_IF_NULL(real_output_cnode);
     MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
     if (pre_output_tensors[i]->device_address() == nullptr) {
-      MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!";
+      MS_LOG(INFO) << "The address of pre output tensor [" << i << "] is a nullptr!";
+      continue;
     }
     if (opt::IsNopNode(real_output_cnode)) {
       if (real_output_cnode->inputs().size() < 2) {

@@ -155,7 +155,7 @@ def test_softmaxloss_grad():

         def __init__(self):
             super(Net, self).__init__()
-            self.weight = Parameter(Tensor(np.ones([64, 10])), name="weight")
+            self.weight = Parameter(Tensor(np.ones([64, 10]).astype(np.float32)), name="weight")
             self.bias = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="bias")
             self.fc = P.MatMul()
             self.fc2 = nn.Dense(10, 10)

@@ -175,7 +175,7 @@ def test_softmaxloss_grad():

     net = GradWrap(NetWithLossClass(Net()))

-    predict = Tensor(np.ones([1, 64]))
+    predict = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     print("pynative run")
     out = net(predict, label)