forked from mindspore-Ecosystem/mindspore
!31752 unified runtime enable the dynamic shape in the heterogeneous
Merge pull request !31752 from limingqi107/bug_fix3
This commit is contained in:
commit
d22da105c8
|
@ -59,16 +59,9 @@ void KernelMod::InferShape() {
|
|||
auto input_size = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
bool skip_nop_node = !context->get_param<bool>(MS_CTX_ENABLE_MINDRT);
|
||||
for (size_t i = 0; i < input_size; i++) {
|
||||
AnfNodePtr real_input = nullptr;
|
||||
size_t real_input_index = 0;
|
||||
if (real_input_nodes_.count(i) > 0) {
|
||||
real_input = real_input_nodes_[i].first.lock();
|
||||
real_input_index = real_input_nodes_[i].second;
|
||||
} else {
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i);
|
||||
real_input = input_node_with_index.first;
|
||||
real_input_index = input_node_with_index.second;
|
||||
}
|
||||
auto input_node_with_index = common::AnfAlgo::GetPrevNodeOutput(cnode, i, false);
|
||||
auto real_input = input_node_with_index.first;
|
||||
auto real_input_index = input_node_with_index.second;
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
if (skip_nop_node) {
|
||||
InferShapeForNopNode(real_input);
|
||||
|
|
|
@ -20,7 +20,6 @@
|
|||
#include <memory>
|
||||
#include <map>
|
||||
#include <set>
|
||||
#include <utility>
|
||||
#include "nlohmann/json.hpp"
|
||||
#include "ir/anf.h"
|
||||
#include "ir/dtype.h"
|
||||
|
@ -223,10 +222,6 @@ class KernelMod {
|
|||
// set true if need to update output's shape after launch in dynamic_shape, like Unique
|
||||
virtual bool IsNeedUpdateOp() { return is_need_updateop_; }
|
||||
|
||||
// Records, for the input slot `input_index`, the node `pre_node` and its output index `pre_node_out_index`.
// Any record already stored for the same slot is overwritten.
void InsertRealInputNode(const AnfNodePtr &pre_node, size_t pre_node_out_index, size_t input_index) {
  auto &record = real_input_nodes_[input_index];
  record.first = pre_node;
  record.second = pre_node_out_index;
}
|
||||
|
||||
protected:
|
||||
void InferShape();
|
||||
void GetDepndLists(const CNodePtr &cnode);
|
||||
|
@ -254,11 +249,6 @@ class KernelMod {
|
|||
std::vector<AddressPtr> inputs_addr_;
|
||||
std::vector<AddressPtr> workspaces_addr_;
|
||||
std::vector<AddressPtr> outputs_addr_;
|
||||
|
||||
// HashMap <input_index, pair<pre_node, pre_node_output_index>> is used to record the real input node to infer the
|
||||
// dynamic shape information of the nodes located at the boundary of the graph partition, such as heterogeneous
|
||||
// scenario and so on.
|
||||
mindspore::HashMap<size_t, std::pair<AnfNodeWeakPtr, size_t>> real_input_nodes_;
|
||||
};
|
||||
using KernelModPtr = std::shared_ptr<KernelMod>;
|
||||
} // namespace kernel
|
||||
|
|
|
@ -111,12 +111,10 @@ void DisableMindRT(const ResourcePtr &res) {
|
|||
auto parallel_mode = parallel_context->parallel_mode();
|
||||
bool is_parallel_mode = parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel;
|
||||
bool enable_old_runtime = (common::GetEnv("MS_DEV_ENABLE_CLOSURE") == "0");
|
||||
bool use_old_vm_for_dynamic_shape = func_graph->exist_multi_target() && IsDynamicShapeGraph(func_graph);
|
||||
bool use_old_vm_for_control_parallel =
|
||||
func_graph->exist_multi_target() && ExistControlFlow(func_graph) && is_parallel_mode;
|
||||
if (enable_old_runtime || use_old_vm_for_dynamic_shape || use_old_vm_for_control_parallel) {
|
||||
// Heterogeneous scenario + dynamic_shape runs in MsBackend.
|
||||
MS_LOG(INFO) << "Disable mindRT in the heterogeneous + dynamic shape scenario.";
|
||||
if (enable_old_runtime || use_old_vm_for_control_parallel) {
|
||||
MS_LOG(INFO) << "Disable mindRT in the heterogeneous + control flow + parallel scenario.";
|
||||
context_ptr->set_param<bool>(MS_CTX_ENABLE_MINDRT, false);
|
||||
// Update the backend.
|
||||
auto new_backend = compile::CreateBackend();
|
||||
|
|
|
@ -22,6 +22,7 @@
|
|||
#include <memory>
|
||||
#include <utility>
|
||||
#include <set>
|
||||
#include <map>
|
||||
#include "mindrt/include/actor/op_actor.h"
|
||||
#include "runtime/graph_scheduler/actor/actor_common.h"
|
||||
#include "runtime/graph_scheduler/device_tensor_store.h"
|
||||
|
@ -64,6 +65,7 @@ class AbstractActor : public OpActor<DeviceTensor> {
|
|||
}
|
||||
// Read-only access to input_data_arrow_aids_.
const std::vector<AID> &input_data_arrow_aids() const {
  return input_data_arrow_aids_;
}
|
||||
// Read-only access to input_control_arrow_aids_.
const std::vector<AID> &input_control_arrow_aids() const {
  return input_control_arrow_aids_;
}
|
||||
// Read-only access to internal_parameters_ (output index -> weakly referenced internal parameter node).
const std::map<size_t, AnfNodeWeakPtr> &internal_parameters() const {
  return internal_parameters_;
}
|
||||
|
||||
protected:
|
||||
friend class GraphScheduler;
|
||||
|
@ -108,6 +110,11 @@ class AbstractActor : public OpActor<DeviceTensor> {
|
|||
// The device tensor stores which have the auto monad attribute.
|
||||
std::set<AnfNodePtr> auto_monad_device_tensor_stores_;
|
||||
|
||||
// Map <output_index, internal_parameter> is used to update the shape of the internal parameter node for inferring the
|
||||
// dynamic shape information of the nodes located at the boundary of the graph partition, such as heterogeneous
|
||||
// scenario and so on.
|
||||
std::map<size_t, AnfNodeWeakPtr> internal_parameters_;
|
||||
|
||||
// The dependent input actors.
|
||||
std::vector<AID> input_data_arrow_aids_;
|
||||
std::vector<AID> input_control_arrow_aids_;
|
||||
|
|
|
@ -88,6 +88,16 @@ void DumpAbstractActor(const AbstractActor *actor, std::ofstream &ofs) {
|
|||
ofs << "\t\t\tto_actor_name:" << aid.Name() << "\n";
|
||||
}
|
||||
}
|
||||
|
||||
if (actor->internal_parameters().size() > 0) {
|
||||
ofs << "\t\tinternal_parameters:" << actor->internal_parameters().size() << "\n ";
|
||||
for (auto &internal_parameter_iter : actor->internal_parameters()) {
|
||||
auto internal_parameter = internal_parameter_iter.second.lock();
|
||||
MS_EXCEPTION_IF_NULL(internal_parameter);
|
||||
ofs << "\t\t\toutput_index:" << internal_parameter_iter.first
|
||||
<< "\tinternal_parameter:" << internal_parameter->DebugString() << "\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void DumpDSActor(const DataSourceActor *actor, std::ofstream &ofs) {
|
||||
|
@ -140,7 +150,8 @@ void DumpKernelActor(const KernelActor *actor, std::ofstream &ofs) {
|
|||
MS_EXCEPTION_IF_NULL(kernel);
|
||||
ofs << "\t\tkernel_name:" << kernel->fullname_with_scope()
|
||||
<< "\tinputs_num:" << common::AnfAlgo::GetInputTensorNum(kernel)
|
||||
<< "\toutputs_num:" << common::AnfAlgo::GetOutputTensorNum(kernel) << "\n";
|
||||
<< "\toutputs_num:" << common::AnfAlgo::GetOutputTensorNum(kernel)
|
||||
<< "\tis_dynamic_shape:" << actor->is_dynamic_shape() << "\n";
|
||||
for (size_t i = 0; i < common::AnfAlgo::GetOutputTensorNum(kernel); ++i) {
|
||||
const auto &device_tensor = AnfAlgo::GetMutableOutputAddr(kernel, i, false);
|
||||
MS_EXCEPTION_IF_NULL(device_tensor);
|
||||
|
@ -201,6 +212,8 @@ void DumpCopyActor(const CopyActor *actor, std::ofstream &ofs) {
|
|||
}
|
||||
|
||||
DumpAbstractActor(actor, ofs);
|
||||
|
||||
ofs << "\t\tis_need_update_output_size:" << actor->is_need_update_output_size() << "\n ";
|
||||
ofs << "\n";
|
||||
}
|
||||
|
||||
|
|
|
@ -118,7 +118,13 @@ void CopyActor::FetchDeviceTensor(OpContext<DeviceTensor> *const context) {
|
|||
input_device_tensor_[0] = input_data->data_;
|
||||
|
||||
MS_EXCEPTION_IF_NULL(output_);
|
||||
output_device_tensor_[0] = output_.get();
|
||||
output_device_tensor_[0] = output_;
|
||||
}
|
||||
|
||||
if (is_need_update_output_size_ && (input_device_tensor_[0]->GetSize() != output_device_tensor_[0]->GetSize())) {
|
||||
MS_LOG(INFO) << GetAID().Name() << " update output size from " << output_device_tensor_[0]->GetSize() << " to "
|
||||
<< input_device_tensor_[0]->GetSize();
|
||||
output_device_tensor_[0]->SetSize(input_device_tensor_[0]->GetSize());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -37,7 +37,9 @@ using mindspore::device::DeviceContext;
|
|||
class CopyActor : public MemoryAwareActor {
|
||||
public:
|
||||
CopyActor(const std::string &name, const AID &memory_manager_aid)
|
||||
: MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid), output_(nullptr) {}
|
||||
: MemoryAwareActor(name, KernelTransformType::kCopyActor, nullptr, memory_manager_aid),
|
||||
output_(nullptr),
|
||||
is_need_update_output_size_(false) {}
|
||||
~CopyActor() override = default;
|
||||
|
||||
// The memory related operation interface.
|
||||
|
@ -46,7 +48,8 @@ class CopyActor : public MemoryAwareActor {
|
|||
// The copy processing after memory alloc finished.
|
||||
void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
const DeviceTensorPtr &output() const { return output_; }
|
||||
const DeviceTensor *output() const { return output_; }
|
||||
// Reports is_need_update_output_size_, i.e. whether the output size needs to be updated.
bool is_need_update_output_size() const {
  return is_need_update_output_size_;
}
|
||||
|
||||
protected:
|
||||
void Init() override;
|
||||
|
@ -66,8 +69,9 @@ class CopyActor : public MemoryAwareActor {
|
|||
// The output device tensor is saved from the output or fetched by device_tensor_store_keys_.
|
||||
std::vector<DeviceTensor *> output_device_tensor_;
|
||||
|
||||
// The output is created in the copy actor build, so can't be the raw pointer.
|
||||
DeviceTensorPtr output_;
|
||||
DeviceTensor *output_;
|
||||
// The output size needs to be updated in the dynamic shape scene.
|
||||
bool is_need_update_output_size_;
|
||||
};
|
||||
|
||||
using CopyActorPtr = std::shared_ptr<CopyActor>;
|
||||
|
|
|
@ -21,8 +21,6 @@
|
|||
|
||||
namespace mindspore {
|
||||
namespace runtime {
|
||||
// No-op: this Init implementation has an empty body.
void CustomActor::Init() {
}
|
||||
|
||||
void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
|
||||
auto node = kernel_.lock();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
|
@ -48,6 +46,14 @@ void CustomActor::Run(OpContext<DeviceTensor> *const ctx) {
|
|||
auto base_node = AnfUtils::GetCustomActorBaseNode(kernel_.lock());
|
||||
auto kernel_info = dynamic_cast<KernelInfo *>(base_node->kernel_info());
|
||||
UpdateOutputAddrSize(kernel_info, base_node);
|
||||
// Update the shape of internal parameter.
|
||||
for (auto &internal_parameter_iter : internal_parameters_) {
|
||||
auto internal_parameter = internal_parameter_iter.second.lock();
|
||||
MS_EXCEPTION_IF_NULL(internal_parameter);
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(
|
||||
{common::AnfAlgo::GetOutputInferDataType(base_node, internal_parameter_iter.first)},
|
||||
{common::AnfAlgo::GetOutputInferShape(base_node, internal_parameter_iter.first)}, internal_parameter.get());
|
||||
}
|
||||
}
|
||||
|
||||
EraseInput(ctx);
|
||||
|
|
|
@ -42,7 +42,6 @@ class CustomActor : public AbstractActor {
|
|||
// Read-only access to the weak node reference stored in kernel_.
const AnfNodeWeakPtr &kernel() const {
  return kernel_;
}
|
||||
|
||||
protected:
|
||||
void Init() override;
|
||||
void Run(OpContext<DeviceTensor> *const context) override;
|
||||
|
||||
private:
|
||||
|
|
|
@ -463,6 +463,14 @@ void KernelActor::PostLaunchKernel(OpContext<DeviceTensor> *const context) {
|
|||
// The size of output address may be changed in dynamic shape scenario.
|
||||
if (is_dynamic_shape_) {
|
||||
UpdateOutputAddrSize(kernel_info_, kernel_);
|
||||
// Update the shape of internal parameter.
|
||||
for (auto &internal_parameter_iter : internal_parameters_) {
|
||||
auto internal_parameter = internal_parameter_iter.second.lock();
|
||||
MS_EXCEPTION_IF_NULL(internal_parameter);
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(
|
||||
{common::AnfAlgo::GetOutputInferDataType(kernel_, internal_parameter_iter.first)},
|
||||
{common::AnfAlgo::GetOutputInferShape(kernel_, internal_parameter_iter.first)}, internal_parameter.get());
|
||||
}
|
||||
}
|
||||
|
||||
running_dependent_msg_num_ = SizeToInt(input_datas_num_ + input_controls_num_);
|
||||
|
|
|
@ -79,6 +79,7 @@ class KernelActor : public DebugAwareActor {
|
|||
// Read-only access to the CNode stored in kernel_.
const CNodePtr &kernel() const {
  return kernel_;
}
|
||||
// Read-only access to modifiable_ref_input_indexes_.
const std::set<size_t> &modifiable_ref_input_indexes() const {
  return modifiable_ref_input_indexes_;
}
|
||||
// Read-only access to modifiable_ref_output_indexes_.
const std::set<size_t> &modifiable_ref_output_indexes() const {
  return modifiable_ref_output_indexes_;
}
|
||||
// Reports the value of the is_dynamic_shape_ flag.
bool is_dynamic_shape() const {
  return is_dynamic_shape_;
}
|
||||
|
||||
protected:
|
||||
void Init() override;
|
||||
|
|
|
@ -1296,13 +1296,20 @@ void GraphScheduler::LinkDataArrowForInternalParameter(AbstractActor *const, Abs
|
|||
kernel_type = actor_pair.first->type_;
|
||||
}
|
||||
|
||||
// Update the real input node.
|
||||
MS_EXCEPTION_IF_NULL(to_kernel_with_input_idx.first);
|
||||
if (to_kernel_with_input_idx.first->isa<CNode>()) {
|
||||
auto kernel_mod = AnfAlgo::GetKernelMod(to_kernel_with_input_idx.first->cast<CNodePtr>());
|
||||
MS_EXCEPTION_IF_NULL(kernel_mod);
|
||||
kernel_mod->InsertRealInputNode(real_from_kernel_with_output_idx.first, real_from_kernel_with_output_idx.second,
|
||||
to_kernel_with_input_idx.second);
|
||||
// Record the internal parameter of dynamic shape kernel.
|
||||
if (common::AnfAlgo::IsDynamicShape(real_from_kernel_with_output_idx.first)) {
|
||||
AbstractActor *dynamic_shape_actor = nullptr;
|
||||
auto from_update_node = AnfUtils::GetCustomUpdateopNode(real_from_kernel_with_output_idx.first);
|
||||
auto from_infer_node = AnfUtils::GetCustomInferopNode(real_from_kernel_with_output_idx.first);
|
||||
if (from_update_node != nullptr) {
|
||||
dynamic_shape_actor = FetchActor(AnfUtils::GetCustomActorName(from_update_node));
|
||||
} else if (from_infer_node != nullptr) {
|
||||
dynamic_shape_actor = FetchActor(AnfUtils::GetCustomActorName(from_infer_node));
|
||||
} else {
|
||||
dynamic_shape_actor = real_from_actor;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(dynamic_shape_actor);
|
||||
dynamic_shape_actor->internal_parameters_[real_from_kernel_with_output_idx.second] = internal_parameter;
|
||||
}
|
||||
|
||||
if (kKernelTypeToLinkFunc.count(kernel_type) == 0) {
|
||||
|
@ -1416,27 +1423,36 @@ void GraphScheduler::LinkDataArrowForCopyActor(AbstractActor *const from_actor,
|
|||
|
||||
// Set the member output_ of the copy actor.
|
||||
if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
|
||||
copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false);
|
||||
copy_actor->output_ = AnfAlgo::GetMutableOutputAddr(to_kernel_with_input_idx.first, 0, false).get();
|
||||
} else {
|
||||
copy_actor->output_ =
|
||||
AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false);
|
||||
AnfAlgo::GetPrevNodeMutableOutputAddr(to_kernel_with_input_idx.first, to_kernel_with_input_idx.second, false)
|
||||
.get();
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(copy_actor->output_);
|
||||
if (copy_actor->output_->DeviceType() != to_device_context->GetDeviceAddressType()) {
|
||||
MS_LOG(EXCEPTION) << "The device type is not equal, output device type:" << copy_actor->output_->DeviceType()
|
||||
<< ", to device context type:" << to_device_context->GetDeviceAddressType();
|
||||
}
|
||||
copy_actor->is_need_update_output_size_ = common::AnfAlgo::IsDynamicShape(to_kernel_with_input_idx.first);
|
||||
|
||||
// Link between from actor and copy actor.
|
||||
AddDataArrow(from_actor, copy_actor, from_kernel, from_kernel_with_output_idx.second, 0);
|
||||
// Link control arrow between custom update actor and copy actor if the custom update actor exists.
|
||||
auto custom_update_node = AnfUtils::GetCustomUpdateopNode(from_kernel);
|
||||
if (custom_update_node != nullptr) {
|
||||
auto custom_update_actor = FetchActor(AnfUtils::GetCustomActorName(custom_update_node));
|
||||
MS_EXCEPTION_IF_NULL(custom_update_actor);
|
||||
AddControlArrow(custom_update_actor, copy_actor);
|
||||
}
|
||||
}
|
||||
|
||||
// If the copy actor already exists, only need link between copy actor and to actor.
|
||||
AddDataArrow(copy_actor, to_actor, nullptr, 0, to_kernel_with_input_idx.second);
|
||||
if (to_actor->type_ == KernelTransformType::kSuperKernelActor) {
|
||||
UpdateRefCount(copy_actor->output_.get(), true);
|
||||
UpdateRefCount(copy_actor->output_, true);
|
||||
} else {
|
||||
UpdateRefCount(copy_actor->output_.get(), false);
|
||||
UpdateRefCount(copy_actor->output_, false);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1682,6 +1698,11 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
|
|||
continue;
|
||||
}
|
||||
|
||||
auto to_kernel_type = FetchKernelTransformType(to_node, graph, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.strategy_);
|
||||
auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_node, graph);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
|
||||
AbstractActor *from_actor = nullptr;
|
||||
// InternalParameter --> CustomActor.
|
||||
if (IsInternalParameter(from_node, graph)) {
|
||||
|
@ -1691,20 +1712,28 @@ void GraphScheduler::LinkControlArrowForCustomActor(ActorSet *const actor_set,
|
|||
if (IsSwitchActor(front_output_node) || (graph_output_to_actor_.count(front_output_with_index) == 0)) {
|
||||
continue;
|
||||
}
|
||||
from_actor = graph_output_to_actor_[front_output_with_index].first;
|
||||
auto real_from_node = graph_output_to_actor_[front_output_with_index].second.first;
|
||||
auto from_update_node = AnfUtils::GetCustomUpdateopNode(real_from_node);
|
||||
auto from_infer_node = AnfUtils::GetCustomInferopNode(real_from_node);
|
||||
if (from_update_node != nullptr) {
|
||||
from_actor = FetchActor(AnfUtils::GetCustomActorName(from_update_node));
|
||||
} else if (from_infer_node != nullptr) {
|
||||
from_actor = FetchActor(AnfUtils::GetCustomActorName(from_infer_node));
|
||||
} else {
|
||||
from_actor = graph_output_to_actor_[front_output_with_index].first;
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
MS_LOG(INFO) << "Custom actor link control arrow by internal parameter, front node: "
|
||||
<< front_output_node->fullname_with_scope() << ", from actor: " << from_actor->GetAID().Name()
|
||||
<< ", to actor: " << to_actor->GetAID().Name();
|
||||
} else if (from_node->isa<Parameter>()) {
|
||||
continue;
|
||||
} else {
|
||||
auto from_kernel_type = FetchKernelTransformType(from_node, graph, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.strategy_);
|
||||
from_actor = FetchActor(from_kernel_type, graph_compiler_info.name_, from_node, graph);
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
}
|
||||
MS_EXCEPTION_IF_NULL(from_actor);
|
||||
|
||||
auto to_kernel_type = FetchKernelTransformType(to_node, graph, graph_compiler_info.origin_parameters_order_,
|
||||
graph_compiler_info.strategy_);
|
||||
auto to_actor = FetchActor(to_kernel_type, graph_compiler_info.name_, to_node, graph);
|
||||
MS_EXCEPTION_IF_NULL(to_actor);
|
||||
AddControlArrow(from_actor, to_actor);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,70 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
# Executed at import time: select graph mode with "Ascend" as the device target for this module.
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
|
||||
|
||||
class Net(nn.Cell):
    """Network made of a single Unique primitive carrying primitive_target="CPU"."""

    def __init__(self):
        super().__init__()
        unique_op = P.Unique()
        self.unique = unique_op.add_prim_attr("primitive_target", "CPU")

    def construct(self, x):
        # Both outputs of Unique are handed back as a tuple.
        first, second = self.unique(x)
        return first, second
|
||||
|
||||
|
||||
class UniqueSquare(nn.Cell):
    """Feeds the first output of a Unique primitive (primitive_target="CPU") into Square."""

    def __init__(self):
        super().__init__()
        unique_op = P.Unique()
        self.unique = unique_op.add_prim_attr("primitive_target", "CPU")
        self.square = P.Square()

    def construct(self, x):
        # Only the first Unique output is used; the second one is dropped.
        first, _ = self.unique(x)
        return self.square(first)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique_ascend():
    """
    Feature: Unique with primitive_target "CPU" while the device target is "Ascend".
    Description: Run Net on an int32 tensor that contains duplicated values.
    Expectation: Both outputs match the expected arrays in value and shape.
    """
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
    unique = Net()
    output = unique(x)
    expect1 = np.array([1, 2, 3])
    expect2 = np.array([0, 0, 1, 1, 2, 2])
    # np.array_equal compares shapes as well; `(a == b).all()` would broadcast and could hide a wrong shape.
    assert np.array_equal(output[0].asnumpy(), expect1)
    assert np.array_equal(output[1].asnumpy(), expect2)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique_square():
    """
    Feature: Unique with primitive_target "CPU" followed by Square while the device target is "Ascend".
    Description: Run UniqueSquare on an int32 tensor that contains duplicated values.
    Expectation: The output matches the expected array in value and shape.
    """
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
    net = UniqueSquare()
    output = net(x)
    expect1 = np.array([1, 4, 9])
    # np.array_equal compares shapes as well; `(a == b).all()` would broadcast and could hide a wrong shape.
    assert np.array_equal(output.asnumpy(), expect1)
|
|
@ -0,0 +1,118 @@
|
|||
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
import numpy as np
|
||||
import pytest
|
||||
import mindspore.context as context
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
# Executed at import time: select graph mode. Only `mode` is passed; no device_target is set by this call.
context.set_context(mode=context.GRAPH_MODE)
|
||||
|
||||
|
||||
class Unique(nn.Cell):
    """Network made of a single Unique primitive carrying primitive_target="CPU"."""

    def __init__(self):
        super().__init__()
        unique_op = P.Unique()
        self.unique_cpu = unique_op.add_prim_attr("primitive_target", "CPU")

    def construct(self, x):
        # Both outputs of Unique are handed back as a tuple.
        first, second = self.unique_cpu(x)
        return first, second
|
||||
|
||||
|
||||
class UniqueSquare(nn.Cell):
    """Feeds the first output of a Unique primitive (primitive_target="CPU") into Square."""

    def __init__(self):
        super().__init__()
        unique_op = P.Unique()
        self.unique_cpu = unique_op.add_prim_attr("primitive_target", "CPU")
        self.square = P.Square()

    def construct(self, x):
        # Only the first Unique output is used; the second one is dropped.
        first, _ = self.unique_cpu(x)
        return self.square(first)
|
||||
|
||||
|
||||
class UniqueReshapeAdd(nn.Cell):
    """Two Unique -> Reshape branches whose results are added.

    For `x`, Unique carries primitive_target="CPU" and Reshape does not; for `y` it is the
    other way round.
    """

    def __init__(self):
        super().__init__()
        cpu_unique = P.Unique()
        self.unique_cpu = cpu_unique.add_prim_attr("primitive_target", "CPU")
        self.unique = P.Unique()
        cpu_reshape = P.Reshape()
        self.reshape_cpu = cpu_reshape.add_prim_attr("primitive_target", "CPU")
        self.reshape = P.Reshape()
        self.add = P.Add()

    def construct(self, x, y):
        # Branch for x: CPU-tagged Unique, then the untagged Reshape to (3, 1).
        x_unique, _ = self.unique_cpu(x)
        lhs = self.reshape(x_unique, (3, 1))
        # Branch for y: untagged Unique, then the CPU-tagged Reshape to (3, 1).
        y_unique, _ = self.unique(y)
        rhs = self.reshape_cpu(y_unique, (3, 1))
        return self.add(lhs, rhs)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique():
    """
    Feature: Dynamic shape with heterogeneity.
    Description: Test unique kernel in dynamic shape with heterogeneity scenarios.
    Expectation: The value and shape of output are the expected values.
    """
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32)
    net = Unique()
    output = net(x)
    expect1 = np.array([1, 2, 3])
    expect2 = np.array([0, 0, 1, 1, 2, 2])
    # np.array_equal compares shapes as well; `(a == b).all()` would broadcast and could hide a wrong shape.
    assert np.array_equal(output[0].asnumpy(), expect1)
    assert np.array_equal(output[1].asnumpy(), expect2)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique_square():
    """
    Feature: Dynamic shape with heterogeneity.
    Description: Test unique and square kernels in dynamic shape with heterogeneity scenarios.
    Expectation: The value and shape of output are the expected values.
    """
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.float32)
    net = UniqueSquare()
    output = net(x)
    expect = np.array([1, 4, 9])
    # np.array_equal compares shapes as well; `(a == b).all()` would broadcast and could hide a wrong shape.
    assert np.array_equal(output.asnumpy(), expect)
|
||||
|
||||
|
||||
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_unique_reshape_add():
    """
    Feature: Dynamic shape with heterogeneity.
    Description: Test unique, reshape and add kernels in dynamic shape with heterogeneity scenarios.
    Expectation: The value and shape of output are the expected values.
    """
    x = Tensor(np.array([1, 1, 2, 2, 3, 3]), mstype.int32)
    y = Tensor(np.array([4, 4, 5, 5, 6, 6]), mstype.int32)
    net = UniqueReshapeAdd()
    output = net(x, y)
    expect = np.array([[5], [7], [9]])
    # np.array_equal compares shapes as well, so the (3, 1) layout is verified and not only the values.
    assert np.array_equal(output.asnumpy(), expect)
|
Loading…
Reference in New Issue