From e88cdc84ec49eabcb242e2fbb03597a95db4afe4 Mon Sep 17 00:00:00 2001
From: looop5
Date: Sat, 17 Apr 2021 09:15:05 +0800
Subject: [PATCH] enhance reorder_ops pass to support reordering cast and type
 insensitive operators

support reordering castup, type-insensitive to type-insensitive, castup
refactor reorder_ops
fix compiling
move reorder_ops pass to a later stage
fix abstract
refactor
fix node input num bug
---
 .../graph_kernel/graph_kernel_optimization.cc |   6 +-
 .../optimizer/graph_kernel/reorder_ops.cc     | 354 +++++++++++++-----
 .../optimizer/graph_kernel/reorder_ops.h      |  11 +
 tests/st/ops/graph_kernel/test_reorder_ops.py | 115 ++++++
 4 files changed, 395 insertions(+), 91 deletions(-)
 create mode 100644 tests/st/ops/graph_kernel/test_reorder_ops.py

diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
index 31f7d8ec1cd..f1808911ad6 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_optimization.cc
@@ -53,9 +53,6 @@ PassManagerPtr GraphKernelOptimizer::PreProcess() {
   if (is_ascend) {
     // Remove redundant Cast(bias, fp16) for Matmul input
     pm->AddPass(std::make_shared<CastMatmulFusion>());
-
-    // Reorder TransData-Cast to Cast-TransData
-    pm->AddPass(std::make_shared<ReorderOps>());
   }
 
   // Spread the MakeTuple input of UpdateState
@@ -78,6 +75,9 @@ PassManagerPtr GraphKernelOptimizer::Cluster() {
 PassManagerPtr GraphKernelOptimizer::HighLevelOpt1() {
   auto pm = std::make_shared<PassManager>("graphkernel_stage3_highlevelopt1");
 
+  // Reorder Cast and Type-insensitive node
+  pm->AddPass(std::make_shared<ReorderOps>());
+
   // normalize the Reduce axis
   pm->AddPass(std::make_shared<AxisNormalizer>());
 
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc
index 1edbbdd4281..55d05bc6f9e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.cc
@@ -1,5 +1,5 @@
 /**
- * Copyright 2020 Huawei Technologies Co., Ltd
+ * Copyright 2020-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@
 #include <memory>
 #include <string>
 #include <vector>
+#include <algorithm>
+#include <unordered_set>
 #include "base/core_ops.h"
 #include "utils/utils.h"
 #include "utils/log_adapter.h"
@@ -27,78 +29,111 @@
 namespace mindspore {
 namespace opt {
 namespace {
-bool CanReorder(const FuncGraphManagerPtr &mng, const CNodePtr &transdata_node, const CNodePtr &cast_node) {
-  auto transdata_input_type = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
-  auto transdata_output_type = AnfAlgo::GetOutputDeviceDataType(transdata_node, 0);
-  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_node, 0);
-  auto cast_output_type = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
-  // Conditions of reordering transdata_cast to cast_transdata:
-  // 1) current transdata is only used by cast
-  // 2) transdata works on float32 (transdata supports float16/float32;
-  //    transdata performances better on float16 due to less data to process)
-  // 3) cast works on float32 -> float16
-  if (mng->node_users()[transdata_node].size() == 1 && transdata_input_type == kNumberTypeFloat32 &&
-      transdata_output_type == transdata_input_type && cast_input_type == transdata_output_type &&
-      cast_output_type == kNumberTypeFloat16) {
-    return true;
-  }
-  return false;
+bool IsTypeInsensitive(const CNodePtr &node) {
+  // Nodes that change the input data type are not treated as type insensitive nodes.
+  static std::unordered_set<PrimitivePtr> type_insensitive_op_list{
+    prim::KPrimTransData, prim::kPrimTranspose, prim::kPrimExpandDims, prim::kPrimReshape,
+    prim::kPrimSqueeze,   prim::kPrimTile,      prim::kPrimNeg,        prim::kPrimRelu,
+    prim::kPrimMaximum,   prim::kPrimMinimum,   prim::kPrimSelect};
+
+  return std::any_of(type_insensitive_op_list.begin(), type_insensitive_op_list.end(),
+                     [&node](const PrimitivePtr &prim) { return IsPrimitiveCNode(node, prim); });
 }
 
-void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, const CNodePtr &node) {
-  // Initial
-  //   TransData: (type0, format0) -> (type0, format1)
-  //   Cast: (type0, format1) -> (type1, format1)
-  // After reorder
-  //   Cast: (type0, format0) -> (type1, format0)
-  //   TransData: (type1, format0) -> (type1, format1)
-  auto type0 = AnfAlgo::GetInputDeviceDataType(transdata_node, 0);
-  auto type1 = AnfAlgo::GetOutputDeviceDataType(cast_node, 0);
-  auto format0 = AnfAlgo::GetInputFormat(transdata_node, 0);
-  auto format1 = AnfAlgo::GetOutputFormat(transdata_node, 0);
+enum CastType { CAST_UP, CAST_DOWN, CAST_OTHER };
+CastType GetCastType(const CNodePtr &node) {
+  MS_EXCEPTION_IF_NULL(node);
+  if (!IsPrimitiveCNode(node, prim::kPrimCast)) {
+    MS_LOG(EXCEPTION) << "Only Cast node can be processed!";
+  }
+  TypeId input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
+  TypeId output_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
 
-  auto abstract = transdata_node->abstract();
-  auto scope = cast_node->scope();
+  if (input_type == kNumberTypeFloat16 && output_type == kNumberTypeFloat32) {
+    return CAST_UP;
+  }
+
+  if (input_type == kNumberTypeFloat32 && output_type == kNumberTypeFloat16) {
+    return CAST_DOWN;
+  }
+
+  return CAST_OTHER;
+}
+
+std::vector<size_t> GetOpDataInputIndexes(const CNodePtr &node) {
+  std::vector<size_t> op_input_indexes;
+  if (node == nullptr || !IsTypeInsensitive(node)) {
+    return op_input_indexes;
+  }
+
+  // Data input index starts from 0.
+  if (IsPrimitiveCNode(node, prim::kPrimMaximum) || IsPrimitiveCNode(node, prim::kPrimMinimum)) {
+    op_input_indexes = {0, 1};
+  } else if (IsPrimitiveCNode(node, prim::kPrimSelect)) {
+    op_input_indexes = {1, 2};
+  } else {
+    op_input_indexes = {0};
+  }
+  return op_input_indexes;
+}
+
+bool CheckInputTypeConsistent(const CNodePtr &node, const std::vector<size_t> &check_indexes,
+                              const TypeId &base_type) {
+  MS_EXCEPTION_IF_NULL(node);
+
+  // node's inputs at check_indexes should all be of type base_type
+  for (const auto &index : check_indexes) {
+    if (AnfAlgo::GetInputDeviceDataType(node, index) != base_type) {
+      return false;
+    }
+  }
+  return true;
+}
+
+void SetNodeInfo(const CNodePtr &orig_node, const CNodePtr &new_node, const TypeId &node_type) {
+  MS_EXCEPTION_IF_NULL(orig_node);
+  MS_EXCEPTION_IF_NULL(new_node);
+
+  auto node_name = AnfAlgo::GetCNodeName(new_node);
+  auto orig_node_name = AnfAlgo::GetCNodeName(orig_node);
+  if (orig_node_name != node_name) {
+    MS_LOG(EXCEPTION) << "Cannot process on different nodes " << orig_node_name << " and " << node_name;
+  }
+
+  AbstractBasePtr new_abstract{nullptr};
   std::vector<std::string> inputs_format;
   std::vector<std::string> outputs_format;
   std::vector<TypeId> inputs_device_type;
-  std::vector<TypeId> outputs_device_type;
-  auto kernel_type = AnfAlgo::GetKernelType(cast_node);
-  auto op_pattern = AnfAlgo::GetOpPattern(cast_node);
-  auto fusion_type = AnfAlgo::GetFusionType(cast_node);
-  auto processor = AnfAlgo::GetProcessor(cast_node);
+  std::vector<TypeId> outputs_device_type{node_type};
+  KernelType kernel_type{AnfAlgo::GetKernelType(orig_node)};
+  kernel::OpPattern op_pattern{AnfAlgo::GetOpPattern(orig_node)};
+  kernel::FusionType fusion_type{AnfAlgo::GetFusionType(orig_node)};
+  kernel::Processor processor{AnfAlgo::GetProcessor(orig_node)};
 
-  auto node_name = AnfAlgo::GetCNodeName(node);
+  auto node_data_inputs_num = AnfAlgo::GetInputNum(new_node);
+  for (size_t i = 0; i < node_data_inputs_num; ++i) {
+    auto node_input = AnfAlgo::GetInputNode(new_node, i);
+    auto node_input_format = AnfAlgo::GetOutputFormat(node_input, 0);
+    auto node_input_type = AnfAlgo::GetOutputDeviceDataType(node_input, 0);
+    inputs_format.push_back(node_input_format);
+    inputs_device_type.push_back(node_input_type);
+  }
   if (node_name == "Cast") {
-    inputs_format.push_back(format0);
-    outputs_format.push_back(format0);
-    inputs_device_type.push_back(type0);
-    outputs_device_type.push_back(type1);
-    // Set attrs
-    AnfAlgo::CopyNodeAttrs(cast_node, node);
-  } else if (node_name == "TransData") {
-    abstract = cast_node->abstract();
-    scope = transdata_node->scope();
-    inputs_format.push_back(format0);
-    outputs_format.push_back(format1);
-    inputs_device_type.push_back(type1);
-    outputs_device_type.push_back(type1);
-    kernel_type = AnfAlgo::GetKernelType(transdata_node);
-    op_pattern = AnfAlgo::GetOpPattern(transdata_node);
-    fusion_type = AnfAlgo::GetFusionType(transdata_node);
-    processor = AnfAlgo::GetProcessor(transdata_node);
-    // Set attrs
-    AnfAlgo::CopyNodeAttrs(transdata_node, node);
+    auto node_input = AnfAlgo::GetInputNode(new_node, 0);
+    new_abstract =
+      std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), node_input->abstract()->BuildShape());
+    outputs_format.push_back(AnfAlgo::GetOutputFormat(node_input, 0));
   } else {
-    MS_LOG(EXCEPTION) << "Node must be Cast or TransData";
+    new_abstract =
+      std::make_shared<abstract::AbstractTensor>(TypeIdToType(node_type), orig_node->abstract()->BuildShape());
+    outputs_format.push_back(AnfAlgo::GetOutputFormat(orig_node, 0));
   }
 
   // Set abstract info
-  node->set_abstract(abstract);
-  // Set scope info
-  node->set_scope(scope);
+  new_node->set_abstract(new_abstract);
+  // Set attrs
+  AnfAlgo::CopyNodeAttrs(orig_node, new_node);
   // Set kernel build info
-  node->set_kernel_info(std::make_shared<device::KernelInfo>());
+  new_node->set_kernel_info(std::make_shared<device::KernelInfo>());
   kernel::KernelBuildInfo::KernelBuildInfoBuilder info_builder;
   info_builder.SetInputsFormat(inputs_format);
   info_builder.SetInputsDeviceType(inputs_device_type);
@@ -108,10 +143,141 @@ void SetNodeInfo(const CNodePtr &transdata_node, const CNodePtr &cast_node, cons
   info_builder.SetOpPattern(op_pattern);
   info_builder.SetFusionType(fusion_type);
   info_builder.SetProcessor(processor);
-  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), node.get());
+  AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), new_node.get());
+}
+}  // namespace
+
+void ReorderOps::SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
+                                              const std::vector<AnfNodePtr> &new_input_at_indexes,
+                                              std::vector<AnfNodePtr> *new_inputs) {
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(new_inputs);
+  if (indexes.size() != new_input_at_indexes.size()) {
+    MS_LOG(EXCEPTION) << "indexes size " << indexes.size() << " is not equal to new_input_at_indexes size "
+                      << new_input_at_indexes.size();
+  }
+  if (!new_inputs->empty()) {
+    new_inputs->resize(0);
+  }
+
+  // node's inputs at indexes are replaced by new_input_at_indexes
+  std::unordered_set<size_t> indexes_set(indexes.begin(), indexes.end());
+  auto node_inputs_num = node->size();
+  size_t idx = 0;
+  for (size_t i = 0; i < node_inputs_num; ++i) {
+    if (indexes_set.find(i) == indexes_set.end()) {
+      new_inputs->push_back(node->input(i));
+    } else {
+      new_inputs->push_back(new_input_at_indexes[idx++]);
+    }
+  }
+}
 
-bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
+bool ReorderOps::ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
+                                                const CNodePtr &node) {
+  // Limitation: The current cast node is CAST_DOWN.
+  if (!IsPrimitiveCNode(node, prim::kPrimCast) || GetCastType(node) != CAST_DOWN) {
+    return false;
+  }
+
+  auto node_input = AnfAlgo::GetInputNode(node, 0);
+  auto type_insens_node = node_input->cast<CNodePtr>();
+  // Limitation:
+  //   There is a type insensitive node before the cast node.
+  //   The type insensitive node is only used by the current cast node.
+  if (type_insens_node == nullptr || !IsTypeInsensitive(type_insens_node) ||
+      mng->node_users()[type_insens_node].size() > 1) {
+    return false;
+  }
+
+  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(node, 0);
+  auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(node, 0);
+  auto op_input_indexes = GetOpDataInputIndexes(type_insens_node);
+  // Limitation: The type insensitive node's data inputs have the same data type.
+  if (op_input_indexes.empty() || !CheckInputTypeConsistent(type_insens_node, op_input_indexes, cast_input_type)) {
+    return false;
+  }
+
+  std::vector<AnfNodePtr> new_cast_nodes;
+  for (const auto &index : op_input_indexes) {
+    auto new_cast_node =
+      func_graph->NewCNode({NewValueNode(prim::kPrimCast), AnfAlgo::GetInputNode(type_insens_node, index)});
+    SetNodeInfo(node, new_cast_node, cast_out_type);
+    new_cast_nodes.push_back(new_cast_node);
+  }
+
+  std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
+                 [](const size_t &idx) { return idx + 1; });
+
+  std::vector<AnfNodePtr> type_insens_node_new_inputs;
+  SetTypeInsensitiveNodeInputs(type_insens_node, op_input_indexes, new_cast_nodes, &type_insens_node_new_inputs);
+  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
+  SetNodeInfo(type_insens_node, new_type_insens_node, cast_out_type);
+
+  (void)mng->Replace(node, new_type_insens_node);
+  return true;
+}
+
+bool ReorderOps::ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
+                                              const CNodePtr &node) {
+  if (!IsTypeInsensitive(node)) {
+    return false;
+  }
+
+  // Limitation:
+  //   Certain inputs of the type insensitive node are cast nodes.
+  //   These cast nodes are CAST_UP.
+  //   All these cast nodes are only used by the current type insensitive node.
+  std::vector<CNodePtr> cast_nodes;
+  std::vector<AnfNodePtr> cast_input_nodes;
+  auto op_input_indexes = GetOpDataInputIndexes(node);
+  for (const auto &index : op_input_indexes) {
+    auto node_input = AnfAlgo::GetInputNode(node, index);
+    auto cast_node = node_input->cast<CNodePtr>();
+    if (cast_node != nullptr && IsPrimitiveCNode(cast_node, prim::kPrimCast) && GetCastType(cast_node) == CAST_UP &&
+        mng->node_users()[cast_node].size() == 1) {
+      cast_nodes.push_back(cast_node);
+      cast_input_nodes.push_back(AnfAlgo::GetInputNode(cast_node, 0));
+    }
+  }
+  if (cast_nodes.empty() || cast_nodes.size() != op_input_indexes.size()) {
+    return false;
+  }
+
+  auto cast_input_type = AnfAlgo::GetInputDeviceDataType(cast_nodes[0], 0);
+  auto cast_out_type = AnfAlgo::GetOutputDeviceDataType(cast_nodes[0], 0);
+  // Limitation: All these cast nodes cast from the same input type.
+  if (!std::all_of(cast_nodes.begin(), cast_nodes.end(), [&cast_input_type](const CNodePtr &cast_node) {
+        return AnfAlgo::GetInputDeviceDataType(cast_node, 0) == cast_input_type;
+      })) {
+    return false;
+  }
+  // Limitation: The type insensitive node's data inputs have the same data type.
+  if (!CheckInputTypeConsistent(node, op_input_indexes, cast_out_type)) {
+    return false;
+  }
+
+  std::transform(op_input_indexes.begin(), op_input_indexes.end(), op_input_indexes.begin(),
+                 [](const size_t &idx) { return idx + 1; });
+
+  std::vector<AnfNodePtr> type_insens_node_new_inputs;
+  SetTypeInsensitiveNodeInputs(node, op_input_indexes, cast_input_nodes, &type_insens_node_new_inputs);
+  auto new_type_insens_node = func_graph->NewCNode(type_insens_node_new_inputs);
+  SetNodeInfo(node, new_type_insens_node, cast_input_type);
+
+  auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), new_type_insens_node});
+  SetNodeInfo(cast_nodes[0], new_cast_node, cast_out_type);
+
+  (void)mng->Replace(node, new_cast_node);
+  return true;
+}
+
+bool ReorderOps::ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph) {
+  // Reorder cast nodes and type insensitive nodes in a graph kernel sub-graph. This function has several
+  // limitations; see the comments that start with "Limitation:" in this file.
+  // Limitation: Assume the type insensitive node does not change the data type of its inputs; otherwise it acts
+  // as another cast node in some sense, such as the LessEqual operator, which takes two inputs and outputs a
+  // boolean result.
   MS_EXCEPTION_IF_NULL(func_graph);
   auto mng = func_graph->manager();
   if (mng == nullptr) {
@@ -121,40 +287,52 @@ bool ReorderTransDataCast(const FuncGraphPtr &func_graph) {
   bool changed = false;
   auto todos = TopoSort(func_graph->get_return());
   for (const auto &anf_node : todos) {
-    // Find cast node.
-    auto cast_node = anf_node->cast<CNodePtr>();
-    if (cast_node == nullptr || !AnfAlgo::CheckPrimitiveType(cast_node, prim::kPrimCast)) {
+    auto node = anf_node->cast<CNodePtr>();
+    if (node == nullptr) {
       continue;
     }
 
-    // Find transdata node before cast node.
-    auto cast_input = AnfAlgo::GetInputNode(cast_node, 0);
-    auto transdata_node = cast_input->cast<CNodePtr>();
-    if (transdata_node == nullptr || !AnfAlgo::CheckPrimitiveType(transdata_node, prim::KPrimTransData)) {
-      continue;
+    if (IsTypeInsensitive(node)) {
+      // Reorder pattern 1: CastUp-TypeInsensitive --> TypeInsensitive-CastUp
+      changed = ReorderCastUpTypeInsensitive(func_graph, mng, node) || changed;
+    } else if (IsPrimitiveCNode(node, prim::kPrimCast)) {
+      // Reorder pattern 2: TypeInsensitive-CastDown --> CastDown-TypeInsensitive
+      changed = ReorderTypeInsensitiveCastDown(func_graph, mng, node) || changed;
     }
-
-    // Reorder transdata_cast to cast_transdata if possible.
-    if (!CanReorder(mng, transdata_node, cast_node)) {
-      continue;
-    }
-
-    MS_LOG(INFO) << "Reorder " << transdata_node->fullname_with_scope() << ", " << cast_node->fullname_with_scope();
-
-    auto new_cast_node = func_graph->NewCNode({NewValueNode(prim::kPrimCast), transdata_node->inputs()[1]});
-    SetNodeInfo(transdata_node, cast_node, new_cast_node);
-
-    auto new_transdata_node = func_graph->NewCNode({NewValueNode(prim::KPrimTransData), new_cast_node});
-    SetNodeInfo(transdata_node, cast_node, new_transdata_node);
-
-    (void)mng->Replace(cast_node, new_transdata_node);
-    changed = true;
   }
 
   return changed;
 }
-}  // namespace
 
-bool ReorderOps::Run(const FuncGraphPtr &func_graph) { return ReorderTransDataCast(func_graph); }
+bool ReorderOps::Run(const FuncGraphPtr &func_graph) {
+  MS_EXCEPTION_IF_NULL(func_graph);
+  auto mng = func_graph->manager();
+  if (mng == nullptr) {
+    mng = Manage(func_graph, true);
+    func_graph->set_manager(mng);
+  }
+
+  bool changed = false;
+  auto todos = TopoSort(func_graph->get_return());
+  for (const auto &anf_node : todos) {
+    auto node = anf_node->cast<CNodePtr>();
+    if (node == nullptr) {
+      continue;
+    }
+
+    if (AnfAlgo::IsGraphKernel(node)) {
+      auto sub_func_graph = AnfAlgo::GetCNodeFuncGraphPtr(node);
+      bool need_traverse = true;
+      while (need_traverse) {
+        need_traverse = ReorderCastTypeInsensitive(sub_func_graph);
+        if (need_traverse) {
+          changed = true;
+        }
+      }
+    }
+  }
+
+  return changed;
+}
 }  // namespace opt
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.h
index 119dbc2c320..9164844628e 100644
--- a/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.h
+++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/reorder_ops.h
@@ -18,6 +18,7 @@
 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_REORDER_OPS_H_
 
 #include <memory>
+#include <vector>
 #include "backend/optimizer/common/pass.h"
 
 namespace mindspore {
@@ -27,6 +28,16 @@ class ReorderOps : public Pass {
   ReorderOps() : Pass("reorder_ops") {}
   ~ReorderOps() override = default;
   bool Run(const FuncGraphPtr &func_graph) override;
+
+ private:
+  void SetTypeInsensitiveNodeInputs(const CNodePtr &node, const std::vector<size_t> &indexes,
+                                    const std::vector<AnfNodePtr> &new_input_at_indexes,
+                                    std::vector<AnfNodePtr> *new_inputs);
+  bool ReorderTypeInsensitiveCastDown(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
+                                      const CNodePtr &node);
+  bool ReorderCastUpTypeInsensitive(const FuncGraphPtr &func_graph, const FuncGraphManagerPtr &mng,
+                                    const CNodePtr &node);
+  bool ReorderCastTypeInsensitive(const FuncGraphPtr &func_graph);
 };
 using ReorderOpsPtr = std::shared_ptr<ReorderOps>;
 }  // namespace opt
diff --git a/tests/st/ops/graph_kernel/test_reorder_ops.py b/tests/st/ops/graph_kernel/test_reorder_ops.py
new file mode 100644
index 00000000000..6652a8ee0e1
--- /dev/null
+++ b/tests/st/ops/graph_kernel/test_reorder_ops.py
@@ -0,0 +1,115 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import pytest
+import numpy as np
+import mindspore.context as context
+import mindspore.nn as nn
+import mindspore.common.dtype as mstype
+from mindspore import Tensor
+from mindspore.ops import operations as P
+
+
+class CastUpNet(nn.Cell):
+    def __init__(self):
+        super(CastUpNet, self).__init__()
+        self.cast = P.Cast()
+        self.transpose = P.Transpose()
+        self.neg = P.Neg()
+
+    def construct(self, i0):
+        res = self.cast(i0, mstype.float32)
+        res = self.transpose(res, (1, 0))
+        res = self.neg(res)
+        return res
+
+
+def get_castup_output(x0, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = CastUpNet()
+    output = net(x0)
+    return output
+
+
+def test_castup():
+    x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float16))
+    expect = get_castup_output(x0, False)
+    output = get_castup_output(x0, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1e-4, 1e-4)
+
+
+class CastDownNet(nn.Cell):
+    def __init__(self):
+        super(CastDownNet, self).__init__()
+        self.cast = P.Cast()
+        self.transpose = P.Transpose()
+        self.neg = P.Neg()
+
+    def construct(self, i0):
+        res = self.transpose(i0, (1, 0))
+        res = self.neg(res)
+        res = self.cast(res, mstype.float16)
+        return res
+
+
+def get_castdown_output(x0, enable_graph_kernel=False):
+    context.set_context(enable_graph_kernel=enable_graph_kernel)
+    net = CastDownNet()
+    output = net(x0)
+    return output
+
+
+def test_castdown():
+    x0 = Tensor(np.random.normal(0, 1, (16, 16)).astype(np.float32))
+    expect = get_castdown_output(x0, False)
+    output = get_castdown_output(x0, True)
+    expect_np = expect.asnumpy().copy()
+    output_np = output.asnumpy().copy()
+    assert np.allclose(expect_np, output_np, 1e-3, 1e-3)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_castup_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_castup()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_castup_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_castup()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_castdown_gpu():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    test_castdown()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_castdown_ascend():
+    context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+    test_castdown()
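
Note: for readers unfamiliar with the two rewrite patterns this pass applies, the numpy sketch below (illustration only, not part of the patch; the function names are hypothetical) shows them outside MindSpore. Pattern 1 sinks a CastUp below a type-insensitive op so the op runs on the smaller float16 data; pattern 2 hoists a CastDown above a type-insensitive op for the same reason. Because Neg only flips the sign bit and float32->float16 rounding is sign-symmetric, both orderings produce identical results:

# numpy sketch of the two reorder patterns (assumes only that Neg is
# type-insensitive, matching the op list in reorder_ops.cc)
import numpy as np

def castup_before(x16):
    # Pattern 1 input:  CastUp -> TypeInsensitive(Neg)
    return np.negative(x16.astype(np.float32))

def castup_after(x16):
    # Pattern 1 output: TypeInsensitive(Neg) -> CastUp; Neg runs on float16
    return np.negative(x16).astype(np.float32)

def castdown_before(x32):
    # Pattern 2 input:  TypeInsensitive(Neg) -> CastDown
    return np.negative(x32).astype(np.float16)

def castdown_after(x32):
    # Pattern 2 output: CastDown -> TypeInsensitive(Neg); Neg runs on float16
    return np.negative(x32.astype(np.float16))

x16 = np.random.normal(0, 1, (16, 16)).astype(np.float16)
x32 = np.random.normal(0, 1, (16, 16)).astype(np.float32)
assert np.array_equal(castup_before(x16), castup_after(x16))
assert np.array_equal(castdown_before(x32), castdown_after(x32))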