enhancement reorder_ops pass to support reordering cast and type insensitive operators

support castup, type-insensitive to type-insensitive, castup

refactor reorder_ops

fix compiling

move reorder_ops pass to later

fix abstract

refactor

fix node input num bug
This commit is contained in:
looop5 2021-04-17 09:15:05 +08:00
parent 6801ef61e0
commit e88cdc84ec
4 changed files with 395 additions and 91 deletions

View File

@ -53,9 +53,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
if (is_ascend) { if (is_ascend) {
// Remove redundant Cast(bias, fp16) for Matmul input // Remove redundant Cast(bias, fp16) for Matmul input
pm->AddPass(std::make_shared<CastMatmulFusion>()); pm->AddPass(std::make_shared<CastMatmulFusion>());
// Reorder TransData-Cast to Cast-TransData
pm->AddPass(std::make_shared<ReorderOps>());
} }
// Spread the MakeTuple input of UpdateState // Spread the MakeTuple input of UpdateState
@ -78,6 +75,9 @@ PassManagerPtr GraphKernelOptimizer::Cluster() {
PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() { PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1"); auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1");
// Reorder Cast and Type-insensitive node
pm->AddPass(std::make_shared<ReorderOps>());
// normalize the Reduce axis // normalize the Reduce axis
pm->AddPass(std::make_shared<AxisNormalizer>()); pm->AddPass(std::make_shared<AxisNormalizer>());

View File

@ -1,5 +1,5 @@
/** /**
* Copyright 2020 Huawei Technologies Co., Ltd * Copyright 2020-2021 Huawei Technologies Co., Ltd
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
@ -18,6 +18,8 @@
#include <memory> #include <memory>
#include <vector> #include <vector>
#include <string> #include <string>
#include <algorithm>
#include <unordered_set>
#include "base/core_ops.h" #include "base/core_ops.h"
#include "utils/utils.h" #include "utils/utils.h"
#include "utils/log_adapter.h" #include "utils/log_adapter.h"
@ -27,78 +29,111 @@
namespace mindspore { namespace mindspore {
namespace opt { namespace opt {
namespace { namespace {
bool CanReorder(const FuncGraphManagerPtr &mng, const CNodePtr &transdata_node, const CNodePtr &cast_node) { bool IsTypeInsensitive(const CNodePtr &node) {
auto transdata_input_type = AnfAlgo::GetInputDeviceDataType(transdata_node, 0); // Nodes that will change the input data type will not seen as type insensitive nodes.
auto transdata_output_type = AnfAlgo::GetOutputDeviceDataType(transdata_node, 0); static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0); prim::KPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0); prim::kPrimSqueeze, prim::kPrimTile, prim::kPrimNeg, prim::kPrimRelu,
// Conditions of reordering transdata_cast to cast_transdata: prim::kPrimMaximum, prim::kPrimMinimum, prim::kPrimSelect};
// 1) current transdata is only used by cast
// 2) transdata works on float32 (transdata supports float16/float32; return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
// transdata performances better on float16 due to less data to process) [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
// 3) cast works on float32 -> float16
if (mng->node_users()[transdata_node].size() == 1 && transdata_input_type == kNumberTypeFloat32 &&
transdata_output_type == transdata_input_type && cast_input_type == transdata_output_type &&
cast_output_type == kNumberTypeFloat16) {
return true;
}
return false;
} }
void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, const CNodePtr &node) { enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
// Initial CastType GetCastType(const CNodePtr &node) {
// TransData: (type0, format0) -> (type0, format1) MS_EXCEPTION_IF_NULL(node);
// Cast: (type0, format1) -> (type1, format1) if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
// After reorder MS_LOG(EXCEPTION) << "Only process for Cast!";
// Cast: (type0, format0) -> (type1, format0) }
// TransData: (type1, format0) -> (type1, format1) TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
auto type0 = AnfAlgo::GetInputDeviceDataType(transdata_node, 0); TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
auto type1 = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
auto format0 = AnfAlgo::GetInputFormat(transdata_node, 0);
auto format1 = AnfAlgo::GetOutputFormat(transdata_node, 0);
auto abstract = transdata_node->abstract(); if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
auto scope = cast_node->scope(); return CAST_UP;
}
if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
return CAST_DOWN;
}
return CAST_OTHER;
}
std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
std::vector<size_t> op_input_indexes;
if (node == nullptr || !IsTypeInsensitive(node)) {
return op_input_indexes;
}
// Data input index starts from 0.
if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
op_input_indexes = {0, 1};
} else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
op_input_indexes = {1, 2};
} else {
op_input_indexes = {0};
}
return op_input_indexes;
}
bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes, const TypeId &base_type) {
MS_EXCEPTION_IF_NULL(node);
// node's inputs at check_indexes should be of type base_type
for (const auto &index : check_indexes) {
if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
return false;
}
}
return true;
}
void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
MS_EXCEPTION_IF_NULL(orig_node);
MS_EXCEPTION_IF_NULL(new_node);
auto node_name = AnfAlgo::GetCNodeName(new_node);
auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
if (orig_node_name != node_name) {
MS_LOG(EXCEPTION) << "Can not process on different nodes " << orig_node_name << " and " << node_name;
}
AbstractBasePtr new_abstract{nullptr};
std::vector<std::string> inputs_format; std::vector<std::string> inputs_format;
std::vector<std::string> outputs_format; std::vector<std::string> outputs_format;
std::vector<TypeId> inputs_device_type; std::vector<TypeId> inputs_device_type;
std::vector<TypeId> outputs_device_type; std::vector<TypeId> outputs_device_type{node_type};
auto kernel_type = AnfAlgo::GetKernelType(cast_node); KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
auto op_pattern = AnfAlgo::GetOpPattern(cast_node); kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
auto fusion_type = AnfAlgo::GetFusionType(cast_node); kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
auto processor = AnfAlgo::GetProcessor(cast_node); kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};
auto node_name = AnfAlgo::GetCNodeName(node); auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
for (size_t i = 0; i < node_data_inputs_num; ++i) {
auto node_input = AnfAlgo::GetInputNode(new_node, i);
auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
inputs_format.push_back(node_input_format);
inputs_device_type.push_back(node_input_type);
}
if (node_name == "Cast") { if (node_name == "Cast") {
inputs_format.push_back(format0); auto node_input = AnfAlgo::GetInputNode(new_node, 0);
outputs_format.push_back(format0); new_abstract =
inputs_device_type.push_back(type0); std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
outputs_device_type.push_back(type1); outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
// Set attrs
AnfAlgo::CopyNodeAttrs(cast_node, node);
} else if (node_name == "TransData") {
abstract = cast_node->abstract();
scope = transdata_node->scope();
inputs_format.push_back(format0);
outputs_format.push_back(format1);
inputs_device_type.push_back(type1);
outputs_device_type.push_back(type1);
kernel_type = AnfAlgo::GetKernelType(transdata_node);
op_pattern = AnfAlgo::GetOpPattern(transdata_node);
fusion_type = AnfAlgo::GetFusionType(transdata_node);
processor = AnfAlgo::GetProcessor(transdata_node);
// Set attrs
AnfAlgo::CopyNodeAttrs(transdata_node, node);
} else { } else {
MS_LOG(EXCEPTION) << "Node must be Cast or TransData"; new_abstract =
std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
} }
// Set abstract info // Set abstract info
node->set_abstract(abstract); new_node->set_abstract(new_abstract);
// Set scope info // Set attrs
node->set_scope(scope); AnfAlgo::CopyNodeAttrs(orig_node, new_node);
// Set kernel build info // Set kernel build info
node->set_kernel_info(std::make_shared<device::KernelInfo>()); new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder; kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
info_builder.SetInputsFormat(inputs_format); info_builder.SetInputsFormat(inputs_format);
info_builder.SetInputsDeviceType(inputs_device_type); info_builder.SetInputsDeviceType(inputs_device_type);
@ -108,10 +143,141 @@ void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, cons
info_builder.SetOpPattern(op_pattern); info_builder.SetOpPattern(op_pattern);
info_builder.SetFusionType(fusion_type); info_builder.SetFusionType(fusion_type);
info_builder.SetProcessor(processor); info_builder.SetProcessor(processor);
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get()); AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
}
} // namespace
void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
const std::vector<AnfNodePtr> &new_input_at_indexes,
std::vector<AnfNodePtr> *new_inputs) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(new_inputs);
if (indexes.size() != new_input_at_indexes.size()) {
MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
<< new_input_at_indexes.size();
}
if (!new_inputs->empty()) {
new_inputs->resize(0);
}
// node's inputs at indexes change to new_input_at_indexes
std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
auto node_inputs_num = node->size();
size_t idx = 0;
for (size_t i = 0; i < node_inputs_num; ++i) {
if (indexes_set.find(i) == indexes_set.end()) {
new_inputs->push_back(node->input(i));
} else {
new_inputs->push_back(new_input_at_indexes[idx++]);
}
}
} }
bool ReorderTransDataCast(const FuncGraphPtr &func_graph) { bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node) {
// Limitation: Current cast node is CAST_DOWN.
if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
return false;
}
auto node_input = AnfAlgo::GetInputNode(node, 0);
auto type_insens_node = node_input->cast<CNodePtr>();
// Limitation:
// Find type insensitive node before cast node.
// Type insensitive node is only used by current cast node.
if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
mng->node_users()[type_insens_node].size() > 1) {
return false;
}
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
// Limitation: Type insensitive node's inputs have same data type.
if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
return false;
}
std::vector<AnfNodePtr> new_cast_nodes;
for (const auto &index : op_input_indexes) {
auto new_cast_node =
func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
SetNodeInfo(node, new_cast_node, cast_out_type);
new_cast_nodes.push_back(new_cast_node);
}
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
[](const size_t &idx) { return idx + 1; });
std::vector<AnfNodePtr> type_insens_node_new_inputs;
SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);
(void)mng->Replace(node, new_type_insens_node);
return true;
}
bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node) {
if (!IsTypeInsensitive(node)) {
return false;
}
// Limitation:
// Certain inputs of type insensitive node are cast node.
// Cast nodes are CAST_UP.
// All these cast nodes are only used by current type insensitive node.
std::vector<CNodePtr> cast_nodes;
std::vector<AnfNodePtr> cast_input_nodes;
auto op_input_indexes = GetOpDataInputIndexes(node);
for (const auto &index : op_input_indexes) {
auto node_input = AnfAlgo::GetInputNode(node, index);
auto cast_node = node_input->cast<CNodePtr>();
if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
mng->node_users()[cast_node].size() == 1) {
cast_nodes.push_back(cast_node);
cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
}
}
if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
return false;
}
auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
// Limitation: All these cast nodes cast same type to another type.
if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
})) {
return false;
}
// Limitation: Type insensitive node's inputs have same data type.
if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
return false;
}
std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
[](const size_t &idx) { return idx + 1; });
std::vector<AnfNodePtr> type_insens_node_new_inputs;
SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
SetNodeInfo(node, new_type_insens_node, cast_input_type);
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);
(void)mng->Replace(node, new_cast_node);
return true;
}
bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
// Reorder cast node and type insensitive node in graph kernel sub-graph, this function has several limitations,
// see the comments that start will "Limitation:" in this file.
// Limitation: Assuming the type insensitive node will not change the type of input nodes, otherwise it can be seen
// as another cast node in some sense, such as LessEqual operator, which performs on two inputs and output a
// a boolean result.
MS_EXCEPTION_IF_NULL(func_graph); MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager(); auto mng = func_graph->manager();
if (mng == nullptr) { if (mng == nullptr) {
@ -121,40 +287,52 @@ bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
bool changed = false; bool changed = false;
auto todos = TopoSort(func_graph->get_return()); auto todos = TopoSort(func_graph->get_return());
for (const auto &anf_node : todos) { for (const auto &anf_node : todos) {
// Find cast node. auto node = anf_node->cast<CNodePtr>();
auto cast_node = anf_node->cast<CNodePtr>(); if (node == nullptr) {
if (cast_node == nullptr || !AnfAlgo::CheckPrimitiveType(cast_node, prim::kPrimCast)) {
continue; continue;
} }
// Find transdata node before cast node. if (IsTypeInsensitive(node)) {
auto cast_input = AnfAlgo::GetInputNode(cast_node, 0); // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
auto transdata_node = cast_input->cast<CNodePtr>(); changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
if (transdata_node == nullptr || !AnfAlgo::CheckPrimitiveType(transdata_node, prim::KPrimTransData)) { } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
continue; // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
} }
// Reorder transdata_cast to cast_transdata if possible.
if (!CanReorder(mng, transdata_node, cast_node)) {
continue;
}
MS_LOG(INFO) << "Reorder " << transdata_node->fullname_with_scope() << ", " << cast_node->fullname_with_scope();
auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), transdata_node->inputs()[1]});
SetNodeInfo(transdata_node, cast_node, new_cast_node);
auto new_transdata_node = func_graph->NewCNode({NewValueNode(prim::KPrimTransData), new_cast_node});
SetNodeInfo(transdata_node, cast_node, new_transdata_node);
(void)mng->Replace(cast_node, new_transdata_node);
changed = true;
} }
return changed; return changed;
} }
} // namespace
bool ReorderOps::Run(const FuncGraphPtr &func_graph) { return ReorderTransDataCast(func_graph); } bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
MS_EXCEPTION_IF_NULL(func_graph);
auto mng = func_graph->manager();
if (mng == nullptr) {
mng = Manage(func_graph, true);
func_graph->set_manager(mng);
}
bool changed = false;
auto todos = TopoSort(func_graph->get_return());
for (const auto &anf_node : todos) {
auto node = anf_node->cast<CNodePtr>();
if (node == nullptr) {
continue;
}
if (AnfAlgo::IsGraphKernel(node)) {
auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
bool need_traverse = true;
while (need_traverse) {
need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
if (need_traverse) {
changed = true;
}
}
}
}
return changed;
}
} // namespace opt } // namespace opt
} // namespace mindspore } // namespace mindspore

View File

@ -18,6 +18,7 @@
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REORDER_OPS_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REORDER_OPS_H_
#include <memory> #include <memory>
#include <vector>
#include "backend/optimizer/common/pass.h" #include "backend/optimizer/common/pass.h"
namespace mindspore { namespace mindspore {
@ -27,6 +28,16 @@ class ReorderOps : public Pass {
ReorderOps() : Pass("reorder_ops") {} ReorderOps() : Pass("reorder_ops") {}
~ReorderOps() override = default; ~ReorderOps() override = default;
bool Run(const FuncGraphPtr &func_graph) override; bool Run(const FuncGraphPtr &func_graph) override;
private:
void SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
const std::vector<AnfNodePtr> &new_input_in_indexes,
std::vector<AnfNodePtr> *new_inputs);
bool ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node);
bool ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
const CNodePtr &node);
bool ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph);
}; };
using ReorderOpsPtr = std::shared_ptr<ReorderOps>; using ReorderOpsPtr = std::shared_ptr<ReorderOps>;
} // namespace opt } // namespace opt

View File

@ -0,0 +1,115 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore import Tensor
from mindspore.ops import operations as P
class CastUpNet(nn.Cell):
def __init__(self):
super(CastUpNet, self).__init__()
self.cast = P.Cast()
self.transpose = P.Transpose()
self.neg = P.Neg()
def construct(self, i0):
res = self.cast(i0, mstype.float32)
res = self.transpose(res, (1, 0))
res = self.neg(res)
return res
def get_castup_output(x0, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = CastUpNet()
output = net(x0)
return output
def test_castup():
x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float16))
expect = get_castup_output(x0, False)
output = get_castup_output(x0, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 1e-4, 1e-4)
class CastDownNet(nn.Cell):
def __init__(self):
super(CastDownNet, self).__init__()
self.cast = P.Cast()
self.transpose = P.Transpose()
self.neg = P.Neg()
def construct(self, i0):
res = self.transpose(i0, (1, 0))
res = self.neg(res)
res = self.cast(res, mstype.float16)
return res
def get_castdown_output(x0, enable_graph_kernel=False):
context.set_context(enable_graph_kernel=enable_graph_kernel)
net = CastDownNet()
output = net(x0)
return output
def test_castdown():
x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float32))
expect = get_castdown_output(x0, False)
output = get_castdown_output(x0, True)
expect_np = expect.asnumpy().copy()
output_np = output.asnumpy().copy()
assert np.allclose(expect_np, output_np, 1e-3, 1e-3)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_castup_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_castup()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_castup_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_castup()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_castdown_gpu():
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
test_castdown()
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_castdown_ascend():
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
test_castdown()