Remove transdata and cast for internal outputs

yujianfeng 2020-07-15 09:38:53 +08:00
parent 11732f0ea2
commit 188d74f15e
15 changed files with 564 additions and 30 deletions

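The new RemoveInternalOutputTransOp and RemoveInternalOutputCast passes drop a TransData or Cast node that is an internal output and only feeds the graph's output MakeTuple, while the insert-trans-op and insert-cast passes re-register the freshly inserted node as the internal output. To keep this bookkeeping correct for nodes with several outputs, KernelGraph now maps an internal node to (output index -> front node) rather than to a single front node, and ReplaceInternalOutput can move either the whole mapping or one output index. A minimal, self-contained C++ sketch of that per-index mapping follows; it is an illustration only and not the MindSpore API: Node stands in for AnfNodePtr, and the two maps and the helper merely mirror internal_outputs_to_front_map_, front_to_internal_outputs_map_, and ReplaceInternalOutput as they appear in the diff below.

#include <iostream>
#include <string>
#include <unordered_map>

using Node = std::string;  // stand-in for AnfNodePtr in this sketch

// internal node -> (output index -> front node); mirrors internal_outputs_to_front_map_
std::unordered_map<Node, std::unordered_map<int, Node>> internal_to_front;
// front node -> internal node; mirrors front_to_internal_outputs_map_
std::unordered_map<Node, Node> front_to_internal;

// Move the internal-output record from `node` to `new_node`.
// src_output_idx == -1 moves every recorded output; otherwise only the entry for
// src_output_idx is moved and re-registered under dst_output_idx of the new node.
void ReplaceInternalOutput(const Node &node, const Node &new_node,
                           int src_output_idx = -1, int dst_output_idx = -1) {
  auto iter = internal_to_front.find(node);
  if (iter == internal_to_front.end()) {
    return;  // `node` was never registered as an internal output
  }
  auto &front_nodes = iter->second;
  if (src_output_idx == -1) {
    internal_to_front[new_node] = front_nodes;
    for (const auto &kv : front_nodes) {
      front_to_internal[kv.second] = new_node;
    }
    internal_to_front.erase(node);
    return;
  }
  auto idx_iter = front_nodes.find(src_output_idx);
  if (idx_iter == front_nodes.end()) {
    return;  // this particular output is not an internal output
  }
  Node front_node = idx_iter->second;
  front_nodes.erase(idx_iter);
  internal_to_front[new_node][dst_output_idx] = front_node;
  front_to_internal[front_node] = new_node;
  if (internal_to_front[node].empty()) {
    internal_to_front.erase(node);
  }
}

int main() {
  // A MaxPool whose two outputs reach the front graph through two tuple_getitem nodes.
  internal_to_front["MaxPool"] = {{0, "getitem0"}, {1, "getitem1"}};
  front_to_internal["getitem0"] = "MaxPool";
  front_to_internal["getitem1"] = "MaxPool";
  // A TransData inserted on MaxPool's output 1 takes over that entry as its own output 0.
  ReplaceInternalOutput("MaxPool", "TransData_1", 1, 0);
  std::cout << front_to_internal["getitem1"] << "\n";       // TransData_1
  std::cout << internal_to_front.count("MaxPool") << "\n";  // 1: output 0 still tracked on MaxPool
  return 0;
}

With this shape, replacing or removing the TransData on one output only touches that output's entry, and the remaining outputs of the same node keep their front-node association.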
View File

@@ -96,6 +96,7 @@
#include "backend/optimizer/ascend/format_type/modify_ops_attrs.h"
#include "backend/optimizer/ascend/format_type/remove_no_use_reshape_op.h"
#include "backend/optimizer/ascend/ir_fusion/add_input_to_output.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include "utils/context/ms_context.h"
#include "utils/config_manager.h"
#include "debug/anf_ir_dump.h"
@@ -199,6 +200,7 @@ void AscendDataLayout(const std::shared_ptr<session::KernelGraph> &kernel_graph)
data_layout_pm->AddPass(std::make_shared<OptimizeDependence>());
data_layout_pm->AddPass(std::make_shared<TransDataSplit>());
data_layout_pm->AddPass(std::make_shared<EraseVisitAttr>());
data_layout_pm->AddPass(std::make_shared<RemoveInternalOutputTransOp>());
optimizer->AddPassManager(data_layout_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();
@@ -220,6 +222,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
optimizer->AddPassManager(mixed_precision_pm);
(void)optimizer->Optimize(kernel_graph);
kernel_graph->SetExecOrderByDefault();

View File

@@ -142,6 +142,7 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
MS_EXCEPTION_IF_NULL(node);
std::vector<AnfNodePtr> make_tuple_inputs;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) {
std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx);
if (output_format == kOpFormat_NC1KHKWHWC0) {
@@ -151,7 +152,11 @@ AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const
auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx);
std::vector<size_t> origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx);
if (kCommonFormatSet.find(output_format) == kCommonFormatSet.end() && origin_shape.size() > 1) {
make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false));
auto trans_op = AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
kernel_graph->ReplaceInternalOutput(node, trans_op, output_idx, 0);
}
make_tuple_inputs.emplace_back(trans_op);
} else {
// No need insert trans op.
make_tuple_inputs.push_back(tuple_getitem);
@@ -249,9 +254,14 @@ AnfNodePtr InsertTransOpForOutput(const FuncGraphPtr &func_graph, const AnfNodeP
if (outputs_num == 0) {
return node;
}
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
// Single output
if (outputs_num == 1 && (!AnfAlgo::IsTupleOutput(node))) {
return InsertTransOpForSingleOutput(func_graph, node, kernel_select);
auto new_node = InsertTransOpForSingleOutput(func_graph, node, kernel_select);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
kernel_graph->ReplaceInternalOutput(node, new_node);
}
return new_node;
}
// Multiple output
return InsertTransOpForMultipleOutput(func_graph, node, kernel_select);

View File

@@ -40,6 +40,7 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
std::vector<AnfNodePtr> make_tuple_inputs;
AbstractBasePtrList abstract_list;
make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple));
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(cnode); ++output_idx) {
AnfNodePtr replace_node = nullptr;
const auto origin_shape = AnfAlgo::GetOutputInferShape(cnode, output_idx);
@@ -64,6 +65,9 @@ AnfNodePtr InsertCastForMultipleOutput(const FuncGraphPtr &func_graph, const CNo
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node, output_idx, 0);
}
} else {
replace_node = getitem;
}
@@ -87,6 +91,7 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
return cnode;
}
MS_EXCEPTION_IF_NULL(cnode->Type());
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
// Single output
if (!cnode->Type()->isa<Tuple>()) {
if (!need_insert_cast[0]) {
@@ -109,6 +114,9 @@ AnfNodePtr InsertCastForOutput(const FuncGraphPtr &func_graph, const CNodePtr &c
MS_EXCEPTION_IF_NULL(replace_node);
replace_node->set_scope(cnode->scope());
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), replace_node);
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(cnode)) {
kernel_graph->ReplaceInternalOutput(cnode, replace_node);
}
}
return replace_node;
}
@@ -188,6 +196,10 @@ const AnfNodePtr InsertCast::Process(const FuncGraphPtr &func_graph, const AnfNo
CNodePtr cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto new_node = InsertCastForInput(func_graph, cnode);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
kernel_graph->ReplaceInternalOutput(node, new_node);
}
// process output
return InsertCastForOutput(func_graph, new_node, std::vector<bool>(AnfAlgo::GetOutputTensorNum(new_node), true));
}

View File

@@ -46,14 +46,13 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
if (node == nullptr || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
AnfNodePtr front_node;
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto kernel_graph = func_graph->cast<std::shared_ptr<session::KernelGraph>>();
if (kernel_graph != nullptr && kernel_graph->IsInternalOutput(node)) {
front_node = kernel_graph->GetFrontNodeByInternalOutput(node);
kernel_graph->ReplaceInternalOutput(node, new_node);
}
AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), node);
MS_LOG(DEBUG) << "====process op: " << node->DebugString();
AnfNodePtr new_node = InsertTransOpForInput(func_graph, node, kernel_select_);
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
if (ms_context->execution_mode() == kPynativeMode && !ms_context->enable_pynative_hook()) {
@@ -61,12 +60,7 @@ const AnfNodePtr InsertTransOp::Process(const FuncGraphPtr &func_graph, const An
return new_node;
}
}
auto final_node = InsertTransOpForOutput(func_graph, new_node, kernel_select_);
if (kernel_graph != nullptr && front_node != nullptr) {
auto old_node = kernel_graph->GetInternalOutputByFrontNode(front_node);
kernel_graph->ReplaceInternalOutput(old_node, final_node);
}
return final_node;
return InsertTransOpForOutput(func_graph, new_node, kernel_select_);
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
namespace {
bool UsedForOutputOnly(const FuncGraphPtr &func_graph, const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(func_graph);
auto manager = func_graph->manager();
MS_EXCEPTION_IF_NULL(manager);
auto &node_users = manager->node_users();
auto iter = node_users.find(node);
if (iter == node_users.end()) {
return false;
}
const auto &node_set = iter->second;
for (const auto &node_index : node_set) {
if (!AnfAlgo::CheckPrimitiveType(node_index.first, prim::kPrimMakeTuple)) {
return false;
}
}
return true;
}
} // namespace
const BaseRef RemoveInternalOutputTransOp::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
auto prim = std::make_shared<Primitive>(kTransDataOpName);
return VectorRef({prim, X});
}
const BaseRef RemoveInternalOutputCast::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::kPrimCast, X});
}
const AnfNodePtr RemoveInternalOutput::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
MS_EXCEPTION_IF_NULL(func_graph);
MS_EXCEPTION_IF_NULL(node);
auto kernel_graph = func_graph->cast<KernelGraphPtr>();
if (kernel_graph == nullptr) {
return nullptr;
}
if (!kernel_graph->IsInternalOutput(node)) {
return nullptr;
}
if (!UsedForOutputOnly(func_graph, node)) {
return nullptr;
}
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
CheckCNodeInputSize(cnode, kTransOpInputNum);
auto input_node = cnode->input(1);
if (!AnfAlgo::CheckPrimitiveType(input_node, prim::kPrimTupleGetItem)) {
kernel_graph->ReplaceInternalOutput(node, input_node);
} else {
auto tuple_getitem = input_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(tuple_getitem);
int idx = AnfAlgo::GetTupleGetItemOutIndex(tuple_getitem);
AnfNodePtr real_input_node = AnfAlgo::GetTupleGetItemRealInput(tuple_getitem);
kernel_graph->ReplaceInternalOutput(node, real_input_node, 0, idx);
}
return input_node;
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,51 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#define MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_
#include <string>
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class RemoveInternalOutput : public PatternProcessPass {
public:
explicit RemoveInternalOutput(const std::string &name, bool multigraph = true)
: PatternProcessPass(name, multigraph) {}
~RemoveInternalOutput() override = default;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
};
class RemoveInternalOutputTransOp : public RemoveInternalOutput {
public:
explicit RemoveInternalOutputTransOp(bool multigraph = true)
: RemoveInternalOutput("remove_internal_output_trans_op", multigraph) {}
~RemoveInternalOutputTransOp() override = default;
const BaseRef DefinePattern() const override;
};
class RemoveInternalOutputCast : public RemoveInternalOutput {
public:
explicit RemoveInternalOutputCast(bool multigraph = true)
: RemoveInternalOutput("remove_internal_output_cast", multigraph) {}
~RemoveInternalOutputCast() override = default;
const BaseRef DefinePattern() const override;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_PRE_ACTIVATE_ASCEND_FORMAT_TYPE_REMOVE_INTERNAL_OUTPUT_H_

View File

@@ -929,10 +929,15 @@ void KernelGraph::AddInternalOutput(const AnfNodePtr &front_node, const AnfNodeP
}
MS_LOG(INFO) << "Add internal node " << node->DebugString() << " with front node " << front_node->DebugString();
front_to_internal_outputs_map_[front_node] = node;
internal_outputs_to_front_map_[node] = front_node;
int output_idx = 0;
if (AnfAlgo::CheckPrimitiveType(front_node, prim::kPrimTupleGetItem)) {
output_idx = AnfAlgo::GetTupleGetItemOutIndex(front_node->cast<CNodePtr>());
}
internal_outputs_to_front_map_[node][output_idx] = front_node;
}
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node) {
void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx,
int dst_output_idx) {
if (new_node == nullptr || node == nullptr) {
MS_LOG(INFO) << "New node or node is nullptr";
return;
@@ -947,9 +952,30 @@ void KernelGraph::ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr
return;
}
MS_LOG(INFO) << "Replace internal node " << node->DebugString() << " To " << new_node->DebugString();
internal_outputs_to_front_map_[new_node] = iter->second;
front_to_internal_outputs_map_[iter->second] = new_node;
internal_outputs_to_front_map_.erase(iter);
auto &front_nodes = iter->second;
// Move all front nodes to new node mapping
if (src_output_idx == -1) {
internal_outputs_to_front_map_[new_node] = front_nodes;
for (const auto &front_node_iter : front_nodes) {
front_to_internal_outputs_map_[front_node_iter.second] = new_node;
}
internal_outputs_to_front_map_.erase(iter);
return;
}
// Move specified front node to new node mapping
int index = SizeToInt(src_output_idx);
auto front_node_iter = front_nodes.find(index);
if (front_node_iter == front_nodes.end()) {
MS_LOG(INFO) << "The output " << src_output_idx << " of node " << node->DebugString() << " is not an internal node";
return;
}
auto front_node = front_node_iter->second;
internal_outputs_to_front_map_[new_node][dst_output_idx] = front_node;
front_to_internal_outputs_map_[front_node] = new_node;
front_nodes.erase(index);
if (front_nodes.empty()) {
internal_outputs_to_front_map_.erase(iter);
}
}
AnfNodePtr KernelGraph::GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const {
@@ -967,14 +993,6 @@ bool KernelGraph::IsInternalOutput(const AnfNodePtr &node) const {
return false;
}
AnfNodePtr KernelGraph::GetFrontNodeByInternalOutput(const AnfNodePtr &node) const {
auto iter = internal_outputs_to_front_map_.find(node);
if (iter != internal_outputs_to_front_map_.end()) {
return iter->second;
}
return nullptr;
}
void KernelGraph::AddFinalOutputKernel(const AnfNodePtr &node) {
if (node == nullptr) {
return;

View File

@@ -148,10 +148,10 @@ class KernelGraph : public FuncGraph {
const std::map<std::string, std::pair<AnfNodePtr, int>> &summary_nodes() const { return summary_nodes_; }
void set_summary_nodes(const std::map<std::string, std::pair<AnfNodePtr, int>> &nodes) { summary_nodes_ = nodes; }
void AddInternalOutput(const AnfNodePtr &front_node, const AnfNodePtr &node);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node);
void ReplaceInternalOutput(const AnfNodePtr &node, const AnfNodePtr &new_node, int src_output_idx = -1,
int dst_output_idx = -1);
AnfNodePtr GetInternalOutputByFrontNode(const AnfNodePtr &front_node) const;
bool IsInternalOutput(const AnfNodePtr &node) const;
AnfNodePtr GetFrontNodeByInternalOutput(const AnfNodePtr &node) const;
void AddFinalOutputKernel(const AnfNodePtr &node);
bool IsFinalOutputKernel(const AnfNodePtr &node) const;
uint32_t current_epoch() const { return current_epoch_; }
@@ -223,7 +223,7 @@ class KernelGraph : public FuncGraph {
CNodePtr end_goto_;
bool null_output_;
std::unordered_map<AnfNodePtr, AnfNodePtr> front_to_internal_outputs_map_;
std::unordered_map<AnfNodePtr, AnfNodePtr> internal_outputs_to_front_map_;
std::unordered_map<AnfNodePtr, std::unordered_map<int, AnfNodePtr>> internal_outputs_to_front_map_;
std::set<AnfNodePtr> final_output_kernels_;
uint32_t current_epoch_;
};

View File

@@ -300,7 +300,11 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
MS_LOG(INFO) << "No corresponding internal output for output node";
return;
}
auto real_kernel = AnfAlgo::VisitKernel(ref_node, 0);
size_t output_idx = 0;
if (AnfAlgo::CheckPrimitiveType(out_node, prim::kPrimTupleGetItem)) {
output_idx = AnfAlgo::GetTupleGetItemOutIndex(out_node->cast<CNodePtr>());
}
auto real_kernel = AnfAlgo::VisitKernel(ref_node, output_idx);
auto ref_real_node = real_kernel.first;
auto ref_real_node_index = real_kernel.second;
if (ref_real_node->isa<CNode>() && node_graph->IsInternalOutput(ref_real_node) &&
@@ -325,6 +329,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
builder.SetOutputsFormat({format});
d_kernel_info->set_select_kernel_build_info(builder.Build());
AnfAlgo::SetOutputAddr(address, 0, parameter.get());
AnfAlgo::SetOutputInferTypeAndShape({type}, {AnfAlgo::GetOutputInferShape(parameter, 0)}, parameter.get());
}
}

View File

@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Momentum
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
class LeNet(nn.Cell):
def __init__(self):
super(LeNet, self).__init__()
self.relu = P.ReLU()
self.batch_size = 32
self.conv1 = nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=0, has_bias=False, pad_mode='valid')
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
self.reshape = P.Reshape()
self.fc1 = nn.Dense(400, 120)
self.fc1.matmul.add_prim_attr("primitive_target", "CPU")
self.fc1.bias_add.add_prim_attr("primitive_target", "CPU")
self.fc2 = nn.Dense(120, 84)
self.fc2.matmul.add_prim_attr("primitive_target", "CPU")
self.fc2.bias_add.add_prim_attr("primitive_target", "CPU")
self.fc3 = nn.Dense(84, 10)
self.fc3.matmul.add_prim_attr("primitive_target", "CPU")
self.fc3.bias_add.add_prim_attr("primitive_target", "CPU")
def construct(self, input_x):
output = self.conv1(input_x)
output = self.relu(output)
output = self.pool(output)
output = self.conv2(output)
output = self.relu(output)
output = self.pool(output)
output = self.reshape(output, (self.batch_size, -1))
output = self.fc1(output)
output = self.relu(output)
output = self.fc2(output)
output = self.relu(output)
output = self.fc3(output)
return output
def train(net, data, label):
learning_rate = 0.01
momentum = 0.9
optimizer = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate, momentum)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(net_with_criterion, optimizer) # optimizer
train_network.set_train()
res = train_network(data, label)
print("+++++++++Loss+++++++++++++")
print(res)
print("+++++++++++++++++++++++++++")
diff = res.asnumpy()[0] - 2.3025851
assert np.all(diff < 1.e-7)
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_lenet():
data = Tensor(np.ones([32, 1, 32, 32]).astype(np.float32) * 0.01)
label = Tensor(np.ones([32]).astype(np.int32))
net = LeNet()
train(net, data, label)

View File

@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@@ -43,6 +44,9 @@ class Net(nn.Cell):
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)

View File

@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@@ -35,6 +36,9 @@ class Net(nn.Cell):
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)

View File

@@ -14,6 +14,7 @@
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
@@ -37,6 +38,9 @@ class Net(nn.Cell):
return out
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_net():
gradient = Tensor(np.ones([3, 3, 3]).astype(np.float32))
indices = Tensor([0, 1, 2], mstype.int32)

View File

@@ -0,0 +1,174 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/backend_common_test.h"
#include "debug/anf_ir_dump.h"
#include "common/py_func_graph_fetcher.h"
#include "backend/optimizer/ascend/format_type/remove_internal_output.h"
#define private public
#define protected public
#include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#undef private
#undef protected
namespace mindspore {
namespace opt {
using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder;
class TestHWRemoveInternalOutput : public BackendCommon {
public:
TestHWRemoveInternalOutput() : getPyFun_("gtest_input.pre_activate.remove_internal_output_test", true) {}
~TestHWRemoveInternalOutput() override = default;
AnfNodePtr GetMakeTuple(const KernelGraphPtr &kg) {
auto ret = kg->get_return();
MS_EXCEPTION_IF_NULL(ret);
auto make_tuple = ret->input(1);
return make_tuple;
}
KernelGraphPtr GetSingleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract, x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto make_tuple = GetMakeTuple(kg);
auto add = make_tuple->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(add);
kg->AddInternalOutput(add, add);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
return kg;
}
KernelGraphPtr GetMutilpleOutputGraph(const std::string &func_name, const std::string &sub_func_name) {
FuncGraphPtr g = getPyFun_.CallAndParseRet(func_name, sub_func_name);
std::vector<int> shp{2, 32, 224, 224};
auto x_abstract = std::make_shared<abstract::AbstractTensor>(kFloat32, shp);
AbstractBasePtrList args_spec_list{x_abstract};
auto kg = GetKernelGraph(g, args_spec_list);
auto output_make_tuple = GetMakeTuple(kg);
auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(make_tuple);
auto tuple_getitem1 = make_tuple->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(tuple_getitem1);
auto tuple_getitem2 = make_tuple->cast<CNodePtr>()->input(2);
MS_EXCEPTION_IF_NULL(tuple_getitem2);
auto max_pool = tuple_getitem1->cast<CNodePtr>()->input(1);
MS_EXCEPTION_IF_NULL(max_pool);
kg->AddInternalOutput(tuple_getitem1, max_pool);
kg->AddInternalOutput(tuple_getitem2, max_pool);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
max_pool->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), max_pool.get());
return kg;
}
UT::PyFuncGraphFetcher getPyFun_;
};
class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
public:
MockRemoveInternalOutputTransOpKernelSelect() = default;
~MockRemoveInternalOutputTransOpKernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_NC1HWC0});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({kFloat32->type_id()});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
};
TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_single_output) {
auto ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
ms_context->set_execution_mode(kGraphMode);
auto kg = GetSingleOutputGraph("test_remove_internal_output_trans_op_for_single_output", "before");
// insert trans op for output
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
auto pass_manager = std::make_shared<opt::PassManager>();
auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
pass_manager->AddPass(insert_trans_op_pass);
graph_optimizer->AddPassManager(pass_manager);
auto new_g = graph_optimizer->Optimize(kg);
FuncGraphPtr g_after =
getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output", "after_insert_trans_op");
EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
auto make_tuple = GetMakeTuple(kg);
auto trans_data = make_tuple->cast<CNodePtr>()->input(1);
EXPECT_TRUE(kg->IsInternalOutput(trans_data));
// remove trans op for internal output
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
auto pass_manager1 = std::make_shared<opt::PassManager>();
auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
pass_manager1->AddPass(remove_internal_output_trans_op_pass);
graph_optimizer1->AddPassManager(pass_manager1);
auto new_g1 = graph_optimizer1->Optimize(new_g);
FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_single_output",
"after_remove_internal_output_trans_op");
EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}
TEST_F(TestHWRemoveInternalOutput, test_remove_internal_output_trans_op_for_multiple_output) {
auto kg = GetMutilpleOutputGraph("test_remove_internal_output_trans_op_for_multiple_output", "before");
// insert trans op for output
auto graph_optimizer = std::make_shared<opt::GraphOptimizer>();
auto pass_manager = std::make_shared<opt::PassManager>();
auto insert_trans_op_pass = std::make_shared<opt::InsertTransOp>();
insert_trans_op_pass->kernel_select_ = std::make_shared<MockRemoveInternalOutputTransOpKernelSelect>();
pass_manager->AddPass(insert_trans_op_pass);
graph_optimizer->AddPassManager(pass_manager);
auto new_g = graph_optimizer->Optimize(kg);
FuncGraphPtr g_after =
getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output", "after_insert_trans_op");
EXPECT_TRUE(CheckEqualGraph(g_after, new_g));
auto output_make_tuple = GetMakeTuple(kg);
auto make_tuple = output_make_tuple->cast<CNodePtr>()->input(1);
auto tuple_getitem = make_tuple->cast<CNodePtr>()->input(1);
auto make_tuple1 = tuple_getitem->cast<CNodePtr>()->input(1);
auto trans_data1 = make_tuple1->cast<CNodePtr>()->input(1);
auto trans_data2 = make_tuple1->cast<CNodePtr>()->input(2);
EXPECT_TRUE(kg->IsInternalOutput(trans_data1));
EXPECT_TRUE(kg->IsInternalOutput(trans_data2));
// remove trans op for internal output
auto graph_optimizer1 = std::make_shared<opt::GraphOptimizer>();
auto pass_manager1 = std::make_shared<opt::PassManager>();
auto remove_internal_output_trans_op_pass = std::make_shared<opt::RemoveInternalOutputTransOp>();
pass_manager1->AddPass(remove_internal_output_trans_op_pass);
graph_optimizer1->AddPassManager(pass_manager1);
auto new_g1 = graph_optimizer1->Optimize(new_g);
FuncGraphPtr g_after1 = getPyFun_.CallAndParseRet("test_remove_internal_output_trans_op_for_multiple_output",
"after_remove_internal_output_trans_op");
EXPECT_TRUE(CheckEqualGraph(g_after1, new_g1));
}
} // namespace opt
} // namespace mindspore

View File

@@ -0,0 +1,83 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
from mindspore.ops import Primitive
from mindspore.ops import operations as P
tuple_getitem = Primitive('tuple_getitem')
add = P.TensorAdd()
max_pool = P.MaxPoolWithArgmax(padding="same", ksize=3, strides=2)
make_tuple = Primitive('make_tuple')
trans_data = Primitive("TransData")
class FnDict:
def __init__(self):
self.fnDict = {}
def __call__(self, fn):
self.fnDict[fn.__name__] = fn
def __getitem__(self, name):
return self.fnDict[name]
def test_remove_internal_output_trans_op_for_single_output(tag):
fns = FnDict()
@fns
def before(x, y):
res = add(x, y)
return res
@fns
def after_insert_trans_op(x, y):
output = add(x, y)
res = trans_data(output)
return make_tuple(res)
@fns
def after_remove_internal_output_trans_op(x, y):
res = add(x, y)
return make_tuple(res)
return fns[tag]
def test_remove_internal_output_trans_op_for_multiple_output(tag):
fns = FnDict()
@fns
def before(x):
max_pool_res = max_pool(x)
res = make_tuple(tuple_getitem(max_pool_res, 0), tuple_getitem(max_pool_res, 1))
return res
@fns
def after_insert_trans_op(x):
output = max_pool(x)
trans_data0 = trans_data(tuple_getitem(output, 0))
trans_data1 = trans_data(tuple_getitem(output, 1))
new_make_tuple = make_tuple(trans_data0, trans_data1)
res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
return make_tuple(res)
@fns
def after_remove_internal_output_trans_op(x):
output = max_pool(x)
new_make_tuple = make_tuple(tuple_getitem(output, 0), tuple_getitem(output, 1))
res = make_tuple(tuple_getitem(new_make_tuple, 0), tuple_getitem(new_make_tuple, 1))
return make_tuple(res)
return fns[tag]