forked from mindspore-Ecosystem/mindspore
Add ut for dynamic shape and control flow.
This commit is contained in:
parent
0c4ca0fefa
commit
89c98f8b73
|
@ -938,6 +938,7 @@ void ControlNodeParser::ParseDeviceContextForFuncGraph(const std::vector<KernelG
|
|||
}
|
||||
|
||||
// If there is no kernel in funcgraph, the parameter uses the default device context type.
|
||||
MS_EXCEPTION_IF_NULL(root_func_graph_->manager());
|
||||
FuncGraphSet sub_graphs = root_func_graph_->manager()->func_graphs();
|
||||
for (auto sub_graph : sub_graphs) {
|
||||
if (func_graph_to_device_contexts_.find(sub_graph) == func_graph_to_device_contexts_.end()) {
|
||||
|
@ -1081,6 +1082,9 @@ void ControlNodeParser::ParseDeviceContextForReturnNode(const DeviceContext *def
|
|||
const auto &cnode = return_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
const auto &inputs = cnode->inputs();
|
||||
if (inputs.size() <= kReturnInputPos) {
|
||||
MS_LOG(EXCEPTION) << "Invalid return node:" << cnode->DebugString();
|
||||
}
|
||||
const auto output_nodes = FetchInputNodeByNode(inputs[kReturnInputPos]);
|
||||
std::vector<const DeviceContext *> return_device_contexts;
|
||||
|
||||
|
|
|
@ -75,6 +75,7 @@ if(ENABLE_MINDDATA)
|
|||
./cxx_api/*.cc
|
||||
./tbe/*.cc
|
||||
./mindapi/*.cc
|
||||
./runtime/graph_scheduler/*.cc
|
||||
)
|
||||
if(NOT ENABLE_SECURITY)
|
||||
file(GLOB_RECURSE UT_SRCS_DEBUG RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
|
@ -124,6 +125,7 @@ file(GLOB_RECURSE MINDSPORE_SRC_LIST RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}
|
|||
"../../../mindspore/ccsrc/runtime/device/kernel_info.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/bucket.cc"
|
||||
"../../../mindspore/ccsrc/runtime/device/launch_kernel.cc"
|
||||
"../../../mindspore/ccsrc/runtime/graph_scheduler/*.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/profiling/*.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/ge_runtime/*.cc"
|
||||
"../../../mindspore/ccsrc/plugin/device/ascend/hal/device/kernel_select_ascend.cc"
|
||||
|
@ -230,6 +232,7 @@ list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/common/graph_kerne
|
|||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_compile.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_mod.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/plugin/device/cpu/kernel/akg/akg_cpu_kernel_build.cc")
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/runtime/graph_scheduler/rpc_node_scheduler.cc")
|
||||
|
||||
if(ENABLE_SECURITY)
|
||||
list(REMOVE_ITEM MINDSPORE_SRC_LIST "../../../mindspore/ccsrc/profiler/device/profiling.cc")
|
||||
|
|
|
@ -0,0 +1,222 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "runtime/graph_scheduler/control_node_parser.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using KernelGraph = session::KernelGraph;
|
||||
using FuncGraphAbstractClosure = abstract::FuncGraphAbstractClosure;
|
||||
using AnalysisContext = abstract::AnalysisContext;
|
||||
using DeviceContextKey = device::DeviceContextKey;
|
||||
using DeviceAddress = device::DeviceAddress;
|
||||
using DeviceAddressPtr = device::DeviceAddressPtr;
|
||||
using DeviceAddressType = device::DeviceAddressType;
|
||||
class ControlNodeParserTest : public UT::Common {
|
||||
public:
|
||||
ControlNodeParserTest() {}
|
||||
};
|
||||
|
||||
FuncGraphPtr BuildFuncGraph() {
|
||||
std::vector<int64_t> shp{2, 2};
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
auto abstract_x = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
auto parameter_x = func_graph->add_parameter();
|
||||
parameter_x->set_abstract(abstract_x);
|
||||
|
||||
auto abstract_y = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
auto parameter_y = func_graph->add_parameter();
|
||||
parameter_y->set_abstract(abstract_y);
|
||||
return func_graph;
|
||||
}
|
||||
|
||||
class TestDeviceAddress : public DeviceAddress {
|
||||
public:
|
||||
TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
|
||||
~TestDeviceAddress() {}
|
||||
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
|
||||
return true;
|
||||
}
|
||||
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
||||
const std::string &format) const {
|
||||
return true;
|
||||
}
|
||||
virtual void *GetMutablePtr() const { return nullptr; }
|
||||
virtual void ClearDeviceMemory() {}
|
||||
};
|
||||
|
||||
class TestDeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit TestDeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
|
||||
~TestDeviceContext() override = default;
|
||||
virtual void Initialize() {}
|
||||
virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
|
||||
virtual void FreeMemory(DeviceAddress *const &address) const {}
|
||||
virtual void *AllocateMemory(size_t size) const { return nullptr; }
|
||||
virtual void FreeMemory(void *const ptr) const {}
|
||||
virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
|
||||
TypeId type_id, const ShapeVector &shape) const {
|
||||
return std::make_shared<TestDeviceAddress>(nullptr, 0);
|
||||
}
|
||||
virtual DeviceAddressType GetDeviceAddressType() const { return DeviceAddressType::kCPU; }
|
||||
virtual void SetOperatorInfo(const std::vector<CNodePtr> &nodes) const {}
|
||||
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
|
||||
};
|
||||
|
||||
KernelGraphPtr BuildKernelGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &front_node,
|
||||
const ValueNodePtr &prim) {
|
||||
auto kernel_graph = std::make_shared<KernelGraph>();
|
||||
auto front_parameter = func_graph->parameters();
|
||||
|
||||
// Build kernel.
|
||||
std::vector<AnfNodePtr> inputs{prim};
|
||||
for (const auto ¶meter : front_parameter) {
|
||||
inputs.emplace_back(kernel_graph->NewParameter(parameter->cast<ParameterPtr>()));
|
||||
}
|
||||
auto backend_node = kernel_graph->NewCNode(inputs);
|
||||
std::vector<int64_t> shp{2, 2};
|
||||
abstract::AbstractTensorPtr abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
backend_node->set_abstract(abs);
|
||||
// build return.
|
||||
std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), backend_node};
|
||||
auto return_node = kernel_graph->NewCNode(return_inputs);
|
||||
|
||||
kernel_graph->set_return(return_node);
|
||||
kernel_graph->set_execution_order({backend_node});
|
||||
kernel_graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, {front_node});
|
||||
return kernel_graph;
|
||||
}
|
||||
|
||||
void BuildGraphs(std::vector<AnfNodePtr> *control_nodes, FuncGraphPtr *func_graph,
|
||||
std::vector<KernelGraphPtr> *kernel_graphs, FuncGraphToKernelGraphGroup *func_graph_to_kernel_graphs) {
|
||||
auto root_func_graph = BuildFuncGraph();
|
||||
auto true_func_graph = BuildFuncGraph();
|
||||
auto false_func_graph = BuildFuncGraph();
|
||||
std::vector<int64_t> shp{2, 2};
|
||||
abstract::AbstractTensorPtr abs;
|
||||
// root graph.
|
||||
auto parameters = root_func_graph->parameters();
|
||||
// Less.
|
||||
std::vector<AnfNodePtr> less_inputs{NewValueNode(prim::kPrimLess), parameters[0], parameters[1]};
|
||||
auto less = root_func_graph->NewCNode(less_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
less->set_abstract(abs);
|
||||
// True partial.
|
||||
std::vector<AnfNodePtr> true_partial_inputs{NewValueNode(prim::kPrimPartial), NewValueNode(true_func_graph),
|
||||
parameters[0], parameters[1]};
|
||||
auto true_partial = root_func_graph->NewCNode(true_partial_inputs);
|
||||
control_nodes->emplace_back(true_partial);
|
||||
// False partial.
|
||||
std::vector<AnfNodePtr> false_partial_inputs{NewValueNode(prim::kPrimPartial), NewValueNode(false_func_graph),
|
||||
parameters[0], parameters[1]};
|
||||
auto false_partial = root_func_graph->NewCNode(false_partial_inputs);
|
||||
control_nodes->emplace_back(false_partial);
|
||||
// Switch.
|
||||
std::vector<AnfNodePtr> switch_inputs{NewValueNode(prim::kPrimSwitch), less, true_partial, false_partial};
|
||||
auto switch_node = root_func_graph->NewCNode(switch_inputs);
|
||||
auto switch_abs = std::make_shared<FuncGraphAbstractClosure>(false_func_graph, AnalysisContext::DummyContext());
|
||||
switch_node->set_abstract(switch_abs);
|
||||
control_nodes->emplace_back(switch_node);
|
||||
// Call.
|
||||
std::vector<AnfNodePtr> call_inputs{switch_node};
|
||||
auto root_call_node = root_func_graph->NewCNode(call_inputs);
|
||||
control_nodes->emplace_back(root_call_node);
|
||||
// Return.
|
||||
std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), root_call_node};
|
||||
auto return_node = root_func_graph->NewCNode(return_inputs);
|
||||
control_nodes->emplace_back(return_node);
|
||||
root_func_graph->set_return(return_node);
|
||||
|
||||
// true graph.
|
||||
auto true_parameters = true_func_graph->parameters();
|
||||
// Call.
|
||||
std::vector<AnfNodePtr> true_call_inputs{NewValueNode(root_func_graph), true_parameters[0], true_parameters[1]};
|
||||
auto true_call_node = true_func_graph->NewCNode(true_call_inputs);
|
||||
auto true_call_abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
true_call_node->set_abstract(true_call_abs);
|
||||
control_nodes->emplace_back(true_call_node);
|
||||
// Add.
|
||||
std::vector<AnfNodePtr> true_add_inputs{NewValueNode(prim::kPrimAdd), true_parameters[0], true_call_node};
|
||||
auto true_add = true_func_graph->NewCNode(true_add_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
true_add->set_abstract(abs);
|
||||
// Return.
|
||||
std::vector<AnfNodePtr> true_return_inputs{NewValueNode(prim::kPrimReturn), true_add};
|
||||
auto true_return_node = true_func_graph->NewCNode(true_return_inputs);
|
||||
control_nodes->emplace_back(true_return_node);
|
||||
true_func_graph->set_return(true_return_node);
|
||||
|
||||
// false graph.
|
||||
// Add.
|
||||
auto false_parameters = false_func_graph->parameters();
|
||||
std::vector<AnfNodePtr> false_add_inputs{NewValueNode(prim::kPrimAdd), false_parameters[0], false_parameters[1]};
|
||||
auto false_add = false_func_graph->NewCNode(false_add_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
false_add->set_abstract(abs);
|
||||
// Return.
|
||||
std::vector<AnfNodePtr> false_return_inputs{NewValueNode(prim::kPrimReturn), false_add};
|
||||
auto false_return_node = false_func_graph->NewCNode(false_return_inputs);
|
||||
control_nodes->emplace_back(false_return_node);
|
||||
false_func_graph->set_return(false_return_node);
|
||||
|
||||
// Build kernel graph.
|
||||
// Root kernel graph.
|
||||
auto root_kernel_graph = BuildKernelGraph(root_func_graph, less, NewValueNode(prim::kPrimLess));
|
||||
kernel_graphs->emplace_back(root_kernel_graph);
|
||||
std::vector<KernelGraphPtr> graphs{root_kernel_graph};
|
||||
(*func_graph_to_kernel_graphs)[root_func_graph].emplace_back(graphs);
|
||||
// True kernel graph.
|
||||
auto true_kernel_graph = BuildKernelGraph(true_func_graph, true_add, NewValueNode(prim::kPrimAdd));
|
||||
kernel_graphs->emplace_back(true_kernel_graph);
|
||||
graphs[0] = true_kernel_graph;
|
||||
(*func_graph_to_kernel_graphs)[true_func_graph].emplace_back(graphs);
|
||||
// False kernel graph.
|
||||
auto false_kernel_graph = BuildKernelGraph(false_func_graph, false_add, NewValueNode(prim::kPrimAdd));
|
||||
kernel_graphs->emplace_back(false_kernel_graph);
|
||||
graphs[0] = false_kernel_graph;
|
||||
(*func_graph_to_kernel_graphs)[false_func_graph].emplace_back(graphs);
|
||||
|
||||
(*func_graph) = root_func_graph;
|
||||
}
|
||||
|
||||
|
||||
/// Feature: control flow support dynamic shape.
|
||||
/// Description: Test the parse interface.
|
||||
/// Expectation: As expected.
|
||||
TEST_F(ControlNodeParserTest, Parse) {
|
||||
std::vector<AnfNodePtr> control_nodes;
|
||||
FuncGraphPtr func_graph;
|
||||
std::vector<KernelGraphPtr> kernel_graphs;
|
||||
FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs;
|
||||
BuildGraphs(&control_nodes, &func_graph, &kernel_graphs, &func_graph_to_kernel_graphs);
|
||||
|
||||
std::vector<FuncGraphPtr> graphs{func_graph};
|
||||
FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
;
|
||||
|
||||
auto parser = std::make_shared<ControlNodeParser>();
|
||||
DeviceContextKey device_context_key{"CPU", 0};
|
||||
auto device_context = std::make_shared<TestDeviceContext>(device_context_key);
|
||||
std::vector<DeviceContext *> device_contexts(kernel_graphs.size(), device_context.get());
|
||||
|
||||
parser->Parse(control_nodes, kernel_graphs, device_contexts, func_graph, func_graph_to_kernel_graphs);
|
||||
ASSERT_EQ(4, parser->control_node_parameters().size());
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue