Merge pull request !31520 from gaoyong10/dynamic_shape_01
This commit is contained in:
i-robot 2022-03-20 04:24:23 +00:00 committed by Gitee
commit 048d089f9a
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
3 changed files with 229 additions and 0 deletions

View File

@ -935,6 +935,7 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<KernelG
}
// If there is no kernel in funcgraph, the parameter uses the default device context type.
MS_EXCEPTION_IF_NULL(root_func_graph_->manager());
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
for (auto sub_graph : sub_graphs) {
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
@ -1078,6 +1079,9 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
const auto &cnode = return_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
const auto &inputs = cnode->inputs();
if (inputs.size() <= kReturnInputPos) {
MS_LOG(EXCEPTION) << "Invalid return node:" << cnode->DebugString();
}
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
std::vector<const DeviceContext *> return_device_contexts;

View File

@ -75,6 +75,7 @@ if(ENABLE_MINDDATA)
./cxx_api/*.cc
./tbe/*.cc
./mindapi/*.cc
./runtime/graph_scheduler/*.cc
)
if(NOT ENABLE_SECURITY)
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
@ -124,6 +125,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
"../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
"../../../mindspore/ccsrc/runtime/graph_scheduler/*.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/profiling/*.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ge_runtime/*.cc"
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc"
@ -229,6 +231,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kerne
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_compile.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/rpc_node_scheduler.cc")
if(ENABLE_SECURITY)
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/profiler/device/profiling.cc")

View File

@ -0,0 +1,222 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/control_node_parser.h"
namespace mindspore {
namespace runtime {
using KernelGraph = session::KernelGraph;
using FuncGraphAbstractClosure = abstract::FuncGraphAbstractClosure;
using AnalysisContext = abstract::AnalysisContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
using DeviceAddressType = device::DeviceAddressType;
class ControlNodeParserTest : public UT::Common {
public:
ControlNodeParserTest() {}
};
FuncGraphPtr BuildFuncGraph() {
std::vector<int64_t> shp{2, 2};
auto func_graph = std::make_shared<FuncGraph>();
auto abstract_x = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
auto parameter_x = func_graph->add_parameter();
parameter_x->set_abstract(abstract_x);
auto abstract_y = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
auto parameter_y = func_graph->add_parameter();
parameter_y->set_abstract(abstract_y);
return func_graph;
}
class TestDeviceAddress : public DeviceAddress {
public:
TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
~TestDeviceAddress() {}
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
return true;
}
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format) const {
return true;
}
virtual void *GetMutablePtr() const { return nullptr; }
virtual void ClearDeviceMemory() {}
};
class TestDeviceContext : public DeviceContext {
public:
explicit TestDeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
~TestDeviceContext() override = default;
virtual void Initialize() {}
virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
virtual void FreeMemory(DeviceAddress *const &address) const {}
virtual void *AllocateMemory(size_t size) const { return nullptr; }
virtual void FreeMemory(void *const ptr) const {}
virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
TypeId type_id, const ShapeVector &shape) const {
return std::make_shared<TestDeviceAddress>(nullptr, 0);
}
virtual DeviceAddressType GetDeviceAddressType() const { return DeviceAddressType::kCPU; }
virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {}
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
};
KernelGraphPtr BuildKernelGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &front_node,
const ValueNodePtr &prim) {
auto kernel_graph = std::make_shared<KernelGraph>();
auto front_parameter = func_graph->parameters();
// Build kernel.
std::vector<AnfNodePtr> inputs{prim};
for (const auto &parameter : front_parameter) {
inputs.emplace_back(kernel_graph->NewParameter(parameter->cast<ParameterPtr>()));
}
auto backend_node = kernel_graph->NewCNode(inputs);
std::vector<int64_t> shp{2, 2};
abstract::AbstractTensorPtr abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
backend_node->set_abstract(abs);
// build return.
std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), backend_node};
auto return_node = kernel_graph->NewCNode(return_inputs);
kernel_graph->set_return(return_node);
kernel_graph->set_execution_order({backend_node});
kernel_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {front_node});
return kernel_graph;
}
void BuildGraphs(std::vector<AnfNodePtr> *control_nodes, FuncGraphPtr *func_graph,
std::vector<KernelGraphPtr> *kernel_graphs, FuncGraphToKernelGraphGroup *func_graph_to_kernel_graphs) {
auto root_func_graph = BuildFuncGraph();
auto true_func_graph = BuildFuncGraph();
auto false_func_graph = BuildFuncGraph();
std::vector<int64_t> shp{2, 2};
abstract::AbstractTensorPtr abs;
// root graph.
auto parameters = root_func_graph->parameters();
// Less.
std::vector<AnfNodePtr> less_inputs{NewValueNode(prim::kPrimLess), parameters[0], parameters[1]};
auto less = root_func_graph->NewCNode(less_inputs);
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
less->set_abstract(abs);
// True partial.
std::vector<AnfNodePtr> true_partial_inputs{NewValueNode(prim::kPrimPartial), NewValueNode(true_func_graph),
parameters[0], parameters[1]};
auto true_partial = root_func_graph->NewCNode(true_partial_inputs);
control_nodes->emplace_back(true_partial);
// False partial.
std::vector<AnfNodePtr> false_partial_inputs{NewValueNode(prim::kPrimPartial), NewValueNode(false_func_graph),
parameters[0], parameters[1]};
auto false_partial = root_func_graph->NewCNode(false_partial_inputs);
control_nodes->emplace_back(false_partial);
// Switch.
std::vector<AnfNodePtr> switch_inputs{NewValueNode(prim::kPrimSwitch), less, true_partial, false_partial};
auto switch_node = root_func_graph->NewCNode(switch_inputs);
auto switch_abs = std::make_shared<FuncGraphAbstractClosure>(false_func_graph, AnalysisContext::DummyContext());
switch_node->set_abstract(switch_abs);
control_nodes->emplace_back(switch_node);
// Call.
std::vector<AnfNodePtr> call_inputs{switch_node};
auto root_call_node = root_func_graph->NewCNode(call_inputs);
control_nodes->emplace_back(root_call_node);
// Return.
std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), root_call_node};
auto return_node = root_func_graph->NewCNode(return_inputs);
control_nodes->emplace_back(return_node);
root_func_graph->set_return(return_node);
// true graph.
auto true_parameters = true_func_graph->parameters();
// Call.
std::vector<AnfNodePtr> true_call_inputs{NewValueNode(root_func_graph), true_parameters[0], true_parameters[1]};
auto true_call_node = true_func_graph->NewCNode(true_call_inputs);
auto true_call_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
true_call_node->set_abstract(true_call_abs);
control_nodes->emplace_back(true_call_node);
// Add.
std::vector<AnfNodePtr> true_add_inputs{NewValueNode(prim::kPrimAdd), true_parameters[0], true_call_node};
auto true_add = true_func_graph->NewCNode(true_add_inputs);
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
true_add->set_abstract(abs);
// Return.
std::vector<AnfNodePtr> true_return_inputs{NewValueNode(prim::kPrimReturn), true_add};
auto true_return_node = true_func_graph->NewCNode(true_return_inputs);
control_nodes->emplace_back(true_return_node);
true_func_graph->set_return(true_return_node);
// false graph.
// Add.
auto false_parameters = false_func_graph->parameters();
std::vector<AnfNodePtr> false_add_inputs{NewValueNode(prim::kPrimAdd), false_parameters[0], false_parameters[1]};
auto false_add = false_func_graph->NewCNode(false_add_inputs);
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
false_add->set_abstract(abs);
// Return.
std::vector<AnfNodePtr> false_return_inputs{NewValueNode(prim::kPrimReturn), false_add};
auto false_return_node = false_func_graph->NewCNode(false_return_inputs);
control_nodes->emplace_back(false_return_node);
false_func_graph->set_return(false_return_node);
// Build kernel graph.
// Root kernel graph.
auto root_kernel_graph = BuildKernelGraph(root_func_graph, less, NewValueNode(prim::kPrimLess));
kernel_graphs->emplace_back(root_kernel_graph);
std::vector<KernelGraphPtr> graphs{root_kernel_graph};
(*func_graph_to_kernel_graphs)[root_func_graph].emplace_back(graphs);
// True kernel graph.
auto true_kernel_graph = BuildKernelGraph(true_func_graph, true_add, NewValueNode(prim::kPrimAdd));
kernel_graphs->emplace_back(true_kernel_graph);
graphs[0] = true_kernel_graph;
(*func_graph_to_kernel_graphs)[true_func_graph].emplace_back(graphs);
// False kernel graph.
auto false_kernel_graph = BuildKernelGraph(false_func_graph, false_add, NewValueNode(prim::kPrimAdd));
kernel_graphs->emplace_back(false_kernel_graph);
graphs[0] = false_kernel_graph;
(*func_graph_to_kernel_graphs)[false_func_graph].emplace_back(graphs);
(*func_graph) = root_func_graph;
}
/// Feature: control flow support dynamic shape.
/// Description: Test the parse interface.
/// Expectation: As expected.
TEST_F(ControlNodeParserTest, Parse) {
std::vector<AnfNodePtr> control_nodes;
FuncGraphPtr func_graph;
std::vector<KernelGraphPtr> kernel_graphs;
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
BuildGraphs(&control_nodes, &func_graph, &kernel_graphs, &func_graph_to_kernel_graphs);
std::vector<FuncGraphPtr> graphs{func_graph};
FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
manager->AddFuncGraph(func_graph);
;
auto parser = std::make_shared<ControlNodeParser>();
DeviceContextKey device_context_key{"CPU", 0};
auto device_context = std::make_shared<TestDeviceContext>(device_context_key);
std::vector<DeviceContext *> device_contexts(kernel_graphs.size(), device_context.get());
parser->Parse(control_nodes, kernel_graphs, device_contexts, func_graph, func_graph_to_kernel_graphs);
ASSERT_EQ(4, parser->control_node_parameters().size());
}
} // namespace runtime
} // namespace mindspore