!68657 rebase master

Merge pull request !68657 from luochao60/Pynative_rebase_master_20240425
i-robot 2024-04-27 00:22:04 +00:00 committed by Gitee
commit 915305f3f8
14 changed files with 99 additions and 17 deletions

View File

@ -591,7 +591,7 @@ NodePtr CtrlFlowBlock::IfThenElse(const NodePtr &cond, const BlockFunc &true_cas
auto fb = BuildSubgraph(false_case);
auto s = emitter_->Emit("Switch", {cond, tb, fb});
auto cnode = func_graph_->NewCNode({s->get()});
auto cnode = func_graph_->FuncGraph::NewCNode({s->get()});
cnode->set_abstract(out_abstract_);
auto node = emitter_->NewIrNode(cnode->cast<AnfNodePtr>());
return node;
@ -676,7 +676,7 @@ NodePtr CtrlFlowBlock::While(const NodePtr &cond, const BlockFunc &while_body_fu
cnode->set_abstract(out_abstract_);
while_fg->set_output(cnode);
auto main_cnode = func_graph_->NewCNode(main_while_fg_inputs);
auto main_cnode = func_graph_->FuncGraph::NewCNode(main_while_fg_inputs);
main_cnode->set_abstract(out_abstract_);
return emitter_->NewIrNode(main_cnode);
}

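Both hunks above qualify the call as func_graph_->FuncGraph::NewCNode(...), presumably to force the base FuncGraph implementation even when func_graph_ holds a derived graph type that overrides NewCNode. A minimal standalone sketch of that base-qualified-call pattern, with illustrative types rather than the real MindSpore classes:

```cpp
#include <iostream>
#include <memory>

// Sketch only: Base/Derived stand in for FuncGraph and a derived graph class
// that overrides a virtual factory method.
struct Base {
  virtual ~Base() = default;
  virtual void MakeNode() { std::cout << "Base::MakeNode\n"; }
};

struct Derived : Base {
  void MakeNode() override { std::cout << "Derived::MakeNode\n"; }
};

int main() {
  std::unique_ptr<Base> g = std::make_unique<Derived>();
  g->MakeNode();        // virtual dispatch: prints Derived::MakeNode
  g->Base::MakeNode();  // base-qualified call: prints Base::MakeNode
  return 0;
}
```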
View File

@ -21,6 +21,7 @@
#include "ops/sequence_ops.h"
#include "ops/array_ops.h"
#include "ops/framework_ops.h"
#include "abstract/ops/primitive_infer_map.h"
#include "include/common/expander/core/infer.h"
#include "include/common/profiler.h"
@ -235,6 +236,11 @@ class PynativeIRBuilderWithCache : public PynativeIRBuilder {
return output_nodes;
}
}
for (auto &node_pair : bprop_nodes_) {
if (IsPrimitiveCNode(node_pair.first->get(), prim::kPrimSwitch)) {
return output_nodes;
}
}
bprop_map[abs_list] = BuildBpropOpGraph(input_nodes, output_nodes);
} else {
need_infer_ = false;

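The added loop above bails out before the bprop cache is filled whenever the recorded bprop nodes contain a Switch primitive, so graphs with control flow are rebuilt each time instead of being cached. A minimal sketch of that guard with plain std types (names here are illustrative, not MindSpore APIs):

```cpp
#include <map>
#include <string>
#include <vector>

struct Node { std::string prim; };
struct Graph { std::vector<Node> nodes; };

// Returns true if any node in the graph is a Switch (control flow).
bool ContainsSwitch(const Graph &g) {
  for (const auto &node : g.nodes) {
    if (node.prim == "Switch") {
      return true;
    }
  }
  return false;
}

// Cache the graph only when it is control-flow free.
void MaybeCache(std::map<std::string, Graph> *cache, const std::string &key, const Graph &g) {
  if (ContainsSwitch(g)) {
    return;  // skip caching: control-flow graphs depend on runtime condition values
  }
  (*cache)[key] = g;
}
```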
View File

@ -282,11 +282,13 @@ NodePtr IrBuilder::EmitValue(const ValuePtr &value) {
NodePtr IrBuilder::Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) {
CtrlFlowBlock cfb(this, this->func_graph());
this->func_graph()->set_flag(kFlagIsControlFlow, true);
return cfb.IfThenElse(cond, true_case, false_case);
}
NodePtr IrBuilder::While(const NodePtr &cond, const BlockFunc &body, const NodePtrList &init_list) {
CtrlFlowBlock cfb(this, this->func_graph());
this->func_graph()->set_flag(kFlagIsControlFlow, true);
return cfb.While(cond, body, init_list);
}
} // namespace bprop

View File

@ -134,10 +134,14 @@ class COMMON_EXPORT Emitter {
if (abs->isa<abstract::AbstractTensor>()) {
return CmpOpWithCast(kEqualOpName, lhs, rhs, dst_type);
} else if (abs->isa<abstract::AbstractScalar>()) {
return CmpOpWithCast("ScalarEq", lhs, rhs, dst_type);
return ScalarEq(lhs, rhs, dst_type);
}
MS_LOG(EXCEPTION) << "'Equal' only support [Tensor] or [Scalar] input, but got: " << abs->ToString();
}
virtual NodePtr ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
auto node = UnifyDtypeAndEmit("ScalarEq", lhs, rhs);
return dst_type == nullptr ? node : Cast(node, dst_type);
}
NodePtr NotEqual(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type = nullptr) {
return CmpOpWithCast("NotEqual", lhs, rhs, dst_type);
}

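Above, Equal now routes scalar operands through a new virtual ScalarEq whose default still emits a ScalarEq op (with an optional cast), while a derived builder can override the hook, as the FuncBuilder change later in this commit does to fold the comparison eagerly. A minimal sketch of that hook pattern, with simplified stand-in types rather than the real Emitter API:

```cpp
#include <iostream>

struct Emitter {
  virtual ~Emitter() = default;
  // Default hook: stands in for emitting a ScalarEq node into the graph.
  virtual bool ScalarEq(int lhs, int rhs) {
    std::cout << "default path: would emit a ScalarEq op\n";
    return lhs == rhs;
  }
  // Equal() dispatches scalar operands through the overridable hook.
  bool Equal(int lhs, int rhs) { return ScalarEq(lhs, rhs); }
};

struct EagerBuilder : Emitter {
  bool ScalarEq(int lhs, int rhs) override {
    return lhs == rhs;  // eager path: fold the comparison, no op emitted
  }
};

int main() {
  EagerBuilder builder;
  std::cout << builder.Equal(2, 2) << "\n";  // uses the overridden hook, prints 1
  return 0;
}
```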
View File

@ -412,6 +412,7 @@ constexpr auto kFlagPyNativeWithJitCallGraph = "pynative_with_jit_call_graph";
constexpr auto kFlagJitCallGraph = "jit_call_graph";
constexpr auto kFlagJitGraph = "jit_graph";
constexpr auto kFlagSwitchInline = "switch_inline_graph";
constexpr auto kFlagIsControlFlow = "is_control_flow";
// custom operator func type
constexpr auto kCustomTypeAOT = "aot";

View File

@ -28,6 +28,7 @@
#include "include/backend/optimizer/op_adaptation_info_factory.h"
#include "pipeline/pynative/pynative_utils.h"
#include "mindspore/core/ops/op_utils.h"
#include "frontend/operator/cc_implementations.h"
namespace mindspore::pynative::autograd {
namespace {
@ -113,6 +114,24 @@ std::vector<int64_t> BuildShape(const abstract::AbstractBasePtr &abs) {
MS_EXCEPTION_IF_NULL(shape);
return shape->shape();
}
bool ParseCond(const NodePtr &cond) {
auto cond_val = cond->Value();
if (cond_val->isa<BoolImm>()) {
return GetValue<bool>(cond_val);
} else if (cond_val->isa<tensor::Tensor>()) {
auto tensor = cond_val->cast<tensor::TensorPtr>();
tensor->data_sync();
size_t data_size = tensor->DataSize();
auto tensor_type = tensor->Dtype();
if (tensor_type->type_id() == kNumberTypeBool) {
auto data_c = reinterpret_cast<bool *>(tensor->data_c());
MS_EXCEPTION_IF_NULL(data_c);
return std::all_of(data_c, data_c + data_size, [](const bool &data) { return static_cast<bool>(data); });
}
}
MS_LOG(EXCEPTION) << "For control flow, the cond should be Tensor[bool] or bool, but got: " << cond_val->ToString();
}
} // namespace
NodePtr FuncBuilder::EmitOp(const PrimitivePtr &prim, const NodePtrList &inputs) {
@ -241,6 +260,33 @@ NodePtr FuncBuilder::MakeTuple(const NodePtrList &inputs) {
NodePtr FuncBuilder::MakeList(const NodePtrList &inputs) { return MakeTuple(inputs); }
NodePtr FuncBuilder::Conditional(const NodePtr &cond, const expander::Emitter::BlockFunc &true_case,
const expander::Emitter::BlockFunc &false_case) {
NodePtrList result;
if (ParseCond(cond)) {
result = true_case(this);
} else {
result = false_case(this);
}
if (result.size() == kSizeOne) {
return result[kIndex0];
}
return MakeTuple(result);
}
NodePtr FuncBuilder::ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) {
auto lhs_val = lhs->Value();
auto rhs_val = rhs->Value();
ValuePtr result;
if (lhs_val->isa<BoolImm>() && rhs_val->isa<BoolImm>()) {
result = MakeValue(GetValue<bool>(lhs_val) == GetValue<bool>(rhs_val));
} else {
result = prim::ScalarEq({lhs->Value(), rhs->Value()});
}
MS_LOG(DEBUG) << "ScalarEq op: lhs " << lhs_val->ToString() << ", rhs " << rhs_val->ToString();
return NewFuncNode(result, nullptr, InputType::kOpOutput);
}
void FuncBuilder::SetInputs(std::string instance_name, const std::vector<NodePtr> *inputs,
mindspore::HashMap<std::string, ValuePtr> *attrs_ptr) {
instance_name_ = std::move(instance_name);

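In the eager FuncBuilder above, the condition value is already concrete when Conditional runs: ParseCond accepts either a bool or a bool Tensor (which counts as true only if every element is true), and Conditional then simply executes the selected branch, with ScalarEq folded to a constant in the same spirit. A minimal standalone sketch of the condition handling, using std types as stand-ins for ValuePtr and Tensor:

```cpp
#include <algorithm>
#include <functional>
#include <variant>
#include <vector>

// A condition is either a plain bool or a bool "tensor" (vector<bool> stand-in).
using Cond = std::variant<bool, std::vector<bool>>;

bool ParseCond(const Cond &cond) {
  if (std::holds_alternative<bool>(cond)) {
    return std::get<bool>(cond);
  }
  // Tensor condition: true only if all elements are true.
  const auto &data = std::get<std::vector<bool>>(cond);
  return std::all_of(data.begin(), data.end(), [](bool v) { return v; });
}

// Eager Conditional: only the chosen branch runs, no Switch node is built.
int Conditional(const Cond &cond, const std::function<int()> &true_case,
                const std::function<int()> &false_case) {
  return ParseCond(cond) ? true_case() : false_case();
}

int main() {
  std::vector<bool> mask{true, true, true};
  return Conditional(Cond{mask}, [] { return 1; }, [] { return 0; });  // runs true_case
}
```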
View File

@ -50,8 +50,9 @@ class FuncBuilder : public BpropBuilder {
NodePtr TupleGetItem(const NodePtr &input, const NodePtr &index) override;
NodePtr MakeTuple(const NodePtrList &inputs) override;
NodePtr MakeList(const NodePtrList &inputs) override;
NodePtr Conditional(const NodePtr &cond, const BlockFunc &true_case, const BlockFunc &false_case) override;
NodePtr ScalarEq(const NodePtr &lhs, const NodePtr &rhs, const TypePtr &dst_type) override;
NodePtr OutZeros(const NodePtr &node) override;
ValuePtr Ones(const ValuePtr &value);
ValuePtr Zeros(const ValuePtr &value);
ValuePtr Add(const ValuePtr &input, const ValuePtr &other);

View File

@ -1247,6 +1247,9 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const autograd::GradAttr &grad_attr,
} else {
top_cell()->resource()->set_optimize_graph(bprop_graph);
}
if (bprop_graph->has_flag(kFlagIsControlFlow)) {
top_cell()->set_has_control_flow(true);
}
if (top_cell()->has_control_flow()) {
bprop_graph = LiftingClone(bprop_graph);
}
@ -1263,7 +1266,9 @@ FuncGraphPtr GradExecutor::GetBpropGraph(const autograd::GradAttr &grad_attr,
// Update run graph by single op flag. There are two scenarios:
// 1. Dynamic shape or dynamic structure
// 2. Has bprop cut op
bprop_graph->set_flag(kFlagEnableRunGraphBySingleOp, auto_grad_cell->bprop_graph_run_by_single_op());
// If set_inputs is used but the graph has control flow, we need to run by actor.
bprop_graph->set_flag(kFlagEnableRunGraphBySingleOp,
auto_grad_cell->bprop_graph_run_by_single_op() && !top_cell()->has_control_flow());
if (top_cell()->has_call_graph()) {
bprop_graph->set_flag(kFlagPyNativeWithJitCallGraph, true);
}

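The two changes above are where the new kFlagIsControlFlow flag set by IrBuilder::Conditional/While reaches the grad executor: a flagged bprop graph marks the top cell as having control flow, which forces LiftingClone and disables run-graph-by-single-op so the graph goes through the actor runtime. A small sketch of that flag propagation with stand-in types (not the real GradExecutor API):

```cpp
struct BpropGraph {
  bool is_control_flow = false;   // set when Conditional/While emitted a Switch
  bool run_by_single_op = false;
  bool lifted = false;
};

BpropGraph Finalize(BpropGraph graph, bool single_op_requested) {
  const bool has_control_flow = graph.is_control_flow;
  if (has_control_flow) {
    graph.lifted = true;  // stands in for LiftingClone(bprop_graph)
  }
  // Control-flow graphs must run through the actor runtime, not op by op.
  graph.run_by_single_op = single_op_requested && !has_control_flow;
  return graph;
}
```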
View File

@ -239,6 +239,7 @@ AnfNodePtr IrBprop::MapParameter(const ValuePtr &value, const abstract::Abstract
param->set_abstract(abs);
return param;
}
bprop_graph_run_by_single_op_ = auto_grad_meta_data->is_register_hook();
if (auto_grad_meta_data->input_type() == InputType::kParameter) {
return AddParameterNode(tensor, abs);
}

View File

@ -98,7 +98,7 @@ class IrBprop {
AdParamPtr ad_param() const { return ad_param_; }
inline bool bprop_graph_run_by_single_op() { return bprop_graph_run_by_single_op_; }
void set_bprop_graph_run_by_single_op(bool bprop_graph_run_by_single_op) {
inline void set_bprop_graph_run_by_single_op(bool bprop_graph_run_by_single_op) {
bprop_graph_run_by_single_op_ = bprop_graph_run_by_single_op_ || bprop_graph_run_by_single_op;
}

View File

@ -436,6 +436,18 @@ CNodePtr IrGrad::ConstructBpropGraphInput(const GradParamPtr &grad_param, const
(void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(
grad_param->op_grad_info->input_value[i], grad_param->op_grad_info->input_abs[i]->Clone()));
}
// A hook registered on the op output forces running the bprop graph by single op
if (!ir_bprop_->bprop_graph_run_by_single_op()) {
ir_bprop()->set_bprop_graph_run_by_single_op([&grad_param]() {
auto tensor = grad_param->op_grad_info->out_value->template cast<tensor::BaseTensorPtr>();
if (tensor == nullptr) {
return false;
}
auto auto_grad_meta = tensor->auto_grad_meta_data();
MS_EXCEPTION_IF_NULL(auto_grad_meta);
return auto_grad_meta->is_register_hook();
}());
}
// Set out
(void)node_list.emplace_back(PyNativeAlgo::Common::CreateValueNodeByValue(grad_param->op_grad_info->out_value,
grad_param->op_grad_info->out_abs));
@ -669,7 +681,7 @@ AnfNodePtr IrGrad::GetInputGrad(bool grad_all_inputs, bool get_by_position, cons
if (index >= cell_inputs_.size()) {
MS_LOG(EXCEPTION) << "Position index " << index << " is exceed input size.";
}
// Tuple, List, scalar will be ignore
// Tuple, List, scalar will be ignored
if (!IsValidTensorInput(cell_inputs_[index].first->abstract())) {
MS_LOG(DEBUG) << "Get input node is not tensor "
<< ", abs " << cell_inputs_[index].first->abstract()->ToString();
@ -827,7 +839,7 @@ void IrGrad::DoParameterReplaceByUser(bool has_sens_arg, expander::bprop::UserTy
void IrGrad::ReplacePrimalParameter(bool has_sens_arg) {
PyNativeAlgo::Common::DumpGraphIR("replace_param.ir", ad_param()->tape_);
if (need_do_manager_replace_) {
if (need_do_manager_replace_ || ad_param()->tape_->has_flag(kFlagIsControlFlow)) {
MS_LOG(DEBUG) << "Do parameter replace by manager.";
DoParameterReplaceByManager(has_sens_arg);
need_do_manager_replace_ = false;

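The added block above computes the hook check only when bprop_graph_run_by_single_op() is still false, using an immediately-invoked lambda, and the setter (made inline earlier in this commit) ORs the new value in so the flag is sticky. A minimal sketch of that pattern with stand-in types:

```cpp
#include <memory>

struct AutoGradMeta { bool has_registered_hook = false; };
struct Tensor { std::shared_ptr<AutoGradMeta> meta; };

struct IrBprop {
  bool run_by_single_op = false;
  // Sticky setter: once true, the flag stays true.
  void SetRunBySingleOp(bool value) { run_by_single_op = run_by_single_op || value; }
};

void RecordOutput(IrBprop *bprop, const std::shared_ptr<Tensor> &out) {
  if (!bprop->run_by_single_op) {
    // Immediately-invoked lambda: a hook registered on the op output forces
    // running the bprop graph op by op.
    bprop->SetRunBySingleOp([&out]() {
      if (out == nullptr || out->meta == nullptr) {
        return false;
      }
      return out->meta->has_registered_hook;
    }());
  }
}
```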
View File

@ -492,6 +492,7 @@ void IrPassForward::ConvertMakeTupleInputToDynamicInput(const AnfNodePtr &node,
// Pyboost ops do not need to plant tuple inputs
auto prim = GetCNodePrimitive(cnode);
MS_EXCEPTION_IF_NULL(prim);
MS_LOG(DEBUG) << "Get run by single op " << run_by_single_op;
if (run_by_single_op && runtime::PyBoostOpExecute::GetInstance().IsPyBoostOpRegistered(prim->name())) {
cnode->AddAttr(kAttrIsPyboostTupleInput, MakeValue(true));
return;

View File

@ -1510,15 +1510,17 @@ FrontendOpRunInfoPtr PyBoost::Init(const PrimitivePtr &prim, const py::list &arg
return op_run_info;
}
void PyBoost::MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const std::vector<TensorPtr> &outputs) {
size_t size = outputs.size();
void PyBoost::MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const kernel::pyboost::OpPtr &op) {
size_t size = op->outputs().size();
if (size == kSizeOne) {
op_run_info->real_out = outputs[0];
return;
if (op->output_abs() != nullptr && !op->output_abs()->isa<abstract::AbstractSequence>()) {
op_run_info->real_out = op->outputs()[0];
return;
}
}
std::vector<ValuePtr> output_values(size);
for (size_t i = 0; i < size; ++i) {
const auto &output_tensor = outputs[i];
const auto &output_tensor = op->outputs()[i];
MS_EXCEPTION_IF_NULL(output_tensor);
output_values[i] = output_tensor;
}
@ -1556,7 +1558,7 @@ void PyBoost::UpdateOpRunInfo(const kernel::pyboost::OpPtr &op, const vector<Val
MS_EXCEPTION_IF_NULL(op);
MS_EXCEPTION_IF_NULL(op_run_info);
// Set result to python
MakeOutputValue(op_run_info, op->outputs());
MakeOutputValue(op_run_info, op);
UpdateStubOutput(op_run_info, op->output_abs());
// Update op run info for auto grad

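The MakeOutputValue change above passes the whole op instead of just its output tensors, so a single output is returned directly only when the op's output abstract is known and is not a sequence; a one-element tuple output stays wrapped. A small sketch of that decision with simplified stand-in types:

```cpp
#include <memory>
#include <vector>

struct Value { virtual ~Value() = default; };
struct TensorValue : Value {};
struct TupleValue : Value {
  std::vector<std::shared_ptr<Value>> items;
};

std::shared_ptr<Value> MakeOutput(const std::vector<std::shared_ptr<Value>> &outputs,
                                  bool output_abs_is_sequence) {
  if (outputs.size() == 1 && !output_abs_is_sequence) {
    return outputs[0];  // scalar-like single output: return it directly
  }
  auto tuple = std::make_shared<TupleValue>();
  tuple->items = outputs;  // sequence output: keep it wrapped, even with one element
  return tuple;
}
```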
View File

@ -85,7 +85,8 @@ struct Common {
static tensor::TensorPtr ConvertToContiguousTensor(const tensor::TensorPtr &tensor);
static ValuePtr CreateTensorByConstantValue(const ValuePtr &value);
template <typename T>
static std::string PrintDebugInfo(std::vector<T> items, const std::string &info_header = "") {
static std::string PrintDebugInfo(std::vector<T> items, const std::string &info_header = "",
bool is_print_tensor_data = false) {
static constexpr size_t end_char_size = 2;
std::ostringstream buf;
buf << info_header;
@ -94,7 +95,7 @@ struct Common {
MS_LOG(DEBUG) << "The " << i << "'th item is nullptr!";
continue;
}
if (items[i]->template isa<tensor::Tensor>()) {
if (items[i]->template isa<tensor::Tensor>() && is_print_tensor_data) {
auto tensor = items[i]->template cast<tensor::TensorPtr>();
auto grad = std::make_shared<tensor::Tensor>(*tensor);
grad->data_sync();
@ -161,7 +162,7 @@ struct DataConvert {
struct PyBoost {
static FrontendOpRunInfoPtr Init(const PrimitivePtr &prim, const py::list &args);
static void DoGrad(const FrontendOpRunInfoPtr &op_run_info);
static void MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const std::vector<TensorPtr> &outputs);
static void MakeOutputValue(const FrontendOpRunInfoPtr &op_run_info, const kernel::pyboost::OpPtr &op);
static void UpdateOutputTensorGradInfo(const std::vector<TensorPtr> &outputs);
static void UpdateStubOutput(const FrontendOpRunInfoPtr &op_run_info, const AbstractBasePtr &abstract);
static void UpdateOpRunInfo(const kernel::pyboost::OpPtr &op, const vector<ValuePtr> &op_inputs,