From f50834e8e0bde10acc2dc204060d0f1d5fe5be8f Mon Sep 17 00:00:00 2001 From: limingqi107 Date: Wed, 23 Mar 2022 09:20:13 +0800 Subject: [PATCH] unified runtime enable the dynamic shape in the heterogeneous --- mindspore/ccsrc/kernel/kernel.cc | 13 +- mindspore/ccsrc/kernel/kernel.h | 10 -- mindspore/ccsrc/pipeline/jit/action.cc | 6 +- .../graph_scheduler/actor/abstract_actor.h | 7 ++ .../graph_scheduler/actor/actor_dump.cc | 15 ++- .../graph_scheduler/actor/copy_actor.cc | 8 +- .../graph_scheduler/actor/copy_actor.h | 12 +- .../graph_scheduler/actor/custom_actor.cc | 10 +- .../graph_scheduler/actor/custom_actor.h | 1 - .../graph_scheduler/actor/kernel_actor.cc | 8 ++ .../graph_scheduler/actor/kernel_actor.h | 1 + .../graph_scheduler/graph_scheduler.cc | 65 +++++++--- tests/st/dynamic_shape/test_ascend_cpu.py | 70 ----------- .../test_dynamic_shape_with_heterogeneity.py | 118 ++++++++++++++++++ 14 files changed, 223 insertions(+), 121 deletions(-) delete mode 100644 tests/st/dynamic_shape/test_ascend_cpu.py create mode 100644 tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py diff --git a/mindspore/ccsrc/kernel/kernel.cc b/mindspore/ccsrc/kernel/kernel.cc index c196d33c31b..e3793af09a9 100644 --- a/mindspore/ccsrc/kernel/kernel.cc +++ b/mindspore/ccsrc/kernel/kernel.cc @@ -59,16 +59,9 @@ void KernelMod::InferShape() { auto input_size = common::AnfAlgo::GetInputTensorNum(cnode); bool skip_nop_node = !context->get_param(MS_CTX_ENABLE_MINDRT); for (size_t i = 0; i < input_size; i++) { - AnfNodePtr real_input = nullptr; - size_t real_input_index = 0; - if (real_input_nodes_.count(i) > 0) { - real_input = real_input_nodes_[i].first.lock(); - real_input_index = real_input_nodes_[i].second; - } else { - auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i); - real_input = input_node_with_index.first; - real_input_index = input_node_with_index.second; - } + auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false); + auto real_input = input_node_with_index.first; + auto real_input_index = input_node_with_index.second; MS_EXCEPTION_IF_NULL(real_input); if (skip_nop_node) { InferShapeForNopNode(real_input); diff --git a/mindspore/ccsrc/kernel/kernel.h b/mindspore/ccsrc/kernel/kernel.h index 4e9d426404e..1b8877aa0a1 100644 --- a/mindspore/ccsrc/kernel/kernel.h +++ b/mindspore/ccsrc/kernel/kernel.h @@ -20,7 +20,6 @@ #include #include #include -#include #include "nlohmann/json.hpp" #include "ir/anf.h" #include "ir/dtype.h" @@ -223,10 +222,6 @@ class KernelMod { // set true if need to update output's shape after launch in dynamic_shape, like Unique virtual bool IsNeedUpdateOp() { return is_need_updateop_; } - void InsertRealInputNode(const AnfNodePtr &pre_node, size_t pre_node_out_index, size_t input_index) { - real_input_nodes_[input_index] = {pre_node, pre_node_out_index}; - } - protected: void InferShape(); void GetDepndLists(const CNodePtr &cnode); @@ -254,11 +249,6 @@ class KernelMod { std::vector inputs_addr_; std::vector workspaces_addr_; std::vector outputs_addr_; - - // HashMap > is used to record the real input node to infer the - // dynamic shape information of the nodes located at the boundary of the graph partition, such as heterogeneous - // scenario and so on. 
- mindspore::HashMap> real_input_nodes_; }; using KernelModPtr = std::shared_ptr; } // namespace kernel diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index 93a883eb4c5..8315c7774f5 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -111,12 +111,10 @@ void DisableMindRT(const ResourcePtr &res) { auto parallel_mode = parallel_context->parallel_mode(); bool is_parallel_mode = parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel; bool enable_old_runtime = (common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "0"); - bool use_old_vm_for_dynamic_shape = func_graph->exist_multi_target() && IsDynamicShapeGraph(func_graph); bool use_old_vm_for_control_parallel = func_graph->exist_multi_target() && ExistControlFlow(func_graph) && is_parallel_mode; - if (enable_old_runtime || use_old_vm_for_dynamic_shape || use_old_vm_for_control_parallel) { - // Heterogeneous scenario + dynamic_shape runs in MsBackend. - MS_LOG(INFO) << "Disable mindRT in the heterogeneous + dynamic shape scenario."; + if (enable_old_runtime || use_old_vm_for_control_parallel) { + MS_LOG(INFO) << "Disable mindRT in the heterogeneous + control flow + parallel scenario."; context_ptr->set_param(MS_CTX_ENABLE_MINDRT, false); // Update the backend. auto new_backend = compile::CreateBackend(); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.h index a528f27f726..1206ca30503 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/abstract_actor.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "mindrt/include/actor/op_actor.h" #include "runtime/graph_scheduler/actor/actor_common.h" #include "runtime/graph_scheduler/device_tensor_store.h" @@ -64,6 +65,7 @@ class AbstractActor : public OpActor { } const std::vector &input_data_arrow_aids() const { return input_data_arrow_aids_; } const std::vector &input_control_arrow_aids() const { return input_control_arrow_aids_; } + const std::map &internal_parameters() const { return internal_parameters_; } protected: friend class GraphScheduler; @@ -108,6 +110,11 @@ class AbstractActor : public OpActor { // The device tensor stores which have the auto monad attribute. std::set auto_monad_device_tensor_stores_; + // HashMap is used to update the shape of internal parameter node for inferring the + // dynamic shape information of the nodes located at the boundary of the graph partition, such as heterogeneous + // scenario and so on. + std::map internal_parameters_; + // The dependent input actors. 
std::vector input_data_arrow_aids_; std::vector input_control_arrow_aids_; diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_dump.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_dump.cc index 28baf524336..f00311eca5f 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_dump.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/actor_dump.cc @@ -88,6 +88,16 @@ void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) { ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n"; } } + + if (actor->internal_parameters().size() > 0) { + ofs << "\t\tinternal_parameters:" << actor->internal_parameters().size() << "\n "; + for (auto &internal_parameter_iter : actor->internal_parameters()) { + auto internal_parameter = internal_parameter_iter.second.lock(); + MS_EXCEPTION_IF_NULL(internal_parameter); + ofs << "\t\t\toutput_index:" << internal_parameter_iter.first + << "\tinternal_parameter:" << internal_parameter->DebugString() << "\n"; + } + } } void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) { @@ -140,7 +150,8 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) { MS_EXCEPTION_IF_NULL(kernel); ofs << "\t\tkernel_name:" << kernel->fullname_with_scope() << "\tinputs_num:" << common::AnfAlgo::GetInputTensorNum(kernel) - << "\toutputs_num:" << common::AnfAlgo::GetOutputTensorNum(kernel) << "\n"; + << "\toutputs_num:" << common::AnfAlgo::GetOutputTensorNum(kernel) + << "\tis_dynamic_shape:" << actor->is_dynamic_shape() << "\n"; for (size_t i = 0; i < common::AnfAlgo::GetOutputTensorNum(kernel); ++i) { const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false); MS_EXCEPTION_IF_NULL(device_tensor); @@ -201,6 +212,8 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) { } DumpAbstractActor(actor, ofs); + + ofs << "\t\tis_need_update_output_size:" << actor->is_need_update_output_size() << "\n "; ofs << "\n"; } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.cc index 08397ffd99a..ffed48c35ae 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.cc @@ -118,7 +118,13 @@ void CopyActor::FetchDeviceTensor(OpContext *const context) { input_device_tensor_[0] = input_data->data_; MS_EXCEPTION_IF_NULL(output_); - output_device_tensor_[0] = output_.get(); + output_device_tensor_[0] = output_; + } + + if (is_need_update_output_size_ && (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize())) { + MS_LOG(INFO) << GetAID().Name() << " update output size from " << output_device_tensor_[0]->GetSize() << " to " + << input_device_tensor_[0]->GetSize(); + output_device_tensor_[0]->SetSize(input_device_tensor_[0]->GetSize()); } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.h index b515a79f6ca..7813c721f9a 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/copy_actor.h @@ -37,7 +37,9 @@ using mindspore::device::DeviceContext; class CopyActor : public MemoryAwareActor { public: CopyActor(const std::string &name, const AID &memory_manager_aid) - : MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid), output_(nullptr) {} + : MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid), + output_(nullptr), + 
is_need_update_output_size_(false) {} ~CopyActor() override = default; // The memory related operation interface. @@ -46,7 +48,8 @@ class CopyActor : public MemoryAwareActor { // The copy processing after memory alloc finished. void OnMemoryAllocFinish(OpContext *const context) override; - const DeviceTensorPtr &output() const { return output_; } + const DeviceTensor *output() const { return output_; } + bool is_need_update_output_size() const { return is_need_update_output_size_; } protected: void Init() override; @@ -66,8 +69,9 @@ class CopyActor : public MemoryAwareActor { // The output device tensor is saved from the output or fetched by device_tensor_store_keys_. std::vector output_device_tensor_; - // The output is created in the copy actor build, so can't be the raw pointer. - DeviceTensorPtr output_; + DeviceTensor *output_; + // The output size needs to be updated in the dynamic shape scene. + bool is_need_update_output_size_; }; using CopyActorPtr = std::shared_ptr; diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc index d5b0ee4a91c..26f0e70e887 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.cc @@ -21,8 +21,6 @@ namespace mindspore { namespace runtime { -void CustomActor::Init() {} - void CustomActor::Run(OpContext *const ctx) { auto node = kernel_.lock(); MS_EXCEPTION_IF_NULL(node); @@ -48,6 +46,14 @@ void CustomActor::Run(OpContext *const ctx) { auto base_node = AnfUtils::GetCustomActorBaseNode(kernel_.lock()); auto kernel_info = dynamic_cast(base_node->kernel_info()); UpdateOutputAddrSize(kernel_info, base_node); + // Update the shape of internal parameter. + for (auto &internal_parameter_iter : internal_parameters_) { + auto internal_parameter = internal_parameter_iter.second.lock(); + MS_EXCEPTION_IF_NULL(internal_parameter); + common::AnfAlgo::SetOutputInferTypeAndShape( + {common::AnfAlgo::GetOutputInferDataType(base_node, internal_parameter_iter.first)}, + {common::AnfAlgo::GetOutputInferShape(base_node, internal_parameter_iter.first)}, internal_parameter.get()); + } } EraseInput(ctx); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.h index 97265032869..2d88ad0146e 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/custom_actor.h @@ -42,7 +42,6 @@ class CustomActor : public AbstractActor { const AnfNodeWeakPtr &kernel() const { return kernel_; } protected: - void Init() override; void Run(OpContext *const context) override; private: diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc index 0d5f7d899af..4f1abc6950a 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.cc @@ -463,6 +463,14 @@ void KernelActor::PostLaunchKernel(OpContext *const context) { // The size of output address may be changed in dynamic shape scenario. if (is_dynamic_shape_) { UpdateOutputAddrSize(kernel_info_, kernel_); + // Update the shape of internal parameter. 
+ for (auto &internal_parameter_iter : internal_parameters_) { + auto internal_parameter = internal_parameter_iter.second.lock(); + MS_EXCEPTION_IF_NULL(internal_parameter); + common::AnfAlgo::SetOutputInferTypeAndShape( + {common::AnfAlgo::GetOutputInferDataType(kernel_, internal_parameter_iter.first)}, + {common::AnfAlgo::GetOutputInferShape(kernel_, internal_parameter_iter.first)}, internal_parameter.get()); + } } running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_); diff --git a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h index ac116899a3f..e4257f602b5 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h +++ b/mindspore/ccsrc/runtime/graph_scheduler/actor/kernel_actor.h @@ -79,6 +79,7 @@ class KernelActor : public DebugAwareActor { const CNodePtr &kernel() const { return kernel_; } const std::set &modifiable_ref_input_indexes() const { return modifiable_ref_input_indexes_; } const std::set &modifiable_ref_output_indexes() const { return modifiable_ref_output_indexes_; } + bool is_dynamic_shape() const { return is_dynamic_shape_; } protected: void Init() override; diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc index 690fa32530e..aaaa443bbfc 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_scheduler.cc @@ -1296,13 +1296,20 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Abs kernel_type = actor_pair.first->type_; } - // Update the real input node. - MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first); - if (to_kernel_with_input_idx.first->isa()) { - auto kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first->cast()); - MS_EXCEPTION_IF_NULL(kernel_mod); - kernel_mod->InsertRealInputNode(real_from_kernel_with_output_idx.first, real_from_kernel_with_output_idx.second, - to_kernel_with_input_idx.second); + // Record the internal parameter of dynamic shape kernel. + if (common::AnfAlgo::IsDynamicShape(real_from_kernel_with_output_idx.first)) { + AbstractActor *dynamic_shape_actor = nullptr; + auto from_update_node = AnfUtils::GetCustomUpdateopNode(real_from_kernel_with_output_idx.first); + auto from_infer_node = AnfUtils::GetCustomInferopNode(real_from_kernel_with_output_idx.first); + if (from_update_node != nullptr) { + dynamic_shape_actor = FetchActor(AnfUtils::GetCustomActorName(from_update_node)); + } else if (from_infer_node != nullptr) { + dynamic_shape_actor = FetchActor(AnfUtils::GetCustomActorName(from_infer_node)); + } else { + dynamic_shape_actor = real_from_actor; + } + MS_EXCEPTION_IF_NULL(dynamic_shape_actor); + dynamic_shape_actor->internal_parameters_[real_from_kernel_with_output_idx.second] = internal_parameter; } if (kKernelTypeToLinkFunc.count(kernel_type) == 0) { @@ -1416,27 +1423,36 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor, // Set the member output_ of the copy actor. 
if (to_actor->type_ == KernelTransformType::kSuperKernelActor) { - copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false); + copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false).get(); } else { copy_actor->output_ = - AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false); + AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false) + .get(); } MS_EXCEPTION_IF_NULL(copy_actor->output_); if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) { MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType() << ", to device context type:" << to_device_context->GetDeviceAddressType(); } + copy_actor->is_need_update_output_size_ = common::AnfAlgo::IsDynamicShape(to_kernel_with_input_idx.first); // Link between from actor and copy actor. AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0); + // Link control arrow between custom update actor and copy actor if the custom update actor exists. + auto custom_update_node = AnfUtils::GetCustomUpdateopNode(from_kernel); + if (custom_update_node != nullptr) { + auto custom_update_actor = FetchActor(AnfUtils::GetCustomActorName(custom_update_node)); + MS_EXCEPTION_IF_NULL(custom_update_actor); + AddControlArrow(custom_update_actor, copy_actor); + } } // If the copy actor already exists, only need link between copy actor and to actor. AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second); if (to_actor->type_ == KernelTransformType::kSuperKernelActor) { - UpdateRefCount(copy_actor->output_.get(), true); + UpdateRefCount(copy_actor->output_, true); } else { - UpdateRefCount(copy_actor->output_.get(), false); + UpdateRefCount(copy_actor->output_, false); } } @@ -1682,6 +1698,11 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set, continue; } + auto to_kernel_type = FetchKernelTransformType(to_node, graph, graph_compiler_info.origin_parameters_order_, + graph_compiler_info.strategy_); + auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_node, graph); + MS_EXCEPTION_IF_NULL(to_actor); + AbstractActor *from_actor = nullptr; // InternalParameter --> CustomActor. 
if (IsInternalParameter(from_node, graph)) { @@ -1691,20 +1712,28 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set, if (IsSwitchActor(front_output_node) || (graph_output_to_actor_.count(front_output_with_index) == 0)) { continue; } - from_actor = graph_output_to_actor_[front_output_with_index].first; + auto real_from_node = graph_output_to_actor_[front_output_with_index].second.first; + auto from_update_node = AnfUtils::GetCustomUpdateopNode(real_from_node); + auto from_infer_node = AnfUtils::GetCustomInferopNode(real_from_node); + if (from_update_node != nullptr) { + from_actor = FetchActor(AnfUtils::GetCustomActorName(from_update_node)); + } else if (from_infer_node != nullptr) { + from_actor = FetchActor(AnfUtils::GetCustomActorName(from_infer_node)); + } else { + from_actor = graph_output_to_actor_[front_output_with_index].first; + } + MS_EXCEPTION_IF_NULL(from_actor); + MS_LOG(INFO) << "Custom actor link control arrow by internal parameter, front node: " + << front_output_node->fullname_with_scope() << ", from actor: " << from_actor->GetAID().Name() + << ", to actor: " << to_actor->GetAID().Name(); } else if (from_node->isa()) { continue; } else { auto from_kernel_type = FetchKernelTransformType(from_node, graph, graph_compiler_info.origin_parameters_order_, graph_compiler_info.strategy_); from_actor = FetchActor(from_kernel_type, graph_compiler_info.name_, from_node, graph); + MS_EXCEPTION_IF_NULL(from_actor); } - MS_EXCEPTION_IF_NULL(from_actor); - - auto to_kernel_type = FetchKernelTransformType(to_node, graph, graph_compiler_info.origin_parameters_order_, - graph_compiler_info.strategy_); - auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_node, graph); - MS_EXCEPTION_IF_NULL(to_actor); AddControlArrow(from_actor, to_actor); } } diff --git a/tests/st/dynamic_shape/test_ascend_cpu.py b/tests/st/dynamic_shape/test_ascend_cpu.py deleted file mode 100644 index e3c4938d603..00000000000 --- a/tests/st/dynamic_shape/test_ascend_cpu.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -import numpy as np -import pytest -import mindspore.context as context -import mindspore.nn as nn -from mindspore import Tensor -import mindspore.common.dtype as mstype -from mindspore.ops import operations as P - -context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") - - -class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.unique = P.Unique().add_prim_attr("primitive_target", "CPU") - - def construct(self, x): - x, y = self.unique(x) - return (x, y) - - -class UniqueSquare(nn.Cell): - def __init__(self): - super(UniqueSquare, self).__init__() - self.unique = P.Unique().add_prim_attr("primitive_target", "CPU") - self.square = P.Square() - - def construct(self, x): - x, _ = self.unique(x) - return self.square(x) - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -def test_unique_ascend(): - x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32) - unique = Net() - output = unique(x) - expect1 = np.array([1, 2, 3]) - expect2 = np.array([0, 0, 1, 1, 2, 2]) - assert (output[0].asnumpy() == expect1).all() - assert (output[1].asnumpy() == expect2).all() - - -@pytest.mark.level0 -@pytest.mark.platform_arm_ascend_training -@pytest.mark.platform_x86_ascend_training -@pytest.mark.env_onecard -def test_unique_square(): - x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32) - net = UniqueSquare() - output = net(x) - expect1 = np.array([1, 4, 9]) - assert (output.asnumpy() == expect1).all() diff --git a/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py b/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py new file mode 100644 index 00000000000..28e0adae23d --- /dev/null +++ b/tests/st/dynamic_shape/test_dynamic_shape_with_heterogeneity.py @@ -0,0 +1,118 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import numpy as np +import pytest +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +import mindspore.common.dtype as mstype +from mindspore.ops import operations as P + +context.set_context(mode=context.GRAPH_MODE) + + +class Unique(nn.Cell): + def __init__(self): + super(Unique, self).__init__() + self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU") + + def construct(self, x): + x, y = self.unique_cpu(x) + return (x, y) + + +class UniqueSquare(nn.Cell): + def __init__(self): + super(UniqueSquare, self).__init__() + self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU") + self.square = P.Square() + + def construct(self, x): + x, _ = self.unique_cpu(x) + return self.square(x) + + +class UniqueReshapeAdd(nn.Cell): + def __init__(self): + super(UniqueReshapeAdd, self).__init__() + self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU") + self.unique = P.Unique() + self.reshape_cpu = P.Reshape().add_prim_attr("primitive_target", "CPU") + self.reshape = P.Reshape() + self.add = P.Add() + + def construct(self, x, y): + x, _ = self.unique_cpu(x) + x = self.reshape(x, (3, 1)) + y, _ = self.unique(y) + y = self.reshape_cpu(y, (3, 1)) + return self.add(x, y) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unique(): + """ + Feature: Dynamic shape with heterogeneity. + Description: Test unique kernel in dynamic shape with heterogeneity scenarios. + Expectation: The value and shape of output are the expected values. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32) + net = Unique() + output = net(x) + expect1 = np.array([1, 2, 3]) + expect2 = np.array([0, 0, 1, 1, 2, 2]) + assert (output[0].asnumpy() == expect1).all() + assert (output[1].asnumpy() == expect2).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unique_square(): + """ + Feature: Dynamic shape with heterogeneity. + Description: Test unique and square kernels in dynamic shape with heterogeneity scenarios. + Expectation: The value and shape of output are the expected values. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32) + net = UniqueSquare() + output = net(x) + expect = np.array([1, 4, 9]) + assert (output.asnumpy() == expect).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_unique_reshape_add(): + """ + Feature: Dynamic shape with heterogeneity. + Description: Test unique, reshape and add kernels in dynamic shape with heterogeneity scenarios. + Expectation: The value and shape of output are the expected values. + """ + x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32) + y = Tensor(np.array([4, 4, 5, 5, 6, 6]), mstype.int32) + net = UniqueReshapeAdd() + output = net(x, y) + expect = np.array([[5], [7], [9]]) + assert (output.asnumpy() == expect).all()
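The new ST file above already doubles as usage documentation for the scenario this patch enables. For convenience, the following is a minimal standalone sketch of the same heterogeneous dynamic-shape case, assuming MindSpore is installed with an Ascend (or GPU) backend available; the device_target value and the __main__ wrapper are illustrative assumptions and are not part of this patch.

# A minimal sketch: a dynamic-shape kernel (Unique) pinned to CPU inside a graph
# that otherwise runs on the accelerator, so its data-dependent output crosses a
# heterogeneous graph-partition boundary.
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
import mindspore.common.dtype as mstype
from mindspore.ops import operations as P

# Assumption: an Ascend device is available; "GPU" works the same way.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")


class UniqueSquare(nn.Cell):
    def __init__(self):
        super(UniqueSquare, self).__init__()
        # Unique runs on CPU and its output shape is only known after launch;
        # Square runs on the accelerator side of the partition boundary.
        self.unique_cpu = P.Unique().add_prim_attr("primitive_target", "CPU")
        self.square = P.Square()

    def construct(self, x):
        x, _ = self.unique_cpu(x)
        return self.square(x)


if __name__ == "__main__":
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32)
    print(UniqueSquare()(x).asnumpy())  # expected: [1. 4. 9.]

Because the Unique output is produced on CPU with a shape known only after launch and is consumed on the device, this exercises exactly the boundary handling added here: the copy actor updating its output size and the producing actor propagating the inferred shape to the recorded internal parameter.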