diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc index 42c83543b6b..59cdf3b4349 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.cc @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "backend/session/session_basic.h" #include "backend/session/session_factory.h" #include "cxx_api/factory.h" @@ -33,6 +35,14 @@ namespace mindspore { API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti); namespace { +std::map kDtypeMap = { + {DataType::kNumberTypeBool, sizeof(bool)}, {DataType::kNumberTypeInt8, sizeof(int8_t)}, + {DataType::kNumberTypeInt16, sizeof(int16_t)}, {DataType::kNumberTypeInt32, sizeof(int32_t)}, + {DataType::kNumberTypeInt64, sizeof(int64_t)}, {DataType::kNumberTypeFloat16, sizeof(float16)}, + {DataType::kNumberTypeFloat32, sizeof(float)}, {DataType::kNumberTypeFloat64, sizeof(double)}, + {DataType::kNumberTypeUInt8, sizeof(uint8_t)}, {DataType::kNumberTypeUInt16, sizeof(uint16_t)}, + {DataType::kNumberTypeUInt32, sizeof(uint32_t)}, {DataType::kNumberTypeUInt64, sizeof(uint64_t)}}; + class MSTensorRef : public BaseRef { public: static VectorRef Convert(const std::vector &tensors) { @@ -369,6 +379,8 @@ Status AclModelMulti::Build() { return abstract; }); (void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args); + // set output + SetOutput(); // create vm auto backend = CreateBackend(std::make_shared(model_context_)); auto context_ptr = MsContext::GetInstance(); @@ -449,6 +461,51 @@ void AclModelMulti::SetInputs() { } } +void AclModelMulti::SetOutput() { + if (outputs_.empty()) { + auto fg = ModelImpl::GetFuncGraph(); + MS_EXCEPTION_IF_NULL(fg); + const auto output = fg->output(); + MS_EXCEPTION_IF_NULL(output); + auto abs = output->abstract(); + MS_EXCEPTION_IF_NULL(abs); + + // DataType + DataType type_id; + if (abs->isa()) { + auto abs_tensor = abs->cast(); + auto ele = abs_tensor->element(); + MS_EXCEPTION_IF_NULL(ele); + MS_EXCEPTION_IF_NULL(ele->GetTypeTrack()); + type_id = static_cast(ele->GetTypeTrack()->type_id()); + } else { + MS_EXCEPTION_IF_NULL(abs->GetTypeTrack()); + type_id = static_cast(abs->GetTypeTrack()->type_id()); + } + // Shape + auto shape_track = abs->GetShapeTrack(); + MS_EXCEPTION_IF_NULL(shape_track); + std::vector shape = {}; + if (shape_track->isa()) { + auto shapeptr = shape_track->cast(); + shape = static_cast>(shapeptr->shape()); + } + // Size + size_t ato_size = 0; + if (kDtypeMap.find(type_id) != kDtypeMap.end()) { + ato_size = kDtypeMap[type_id]; + } + int64_t ele_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); + size_t size = ato_size * ele_num; + // create tensor + auto output_tensor = MSTensor::CreateTensor("", type_id, shape, nullptr, size); + outputs_.emplace_back(*output_tensor); + MSTensor::DestroyTensorPtr(output_tensor); + } else { + MS_LOG(DEBUG) << "outputs_ has been set."; + } +} + std::vector AclModelMulti::GetInputs() { if (!is_multi_graph_.has_value()) { is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph()); diff --git a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h index 11f177f5b09..56d14070418 100644 --- a/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h +++ b/mindspore/ccsrc/cxx_api/model/acl/acl_model_multi.h @@ -21,6 +21,7 @@ #include #include #include +#include namespace mindspore { namespace compile { @@ -41,6 +42,7 @@ class AclModelMulti : public AclModel { private: void SetInputs(); + void SetOutput(); std::optional is_multi_graph_; std::shared_ptr backend_; diff --git a/tests/st/cpp/model/test_control.cc b/tests/st/cpp/model/test_control.cc index 42d86050d41..c083cae248e 100644 --- a/tests/st/cpp/model/test_control.cc +++ b/tests/st/cpp/model/test_control.cc @@ -72,6 +72,17 @@ TEST_F(TestControl, InferIfbyIf) { EXPECT_EQ(inputs_before[4].Shape()[2], 4); EXPECT_EQ(inputs_before[4].Shape()[3], 5); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size()); + ASSERT_EQ(outputs_before[0].Shape().size(), 4); + EXPECT_EQ(outputs_before[0].Shape()[0], 2); + EXPECT_EQ(outputs_before[0].Shape()[1], 3); + EXPECT_EQ(outputs_before[0].Shape()[2], 4); + EXPECT_EQ(outputs_before[0].Shape()[3], 5); + // prepare input std::vector outputs; std::vector inputs; @@ -130,6 +141,17 @@ TEST_F(TestControl, InferSimpleWhile) { EXPECT_EQ(inputs_before[2].Shape()[2], 4); EXPECT_EQ(inputs_before[2].Shape()[3], 5); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size()); + ASSERT_EQ(outputs_before[0].Shape().size(), 4); + EXPECT_EQ(outputs_before[0].Shape()[0], 2); + EXPECT_EQ(outputs_before[0].Shape()[1], 3); + EXPECT_EQ(outputs_before[0].Shape()[2], 4); + EXPECT_EQ(outputs_before[0].Shape()[3], 5); + // prepare input std::vector outputs; std::vector inputs; @@ -173,6 +195,15 @@ TEST_F(TestControl, InferRecursive) { ASSERT_EQ(inputs_before[0].Shape().size(), 1); EXPECT_EQ(inputs_before[0].Shape()[0], 1); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); + ASSERT_EQ(outputs_before[0].Shape().size(), 1); + EXPECT_EQ(outputs_before[0].Shape()[0], 1); + + // prepare input std::vector outputs; std::vector inputs; @@ -226,6 +257,14 @@ TEST_F(TestControl, InferMixedWhileIf) { ASSERT_EQ(inputs_before[4].Shape().size(), 1); EXPECT_EQ(inputs_before[4].Shape()[0], 1); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); + ASSERT_EQ(outputs_before[0].Shape().size(), 1); + EXPECT_EQ(outputs_before[0].Shape()[0], 1); + // prepare input std::vector outputs; std::vector inputs; @@ -279,6 +318,14 @@ TEST_F(TestControl, InferSingleFor) { ASSERT_EQ(inputs_before[2].Shape().size(), 1); EXPECT_EQ(inputs_before[2].Shape()[0], 1); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t)); + ASSERT_EQ(outputs_before[0].Shape().size(), 1); + EXPECT_EQ(outputs_before[0].Shape()[0], 1); + // prepare input std::vector outputs; std::vector inputs; @@ -324,6 +371,12 @@ TEST_F(TestControl, InferSingleOr) { ASSERT_EQ(inputs_before[1].Shape().size(), 1); EXPECT_EQ(inputs_before[1].Shape()[0], 2); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float)); + // prepare input std::vector outputs; std::vector inputs; @@ -339,6 +392,13 @@ TEST_F(TestControl, InferSingleOr) { // infer ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess); + // assert outputs + std::vector outputs_after = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_after.size()); + EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32); + ASSERT_TRUE(outputs_after[0].DataSize() == sizeof(float)); + EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size()); + // assert output ASSERT_TRUE(outputs.size() == 1); auto out = outputs[0]; @@ -375,6 +435,17 @@ TEST_F(TestControl, InferSingleSwitch) { ASSERT_EQ(inputs_before[2].Shape().size(), 1); EXPECT_EQ(inputs_before[2].Shape()[0], 1); + // assert outputs + std::vector outputs_before = control_model.GetOutputs(); + ASSERT_EQ(1, outputs_before.size()); + EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32); + ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * 224 * 224); + ASSERT_EQ(outputs_before[0].Shape().size(), 4); + EXPECT_EQ(outputs_before[0].Shape()[0], 1); + EXPECT_EQ(outputs_before[0].Shape()[1], 1); + EXPECT_EQ(outputs_before[0].Shape()[2], 224); + EXPECT_EQ(outputs_before[0].Shape()[3], 224); + // prepare input std::vector outputs; std::vector inputs;