!45635 解耦SetArgs和InferOp
Merge pull request !45635 from nomindcarry/master
This commit is contained in:
commit
d5b972e903
|
@ -244,6 +244,65 @@ void InferShape(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *de
|
|||
cnode->set_abstract(new_abs);
|
||||
}
|
||||
|
||||
// Re-infers the output abstract of `cnode` on the dynamic launch path.
//
// When InferShapeForDefiniteOutputNode() reports that it handled the node, nothing else is done. Otherwise, for every
// input index returned by abstract::GetValueDependArgIndices(), the tensor produced by GetDependValueTensor() is set
// as the value of that input's abstract (or of the matching tuple element). opt::CppInferShape() is then run on a
// clone of the node's current abstract, and the clone is installed on the node.
//
// cnode:             node to infer; must not be null and must have at least one input (inputs[0] is read as the
//                    primitive).
// depend_tensor_map: must not be null. It is cleared once the early-return path is ruled out and is not refilled by
//                    this function.
// args:              opaque pointer forwarded unchanged to GetDependValueTensor().
//
// Raises through MS_EXCEPTION_IF_NULL / MS_LOG(EXCEPTION) on null arguments, empty inputs, a null real input, a null
// input abstract, or an out-of-range tuple index.
void InferShapeDynamic(const CNodePtr &cnode, std::map<uint32_t, tensor::TensorPtr> *depend_tensor_map, void *args) {
  MS_EXCEPTION_IF_NULL(cnode);
  MS_EXCEPTION_IF_NULL(depend_tensor_map);
  MS_LOG(DEBUG) << "InferShape start, node:" << cnode->fullname_with_scope();
  std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);
  auto ret = InferShapeForDefiniteOutputNode(cnode);
  if (ret) {
    return;
  }

  depend_tensor_map->clear();
  auto &inputs = cnode->inputs();
  if (inputs.empty()) {
    MS_LOG(EXCEPTION) << "Invalid inputs.";
  }
  auto context = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context);
  AbstractBasePtrList args_spec_list;
  auto primitive = GetValueNode<PrimitivePtr>(inputs[0]);
  auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
  for (size_t i = 0; i < input_size; i++) {
    auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
    auto real_input = input_node_with_index.first;
    auto real_input_index = input_node_with_index.second;

    // Validate the producer node before anything dereferences it.
    MS_EXCEPTION_IF_NULL(real_input);
    if (depend_list.find(i) != depend_list.end()) {
      auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, false, args, false);

      // The lock object is kept alive until the end of this scope while the input abstract is updated.
      // cppcheck-suppress unreadVariable
      auto lock = AnfUtils::GetAbstractLock(real_input.get());
      AbstractBasePtr real_abs = real_input->abstract();
      MS_EXCEPTION_IF_NULL(real_abs);
      if (real_abs->isa<abstract::AbstractTensor>()) {
        real_abs->set_value(out_tensor);
      } else if (real_abs->isa<abstract::AbstractTuple>()) {
        auto abstract_tuple = real_abs->cast<abstract::AbstractTuplePtr>();
        MS_EXCEPTION_IF_NULL(abstract_tuple);
        MS_EXCEPTION_IF_CHECK_FAIL((real_input_index < abstract_tuple->elements().size()), "Index is out of range.");
        auto tuple_elements = abstract_tuple->elements()[real_input_index];
        tuple_elements->set_value(out_tensor);
      }
    }
    common::AnfAlgo::AddArgList(&args_spec_list, real_input, real_input_index);
  }

  // Pynative mode is rely on the origin abstract of cnode, so cannot modify the abstract inplace, clone from old
  // abstract instead.
  auto old_abs = cnode->abstract();
  MS_EXCEPTION_IF_NULL(old_abs);
  auto new_abs = old_abs->Clone();
  opt::CppInferShape(primitive, args_spec_list, new_abs);
  MS_LOG(DEBUG) << "The abstract of " << cnode->fullname_with_scope() << " changes from " << old_abs << " to "
                << new_abs;
  cnode->set_abstract(new_abs);
}
|
||||
|
||||
inline bool IsDeprecatedCpuOrGpuKernelMod(kernel::KernelModType kernel_mod_type) {
|
||||
return kernel_mod_type == kernel::KernelModType::DeprecatedNativeGpuKernelMod ||
|
||||
kernel_mod_type == kernel::KernelModType::DeprecatedNativeCpuKernelMod;
|
||||
|
@ -319,6 +378,59 @@ void InferOp(const CNodePtr &cnode, void *args) {
|
|||
}
|
||||
}
|
||||
|
||||
void InferOpDynamic(const CNodePtr &cnode, void *args) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
kernel::KernelArgs kernel_args;
|
||||
if (AnfAlgo::IsDynamicShapeSkipExecute(cnode)) {
|
||||
std::vector<TypeId> dtypes{common::AnfAlgo::GetOutputInferDataType(cnode, 0)};
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(dtypes, {AnfAlgo::GetInputDeviceShape(cnode, 0)}, cnode.get());
|
||||
} else {
|
||||
InferShapeDynamic(cnode, &kernel_args.depend_tensor_map, args);
|
||||
}
|
||||
}
|
||||
|
||||
void SetOpArgs(const CNodePtr &cnode, void *args) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
kernel::KernelArgs kernel_args;
|
||||
|
||||
std::set<int64_t> depend_list = abstract::GetValueDependArgIndices(cnode);
|
||||
auto *depend_tensor_map = &kernel_args.depend_tensor_map;
|
||||
|
||||
auto &inputs = cnode->inputs();
|
||||
if (inputs.empty()) {
|
||||
MS_LOG(EXCEPTION) << "Invalid inputs.";
|
||||
}
|
||||
auto context = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context);
|
||||
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
bool abstract_in_cache = DynamicShapeDtypeManager::GetInstance().CheckDeviceType(real_input);
|
||||
if (depend_list.find(i) != depend_list.end()) {
|
||||
auto out_tensor = GetDependValueTensor(cnode, i, input_node_with_index, skip_nop_node, args, abstract_in_cache);
|
||||
auto ret2 = depend_tensor_map->try_emplace(i, out_tensor);
|
||||
if (!ret2.second) {
|
||||
MS_LOG(EXCEPTION) << "Insert map failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (auto kernel_mod_type = kernel_mod->GetKernelModType(); IsCpuGpuKernelMod(kernel_mod_type)) {
|
||||
auto update = kernel::AbstractArgsFromCNode(cnode, IsDeprecatedCpuOrGpuKernelMod(kernel_mod_type));
|
||||
update.depend_tensor_map = std::move(kernel_args.depend_tensor_map);
|
||||
kernel::SetInputsByDependMap(update.depend_tensor_map, &update.inputs, IsCpuKernelMod(kernel_mod_type));
|
||||
kernel::SetArgsToCNode(cnode, update);
|
||||
} else {
|
||||
kernel::SetArgsToCNode(cnode, kernel_args);
|
||||
}
|
||||
}
|
||||
|
||||
CustomActorNodeManager &CustomActorNodeManager::Instance() {
|
||||
static CustomActorNodeManager instance{};
|
||||
return instance;
|
||||
|
|
|
@ -24,6 +24,8 @@
|
|||
namespace mindspore::opt::dynamic_shape {
|
||||
bool IsRealCNode(const BaseRef &n);
|
||||
BACKEND_EXPORT void InferOp(const CNodePtr &node, void *args = nullptr);
|
||||
// Shape/type inference step of the dynamic launch path. Either copies input 0's device shape to the node output
// (when AnfAlgo::IsDynamicShapeSkipExecute(node) holds) or runs InferShapeDynamic() on the node.
// `args` is an opaque pointer passed through to the value-depend tensor lookup.
BACKEND_EXPORT void InferOpDynamic(const CNodePtr &node, void *args = nullptr);
|
||||
// Kernel-argument step of the dynamic launch path. Collects the value-depend input tensors of `node` and stores the
// resulting kernel arguments on the node via kernel::SetArgsToCNode().
// `args` is an opaque pointer passed through to the value-depend tensor lookup.
BACKEND_EXPORT void SetOpArgs(const CNodePtr &node, void *args = nullptr);
|
||||
AnfNodePtr GenInferNode(const AnfNodePtr &node);
|
||||
AnfNodePtr GenInitNode(const AnfNodePtr &node);
|
||||
|
||||
|
|
|
@ -486,7 +486,8 @@ void LaunchKernelsDynamic(const KernelGraphPtr &graph, const device::DeviceConte
|
|||
}
|
||||
auto inputs = CreateKernelInputAddress(runtime_info);
|
||||
|
||||
InferNodeRealShape(node);
|
||||
opt::dynamic_shape::InferOpDynamic(node);
|
||||
opt::dynamic_shape::SetOpArgs(node);
|
||||
|
||||
runtime::DeviceAddressUtils::CreateKernelOutputDeviceAddress(device_context, graph, is_gradient_out);
|
||||
runtime::DeviceAddressUtils::UpdateDeviceAddressForInplaceNode(graph);
|
||||
|
|
Loading…
Reference in New Issue