forked from mindspore-Ecosystem/mindspore
!23614 310 control getoutput
Merge pull request !23614 from TuDouNi/master
This commit is contained in:
commit
78b6fd17d6
|
@ -20,6 +20,8 @@
|
|||
#include <map>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <functional>
|
||||
#include "backend/session/session_basic.h"
|
||||
#include "backend/session/session_factory.h"
|
||||
#include "cxx_api/factory.h"
|
||||
|
@ -33,6 +35,14 @@ namespace mindspore {
|
|||
API_FACTORY_REG(ModelImpl, Ascend310, AclModelMulti);
|
||||
|
||||
namespace {
|
||||
std::map<DataType, size_t> kDtypeMap = {
|
||||
{DataType::kNumberTypeBool, sizeof(bool)}, {DataType::kNumberTypeInt8, sizeof(int8_t)},
|
||||
{DataType::kNumberTypeInt16, sizeof(int16_t)}, {DataType::kNumberTypeInt32, sizeof(int32_t)},
|
||||
{DataType::kNumberTypeInt64, sizeof(int64_t)}, {DataType::kNumberTypeFloat16, sizeof(float16)},
|
||||
{DataType::kNumberTypeFloat32, sizeof(float)}, {DataType::kNumberTypeFloat64, sizeof(double)},
|
||||
{DataType::kNumberTypeUInt8, sizeof(uint8_t)}, {DataType::kNumberTypeUInt16, sizeof(uint16_t)},
|
||||
{DataType::kNumberTypeUInt32, sizeof(uint32_t)}, {DataType::kNumberTypeUInt64, sizeof(uint64_t)}};
|
||||
|
||||
class MSTensorRef : public BaseRef {
|
||||
public:
|
||||
static VectorRef Convert(const std::vector<MSTensor> &tensors) {
|
||||
|
@ -369,6 +379,8 @@ Status AclModelMulti::Build() {
|
|||
return abstract;
|
||||
});
|
||||
(void)InferMindir(ModelImpl::GetFuncGraph(), broaded_args);
|
||||
// set output
|
||||
SetOutput();
|
||||
// create vm
|
||||
auto backend = CreateBackend(std::make_shared<AclModelOptions>(model_context_));
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
|
@ -449,6 +461,51 @@ void AclModelMulti::SetInputs() {
|
|||
}
|
||||
}
|
||||
|
||||
void AclModelMulti::SetOutput() {
|
||||
if (outputs_.empty()) {
|
||||
auto fg = ModelImpl::GetFuncGraph();
|
||||
MS_EXCEPTION_IF_NULL(fg);
|
||||
const auto output = fg->output();
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
auto abs = output->abstract();
|
||||
MS_EXCEPTION_IF_NULL(abs);
|
||||
|
||||
// DataType
|
||||
DataType type_id;
|
||||
if (abs->isa<abstract::AbstractTensor>()) {
|
||||
auto abs_tensor = abs->cast<abstract::AbstractTensorPtr>();
|
||||
auto ele = abs_tensor->element();
|
||||
MS_EXCEPTION_IF_NULL(ele);
|
||||
MS_EXCEPTION_IF_NULL(ele->GetTypeTrack());
|
||||
type_id = static_cast<DataType>(ele->GetTypeTrack()->type_id());
|
||||
} else {
|
||||
MS_EXCEPTION_IF_NULL(abs->GetTypeTrack());
|
||||
type_id = static_cast<DataType>(abs->GetTypeTrack()->type_id());
|
||||
}
|
||||
// Shape
|
||||
auto shape_track = abs->GetShapeTrack();
|
||||
MS_EXCEPTION_IF_NULL(shape_track);
|
||||
std::vector<int64_t> shape = {};
|
||||
if (shape_track->isa<abstract::Shape>()) {
|
||||
auto shapeptr = shape_track->cast<abstract::ShapePtr>();
|
||||
shape = static_cast<std::vector<int64_t>>(shapeptr->shape());
|
||||
}
|
||||
// Size
|
||||
size_t ato_size = 0;
|
||||
if (kDtypeMap.find(type_id) != kDtypeMap.end()) {
|
||||
ato_size = kDtypeMap[type_id];
|
||||
}
|
||||
int64_t ele_num = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int64_t>());
|
||||
size_t size = ato_size * ele_num;
|
||||
// create tensor
|
||||
auto output_tensor = MSTensor::CreateTensor("", type_id, shape, nullptr, size);
|
||||
outputs_.emplace_back(*output_tensor);
|
||||
MSTensor::DestroyTensorPtr(output_tensor);
|
||||
} else {
|
||||
MS_LOG(DEBUG) << "outputs_ has been set.";
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<MSTensor> AclModelMulti::GetInputs() {
|
||||
if (!is_multi_graph_.has_value()) {
|
||||
is_multi_graph_ = ModelImpl::GetFuncGraph() == nullptr ? false : HasMultiGraph(ModelImpl::GetFuncGraph());
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
#include <optional>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <map>
|
||||
|
||||
namespace mindspore {
|
||||
namespace compile {
|
||||
|
@ -41,6 +42,7 @@ class AclModelMulti : public AclModel {
|
|||
|
||||
private:
|
||||
void SetInputs();
|
||||
void SetOutput();
|
||||
|
||||
std::optional<bool> is_multi_graph_;
|
||||
std::shared_ptr<compile::MsBackend> backend_;
|
||||
|
|
|
@ -72,6 +72,17 @@ TEST_F(TestControl, InferIfbyIf) {
|
|||
EXPECT_EQ(inputs_before[4].Shape()[2], 4);
|
||||
EXPECT_EQ(inputs_before[4].Shape()[3], 5);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 4);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 2);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[1], 3);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[2], 4);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[3], 5);
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -130,6 +141,17 @@ TEST_F(TestControl, InferSimpleWhile) {
|
|||
EXPECT_EQ(inputs_before[2].Shape()[2], 4);
|
||||
EXPECT_EQ(inputs_before[2].Shape()[3], 5);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * input_data.size());
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 4);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 2);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[1], 3);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[2], 4);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[3], 5);
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -173,6 +195,15 @@ TEST_F(TestControl, InferRecursive) {
|
|||
ASSERT_EQ(inputs_before[0].Shape().size(), 1);
|
||||
EXPECT_EQ(inputs_before[0].Shape()[0], 1);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 1);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 1);
|
||||
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -226,6 +257,14 @@ TEST_F(TestControl, InferMixedWhileIf) {
|
|||
ASSERT_EQ(inputs_before[4].Shape().size(), 1);
|
||||
EXPECT_EQ(inputs_before[4].Shape()[0], 1);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 1);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 1);
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -279,6 +318,14 @@ TEST_F(TestControl, InferSingleFor) {
|
|||
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeInt32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(int32_t));
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 1);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 1);
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -324,6 +371,12 @@ TEST_F(TestControl, InferSingleOr) {
|
|||
ASSERT_EQ(inputs_before[1].Shape().size(), 1);
|
||||
EXPECT_EQ(inputs_before[1].Shape()[0], 2);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float));
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
@ -339,6 +392,13 @@ TEST_F(TestControl, InferSingleOr) {
|
|||
// infer
|
||||
ASSERT_TRUE(control_model.Predict(inputs, &outputs) == kSuccess);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_after = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_after.size());
|
||||
EXPECT_EQ(outputs_after[0].DataType(), DataType::kNumberTypeFloat32);
|
||||
ASSERT_TRUE(outputs_after[0].DataSize() == sizeof(float));
|
||||
EXPECT_EQ(outputs_after[0].Shape().size(), outputs_before[0].Shape().size());
|
||||
|
||||
// assert output
|
||||
ASSERT_TRUE(outputs.size() == 1);
|
||||
auto out = outputs[0];
|
||||
|
@ -375,6 +435,17 @@ TEST_F(TestControl, InferSingleSwitch) {
|
|||
ASSERT_EQ(inputs_before[2].Shape().size(), 1);
|
||||
EXPECT_EQ(inputs_before[2].Shape()[0], 1);
|
||||
|
||||
// assert outputs
|
||||
std::vector<MSTensor> outputs_before = control_model.GetOutputs();
|
||||
ASSERT_EQ(1, outputs_before.size());
|
||||
EXPECT_EQ(outputs_before[0].DataType(), DataType::kNumberTypeFloat32);
|
||||
ASSERT_TRUE(outputs_before[0].DataSize() == sizeof(float) * 224 * 224);
|
||||
ASSERT_EQ(outputs_before[0].Shape().size(), 4);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[0], 1);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[1], 1);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[2], 224);
|
||||
EXPECT_EQ(outputs_before[0].Shape()[3], 224);
|
||||
|
||||
// prepare input
|
||||
std::vector<MSTensor> outputs;
|
||||
std::vector<MSTensor> inputs;
|
||||
|
|
Loading…
Reference in New Issue