!35363 Add test cases for nop node elimination.

Merge pull request !35363 from gaoyong10/dynamic_shape_01
i-robot 2022-06-09 22:09:10 +00:00 committed by Gitee
commit 764656a010
5 changed files with 345 additions and 2 deletions


@@ -15,7 +15,8 @@
 import numpy as np
 import pytest
 import mindspore
-from mindspore import context, nn, ops, Tensor, Parameter
+from mindspore import context, nn, ops, Tensor, Parameter, ms_function
+from mindspore.ops import functional as F
 
 
 class Net(nn.Cell):
@@ -61,3 +62,26 @@ def test_repeat_control_arrow_for_stack_actor():
     out = net(x)
     result = 10
     assert out == result
+
+
+@ms_function
+def switch_op(x, y):
+    z1 = y + 1
+    z2 = Tensor(5, mindspore.int32)
+    return F.switch(x, z1, z2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_switch_op():
+    """
+    Feature: Runtime.
+    Description: Test switch op.
+    Expectation: No exception.
+    """
+    context.set_context(mode=context.GRAPH_MODE)
+    x = Tensor(False, mindspore.bool_)
+    y = Tensor(1, mindspore.int32)
+    out = switch_op(x, y)
+    assert out == 5
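
For context: ms_function compiles switch_op into a graph, and F.switch(cond, true_branch, false_branch) returns the branch selected by the boolean condition, which is why the False input above yields the constant 5. A minimal sketch of both outcomes (my illustration, not part of the PR; it assumes the same imports as the test file):

y = Tensor(1, mindspore.int32)
assert switch_op(Tensor(True, mindspore.bool_), y) == 2   # True selects z1 = y + 1
assert switch_op(Tensor(False, mindspore.bool_), y) == 5  # False selects the constant z2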


@@ -0,0 +1,100 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest
import mindspore
from mindspore import context, ops, nn, Tensor


class Net(nn.Cell):
    def __init__(self):
        super().__init__()
        self.reshape = ops.Reshape()

    def construct(self, x, y, z):
        a = x + y
        b = self.reshape(a, (3, 2))
        c = self.reshape(z, (3, 2))
        return b + c


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_eliminate_nopnode():
    """
    Feature: eliminate nopnode.
    Description: base scene.
    Expectation: No exception.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.ones([6, 1]), mindspore.float32)
    y = Tensor(np.ones([6, 1]), mindspore.float32)
    z = Tensor(np.ones([6, 1]), mindspore.float32)
    net = Net()
    out = net(x, y, z)
    assert out.shape == (3, 2)


class NetWithNopNodeOutput(nn.Cell):
    def __init__(self):
        super().__init__()
        self.reshape = ops.Reshape()

    def construct(self, x, y):
        a = x + y
        return self.reshape(a, (3, 2))


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nopnode_output():
    """
    Feature: eliminate nopnode.
    Description: The nop node is the graph output.
    Expectation: No exception.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x = Tensor(np.ones([6, 1]), mindspore.float32)
    y = Tensor(np.ones([6, 1]), mindspore.float32)
    net = NetWithNopNodeOutput()
    out = net(x, y)
    assert out.shape == (3, 2)


@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_nopnode_dynamic_shape():
    """
    Feature: eliminate nopnode.
    Description: The nop node has dynamic shape inputs.
    Expectation: No exception.
    """
    context.set_context(mode=context.GRAPH_MODE)
    x_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
    y_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
    z_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
    net = Net()
    net.set_inputs(x_dyn, y_dyn, z_dyn)
    x = Tensor(np.ones([6, 1]), mindspore.float32)
    y = Tensor(np.ones([6, 1]), mindspore.float32)
    z = Tensor(np.ones([6, 1]), mindspore.float32)
    out = net(x, y, z)
    assert out.shape == (3, 2)
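
A note on the two mechanisms this file exercises: Reshape does no arithmetic and only rewrites shape metadata, which is what lets the runtime eliminate it as a "nop node"; and net.set_inputs with Tensor(shape=[6, None], ...) placeholders marks the None axes as dynamic, so one compiled graph serves any matching concrete shape. A minimal standalone sketch of the set_inputs usage (ReluNet is my illustration, not part of the PR, and assumes a MindSpore version that supports set_inputs):

import numpy as np
import mindspore
from mindspore import Tensor, context, nn, ops


class ReluNet(nn.Cell):  # illustrative net, not from the PR
    def __init__(self):
        super().__init__()
        self.relu = ops.ReLU()

    def construct(self, x):
        return self.relu(x)


context.set_context(mode=context.GRAPH_MODE)
net = ReluNet()
# Declare the second axis dynamic; concrete calls may vary along it.
net.set_inputs(Tensor(shape=[4, None], dtype=mindspore.float32))
assert net(Tensor(np.ones([4, 3]), mindspore.float32)).shape == (4, 3)
assert net(Tensor(np.ones([4, 5]), mindspore.float32)).shape == (4, 5)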


@@ -207,7 +207,6 @@ TEST_F(ControlNodeParserTest, Parse) {
   std::vector<FuncGraphPtr> graphs{func_graph};
   FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
   manager->AddFuncGraph(func_graph);
-  ;
 
   auto parser = std::make_shared<ControlNodeParser>();
   DeviceContextKey device_context_key{"CPU", 0};


@@ -0,0 +1,157 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/graph_compiler.h"
#include "kernel/kernel.h"
namespace mindspore {
namespace runtime {
using KernelGraph = session::KernelGraph;
using FuncGraphAbstractClosure = abstract::FuncGraphAbstractClosure;
using AnalysisContext = abstract::AnalysisContext;
using DeviceContextKey = device::DeviceContextKey;
using DeviceAddress = device::DeviceAddress;
using DeviceAddressPtr = device::DeviceAddressPtr;
using DeviceType = device::DeviceType;
using AddressPtr = kernel::AddressPtr;
class TestDeviceAddress : public DeviceAddress {
public:
TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
~TestDeviceAddress() {}
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
return true;
}
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
const std::string &format) const {
return true;
}
virtual void *GetMutablePtr() const { return nullptr; }
virtual void ClearDeviceMemory() {}
};
class TestKernelMod : public kernel::KernelMod {
public:
TestKernelMod() = default;
~TestKernelMod() override = default;
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
return true;
}
};
class TestADeviceContext : public DeviceContext {
public:
explicit TestADeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
~TestADeviceContext() override = default;
virtual void Initialize() {}
virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
virtual void FreeMemory(DeviceAddress *const &address) const {}
virtual void *AllocateMemory(size_t size) const { return nullptr; }
virtual void FreeMemory(void *const ptr) const {}
virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
TypeId type_id, const ShapeVector &shape) const {
return std::make_shared<TestDeviceAddress>(nullptr, 0);
}
virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
virtual void SetOperatorInfo(const KernelGraphPtr &graph) const {}
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
for (const auto node : nodes) {
MS_EXCEPTION_IF_NULL(node);
if (node->kernel_info() == nullptr) {
auto kernel_info = std::make_shared<device::KernelInfo>();
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
kernel_info->set_select_kernel_build_info(builder->Build());
node->set_kernel_info(kernel_info);
} else {
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
if (kernel_info->select_kernel_build_info() == nullptr) {
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
kernel_info->set_select_kernel_build_info(builder->Build());
}
}
auto kernel_mod_ptr = std::make_shared<TestKernelMod>();
kernel_mod_ptr->SetInputSizeList({4});
kernel_mod_ptr->SetOutputSizeList({4});
kernel_mod_ptr->SetWorkspaceSizeList({4});
AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get());
}
}
};
class GraphCompilerTest : public UT::Common {
public:
GraphCompilerTest() {}
};
/// Feature: Control flow supports dynamic shape.
/// Description: Test the CompileGraph interface.
/// Expectation: As expected.
TEST_F(GraphCompilerTest, CompileGraph) {
  std::vector<int64_t> shp{2, 2};
  abstract::AbstractTensorPtr abs;

  // Func graph.
  auto func_graph = std::make_shared<FuncGraph>();

  // Parameters.
  auto abstract_x = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  auto parameter_x = func_graph->add_parameter();
  parameter_x->set_abstract(abstract_x);
  auto abstract_y = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  auto parameter_y = func_graph->add_parameter();
  parameter_y->set_abstract(abstract_y);
  auto parameters = func_graph->parameters();

  // Add.
  std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd), parameters[0], parameters[1]};
  auto add_node = func_graph->NewCNode(add_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  add_node->set_abstract(abs);

  // Reshape.
  std::vector<AnfNodePtr> reshape_inputs{NewValueNode(prim::kPrimReshape), add_node};
  auto reshape_node = func_graph->NewCNode(reshape_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  reshape_node->set_abstract(abs);

  // Sub.
  std::vector<AnfNodePtr> sub_inputs{NewValueNode(prim::kPrimSub), reshape_node, parameters[0]};
  auto sub_node = func_graph->NewCNode(sub_inputs);
  abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
  sub_node->set_abstract(abs);

  // Return.
  std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), sub_node};
  auto return_node = func_graph->NewCNode(return_inputs);
  func_graph->set_return(return_node);

  std::vector<AnfNodePtr> nodes{add_node, reshape_node, sub_node};
  std::vector<AnfNodePtr> outputs{sub_node};
  auto segment = std::make_shared<GraphSegment>(nodes, false);

  auto compiler = std::make_shared<GraphCompiler>();
  DeviceContextKey device_context_key{"CPU", 0};
  auto device_context = std::make_shared<TestADeviceContext>(device_context_key);
  auto graph_id = compiler->CompileGraph(segment, outputs, device_context.get(), device::RunMode::kKernelMode, false);

  const auto &kernel_graph = compiler->Fetch(graph_id);
  ASSERT_EQ(3, kernel_graph->execution_order().size());
}
} // namespace runtime
} // namespace mindspore
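
In summary, the new unit test above builds a three-node func graph (Add, Reshape, Sub), wraps the nodes in a GraphSegment, compiles it through GraphCompiler against a stub CPU device context, and expects all three kernels to appear in the compiled graph's execution order.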


@@ -0,0 +1,63 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/common_test.h"
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/scheduler_helper.h"
namespace mindspore {
namespace runtime {
class SchedulerHelperTest : public UT::Common {
public:
SchedulerHelperTest() {}
};
/// Feature: Add fusion actor.
/// Description: Test the common interface.
/// Expectation: As expected.
TEST_F(SchedulerHelperTest, AddDependency) {
  auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
  MS_EXCEPTION_IF_NULL(memory_manager_actor);
  auto kernel_graph = std::make_shared<KernelGraph>();
  MS_EXCEPTION_IF_NULL(kernel_graph);
  std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimLess)};
  auto backend_node1 = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(backend_node1);
  auto backend_node2 = kernel_graph->NewCNode(inputs);
  MS_EXCEPTION_IF_NULL(backend_node2);

  std::set<size_t> ref_input_indexes;
  std::set<size_t> ref_output_indexes;
  auto from_actor =
    std::make_shared<KernelActor>("from_actor", backend_node1, nullptr, memory_manager_actor->GetAID(), nullptr,
                                  nullptr, GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);
  auto to_actor =
    std::make_shared<KernelActor>("to_actor", backend_node2, nullptr, memory_manager_actor->GetAID(), nullptr, nullptr,
                                  GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);

  SchedulerHelper::AddDependency(from_actor.get(), to_actor.get());
  ASSERT_EQ(1, from_actor->dependent_actors().size());

  auto fusion_actor = SchedulerHelper::BuildFusionActor({from_actor, to_actor});
  ASSERT_EQ(2, fusion_actor->sub_actors().size());

  SchedulerHelper::AddArrowForFusionActor(fusion_actor.get());
  ASSERT_EQ(0, fusion_actor->input_data_arrow_aids().size());

  SchedulerHelper::FuseDataArrowsToBatchDataArrow(fusion_actor.get());
  ASSERT_EQ(0, fusion_actor->batch_output_data_arrows().size());
}
} // namespace runtime
} // namespace mindspore