!18172 unify runtime: support bprop hook and custom bprop operator in PyNative
Merge pull request !18172 from zyli2020/fix_issue_defect
commit 19b87fe35e
@@ -2129,5 +2129,15 @@ bool AnfRuntimeAlgorithm::IsOneOfPrimitiveCNode(const AnfNodePtr &node, const Pr
   }
   return IsOneOfPrimitive(cnode->inputs().at(kAnfPrimitiveIndex), prim_set);
 }
+
+bool AnfRuntimeAlgorithm::IsControlOpExecInBackend(const AnfNodePtr &node) {
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  // Operators in the set control_ops_exec_in_backend will be compiled into the kernel graph rather than cut into
+  // single ops and executed in the VM.
+  static std::set<std::string> control_ops_exec_in_backend = {kBpropCutOpName};
+  return control_ops_exec_in_backend.find(AnfAlgo::GetCNodeName(node)) != control_ops_exec_in_backend.end();
+}
 }  // namespace session
 }  // namespace mindspore
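Note: the new helper is just a set-membership test on the node's primitive name. Below is a minimal standalone sketch of the same idiom; the names are hypothetical stand-ins, independent of the MindSpore node types and AnfAlgo::GetCNodeName.

#include <iostream>
#include <set>
#include <string>

// Stand-in for AnfAlgo::IsControlOpExecInBackend, keyed directly on the op name.
bool IsControlOpExecInBackend(const std::string &op_name) {
  // Only ops in this set stay in the kernel graph and run in the backend.
  static const std::set<std::string> control_ops_exec_in_backend = {"bprop_cut"};
  return control_ops_exec_in_backend.count(op_name) != 0;
}

int main() {
  std::cout << std::boolalpha << IsControlOpExecInBackend("bprop_cut") << '\n';  // true
  std::cout << IsControlOpExecInBackend("Add") << '\n';                          // false
  return 0;
}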
@@ -298,6 +298,11 @@ class AnfRuntimeAlgorithm {
     return result;
   }
   static bool IsOneOfPrimitiveCNode(const AnfNodePtr &node, const PrimitiveSet &prim_set);
+
+  // Judge whether a control operator should be compiled into the kernel graph rather than cut into a single op
+  // and executed in the VM. For example, the operator "bprop_cut" is compiled into the kernel graph and launched
+  // in the backend in PyNative mode.
+  static bool IsControlOpExecInBackend(const AnfNodePtr &node);
 };
 }  // namespace session
 using AnfAlgo = session::AnfRuntimeAlgorithm;
@@ -424,6 +424,23 @@ BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr
   return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
 }
+
+void CheckInputTensorShape(const TensorPtr &tensor, const CNodePtr &kernel, size_t input_index) {
+  const auto &tensor_shape = tensor->shape();
+  const auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel, input_index);
+
+  if (tensor_shape.size() != input_shape.size()) {
+    MS_LOG(EXCEPTION) << "The input tensor's shape size: " << tensor_shape.size()
+                      << " is not equal to expected size: " << input_shape.size() << " for input[" << input_index
+                      << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
+  }
+  for (size_t i = 0; i < tensor_shape.size(); i++) {
+    if (tensor_shape[i] < 0 || static_cast<size_t>(tensor_shape[i]) != input_shape[i]) {
+      MS_LOG(EXCEPTION) << "The input tensor's shape: " << tensor_shape
+                        << " is not equal to expected shape: " << input_shape << " for input[" << input_index
+                        << "] of kernel: " << AnfAlgo::GetCNodeName(kernel);
+    }
+  }
+}
 }  // namespace

 GraphId SessionBasic::graph_sum_ = 0;
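Note: CheckInputTensorShape enforces two invariants before a backend control op consumes a forward output: equal rank, and every runtime dimension non-negative and equal to the inferred one. A standalone sketch of the same check, assuming plain vectors in place of TensorPtr and the inferred shape:

#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

// actual: runtime tensor shape (int64_t, may hold negative sentinel values).
// expected: shape inferred at compile time.
void CheckShape(const std::vector<int64_t> &actual, const std::vector<std::size_t> &expected) {
  if (actual.size() != expected.size()) {
    throw std::runtime_error("input tensor rank does not match the inferred rank");
  }
  for (std::size_t i = 0; i < actual.size(); ++i) {
    if (actual[i] < 0 || static_cast<std::size_t>(actual[i]) != expected[i]) {
      throw std::runtime_error("input tensor dimension does not match the inferred dimension");
    }
  }
}

int main() {
  CheckShape({2, 3}, {2, 3});  // passes; any mismatch would throw
  return 0;
}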
@@ -1380,6 +1397,9 @@ void SessionBasic::GetOpInputTensors(const CNodePtr &cnode,
       tensor = GetParameterOutputTensor(real_input, parameter_index, graph_inputs);
     } else if (real_input->isa<CNode>()) {
       tensor = GetCNodeOutputTensor(kernel_with_index, op_output);
+      if (AnfAlgo::IsControlOpExecInBackend(real_input)) {
+        CheckInputTensorShape(tensor, cnode, i - 1);
+      }
       input_tensor_info->input_kernel.insert(kernel_with_index);
     } else {
       MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
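Note: the i - 1 offset reflects the usual AnfNode convention, visible in the first hunk via kAnfPrimitiveIndex: a CNode's input 0 is the primitive itself, so node input i is tensor input i - 1. A toy illustration:

#include <cstddef>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // Input 0 is the primitive; real data inputs start at index 1.
  std::vector<std::string> cnode_inputs = {"Conv2D(prim)", "x", "w"};
  for (std::size_t i = 1; i < cnode_inputs.size(); ++i) {
    std::cout << "tensor input " << i - 1 << " = " << cnode_inputs[i] << '\n';
  }
  return 0;
}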
@@ -1673,7 +1673,7 @@ py::object ForwardExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynativ
   auto ms_context = MsContext::GetInstance();
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, true);

-  if (kSession == nullptr) {
+  if (kSession == nullptr && !IsMindRTUsed()) {
     std::string device_target = ms_context->get_param<std::string>(MS_CTX_DEVICE_TARGET);
     kSession = session::SessionFactory::Get().Create(device_target);
     MS_EXCEPTION_IF_NULL(kSession);
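Note: with this guard, the legacy session is only created lazily when the new unified (MindRT) runtime is not in use. A sketch of the pattern; Session and the IsMindRTUsed stub are placeholders, not the real MindSpore declarations:

#include <memory>

struct Session {};  // stand-in for the backend session type

// Stand-in: the real check reads whether the MindRT runtime is enabled from context.
bool IsMindRTUsed() { return true; }

std::unique_ptr<Session> kSession;

void EnsureLegacySession() {
  // With MindRT enabled, execution goes through the actor runtime, so no session is built.
  if (kSession == nullptr && !IsMindRTUsed()) {
    kSession = std::make_unique<Session>();
  }
}

int main() {
  EnsureLegacySession();
  return kSession == nullptr ? 0 : 1;  // 0 here: MindRT path, no legacy session
}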
@@ -54,7 +54,7 @@ void CreateGPUKernel(const std::vector<CNodePtr> &kernels) {
         }
       }
       akg_nodes.push_back(kernel);
-    } else {
+    } else if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
       auto gpu_kernel_ptr = kernel::GpuKernelFactory::GetInstance().Create(kernel_name, kernel);
       if (!gpu_kernel_ptr) {
         MS_LOG(EXCEPTION) << "Build gpu kernel op[" << kernel->fullname_with_scope() << "] failed";
@@ -424,14 +424,14 @@ void SetKernelInfo(const CNodePtr &kernel_node, KernelType kernel_type) {
     if (!result) {
       result = kernel::GpuKernelFactory::GetInstance().ReducePrecision(AnfAlgo::GetCNodeName(kernel_node), builder);
     }
-    if (!result) {
+    if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
       result = SelectAkgKernel(kernel_node, builder->Build());
       kernel_type = AKG_KERNEL;
     }
   } else if (kernel_type == AKG_KERNEL) {
     result = SelectAkgKernel(kernel_node, builder->Build());
   }
-  if (!result) {
+  if (!result && (!AnfAlgo::IsControlOpExecInBackend(kernel_node))) {
     PrintUnsupportedTypeException(kernel_node, inputs_type, outputs_type);
     return;
   }
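Note: the two new conditions give backend control ops an early exit from kernel selection: they never get a device kernel, so both the AKG fallback and the unsupported-type exception are skipped. A condensed sketch of the fallback chain; every selector here is a stub standing in for the real one:

#include <stdexcept>
#include <string>

// Stubs standing in for the real selection steps in SetKernelInfo.
bool SelectNativeKernel(const std::string &) { return false; }
bool ReducePrecisionFallback(const std::string &) { return false; }
bool SelectAkgKernel(const std::string &) { return false; }
bool IsControlOpExecInBackend(const std::string &op) { return op == "bprop_cut"; }

bool SetKernelInfoSketch(const std::string &op) {
  bool result = SelectNativeKernel(op);
  if (!result) {
    result = ReducePrecisionFallback(op);
  }
  // Control ops skip the AKG fallback ...
  if (!result && !IsControlOpExecInBackend(op)) {
    result = SelectAkgKernel(op);
  }
  // ... and skip the hard failure, since they run in the backend without a kernel.
  if (!result && !IsControlOpExecInBackend(op)) {
    throw std::runtime_error("no kernel supports op " + op);
  }
  return result;
}

int main() {
  return SetKernelInfoSketch("bprop_cut") ? 1 : 0;  // 0: no kernel selected, but no throw either
}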
@@ -289,8 +289,12 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph, const Devic
   // 'KernelMod' is real executive object of kernel.
   device_context->CreateKernel(graph->execution_order());

-  // Create device address for all anf nodes of graph.
-  CreateDeviceAddress(graph, device_context);
+  const auto &ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
+  if (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kGraphMode) {
+    // Create device address for all anf nodes of graph.
+    CreateDeviceAddress(graph, device_context);
+  }

   graph->set_is_all_nop_node(opt::IsAllNopNode(graph.get()));
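Note: device-address creation is now gated on graph mode; in PyNative mode it is presumably deferred to the per-op path, where each single-op graph binds device memory when it actually runs. A toy sketch of the gating, with stub functions and hypothetical names:

enum ExecutionMode { kGraphMode, kPynativeMode };

void CreateKernels() {}        // stub: builds KernelMod objects
void CreateDeviceAddress() {}  // stub: binds device memory to graph nodes

void CompileGraphImplSketch(ExecutionMode mode) {
  CreateKernels();
  if (mode == kGraphMode) {
    // Whole-graph execution prepares all device addresses at compile time.
    CreateDeviceAddress();
  }
  // In kPynativeMode this step is deferred to single-op execution.
}

int main() {
  CompileGraphImplSketch(kPynativeMode);
  return 0;
}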
@@ -193,6 +193,7 @@ constexpr auto kCallOpName = "call";
 constexpr auto kPartialOpName = "partial";
 constexpr auto kSwitchOpName = "Switch";
 constexpr auto kReturnOpName = "Return";
+constexpr auto kBpropCutOpName = "bprop_cut";
 constexpr auto kLarsV2OpName = "LarsV2";
 constexpr auto kLarsV2UpdateOpName = "LarsV2Update";
 constexpr auto kSquareSumAllOpName = "SquareSumAll";
@@ -21,7 +21,9 @@

 #include "vm/transform.h"
 #include "backend/session/session_factory.h"
+#include "backend/optimizer/common/helper.h"
 #include "pipeline/pynative/pynative_execute.h"
+#include "pipeline/jit/parse/data_converter.h"
 #include "ir/anf.h"
 #include "pybind_api/ir/base_ref_py.h"
 #include "utils/callbacks.h"
@@ -251,7 +253,10 @@ void MsBackend::SetDebugger() { target_sess_->SetDebugger(); }
 MindRTBackend::MindRTBackend(const std::string &backend_name, const std::string &device_name, uint32_t device_id)
     : Backend(backend_name), device_name_(device_name), device_id_(device_id) {
   root_graph_ = nullptr;
-  auto cut_list = compile::GetMsNonlinearOps();
+  auto ms_context = MsContext::GetInstance();
+  const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode);
+  auto &cut_list = pynative_mode ? compile::control_ops : GetMsNonlinearOps();
+
   graph_partition_ = std::make_shared<GraphPartition>(cut_list, backend_name);
   graph_compiler_ = std::make_shared<GraphCompiler>();
 }
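Note: the constructor now picks the graph-partition cut list by execution mode. In PyNative mode it uses control_ops, which deliberately omits kPrimBpropCut (see the cut-list definitions in the vm/transform hunk below), so bprop_cut nodes are not split out of the kernel graph. A sketch with strings standing in for PrimitivePtr:

#include <string>
#include <vector>

enum ExecutionMode { kGraphMode, kPynativeMode };

// Strings stand in for the PrimitivePtr entries of the real lists.
static const std::vector<std::string> nonlinear_ops = {"Return", "Partial", "Switch", "MakeTuple", "bprop_cut"};
static const std::vector<std::string> control_ops = {"Return", "Partial", "Switch", "MakeTuple", "SwitchLayer"};

const std::vector<std::string> &SelectCutList(ExecutionMode mode) {
  // PyNative omits bprop_cut from the cut list, so bprop_cut nodes stay inside the
  // kernel graph and are executed in the backend instead of the VM.
  return mode == kPynativeMode ? control_ops : nonlinear_ops;
}

int main() {
  return SelectCutList(kPynativeMode).size() == 5 ? 0 : 1;
}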
@@ -372,6 +377,93 @@ const ActorInfo &MindRTBackend::CompileGraph(const OpRunInfo &op_run_info, const
   return ret.first->first;
 }
+
+namespace {
+void PlantTensorTupleToVector(const py::tuple &tuple_inputs, std::vector<tensor::TensorPtr> *tensors) {
+  MS_EXCEPTION_IF_NULL(tensors);
+  for (const auto &input_object : tuple_inputs) {
+    if (!py::isinstance<tensor::Tensor>(input_object)) {
+      MS_LOG(EXCEPTION) << "The input object is not a tensor!";
+    }
+    auto tensor = py::cast<tensor::TensorPtr>(input_object);
+    MS_EXCEPTION_IF_NULL(tensor);
+    tensors->emplace_back(tensor);
+  }
+}
+
+void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
+  MS_EXCEPTION_IF_NULL(tensors);
+  ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
+  MS_EXCEPTION_IF_NULL(input_value);
+  if (!input_value->isa<ValueTuple>()) {
+    MS_LOG(EXCEPTION) << "The input object is not a value tuple!";
+  }
+
+  auto value_tuple = input_value->cast<ValueTuplePtr>();
+  MS_EXCEPTION_IF_NULL(value_tuple);
+  tensor::TensorPtr tensor_ptr = opt::CreateTupleTensor(value_tuple);
+  MS_EXCEPTION_IF_NULL(tensor_ptr);
+  tensors->emplace_back(tensor_ptr);
+}
+
+void ConvertMultiPyObjectToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *tensors) {
+  MS_EXCEPTION_IF_NULL(tensors);
+  if (!py::isinstance<py::tuple>(input_object)) {
+    MS_LOG(EXCEPTION) << "The input should be a tuple!";
+  }
+
+  auto tuple_inputs = py::cast<py::tuple>(input_object);
+  if (tuple_inputs.empty()) {
+    MS_LOG(EXCEPTION) << "The size of input list or tuple is 0!";
+  }
+  if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
+    PlantTensorTupleToVector(tuple_inputs, tensors);
+  } else {
+    ConvertValueTupleToTensor(input_object, tensors);
+  }
+}
+
+void RunControlOperator(const KernelGraphPtr &graph, const AnfNodePtr &kernel, std::vector<TensorPtr> *input_tensors,
+                        VectorRef *op_outputs) {
+  AnfNodePtr front_node = graph->GetFrontAnfByBackendAnf(kernel);
+  MS_EXCEPTION_IF_NULL(front_node);
+  if (!front_node->isa<CNode>()) {
+    MS_LOG(EXCEPTION) << "The front node of bprop_cut is not a CNode";
+  }
+
+  CNodePtr cnode = front_node->cast<CNodePtr>();
+  const std::vector<AnfNodePtr> &node_inputs = cnode->inputs();
+  if (node_inputs.empty()) {
+    MS_LOG(EXCEPTION) << "The inputs of node[" << cnode->fullname_with_scope() << "] are empty";
+  }
+
+  const AnfNodePtr &fn = node_inputs.at(0);
+  if (!IsValueNode<Primitive>(fn)) {
+    MS_LOG(EXCEPTION) << "The input[0] of kernel[" << kernel->fullname_with_scope()
+                      << "] is not a ValueNode of Primitive";
+  }
+
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(fn);
+  if (prim->name() == kBpropCutOpName) {
+    VectorRef args;
+    (void)std::transform(input_tensors->begin(), input_tensors->end(), std::back_inserter(args.elements_),
+                         [](tensor::TensorPtr &tensor) { return std::move(tensor); });
+
+    BaseRef out = prim->RunHookFunction(args);
+
+    if (utils::isa<PyObjectRef>(out)) {
+      PyObjectRef py_ref = utils::cast<PyObjectRef>(out);
+      auto out_py_tuple = py_ref.object_;
+      std::vector<tensor::TensorPtr> output_tensors;
+      ConvertMultiPyObjectToTensor(out_py_tuple, &output_tensors);
+      (void)std::transform(output_tensors.begin(), output_tensors.end(), std::back_inserter(op_outputs->elements_),
+                           [](tensor::TensorPtr &tensor) { return std::move(tensor); });
+    }
+  }
+}
+}  // namespace
+
 void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs,
                                        const std::vector<std::vector<tensor::TensorPtr>> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(graph_compiler_);
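Note: RunControlOperator hands the op's input tensors to the Python hook as a VectorRef, moving each TensorPtr rather than copying it. The same std::transform/std::move idiom in isolation, with shared_ptr<int> standing in for tensor::TensorPtr:

#include <algorithm>
#include <iostream>
#include <iterator>
#include <memory>
#include <vector>

using TensorPtr = std::shared_ptr<int>;  // stand-in for tensor::TensorPtr

int main() {
  std::vector<TensorPtr> inputs = {std::make_shared<int>(1), std::make_shared<int>(2)};
  std::vector<TensorPtr> args;
  // Ownership is transferred into args; inputs is left holding empty pointers.
  std::transform(inputs.begin(), inputs.end(), std::back_inserter(args),
                 [](TensorPtr &t) { return std::move(t); });
  std::cout << args.size() << " moved, first source now "
            << (inputs[0] == nullptr ? "null" : "non-null") << '\n';  // 2 moved, null
  return 0;
}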
@@ -402,11 +494,15 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       graph_compiler_->GetSingleOpRunInfoAndGraphInfo(kernel, input_tensor_info.input_tensors, &op_run_info,
                                                       &graph_info);

-      const ActorInfo &actor_info =
-        CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors);
       VectorRef op_outputs;
-      RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
-               &op_outputs);
+      if (!AnfAlgo::IsControlOpExecInBackend(kernel)) {
+        const ActorInfo &actor_info = CompileGraph(op_run_info, graph_info, &input_tensor_info.input_tensors_mask,
+                                                   &input_tensor_info.input_tensors);
+        RunGraph(actor_info, &op_run_info, &input_tensor_info.input_tensors_mask, &input_tensor_info.input_tensors,
+                 &op_outputs);
+      } else {
+        RunControlOperator(graph, kernel, &input_tensor_info.input_tensors, &op_outputs);
+      }

       graph_compiler_->UpdateRefCount(input_tensor_info.input_kernel, &cnode_ref_count, &op_output_map);
@@ -414,7 +510,7 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector<KernelGraphPtr> &graphs
       graph_compiler_->RecoverGraphOutput(kernel, op_outputs, cnode_ref_count, &op_output_map, &graph_output_info);

       // Save grad node to Bucket
-      if (graph->is_bprop()) {
+      if (graph->is_bprop() && (!AnfAlgo::IsControlOpExecInBackend(kernel))) {
         graph_compiler_->AddGradAddrToBucket(graph->graph_id(), graph_output_info.graph_output_tensors);
       }
     }
@@ -44,6 +44,10 @@ using TypedPrimitiveAbstractClosurePtr = std::shared_ptr<abstract::TypedPrimitiv

 std::vector<PrimitivePtr> nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch,
                                            prim::kPrimMakeTuple, prim::kPrimBpropCut};
+
+std::vector<PrimitivePtr> control_ops = {prim::kPrimReturn, prim::kPrimPartial, prim::kPrimSwitch, prim::kPrimMakeTuple,
+                                         prim::kPrimSwitchLayer};
+
 const std::vector<PrimitivePtr> &GetMsNonlinearOps() {
   static const std::vector<PrimitivePtr> ms_nonlinear_ops = {prim::kPrimReturn, prim::kPrimPartial,
                                                              prim::kPrimSwitch, prim::kPrimMakeTuple,
@@ -43,6 +43,7 @@ extern const char kGeVm[];
 // A sub namespace in ME to support compile related definition.
 namespace compile {
 extern std::vector<PrimitivePtr> nonlinear_ops;
+extern std::vector<PrimitivePtr> control_ops;
 const std::vector<PrimitivePtr> &GetMsNonlinearOps();
 FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph);
 using VmEvalFunc = std::function<BaseRef(const VectorRef &)>;