!18172 unify runtime: support bprop hook and custom bprop operator in PyNative

Merge pull request !18172 from zyli2020/fix_issue_defect
i-robot 2021-06-16 14:17:31 +08:00 committed by Gitee
commit 19b87fe35e
11 changed files with 153 additions and 12 deletions

View File

@@ -2129,5 +2129,15 @@ bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const Pr
}
return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
}
bool AnfRuntimeAlgorithm::IsControlOpExecInBackend(const AnfNodePtr &node) {
if (!node->isa<CNode>()) {
return false;
}
// Operators in the set control_ops_exec_in_backend will be compiled into the kernel graph rather than being cut
// into single ops and executed in the VM.
static std::set<std::string> control_ops_exec_in_backend = {kBpropCutOpName};
return control_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != control_ops_exec_in_backend.end();
}
} // namespace session
} // namespace mindspore
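A minimal, self-contained sketch of the predicate above (stand-in types, not part of this diff): membership in a fixed set of control-op names decides whether an op stays in the kernel graph for backend hook execution instead of being cut out and run as a single VM op.

#include <set>
#include <string>
#include <vector>

// Mirrors IsControlOpExecInBackend: a name-set lookup.
bool IsControlOpExecInBackend(const std::string &op_name) {
  static const std::set<std::string> control_ops_exec_in_backend = {"bprop_cut"};
  return control_ops_exec_in_backend.count(op_name) != 0;
}

int main() {
  std::vector<std::string> order = {"MatMul", "bprop_cut", "ReLU"};
  for (const auto &name : order) {
    if (IsControlOpExecInBackend(name)) continue;  // skip: the hook path handles it at launch time
    // ... create a device kernel for `name` here ...
  }
}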

View File

@@ -298,6 +298,11 @@ class AnfRuntimeAlgorithm {
return result;
}
static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
// Judge whether a control operator needs to be compiled into the kernel graph rather than cut into a single op and
// executed in the VM. For example, the operator "bprop_cut" will be compiled into the kernel graph and launched in
// the backend in PyNative mode.
static bool IsControlOpExecInBackend(const AnfNodePtr &node);
};
} // namespace session
using AnfAlgo = session::AnfRuntimeAlgorithm;

View File

@@ -424,6 +424,23 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
}
void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
const auto &tensor_shape = tensor->shape();
const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
if (tensor_shape.size() != input_shape.size()) {
MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
<< " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
<< "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
}
for (size_t i = 0; i < tensor_shape.size(); i++) {
if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
<< " is not equal to expected shape: " << input_shape << " for input[" << input_index
<< "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
}
}
}
} // namespace
GraphId SessionBasic::graph_sum_ = 0;
@@ -1380,6 +1397,9 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
} else if (real_input->isa<CNode>()) {
tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
CheckInputTensorShape(tensor, cnode, i - 1);
}
input_tensor_info->input_kernel.insert(kernel_with_index);
} else {
MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();

View File

@@ -1673,7 +1673,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
auto ms_context = MsContext::GetInstance();
ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);
if (kSession == nullptr) {
if (kSession == nullptr && !IsMindRTUsed()) {
std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
kSession = session::SessionFactory::Get().Create(device_target);
MS_EXCEPTION_IF_NULL(kSession);

View File

@@ -54,7 +54,7 @@ void CreateGPUKernel(const std::vector<CNodePtr> &kernels) {
}
}
akg_nodes.push_back(kernel);
} else {
} else if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel);
if (!gpu_kernel_ptr) {
MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel->fullname_with_scope() << "] failed";

View File

@@ -424,14 +424,14 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
if (!result) {
result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
}
if (!result) {
if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
result = SelectAkgKernel(kernel_node, builder->Build());
kernel_type = AKG_KERNEL;
}
} else if (kernel_type == AKG_KERNEL) {
result = SelectAkgKernel(kernel_node, builder->Build());
}
if (!result) {
if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
return;
}
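A distilled sketch of the selection fallback above (stub functions, assumed behavior): native GPU kernel first, then a reduced-precision retry, then AKG codegen, and finally an unsupported-type error; backend-executed control ops skip the last two stages because they never get a device kernel.

#include <stdexcept>

bool TryNativeKernel() { return false; }      // stand-ins for the real
bool TryReducedPrecision() { return false; }  // factory lookups above
bool TryAkgCodegen() { return true; }

bool SelectKernelSketch(bool is_control_op_in_backend) {
  bool result = TryNativeKernel() || TryReducedPrecision();
  if (!result && !is_control_op_in_backend) {
    result = TryAkgCodegen();  // AKG fallback, skipped for control ops
  }
  if (!result && !is_control_op_in_backend) {
    throw std::runtime_error("unsupported data type for kernel");
  }
  return result;  // a control op may fail selection and still run via its hook
}

int main() { SelectKernelSketch(true); }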

View File

@@ -289,8 +289,12 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
// 'KernelMod' is the real executive object of the kernel.
device_context->CreateKernel(graph->execution_order());
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph, device_context);
}
graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
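Why the mode guard: graph mode plans device memory for the whole graph at compile time, while PyNative compiles many single-op graphs, some of which (backend-executed control ops such as bprop_cut) never launch a device kernel, so eager allocation would be wasted. A compile-path sketch under that assumption (kGraphMode/kPynativeMode stand in for MindSpore's execution-mode values):

enum ExecutionMode { kGraphMode, kPynativeMode };

void CreateKernels() {}          // build KernelMod objects (always)
void CreateDeviceAddresses() {}  // whole-graph device memory planning

// Mirrors the hunk above: device addresses are created eagerly only in
// graph mode; PyNative defers allocation to the moment an op launches.
void CompileSketch(ExecutionMode mode) {
  CreateKernels();
  if (mode == kGraphMode) {
    CreateDeviceAddresses();
  }
}

int main() { CompileSketch(kPynativeMode); }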

View File

@@ -193,6 +193,7 @@ constexpr auto kCallOpName = "call";
constexpr auto kPartialOpName = "partial";
constexpr auto kSwitchOpName = "Switch";
constexpr auto kReturnOpName = "Return";
constexpr auto kBpropCutOpName = "bprop_cut";
constexpr auto kLarsV2OpName = "LarsV2";
constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
constexpr auto kSquareSumAllOpName = "SquareSumAll";

View File

@@ -21,7 +21,9 @@
#include "vm/transform.h"
#include "backend/session/session_factory.h"
#include "backend/optimizer/common/helper.h"
#include "pipeline/pynative/pynative_execute.h"
#include "pipeline/jit/parse/data_converter.h"
#include "ir/anf.h"
#include "pybind_api/ir/base_ref_py.h"
#include "utils/callbacks.h"
@@ -251,7 +253,10 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
: Backend(backend_name), device_name_(device_name), device_id_(device_id) {
root_graph_ = nullptr;
auto cut_list = compile::GetMsNonlinearOps();
auto ms_context = MsContext::GetInstance();
const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
auto &cut_list = pynative_mode ? compile::control_ops : GetMsNonlinearOps();
graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
graph_compiler_ = std::make_shared<GraphCompiler>();
}
@@ -372,6 +377,93 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
return ret.first->first;
}
namespace {
void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
for (const auto &input_object : tuple_inputs) {
if (!py::isinstance<tensor::Tensor>(input_object)) {
MS_LOG(EXCEPTION) << "The input object is not a tensor!";
}
auto tensor = py::cast<tensor::TensorPtr>(input_object);
MS_EXCEPTION_IF_NULL(tensor);
tensors->emplace_back(tensor);
}
}
void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
MS_EXCEPTION_IF_NULL(input_value);
if (!input_value->isa<ValueTuple>()) {
MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
}
auto value_tuple = input_value->cast<ValueTuplePtr>();
MS_EXCEPTION_IF_NULL(value_tuple);
tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
MS_EXCEPTION_IF_NULL(tensor_ptr);
tensors->emplace_back(tensor_ptr);
}
void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
MS_EXCEPTION_IF_NULL(tensors);
if (!py::isinstance<py::tuple>(input_object)) {
MS_LOG(EXCEPTION) << "The input should be a tuple!";
}
auto tuple_inputs = py::cast<py::tuple>(input_object);
if (tuple_inputs.empty()) {
MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
}
if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
PlantTensorTupleToVector(tuple_inputs, tensors);
} else {
ConvertValueTupleToTensor(input_object, tensors);
}
}
void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, std::vector<TensorPtr> *input_tensors,
VectorRef *op_outputs) {
AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
MS_EXCEPTION_IF_NULL(front_node);
if (!front_node->isa<CNode>()) {
MS_LOG(EXCEPTION) << "The front node of bprop_cut is not CNode";
}
CNodePtr cnode = front_node->cast<CNodePtr>();
const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
if (node_inputs.empty()) {
MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] is empty";
}
const AnfNodePtr &fn = node_inputs.at(0);
if (!IsValueNode<Primitive>(fn)) {
MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
<< "] is not a ValueNode of Primitive";
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
if (prim->name() == kBpropCutOpName) {
VectorRef args;
(void)std::transform(input_tensors->begin(), input_tensors->end(), std::back_inserter(args.elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
BaseRef out = prim->RunHookFunction(args);
if (utils::isa<PyObjectRef>(out)) {
PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
auto out_py_tuple = py_ref.object_;
std::vector<tensor::TensorPtr> output_tensors;
ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
(void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
[](tensor::TensorPtr &tensor) { return std::move(tensor); });
}
}
}
} // namespace
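The flow above, end to end: for a bprop_cut node the backend looks up the front-end CNode, packs the op's input tensors into a VectorRef, invokes the registered Python hook through Primitive::RunHookFunction, and converts the returned py::tuple back into tensors (a tuple of Tensors is planted element-wise; anything else is converted as a single ValueTuple). A self-contained sketch of that round trip, with std::function standing in for the Python hook:

#include <functional>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for tensor::TensorPtr
using Hook = std::function<std::vector<Tensor>(const std::vector<Tensor> &)>;

// Mirrors RunControlOperator: pack inputs, run the hook, unpack the outputs.
std::vector<Tensor> RunControlOpSketch(const Hook &hook, const std::vector<Tensor> &inputs) {
  return hook(inputs);  // corresponds to prim->RunHookFunction(args)
}

int main() {
  // A user-defined bprop hook that scales every incoming gradient by 2.
  Hook scale_grad = [](const std::vector<Tensor> &grads) {
    std::vector<Tensor> out;
    for (const auto &g : grads) {
      Tensor t;
      for (float v : g) t.push_back(2.0f * v);
      out.push_back(t);
    }
    return out;
  };
  auto outs = RunControlOpSketch(scale_grad, {{1.0f, 2.0f}});  // -> {{2, 4}}
  (void)outs;
}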
void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(graph_compiler_);
@@ -402,11 +494,15 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
&graph_info);
const ActorInfo &actor_info =
CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors);
VectorRef op_outputs;
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
&op_outputs);
if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
&input_tensor_info.input_tensors);
RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
&op_outputs);
} else {
RunControlOperator(graph, kernel, &input_tensor_info.input_tensors, &op_outputs);
}
graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
@@ -414,7 +510,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);
// Save grad node to Bucket
if (graph->is_bprop()) {
if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
}
}

View File

@@ -44,6 +44,10 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv
std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
prim::kPrimMakeTuple, prim::kPrimBpropCut};
std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple,
prim::kPrimSwitchLayer};
const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
prim::kPrimSwitch, prim::kPrimMakeTuple,

View File

@@ -43,6 +43,7 @@ extern const char kGeVm[];
// A sub namespace in ME to support compile related definition.
namespace compile {
extern std::vector<PrimitivePtr> nonlinear_ops;
extern std::vector<PrimitivePtr> control_ops;
const std::vector<PrimitivePtr> &GetMsNonlinearOps();
FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;