// mindspore/tests/ut/cpp/transform/convert_test.cc
//
// 945 lines
// 32 KiB
// C++

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <iostream>
#include <unordered_map>
#include "pybind11/pybind11.h"
#include "transform/transform_base_test.h"
#include "common/py_func_graph_fetcher.h"
#include "pipeline/parse/parse.h"
#include "debug/draw.h"
#include "debug/anf_ir_dump.h"
#include "pipeline/static_analysis/prim.h"
#include "operator/ops.h"
#include "common/common_test.h"
#define private public
#include "transform/types.h"
#include "transform/convert.h"
#include "securec/include/securec.h"
#include "utils/utils.h"
using std::cout;
using std::endl;
using std::string;
using std::unordered_map;
namespace mindspore {
namespace transform {
using AbstractScalar = abstract::AbstractScalar;
using mindspore::parse::ResolveAll;
class TestConvert : public UT::Common {
public:
TestConvert() {}
virtual void SetUp();
virtual void TearDown();
static const std::shared_ptr<Float> kF32;
};
// Runs before every test: initialises the Python path through UT::InitPythonPath().
void TestConvert::SetUp() {
  UT::InitPythonPath();
}
// Runs after every test: there is nothing to release.
void TestConvert::TearDown() {
}
// Out-of-class definition of the fixture's shared Float(32) type object.
const std::shared_ptr<Float> TestConvert::kF32(std::make_shared<Float>(32));
// Returns a freshly allocated, empty ANF graph.
AnfGraphPtr createAnfGraph() {
  AnfGraphPtr fresh_graph = std::make_shared<AnfGraph>();
  return fresh_graph;
}
// Converting a default-constructed (empty) ANF graph must leave the convertor
// in a non-SUCCESS error state.
TEST_F(TestConvert, TestConstruct) {
  AnfGraphPtr empty_graph = std::make_shared<AnfGraph>();
  DfGraphConvertor convertor(empty_graph);
  // The returned compute graph is deliberately ignored; only the error code matters.
  convertor.ConvertAllNode().GetComputeGraph();
  ASSERT_NE(convertor.ErrCode(), SUCCESS);
}
#if (!defined ENABLE_GE)
namespace {
// Builds an ANF graph around `prim` with `nparam` parameters, dumps it
// (.dot and .ir), converts it with DfGraphConvertor and draws the result.
// Returns true only when the convertor reports error code 0 and the compute
// graph is non-null; each failure is logged before returning false.
bool MakeDfGraph(PrimitivePtr prim, unsigned int nparam) {
  const std::string prim_name = prim->name();

  FuncGraphPtr graph = MakeFuncGraph(prim, nparam);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_" + prim_name + ".dot", graph);
  DumpIR("ut_prim_" + prim_name + ".ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  // The compute graph is drawn before the error checks, as in every other test here.
  convertor.DrawComputeGraph(prim_name + ".dot");

  const bool conversion_ok = (convertor.ErrCode() == 0);
  if (!conversion_ok) {
    MS_LOG(ERROR) << "DfGraphConvertor convert " << prim_name << " error, error code is: " << convertor.ErrCode();
    return false;
  }
  if (compute_graph == nullptr) {
    MS_LOG(ERROR) << "DfGraphConvertor get " << prim_name << " compute func_graph failed";
    return false;
  }
  return true;
}
}  // namespace
// Conv2D with stride/pad/dilation attributes and two inputs must convert
// with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertConv2d) {
  // NOTE(review): the attributes are added to the prim::kPrimConv2D object itself,
  // so they presumably stay set for later tests - confirm this is intended.
  PrimitivePtr conv_prim = prim::kPrimConv2D;
  conv_prim->AddAttr("stride", MakeValue(2));
  conv_prim->AddAttr("pad", MakeValue(0));
  conv_prim->AddAttr("dilation", MakeValue(0));

  FuncGraphPtr graph = MakeFuncGraph(conv_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_conv2d1.dot", graph);
  DumpIR("ut_prim_conv2d1.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("conv2d.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// MaxPool with five inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertMaxpooling) {
  auto pool_prim = std::make_shared<Primitive>("MaxPool");
  // The five parameters stand for: ary, ksize, stride, padding, data_format.
  FuncGraphPtr graph = MakeFuncGraph(pool_prim, 5);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_maxpooling.dot", graph);
  DumpIR("ut_prim_maxpooling.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("maxpooling.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Relu with one input and a "T" attribute must resolve and convert with error code 0.
TEST_F(TestConvert, TestReluOps) {
  auto relu_prim = prim::kPrimRelu;
  relu_prim->AddAttr("T", MakeValue(0));
  auto func_graph = MakeFuncGraph(relu_prim, 1);
  ASSERT_TRUE(nullptr != func_graph);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about (nothing is drawn here).
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(convertor.ErrCode(), 0);
}
// Builds BatchNorm(5 params) -> TupleGetItem(index 2) -> Relu -> return by hand
// and checks that the graph converts with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertBatchNorm) {
  PrimitivePtr batch_norm = prim::kPrimBatchNorm;
  batch_norm->AddAttr("epsilon", MakeValue(0.001f));
  batch_norm->AddAttr("momentum", MakeValue(0.1f));

  FuncGraphPtr anf_graph = std::make_shared<FuncGraph>();
  std::vector<AnfNodePtr> inputs;

  // BatchNorm node fed by five graph parameters.
  inputs.push_back(NewValueNode(batch_norm));
  for (unsigned int i = 0; i < 5; i++) {
    inputs.push_back(anf_graph->add_parameter());
  }
  CNodePtr cnode_prim = anf_graph->NewCNode(inputs);

  // Select element 2 of the BatchNorm output tuple.
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimTupleGetItem));
  inputs.push_back(cnode_prim);
  inputs.push_back(NewValueNode(2));
  CNodePtr cnode_getitem = anf_graph->NewCNode(inputs);

  // Apply Relu to the selected element.
  inputs.clear();
  inputs.push_back(NewValueNode(prim::kPrimRelu));
  inputs.push_back(cnode_getitem);
  CNodePtr cnode_relu = anf_graph->NewCNode(inputs);

  // Return the Relu result from the graph.
  inputs.clear();
  inputs.push_back(NewValueNode(std::make_shared<Primitive>("return")));
  inputs.push_back(cnode_relu);
  CNodePtr cnode_return = anf_graph->NewCNode(inputs);
  anf_graph->set_return(cnode_return);

  std::shared_ptr<FuncGraphManager> graph_manager = MakeManager({anf_graph});
  draw::Draw("ut_prim_batchnorm.dot", anf_graph);
  DumpIR("ut_prim_batchnorm.ir", anf_graph);

  DfGraphConvertor convertor(anf_graph);
  auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  // Output file name was previously misspelled as "batchnrom.dot".
  convertor.DrawComputeGraph("batchnorm.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
// Conv2DBackpropInput with three inputs must resolve and convert with
// error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertConvBackpropInput) {
  auto prim = prim::kPrimConv2DBackpropInput;
  const std::vector<int> list{1, 1};
  prim->AddAttr("stride", MakeValue(list));
  prim->AddAttr("pad", MakeValue(0));
  prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  // "dilation" used to be added twice with the same value; once is enough.
  prim->AddAttr("dilation", MakeValue(1));
  prim->AddAttr("group", MakeValue(1));
  prim->AddAttr("mode", MakeValue(1));

  auto func_graph = MakeFuncGraph(prim, 3);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(anf_graph);
  auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("Conv2DBackpropInput.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
// Conv2DBackpropFilter with three inputs must resolve and convert with
// error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertConvBackpropFilter) {
  auto prim = prim::kPrimConv2DBackpropFilter;
  const std::vector<int> list{1, 1};
  prim->AddAttr("stride", MakeValue(list));
  prim->AddAttr("pad", MakeValue(0));
  prim->AddAttr("pad_mode", MakeValue(std::string("pad")));
  // "dilation" used to be added twice with the same value; once is enough.
  prim->AddAttr("dilation", MakeValue(1));
  prim->AddAttr("group", MakeValue(1));
  prim->AddAttr("mode", MakeValue(1));

  auto func_graph = MakeFuncGraph(prim, 3);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(anf_graph);
  auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("Conv2DBackpropFilter.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
// ReluGrad with two inputs must resolve and convert with error code 0
// and a non-null compute graph.
TEST_F(TestConvert, TestConvertReluGrad) {
  auto grad_prim = prim::kPrimReluGrad;
  grad_prim->AddAttr("alpha", MakeValue(0.1f));
  grad_prim->AddAttr("beta", MakeValue(0.1f));
  grad_prim->AddAttr("mode", MakeValue(1));

  auto func_graph = MakeFuncGraph(grad_prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("ReluGrad.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// BiasAdd with two inputs must resolve and convert with error code 0
// and a non-null compute graph.
TEST_F(TestConvert, TestConvertBiasAdd) {
  auto bias_prim = std::make_shared<Primitive>("BiasAdd");
  bias_prim->AddAttr("alpha", MakeValue(0.0f));
  bias_prim->AddAttr("beta", MakeValue(1.0f));
  bias_prim->AddAttr("format", MakeValue(1));

  auto func_graph = MakeFuncGraph(bias_prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("BiasAdd.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// BiasAddGrad with two inputs must resolve and convert with error code 0
// and a non-null compute graph.
TEST_F(TestConvert, TestConvertBiasAddGrad) {
  auto grad_prim = prim::kPrimBiasAddGrad;
  grad_prim->AddAttr("alpha", MakeValue(0.0f));
  grad_prim->AddAttr("beta", MakeValue(1.0f));
  grad_prim->AddAttr("format", MakeValue(1));

  auto func_graph = MakeFuncGraph(grad_prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("BiasAddGrad.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// MaxPoolGradWithArgmax with two inputs must resolve and convert with
// error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConvertMaxPoolGradWithArgmax) {
  auto prim = std::make_shared<Primitive>("MaxPoolGradWithArgmax");
  prim->AddAttr("window", MakeValue(2));
  prim->AddAttr("stride", MakeValue(1));
  prim->AddAttr("ceil_mode", MakeValue(0));
  prim->AddAttr("data_mode", MakeValue(0));
  // "alpha"/"beta" used to be set twice (alpha 0.0f then 0.1f, beta 1.0f both times).
  // Only the values that were set last are kept.
  // NOTE(review): assumes AddAttr overwrites an existing attribute of the same name - confirm.
  prim->AddAttr("alpha", MakeValue(0.1f));
  prim->AddAttr("beta", MakeValue(1.0f));

  auto func_graph = MakeFuncGraph(prim, 2);
  ASSERT_NE(func_graph, nullptr);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about and draw the result.
  auto anf_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(anf_graph);
  auto df_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("MaxPoolGradWithArgmax.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(df_graph, nullptr);
}
// Concat with two inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestConcat) {
  auto concat_prim = prim::kPrimConcat;
  FuncGraphPtr graph = MakeFuncGraph(concat_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_concat.dot", graph);
  DumpIR("ut_prim_concat.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("concat.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// GatherV2 with three inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestGatherV2) {
  auto gather_prim = prim::kPrimGatherV2;
  FuncGraphPtr graph = MakeFuncGraph(gather_prim, 3);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_gatherv2.dot", graph);
  DumpIR("ut_prim_gatherv2.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("gatherv2.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Cast with two inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestCast) {
  auto cast_prim = prim::kPrimCast;
  FuncGraphPtr graph = MakeFuncGraph(cast_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_cast.dot", graph);
  DumpIR("ut_prim_cast.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("cast.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Exp with one input must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestExp) {
  auto exp_prim = std::make_shared<Primitive>("Exp");
  FuncGraphPtr graph = MakeFuncGraph(exp_prim, 1);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_exp.dot", graph);
  DumpIR("ut_prim_exp.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("exp.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Floor with one input must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestFloor) {
  auto floor_prim = std::make_shared<Primitive>("Floor");
  FuncGraphPtr graph = MakeFuncGraph(floor_prim, 1);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_floor.dot", graph);
  DumpIR("ut_prim_floor.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("floor.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// GreaterEqual with two inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestGreaterEqual) {
  auto ge_prim = std::make_shared<Primitive>("GreaterEqual");
  FuncGraphPtr graph = MakeFuncGraph(ge_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_greater_equal.dot", graph);
  DumpIR("ut_prim_greater_equal.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("greater_equal.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Less with two inputs and a "T" attribute of kFloat32 must convert with
// error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestLess) {
  auto less_prim = std::make_shared<Primitive>("Less");
  less_prim->AddAttr("T", MakeValue(kFloat32));
  FuncGraphPtr graph = MakeFuncGraph(less_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_less.dot", graph);
  DumpIR("ut_prim_less.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("less.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// LessEqual with two inputs must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestLessEqual) {
  auto le_prim = std::make_shared<Primitive>("LessEqual");
  FuncGraphPtr graph = MakeFuncGraph(le_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_less_equal.dot", graph);
  DumpIR("ut_prim_less_equal.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("less_equal.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// LogicalNot with one input must convert with error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestLogicalNot) {
  auto not_prim = std::make_shared<Primitive>("LogicalNot");
  FuncGraphPtr graph = MakeFuncGraph(not_prim, 1);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_logical_not.dot", graph);
  DumpIR("ut_prim_logical_not.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("logical_not.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// AssignAdd with two inputs and use_locking=true must convert with
// error code 0 and a non-null compute graph.
TEST_F(TestConvert, TestAssignAdd) {
  auto assign_prim = prim::kPrimAssignAdd;
  assign_prim->AddAttr("use_locking", MakeValue(true));
  FuncGraphPtr graph = MakeFuncGraph(assign_prim, 2);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_assign_add.dot", graph);
  DumpIR("ut_prim_assign_add.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("assign_add.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// LogSoftmax with one input and axis=0 must convert with error code 0
// and a non-null compute graph.
TEST_F(TestConvert, LogSoftmax) {
  auto softmax_prim = prim::kPrimLogSoftmax;
  softmax_prim->AddAttr("axis", MakeValue(0));
  FuncGraphPtr graph = MakeFuncGraph(softmax_prim, 1);
  std::shared_ptr<FuncGraphManager> manager = MakeManager({graph});
  draw::Draw("ut_prim_log_softmax.dot", graph);
  DumpIR("ut_prim_log_softmax.ir", graph);

  DfGraphConvertor convertor(graph);
  auto compute_graph = convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  convertor.DrawComputeGraph("log_softmax.dot");
  ASSERT_EQ(convertor.ErrCode(), 0);
  ASSERT_NE(compute_graph, nullptr);
}
// Maximum with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestMaximumOps) {
  const bool converted = MakeDfGraph(prim::kPrimMaximum, 2);
  ASSERT_TRUE(converted);
}
// ReduceMean with keepdims=true and two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestReduceMeanOps) {
  auto mean_prim = prim::kPrimReduceMean;
  mean_prim->AddAttr("keepdims", MakeValue(true));
  const bool converted = MakeDfGraph(mean_prim, 2);
  ASSERT_TRUE(converted);
}
// Minimum with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestMinimumOps) {
  const bool converted = MakeDfGraph(prim::kPrimMinimum, 2);
  ASSERT_TRUE(converted);
}
// Placeholder: always passes and exercises no conversion.
TEST_F(TestConvert, TestFusedMinOrMaxGradOps) {
  // TODO: add an infer step to this test case.
  SUCCEED();
}
// Squeeze with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestSqueezeOps) {
  const bool converted = MakeDfGraph(prim::kPrimSqueeze, 2);
  ASSERT_TRUE(converted);
}
// Mul with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestMulOps) {
  const bool converted = MakeDfGraph(prim::kPrimMul, 2);
  ASSERT_TRUE(converted);
}
// Neg with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestNegOps) {
  const bool converted = MakeDfGraph(prim::kPrimNeg, 1);
  ASSERT_TRUE(converted);
}
// OneHot with axis=0 and four inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestOneHotOps) {
  auto one_hot_prim = prim::kPrimOneHot;
  one_hot_prim->AddAttr("axis", MakeValue(0));
  const bool converted = MakeDfGraph(one_hot_prim, 4);
  ASSERT_TRUE(converted);
}
// Pow with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestPowOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Pow"), 2);
  ASSERT_TRUE(converted);
}
// Reciprocal with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestReciprocalOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Reciprocal"), 1);
  ASSERT_TRUE(converted);
}
// Select with three inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestSelectOps) {
  const bool converted = MakeDfGraph(prim::kPrimSelect, 3);
  ASSERT_TRUE(converted);
}
// Sqrt with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestSqrtOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Sqrt"), 1);
  ASSERT_TRUE(converted);
}
// Square with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestSquareOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Square"), 1);
  ASSERT_TRUE(converted);
}
// ScalarSummary built with two parameters must pass MakeDfGraph.
TEST_F(TestConvert, TestScalarSummaryOps) {
  // NOTE(review): the earlier comment here said the op "should have only 1 input",
  // yet the graph is built with 2 parameters - confirm which is intended.
  const bool converted = MakeDfGraph(prim::kPrimScalarSummary, 2);
  ASSERT_TRUE(converted);
}
// TensorSummary with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestTensorSummaryOps) {
  const bool converted = MakeDfGraph(prim::kPrimTensorSummary, 2);
  ASSERT_TRUE(converted);
}
// HistogramSummary with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestHistogramSummaryOps) {
  const bool converted = MakeDfGraph(prim::kPrimHistogramSummary, 2);
  ASSERT_TRUE(converted);
}
// Greater with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestGreaterOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Greater"), 2);
  ASSERT_TRUE(converted);
}
// Equal with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestEqualOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Equal"), 2);
  ASSERT_TRUE(converted);
}
// Argmax with two inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestArgMaxiOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("Argmax"), 2);
  ASSERT_TRUE(converted);
}
// ResizeNearestNeighbor with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestResizeNearestNeighborOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("ResizeNearestNeighbor"), 1);
  ASSERT_TRUE(converted);
}
// ApplyMomentum with five inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestApplyMomentumOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("ApplyMomentum"), 5);
  ASSERT_TRUE(converted);
}
// NPUGetFloatStatus with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestNPUGetFloatStatusOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("NPUGetFloatStatus"), 1);
  ASSERT_TRUE(converted);
}
// NPUAllocFloatStatus with no inputs must pass MakeDfGraph.
TEST_F(TestConvert, TestNPUAllocFloatStatusOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("NPUAllocFloatStatus"), 0);
  ASSERT_TRUE(converted);
}
// NPUClearFloatStatus with one input must pass MakeDfGraph.
TEST_F(TestConvert, TestNPUClearFloatStatusOps) {
  const bool converted = MakeDfGraph(std::make_shared<Primitive>("NPUClearFloatStatus"), 1);
  ASSERT_TRUE(converted);
}
#endif
// TensorAdd with two inputs must resolve and convert with error code 0.
TEST_F(TestConvert, TestAddOps) {
  auto add_prim = std::make_shared<Primitive>("TensorAdd");
  auto func_graph = MakeFuncGraph(add_prim, 2);
  ASSERT_TRUE(nullptr != func_graph);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(func_graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about (nothing is drawn here).
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(convertor.ErrCode(), 0);
}
// Fills a 2x2x3 float32 ME tensor, converts it to a GE tensor (NCHW) and checks
// that the format, data type, leading dims and every element survive.
TEST_F(TestConvert, TestConvertTensor) {
  constexpr size_t kElemCount = 12;
  float data[kElemCount] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11};
  // Create a tensor with wanted data type and shape
  std::vector<int> dims{2, 2, 3};
  std::vector<int64_t> ge_dims{2, 2, 3};
  auto type_id = kNumberTypeFloat32;
  MeTensor me_tensor(type_id, dims);
  // Get the writable data pointer of the tensor and cast it to its data type
  uint8_t* me_data_ptr = reinterpret_cast<uint8_t*>(me_tensor.data_c(true));
  ASSERT_TRUE(me_data_ptr != nullptr);
  // Copy the reference data in; a failed copy would make every later check meaningless.
  auto copy_ret = memcpy_s(me_data_ptr, me_tensor.data().nbytes(), data, kElemCount * sizeof(float));
  ASSERT_EQ(copy_ret, EOK);

  auto me_tensor_ptr = std::make_shared<MeTensor>(me_tensor);
  auto ge_tensor_ptr = TransformUtil::ConvertTensor(me_tensor_ptr, kOpFormat_NCHW);
  // ConvertTensor returns nullptr on failure (see TestConvertTensorError), so guard the dereferences.
  ASSERT_NE(ge_tensor_ptr, nullptr);
  ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetFormat(), GeFormat::FORMAT_NCHW);
  ASSERT_EQ(ge_tensor_ptr->GetTensorDesc().GetDataType(), GeDataType::DT_FLOAT);

  // Compare the expected dims against the converted shape without reading past its end.
  const auto actual_dims = ge_tensor_ptr->GetTensorDesc().GetShape().GetDims();
  ASSERT_GE(actual_dims.size(), ge_dims.size());
  for (size_t i = 0; i < ge_dims.size(); ++i) {
    ASSERT_EQ(ge_dims[i], actual_dims[i]);
  }

  // The element count must match the reference buffer before comparing element-wise.
  const int64_t shape_size = ge_tensor_ptr->GetTensorDesc().GetShape().GetShapeSize();
  ASSERT_EQ(shape_size, static_cast<int64_t>(kElemCount));
  const float* ge_data = reinterpret_cast<float*>(ge_tensor_ptr->GetData());
  for (int64_t i = 0; i < shape_size; ++i) {
    ASSERT_EQ(data[i], ge_data[i]);
  }
}
// A tensor whose shape has zero dims is still valid: conversion must not return nullptr.
TEST_F(TestConvert, TestConvertTensor0Dims) {
  std::vector<int> empty_dims{};
  auto dtype = kNumberTypeFloat32;
  auto scalar_tensor = std::make_shared<MeTensor>(dtype, empty_dims);
  auto converted = TransformUtil::ConvertTensor(scalar_tensor, kOpFormat_NCHW);
  ASSERT_NE(converted, nullptr);
}
// Converting with the format string "xyz" must return nullptr.
TEST_F(TestConvert, TestConvertTensorError) {
  std::vector<int> shape{2, 3, 4};
  auto dtype = kNumberTypeFloat32;
  auto tensor = std::make_shared<MeTensor>(dtype, shape);
  auto converted = TransformUtil::ConvertTensor(tensor, "xyz");
  ASSERT_EQ(converted, nullptr);
}
// Checks the ME -> GE data type mapping for each listed pair.
TEST_F(TestConvert, TestUtilsConvertDataType) {
  const std::vector<std::pair<MeDataType, GeDataType>> expectations = {
    {MeDataType::kNumberTypeFloat16, GeDataType::DT_FLOAT16},
    {MeDataType::kNumberTypeFloat32, GeDataType::DT_FLOAT},
    {MeDataType::kNumberTypeFloat64, GeDataType::DT_DOUBLE},
    {MeDataType::kNumberTypeInt8, GeDataType::DT_INT8},
    {MeDataType::kNumberTypeInt16, GeDataType::DT_INT16},
    {MeDataType::kNumberTypeInt32, GeDataType::DT_INT32},
    {MeDataType::kNumberTypeInt64, GeDataType::DT_INT64},
    {MeDataType::kNumberTypeUInt32, GeDataType::DT_UINT32},
    {MeDataType::kNumberTypeBool, GeDataType::DT_BOOL},
  };
  for (const auto& entry : expectations) {
    ASSERT_EQ(TransformUtil::ConvertDataType(entry.first), entry.second);
  }
}
// Checks the format-string -> GE format mapping; the string "xyz" is expected to map to FORMAT_ND.
TEST_F(TestConvert, TestUtilsConvertFormat) {
  const std::vector<std::pair<std::string, GeFormat>> expectations = {
    {kOpFormat_NCHW, GeFormat::FORMAT_NCHW},
    {kOpFormat_NC1HWC0, GeFormat::FORMAT_NC1HWC0},
    {kOpFormat_NHWC, GeFormat::FORMAT_NHWC},
    {"xyz", GeFormat::FORMAT_ND},
  };
  for (const auto& entry : expectations) {
    ASSERT_EQ(TransformUtil::ConvertFormat(entry.first), entry.second);
  }
}
// Checks the size (as returned by GetDataTypeSize) expected for each listed ME data type.
TEST_F(TestConvert, TestUtilsDataSize) {
  const std::vector<std::pair<MeDataType, int>> expectations = {
    {MeDataType::kNumberTypeFloat32, 4},
    {MeDataType::kNumberTypeFloat16, 2},
    {MeDataType::kNumberTypeFloat64, 8},
    {MeDataType::kNumberTypeInt8, 1},
    {MeDataType::kNumberTypeInt16, 2},
    {MeDataType::kNumberTypeInt32, 4},
    {MeDataType::kNumberTypeInt64, 8},
    {MeDataType::kNumberTypeUInt32, 4},
    {MeDataType::kNumberTypeBool, 1},
  };
  for (const auto& entry : expectations) {
    ASSERT_EQ(TransformUtil::GetDataTypeSize(entry.first), entry.second);
  }
}
// Wraps eight floats in a 2x2x2 NCHW GE tensor, converts it to an ME tensor and
// checks that the byte size and the raw contents are identical.
TEST_F(TestConvert, TestConvertGeTensor) {
  // Local alias instead of a function-local #define, which would leak into the rest of the file.
  using DType = float;
  ge::DataType dt = ge::DataType::DT_FLOAT;
  std::vector<DType> data = {1.1, 2.2, 3.3, 4.4, 6.6, 7.7, 8.8, 9.9};
  ge::Shape shape({2, 2, 2});
  ge::Format format = ge::Format::FORMAT_NCHW;
  ge::TensorDesc desc(shape, format, dt);
  GeTensorPtr ge_tensor_ptr =
    std::make_shared<GeTensor>(desc, reinterpret_cast<uint8_t*>(data.data()), data.size() * sizeof(DType));
  GeTensor& ge_tensor = *ge_tensor_ptr;
  const DType* ge_data = reinterpret_cast<DType*>(ge_tensor.GetData());
  // GetData() must hand back the same buffer on every call.
  // gtest assertions are used throughout because plain assert() is compiled out under NDEBUG.
  ASSERT_TRUE(ge_data == reinterpret_cast<DType*>(ge_tensor.GetData()));

  cout << "ge data size is: " << std::dec << ge_tensor.GetSize() << " bytes" << endl;
  const size_t elem_count = ge_tensor.GetSize() / sizeof(DType);
  for (size_t i = 0; i < elem_count; i++) {
    cout << "ge data is: " << static_cast<DType>(*(ge_data + i)) << endl;
  }

  MeTensorPtr me_tensor_ptr = TransformUtil::ConvertGeTensor(ge_tensor_ptr);
  // Guard the dereference below against a failed conversion.
  ASSERT_NE(me_tensor_ptr, nullptr);
  MeTensor& me_tensor = *me_tensor_ptr;
  cout << "after convert ge tensor to me tensor" << endl;
  DType* me_data = reinterpret_cast<DType*>(me_tensor.data_c());
  PrintMeTensor(&me_tensor);

  ASSERT_TRUE(ge_tensor.GetSize() == me_tensor.data().nbytes());
  ASSERT_EQ(memcmp(ge_data, me_data, ge_tensor.GetSize()), 0);
}
// Builds make_tuple(x0, x1, x2) -> return by hand and checks that the graph
// resolves and converts with error code 0.
TEST_F(TestConvert, TestConvertMakeTuple) {
  FuncGraphPtr graph = std::make_shared<FuncGraph>();

  // make_tuple node fed by three named parameters x0..x2.
  std::vector<AnfNodePtr> node_inputs;
  node_inputs.push_back(NewValueNode(std::make_shared<Primitive>("make_tuple")));
  for (int idx = 0; idx < 3; ++idx) {
    auto param = graph->add_parameter();
    param->set_name("x" + std::to_string(idx));
    node_inputs.push_back(param);
  }
  CNodePtr tuple_node = graph->NewCNode(node_inputs);

  // Return the tuple from the graph.
  node_inputs.clear();
  node_inputs.push_back(NewValueNode(std::make_shared<Primitive>("return")));
  node_inputs.push_back(tuple_node);
  CNodePtr return_node = graph->NewCNode(node_inputs);
  graph->set_return(return_node);

  // Hand the graph to a manager and resolve it.
  std::shared_ptr<FuncGraphManager> manager = Manage(graph);
  const bool resolved = ResolveAll(manager);
  ASSERT_TRUE(resolved);

  // Convert the first graph the manager knows about (nothing is drawn here).
  auto managed_graph = *(manager->func_graphs().begin());
  DfGraphConvertor convertor(managed_graph);
  convertor.ConvertAllNode().BuildGraph().GetComputeGraph();
  ASSERT_EQ(convertor.ErrCode(), 0);
}
// Converts three ME float tensors to GE tensors (NCHW) and checks that each
// result keeps the byte size, raw contents and dims of its source.
TEST_F(TestConvert, TestConvertInputTensors) {
  // Local alias instead of a function-local #define, which would leak into the rest of the file.
  using DType = float;
  MeTensorPtr input_ptr1 = MakeTensor(kF32, {1, 1, 4, 4});
  MeTensorPtr input_ptr2 = MakeTensor(kF32, {2, 3, 4, 5});
  MeTensorPtr input_ptr3 = MakeTensor(kF32, {9, 9, 1, 1});
  std::vector<MeTensorPtr> me_inputs;
  me_inputs.emplace_back(input_ptr1);
  me_inputs.emplace_back(input_ptr2);
  me_inputs.emplace_back(input_ptr3);

  std::vector<GeTensorPtr> ge_tensors = TransformUtil::ConvertInputTensors(me_inputs, kOpFormat_NCHW);
  // Without this check an empty result would let the loop below pass vacuously.
  ASSERT_EQ(ge_tensors.size(), me_inputs.size());

  for (size_t i = 0; i < ge_tensors.size(); i++) {
    ASSERT_NE(ge_tensors[i], nullptr);
    DType* me_data = reinterpret_cast<DType*>(me_inputs[i]->data_c());
    const DType* ge_data = reinterpret_cast<DType*>(ge_tensors[i]->GetData());
    ASSERT_TRUE(ge_tensors[i]->GetSize() == me_inputs[i]->data().nbytes());
    ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
    ASSERT_TRUE(ge_tensors[i]->GetTensorDesc().GetShape().GetDims() ==
                TransformUtil::ConvertMeShape(me_inputs[i]->shape_c()).GetDims());
  }
}
// Converts three GE float tensors to ME tensors with explicit requested dims and
// checks that each result keeps the byte size and raw contents and has the requested shape.
TEST_F(TestConvert, TestConvertGeTensors) {
  // Local alias instead of a function-local #define, which would leak into the rest of the file.
  using DType = float;
  ge::DataType dt = ge::DataType::DT_FLOAT;
  // Backing buffers sized to match the shapes below (16, 120 and 81 elements).
  std::vector<float> data1(16);
  std::vector<float> data2(120);
  std::vector<float> data3(81);
  ge::Shape shape1({1, 1, 4, 4});
  ge::Shape shape2({2, 3, 4, 5});
  ge::Shape shape3({9, 9, 1, 1});
  ge::Format format = ge::Format::FORMAT_NCHW;
  ge::TensorDesc desc1(shape1, format, dt);
  ge::TensorDesc desc2(shape2, format, dt);
  ge::TensorDesc desc3(shape3, format, dt);
  GeTensorPtr ge_tensor_ptr1 =
    std::make_shared<GeTensor>(desc1, reinterpret_cast<uint8_t*>(data1.data()), data1.size() * sizeof(DType));
  GeTensorPtr ge_tensor_ptr2 =
    std::make_shared<GeTensor>(desc2, reinterpret_cast<uint8_t*>(data2.data()), data2.size() * sizeof(DType));
  GeTensorPtr ge_tensor_ptr3 =
    std::make_shared<GeTensor>(desc3, reinterpret_cast<uint8_t*>(data3.data()), data3.size() * sizeof(DType));
  std::vector<GeTensorPtr> ge_tensors;
  ge_tensors.emplace_back(ge_tensor_ptr1);
  ge_tensors.emplace_back(ge_tensor_ptr2);
  ge_tensors.emplace_back(ge_tensor_ptr3);

  std::vector<std::vector<int>> request_dims;
  std::vector<int> dims1 = {1, 1, 4, 4};
  std::vector<int> dims2 = {2, 3, 4, 5};
  std::vector<int> dims3 = {9, 9, 1, 1};
  request_dims.emplace_back(dims1);
  request_dims.emplace_back(dims2);
  request_dims.emplace_back(dims3);

  std::vector<MeTensorPtr> me_outputs = TransformUtil::ConvertGeTensors(ge_tensors, request_dims);
  // me_outputs is indexed by the GE tensor count below, so the sizes must agree first.
  ASSERT_EQ(me_outputs.size(), ge_tensors.size());

  for (size_t i = 0; i < ge_tensors.size(); i++) {
    ASSERT_NE(me_outputs[i], nullptr);
    DType* me_data = reinterpret_cast<DType*>(me_outputs[i]->data_c());
    const DType* ge_data = reinterpret_cast<DType*>(ge_tensors[i]->GetData());
    ASSERT_TRUE(ge_tensors[i]->GetSize() == me_outputs[i]->data().nbytes());
    ASSERT_EQ(memcmp(ge_data, me_data, ge_tensors[i]->GetSize()), 0);
    ASSERT_TRUE(request_dims[i] == me_outputs[i]->shape_c());
  }
}
// GE shape {10, 1, 1, 1} with requested dims {10}: the requested dims are expected back.
TEST_F(TestConvert, TestConvertGeShape1) {
  GeShape ge_shape({10, 1, 1, 1});
  std::vector<int> request_dims{10};
  const auto converted = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  ASSERT_TRUE(converted == request_dims);
}
// GE shape {10, 15, 1, 1} with requested dims {10, 15}: the requested dims are expected back.
TEST_F(TestConvert, TestConvertGeShape2) {
  GeShape ge_shape({10, 15, 1, 1});
  std::vector<int> request_dims{10, 15};
  const auto converted = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  ASSERT_TRUE(converted == request_dims);
}
// GE shape {10, 13, 18, 1} with requested dims {10, 13, 18}: the requested dims are expected back.
TEST_F(TestConvert, TestConvertGeShape3) {
  GeShape ge_shape({10, 13, 18, 1});
  std::vector<int> request_dims{10, 13, 18};
  const auto converted = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  ASSERT_TRUE(converted == request_dims);
}
// GE shape {1, 10, 1, 1} with requested dims {10}: the requested dims are expected back.
TEST_F(TestConvert, TestConvertGeShape4) {
  GeShape ge_shape({1, 10, 1, 1});
  std::vector<int> request_dims{10};
  const auto converted = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  ASSERT_TRUE(converted == request_dims);
}
// GE shape {10, 1, 1, 2} with requested dims {10}: the result is expected to equal
// the conversion done without any requested dims.
TEST_F(TestConvert, TestConvertGeShape5) {
  GeShape ge_shape({10, 1, 1, 2});
  std::vector<int> request_dims{10};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
// GE shape {5, 2, 1, 1} with requested dims {10}: the result is expected to equal
// the conversion done without any requested dims.
TEST_F(TestConvert, TestConvertGeShape6) {
  GeShape ge_shape({5, 2, 1, 1});
  std::vector<int> request_dims{10};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
// GE shape {10} with requested dims {10, 1}: the result is expected to equal
// the conversion done without any requested dims.
TEST_F(TestConvert, TestConvertGeShape7) {
  GeShape ge_shape({10});
  std::vector<int> request_dims{10, 1};
  const auto with_request = TransformUtil::ConvertGeShape(ge_shape, request_dims);
  const auto without_request = TransformUtil::ConvertGeShape(ge_shape);
  ASSERT_TRUE(with_request == without_request);
}
} // namespace transform
} // namespace mindspore