forked from mindspore-Ecosystem/mindspore
!35363 Add testcase for eliminate nopnode.
Merge pull request !35363 from gaoyong10/dynamic_shape_01
This commit is contained in:
commit
764656a010
|
@ -15,7 +15,8 @@
|
|||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
from mindspore import context, nn, ops, Tensor, Parameter
|
||||
from mindspore import context, nn, ops, Tensor, Parameter, ms_function
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
|
@ -61,3 +62,26 @@ def test_repeat_control_arrow_for_stack_actor():
|
|||
out = net(x)
|
||||
result = 10
|
||||
assert out == result
|
||||
|
||||
|
||||
@ms_function
|
||||
def switch_op(x, y):
|
||||
z1 = y + 1
|
||||
z2 = Tensor(5, mindspore.int32)
|
||||
return F.switch(x, z1, z2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_switch_op():
|
||||
"""
|
||||
Feature: Runtime.
|
||||
Description: Test switch op.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(False, mindspore.bool_)
|
||||
y = Tensor(1, mindspore.int32)
|
||||
out = switch_op(x, y)
|
||||
assert out == 5
|
||||
|
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore
|
||||
from mindspore import context, ops, nn, Tensor
|
||||
|
||||
|
||||
class Net(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = ops.Reshape()
|
||||
|
||||
def construct(self, x, y, z):
|
||||
a = x + y
|
||||
b = self.reshape(a, (3, 2))
|
||||
c = self.reshape(z, (3, 2))
|
||||
return b + c
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_eliminate_nopnode():
|
||||
"""
|
||||
Feature: eliminate nopnode.
|
||||
Description: base scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
y = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
z = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
net = Net()
|
||||
out = net(x, y, z)
|
||||
assert out.shape == (3, 2)
|
||||
|
||||
|
||||
class NetWithNopNodeOutput(nn.Cell):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reshape = ops.Reshape()
|
||||
|
||||
def construct(self, x, y):
|
||||
a = x + y
|
||||
return self.reshape(a, (3, 2))
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_nopnode_output():
|
||||
"""
|
||||
Feature: eliminate nopnode.
|
||||
Description: base scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
y = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
net = NetWithNopNodeOutput()
|
||||
out = net(x, y)
|
||||
assert out.shape == (3, 2)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_nopnode_dynamic_shape():
|
||||
"""
|
||||
Feature: eliminate nopnode.
|
||||
Description: base scene.
|
||||
Expectation: No exception.
|
||||
"""
|
||||
context.set_context(mode=context.GRAPH_MODE)
|
||||
x_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
|
||||
y_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
|
||||
z_dyn = Tensor(shape=[6, None], dtype=mindspore.float32)
|
||||
net = Net()
|
||||
net.set_inputs(x_dyn, y_dyn, z_dyn)
|
||||
x = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
y = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
z = Tensor(np.ones([6, 1]), mindspore.float32)
|
||||
out = net(x, y, z)
|
||||
assert out.shape == (3, 2)
|
|
@ -207,7 +207,6 @@ TEST_F(ControlNodeParserTest, Parse) {
|
|||
std::vector<FuncGraphPtr> graphs{func_graph};
|
||||
FuncGraphManagerPtr manager = std::make_shared<FuncGraphManager>(graphs);
|
||||
manager->AddFuncGraph(func_graph);
|
||||
;
|
||||
|
||||
auto parser = std::make_shared<ControlNodeParser>();
|
||||
DeviceContextKey device_context_key{"CPU", 0};
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "runtime/graph_scheduler/graph_compiler.h"
|
||||
#include "kernel/kernel.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
using KernelGraph = session::KernelGraph;
|
||||
using FuncGraphAbstractClosure = abstract::FuncGraphAbstractClosure;
|
||||
using AnalysisContext = abstract::AnalysisContext;
|
||||
using DeviceContextKey = device::DeviceContextKey;
|
||||
using DeviceAddress = device::DeviceAddress;
|
||||
using DeviceAddressPtr = device::DeviceAddressPtr;
|
||||
using DeviceType = device::DeviceType;
|
||||
using AddressPtr = kernel::AddressPtr;
|
||||
|
||||
class TestDeviceAddress : public DeviceAddress {
|
||||
public:
|
||||
TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
|
||||
~TestDeviceAddress() {}
|
||||
virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
|
||||
return true;
|
||||
}
|
||||
virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
|
||||
const std::string &format) const {
|
||||
return true;
|
||||
}
|
||||
virtual void *GetMutablePtr() const { return nullptr; }
|
||||
virtual void ClearDeviceMemory() {}
|
||||
};
|
||||
|
||||
class TestKernelMod : public kernel::KernelMod {
|
||||
public:
|
||||
TestKernelMod() = default;
|
||||
~TestKernelMod() override = default;
|
||||
virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
|
||||
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
class TestADeviceContext : public DeviceContext {
|
||||
public:
|
||||
explicit TestADeviceContext(const DeviceContextKey &device_context_key) : DeviceContext(device_context_key) {}
|
||||
~TestADeviceContext() override = default;
|
||||
virtual void Initialize() {}
|
||||
virtual bool AllocateMemory(DeviceAddress *const &address, size_t size) const { return true; }
|
||||
virtual void FreeMemory(DeviceAddress *const &address) const {}
|
||||
virtual void *AllocateMemory(size_t size) const { return nullptr; }
|
||||
virtual void FreeMemory(void *const ptr) const {}
|
||||
virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
|
||||
TypeId type_id, const ShapeVector &shape) const {
|
||||
return std::make_shared<TestDeviceAddress>(nullptr, 0);
|
||||
}
|
||||
virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
|
||||
virtual void SetOperatorInfo(const KernelGraphPtr &graph) const {}
|
||||
virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
|
||||
for (const auto node : nodes) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
if (node->kernel_info() == nullptr) {
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
|
||||
kernel_info->set_select_kernel_build_info(builder->Build());
|
||||
node->set_kernel_info(kernel_info);
|
||||
} else {
|
||||
const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
if (kernel_info->select_kernel_build_info() == nullptr) {
|
||||
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
|
||||
kernel_info->set_select_kernel_build_info(builder->Build());
|
||||
}
|
||||
}
|
||||
auto kernel_mod_ptr = std::make_shared<TestKernelMod>();
|
||||
kernel_mod_ptr->SetInputSizeList({4});
|
||||
kernel_mod_ptr->SetOutputSizeList({4});
|
||||
kernel_mod_ptr->SetWorkspaceSizeList({4});
|
||||
AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get());
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class GraphCompilerTest : public UT::Common {
|
||||
public:
|
||||
GraphCompilerTest() {}
|
||||
};
|
||||
|
||||
/// Feature: control flow support dynamic shape.
|
||||
/// Description: Test the parse interface.
|
||||
/// Expectation: As expected.
|
||||
TEST_F(GraphCompilerTest, CompileGraph) {
|
||||
std::vector<int64_t> shp{2, 2};
|
||||
abstract::AbstractTensorPtr abs;
|
||||
|
||||
// Func graph.
|
||||
auto func_graph = std::make_shared<FuncGraph>();
|
||||
|
||||
// Parameter.
|
||||
auto abstract_x = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
auto parameter_x = func_graph->add_parameter();
|
||||
parameter_x->set_abstract(abstract_x);
|
||||
|
||||
auto abstract_y = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
auto parameter_y = func_graph->add_parameter();
|
||||
parameter_y->set_abstract(abstract_y);
|
||||
auto parameters = func_graph->parameters();
|
||||
|
||||
// Add.
|
||||
std::vector<AnfNodePtr> add_inputs{NewValueNode(prim::kPrimAdd), parameters[0], parameters[1]};
|
||||
auto add_node = func_graph->NewCNode(add_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
add_node->set_abstract(abs);
|
||||
|
||||
// Reshape.
|
||||
std::vector<AnfNodePtr> reshape_inputs{NewValueNode(prim::kPrimReshape), add_node};
|
||||
auto reshape_node = func_graph->NewCNode(reshape_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
reshape_node->set_abstract(abs);
|
||||
|
||||
// sub.
|
||||
std::vector<AnfNodePtr> sub_inputs{NewValueNode(prim::kPrimSub), reshape_node, parameters[0]};
|
||||
auto sub_node = func_graph->NewCNode(sub_inputs);
|
||||
abs = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
|
||||
sub_node->set_abstract(abs);
|
||||
|
||||
// Return.
|
||||
std::vector<AnfNodePtr> return_inputs{NewValueNode(prim::kPrimReturn), sub_node};
|
||||
auto return_node = func_graph->NewCNode(return_inputs);
|
||||
func_graph->set_return(return_node);
|
||||
|
||||
std::vector<AnfNodePtr> nodes{add_node, reshape_node, sub_node};
|
||||
std::vector<AnfNodePtr> outputs{sub_node};
|
||||
auto segment = std::make_shared<GraphSegment>(nodes, false);
|
||||
|
||||
auto compiler = std::make_shared<GraphCompiler>();
|
||||
DeviceContextKey device_context_key{"CPU", 0};
|
||||
auto device_context = std::make_shared<TestADeviceContext>(device_context_key);
|
||||
auto graph_id = compiler->CompileGraph(segment, outputs, device_context.get(), device::RunMode::kKernelMode, false);
|
||||
const auto &kernel_graph = compiler->Fetch(graph_id);
|
||||
ASSERT_EQ(3, kernel_graph->execution_order().size());
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
|
@ -0,0 +1,63 @@
|
|||
/**
|
||||
* Copyright 2022 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "common/common_test.h"
|
||||
#include "abstract/abstract_function.h"
|
||||
#include "runtime/graph_scheduler/scheduler_helper.h"
|
||||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
class SchedulerHelperTest : public UT::Common {
|
||||
public:
|
||||
SchedulerHelperTest() {}
|
||||
};
|
||||
|
||||
/// Feature: Add fusion actor.
|
||||
/// Description: Test the common interface.
|
||||
/// Expectation: As expected.
|
||||
TEST_F(SchedulerHelperTest, AddDependency) {
|
||||
auto memory_manager_actor = std::make_shared<MemoryManagerActor>();
|
||||
MS_EXCEPTION_IF_NULL(memory_manager_actor);
|
||||
auto kernel_graph = std::make_shared<KernelGraph>();
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
std::vector<AnfNodePtr> inputs{NewValueNode(prim::kPrimLess)};
|
||||
auto backend_node1 = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(backend_node1);
|
||||
auto backend_node2 = kernel_graph->NewCNode(inputs);
|
||||
MS_EXCEPTION_IF_NULL(backend_node2);
|
||||
std::set<size_t> ref_input_indexes;
|
||||
std::set<size_t> ref_output_indexes;
|
||||
|
||||
auto from_actor =
|
||||
std::make_shared<KernelActor>("from_actor", backend_node1, nullptr, memory_manager_actor->GetAID(), nullptr,
|
||||
nullptr, GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);
|
||||
auto to_actor =
|
||||
std::make_shared<KernelActor>("to_actor", backend_node2, nullptr, memory_manager_actor->GetAID(), nullptr, nullptr,
|
||||
GraphExecutionStrategy::kPipeline, ref_input_indexes, ref_output_indexes);
|
||||
SchedulerHelper::AddDependency(from_actor.get(), to_actor.get());
|
||||
ASSERT_EQ(1, from_actor->dependent_actors().size());
|
||||
|
||||
auto fusion_actor = SchedulerHelper::BuildFusionActor({from_actor, to_actor});
|
||||
ASSERT_EQ(2, fusion_actor->sub_actors().size());
|
||||
|
||||
SchedulerHelper::AddArrowForFusionActor(fusion_actor.get());
|
||||
ASSERT_EQ(0, fusion_actor->input_data_arrow_aids().size());
|
||||
|
||||
SchedulerHelper::FuseDataArrowsToBatchDataArrow(fusion_actor.get());
|
||||
ASSERT_EQ(0, fusion_actor->batch_output_data_arrows().size());
|
||||
}
|
||||
} // namespace runtime
|
||||
} // namespace mindspore
|
Loading…
Reference in New Issue