!45635 解耦SetArgs和InferOp

Merge pull request !45635 from nomindcarry/master
This commit is contained in:
i-robot 2022-11-23 09:47:14 +00:00 committed by Gitee
commit d5b972e903
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 116 additions and 1 deletions

View File

@ -244,6 +244,65 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
cnode->set_abstract(new_abs);
}
void InferShapeDynamic(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *depend_tensor_map, void *args) {
MS_EXCEPTION_IF_NULL(cnode);
MS_EXCEPTION_IF_NULL(depend_tensor_map);
MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope();
std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);
auto ret = InferShapeForDefiniteOutputNode(cnode);
if (ret) {
return;
}
depend_tensor_map->clear();
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs.";
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
AbstractBasePtrList args_spec_list;
auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
auto real_input = input_node_with_index.first;
auto real_input_index = input_node_with_index.second;
AbstractBasePtr cached_abstract;
AbstractBasePtr real_input_abs = real_input->abstract();
MS_EXCEPTION_IF_NULL(real_input);
if (depend_list.find(i) != depend_list.end()) {
auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, false, args, false);
// cppcheck-suppress unreadVariable
auto lock = AnfUtils::GetAbstractLock(real_input.get());
AbstractBasePtr real_abs = real_input->abstract();
if (real_abs->isa<abstract::AbstractTensor>()) {
real_abs->set_value(out_tensor);
} else if (real_abs->isa<abstract::AbstractTuple>()) {
auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
MS_EXCEPTION_IF_NULL(abstract_tuple);
MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abstract_tuple->elements().size()), "Index is out of range.");
auto tuple_elements = abstract_tuple->elements()[real_input_index];
tuple_elements->set_value(out_tensor);
}
}
common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
}
// Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old
// abstract instead.
auto old_abs = cnode->abstract();
MS_EXCEPTION_IF_NULL(old_abs);
auto new_abs = old_abs->Clone();
opt::CppInferShape(primitive, args_spec_list, new_abs);
MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << old_abs << " to "
<< new_abs;
cnode->set_abstract(new_abs);
}
inline bool IsDeprecatedCpuOrGpuKernelMod(kernel::KernelModType kernel_mod_type) {
return kernel_mod_type == kernel::KernelModType::DeprecatedNativeGpuKernelMod ||
kernel_mod_type == kernel::KernelModType::DeprecatedNativeCpuKernelMod;
@ -319,6 +378,59 @@ void InferOp(const CNodePtr &cnode, void *args) {
}
}
void InferOpDynamic(const CNodePtr &cnode, void *args) {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod);
kernel::KernelArgs kernel_args;
if (AnfAlgo::IsDynamicShapeSkipExecute(cnode)) {
std::vector<TypeId> dtypes{common::AnfAlgo::GetOutputInferDataType(cnode, 0)};
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetInputDeviceShape(cnode, 0)}, cnode.get());
} else {
InferShapeDynamic(cnode, &kernel_args.depend_tensor_map, args);
}
}
void SetOpArgs(const CNodePtr &cnode, void *args) {
MS_EXCEPTION_IF_NULL(cnode);
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
MS_EXCEPTION_IF_NULL(kernel_mod);
kernel::KernelArgs kernel_args;
std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);
auto *depend_tensor_map = &kernel_args.depend_tensor_map;
auto &inputs = cnode->inputs();
if (inputs.empty()) {
MS_LOG(EXCEPTION) << "Invalid inputs.";
}
auto context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context);
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
for (size_t i = 0; i < input_size; i++) {
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
auto real_input = input_node_with_index.first;
bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
if (depend_list.find(i) != depend_list.end()) {
auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args, abstract_in_cache);
auto ret2 = depend_tensor_map->try_emplace(i, out_tensor);
if (!ret2.second) {
MS_LOG(EXCEPTION) << "Insert map failed.";
}
}
}
if (auto kernel_mod_type = kernel_mod->GetKernelModType(); IsCpuGpuKernelMod(kernel_mod_type)) {
auto update = kernel::AbstractArgsFromCNode(cnode, IsDeprecatedCpuOrGpuKernelMod(kernel_mod_type));
update.depend_tensor_map = std::move(kernel_args.depend_tensor_map);
kernel::SetInputsByDependMap(update.depend_tensor_map, &update.inputs, IsCpuKernelMod(kernel_mod_type));
kernel::SetArgsToCNode(cnode, update);
} else {
kernel::SetArgsToCNode(cnode, kernel_args);
}
}
CustomActorNodeManager &CustomActorNodeManager::Instance() {
static CustomActorNodeManager instance{};
return instance;

View File

@ -24,6 +24,8 @@
namespace mindspore::opt::dynamic_shape {
bool IsRealCNode(const BaseRef &n);
BACKEND_EXPORT void InferOp(const CNodePtr &node, void *args = nullptr);
BACKEND_EXPORT void InferOpDynamic(const CNodePtr &node, void *args = nullptr);
BACKEND_EXPORT void SetOpArgs(const CNodePtr &node, void *args = nullptr);
AnfNodePtr GenInferNode(const AnfNodePtr &node);
AnfNodePtr GenInitNode(const AnfNodePtr &node);

View File

@ -486,7 +486,8 @@ void LaunchKernelsDynamic(const KernelGraphPtr &graph, const device::DeviceConte
}
auto inputs = CreateKernelInputAddress(runtime_info);
InferNodeRealShape(node);
opt::dynamic_shape::InferOpDynamic(node);
opt::dynamic_shape::SetOpArgs(node);
runtime::DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
runtime::DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);