!46095 MSLite, cloud opt ascend

Merge pull request !46095 from 徐永飞/opt_ascend
i-robot 2022-11-28 07:38:53 +00:00 committed by Gitee
commit f94cb2c1fe
19 changed files with 1889 additions and 1077 deletions


@@ -0,0 +1,62 @@
cmake_minimum_required(VERSION 3.14)
project(QuickStartCpp)
if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3.0)
message(FATAL_ERROR "GCC version must be at least 7.3.0, current version: ${CMAKE_CXX_COMPILER_VERSION}")
endif()
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
if(DEFINED ENV{LITE_HOME})
set(LITE_HOME $ENV{LITE_HOME})
endif()
if(DEFINED ENV{EXAMPLE_TARGET})
set(EXAMPLE_TARGET $ENV{EXAMPLE_TARGET})
endif()
# Add directory to include search path
include_directories(${LITE_HOME}/runtime)
# Add directory to linker search path
link_directories(${LITE_HOME}/runtime/lib)
if("${EXAMPLE_TARGET}" STREQUAL "Ascend")
include_directories(/usr/local/Ascend/latest/fwkacllib/include)
link_directories(/usr/local/Ascend/latest/fwkacllib/lib64)
add_definitions(-DENABLE_ASCEND)
else()
set(CUDA_HOME $ENV{CUDA_HOME})
include_directories(${CUDA_HOME}/include)
link_directories(${CUDA_HOME}/lib64)
add_definitions(-DENABLE_GPU)
endif()
file(GLOB_RECURSE QUICK_START_CXX ${CMAKE_CURRENT_SOURCE_DIR}/*.cc)
add_executable(mindspore_quick_start_cpp ${QUICK_START_CXX})
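# --whole-archive keeps every object from the static mindspore-lite library,
# so operators registered through global constructors are not dropped by the linker.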
target_link_libraries(
mindspore_quick_start_cpp
-Wl,--whole-archive mindspore-lite -Wl,--no-whole-archive
pthread
)
if("${EXAMPLE_TARGET}" STREQUAL "Ascend")
target_link_libraries(mindspore_quick_start_cpp ascendcl)
else()
target_link_libraries(mindspore_quick_start_cpp cudart cublas)
endif()
# Stack-protection compile options are enabled in the static library,
# so the ssp library must also be linked when using it on Windows.
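# (On MinGW, libssp supplies the __stack_chk_* symbols that -fstack-protector code references.)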
if(WIN32)
target_link_libraries(
mindspore_quick_start_cpp
ssp
)
else()
target_link_libraries(
mindspore_quick_start_cpp
dl
)
endif()


@@ -0,0 +1,43 @@
#!/usr/bin/env bash
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]; then
echo "Usage: bash build.sh [DEVICE_TARGET]
DEVICE_TARGET can choose from ['Ascend', 'GPU']."
exit
fi
device_target=$1
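# Prefixing both operands with 0 keeps the string comparisons valid when a variable is unset.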
if [ 0"$LITE_HOME" = "0" ]; then
echo "Please set env LITE_HOME to MindSpore Lite tar path"
exit
fi
if [ 0"$device_target" != "0GPU" ] && [ 0"$device_target" != "0Ascend" ]; then
echo "Please set args 1 EXAMPLE_TARGET to Ascend or GPU"
exit
fi
if [ 0"$device_target" = "0GPU" ] && [ 0"$CUDA_HOME" = "0" ]; then
echo "Please set env CUDA_HOME to path of cuda, if env EXAMPLE_TARGET is GPU"
exit
fi
rm -rf build
mkdir build && cd build || exit
cmake ../ -DEXAMPLE_TARGET="$device_target"
make


@@ -0,0 +1,441 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <algorithm>
#include <random>
#include <iostream>
#include <fstream>
#include <cstring>
#include <memory>
#include "include/api/model.h"
#include "include/api/context.h"
#include "include/api/status.h"
#include "include/api/types.h"
#ifdef ENABLE_ASCEND
#include "./mem_ascend.h"
#else
#include "./mem_gpu.h"
#endif
namespace {
constexpr int kNumPrintOfOutData = 50;
}
static std::string ShapeToString(const std::vector<int64_t> &shape) {
std::string result = "[";
for (size_t i = 0; i < shape.size(); ++i) {
result += std::to_string(shape[i]);
if (i + 1 < shape.size()) {
result += ", ";
}
}
result += "]";
return result;
}
std::vector<char> ReadFile(const std::string &file) {
std::ifstream ifs(file, std::ifstream::in | std::ifstream::binary);
if (!ifs.good()) {
std::cerr << "file: " << file << " is not exist." << std::endl;
return {};
}
if (!ifs.is_open()) {
std::cerr << "file: " << file << " open failed." << std::endl;
return {};
}
ifs.seekg(0, std::ios::end);
auto size = ifs.tellg();
std::vector<char> buf;
buf.resize(size);
ifs.seekg(0, std::ios::beg);
ifs.read(buf.data(), buf.size());
ifs.close();
return buf;
}
template <typename T, typename Distribution>
void GenerateRandomData(int size, void *data, Distribution distribution) {
std::random_device rd{};
std::mt19937 random_engine{rd()};
int elements_num = size / sizeof(T);
(void)std::generate_n(static_cast<T *>(data), elements_num,
[&distribution, &random_engine]() { return static_cast<T>(distribution(random_engine)); });
}
int GenerateRandomInputData(std::vector<mindspore::MSTensor> inputs, std::vector<uint8_t *> *host_data_buffer) {
for (auto tensor : inputs) {
auto data_size = tensor.DataSize();
if (data_size == 0) {
std::cerr << "Data size cannot be 0, tensor shape: " << ShapeToString(tensor.Shape()) << std::endl;
return -1;
}
auto host_data = new uint8_t[data_size];
host_data_buffer->push_back(host_data);
GenerateRandomData<float>(data_size, host_data, std::normal_distribution<float>(0.0f, 1.0f));
}
return 0;
}
int SetHostData(std::vector<mindspore::MSTensor> tensors, const std::vector<uint8_t *> &host_data_buffer) {
for (size_t i = 0; i < tensors.size(); i++) {
tensors[i].SetData(host_data_buffer[i], false);
tensors[i].SetDeviceData(nullptr);
}
return 0;
}
int SetDeviceData(std::vector<mindspore::MSTensor> tensors, const std::vector<uint8_t *> &host_data_buffer,
std::vector<void *> *device_buffers) {
for (size_t i = 0; i < tensors.size(); i++) {
auto &tensor = tensors[i];
auto host_data = host_data_buffer[i];
auto data_size = tensor.DataSize();
if (data_size == 0) {
std::cerr << "Data size cannot be 0, tensor shape: " << ShapeToString(tensor.Shape()) << std::endl;
return -1;
}
auto device_data = MallocDeviceMemory(data_size);
if (device_data == nullptr) {
std::cerr << "Failed to alloc device data, data size " << data_size << std::endl;
return -1;
}
device_buffers->push_back(device_data);
if (CopyMemoryHost2Device(device_data, data_size, host_data, data_size) != 0) {
std::cerr << "Failed to copy data to device, data size " << data_size << std::endl;
return -1;
}
tensor.SetDeviceData(device_data);
tensor.SetData(nullptr, false);
}
return 0;
}
int SetOutputHostData(std::vector<mindspore::MSTensor> tensors, std::vector<uint8_t *> *host_buffers) {
for (size_t i = 0; i < tensors.size(); i++) {
auto &tensor = tensors[i];
auto data_size = tensor.DataSize();
if (data_size == 0) {
std::cerr << "Data size cannot be 0, tensor shape: " << ShapeToString(tensor.Shape()) << std::endl;
return -1;
}
auto host_data = new uint8_t[data_size];
host_buffers->push_back(host_data);
tensor.SetData(host_data, false);
tensor.SetDeviceData(nullptr);
}
return 0;
}
int SetOutputDeviceData(std::vector<mindspore::MSTensor> tensors, std::vector<void *> *device_buffers) {
for (size_t i = 0; i < tensors.size(); i++) {
auto &tensor = tensors[i];
auto data_size = tensor.DataSize();
if (data_size == 0) {
std::cerr << "Data size cannot be 0, tensor shape: " << ShapeToString(tensor.Shape()) << std::endl;
return -1;
}
auto device_data = MallocDeviceMemory(data_size);
if (device_data == nullptr) {
std::cerr << "Failed to alloc device data, data size " << data_size << std::endl;
return -1;
}
device_buffers->push_back(device_data);
tensor.SetDeviceData(device_data);
tensor.SetData(nullptr, false);
}
return 0;
}
template <class T>
void PrintBuffer(const void *buffer, size_t elem_count) {
auto data = reinterpret_cast<const T *>(buffer);
for (size_t i = 0; i < elem_count && i < static_cast<size_t>(kNumPrintOfOutData); i++) {
std::cout << data[i] << " ";
}
std::cout << std::endl;
}
void PrintOutputsTensor(std::vector<mindspore::MSTensor> outputs) {
for (auto tensor : outputs) {
auto elem_num = tensor.ElementNum();
auto data_size = tensor.DataSize();
std::vector<uint8_t> host_data;
const void *print_data;
if (tensor.GetDeviceData() != nullptr) {
host_data.resize(data_size);
CopyMemoryDevice2Host(host_data.data(), host_data.size(), tensor.GetDeviceData(), data_size);
print_data = host_data.data();
std::cout << "Device data, tensor name is:" << tensor.Name() << " tensor size is:" << data_size
<< " tensor elements num is:" << elem_num << std::endl;
} else {
print_data = tensor.Data().get();
std::cout << "Host data, tensor name is:" << tensor.Name() << " tensor size is:" << data_size
<< " tensor elements num is:" << elem_num << std::endl;
}
auto data_type = tensor.DataType();
if (data_type == mindspore::DataType::kNumberTypeFloat32) {
PrintBuffer<float>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeFloat64) {
PrintBuffer<double>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeInt64) {
PrintBuffer<int64_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeInt32) {
PrintBuffer<int32_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeInt16) {
PrintBuffer<int16_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeInt8) {
PrintBuffer<int8_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeUInt64) {
PrintBuffer<uint64_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeUInt32) {
PrintBuffer<uint32_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeUInt16) {
PrintBuffer<uint16_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeUInt8) {
PrintBuffer<uint8_t>(print_data, elem_num);
} else if (data_type == mindspore::DataType::kNumberTypeBool) {
PrintBuffer<bool>(print_data, elem_num);
} else {
std::cout << "Unsupported data type " << static_cast<int>(tensor.DataType()) << std::endl;
}
}
}
int Predict(mindspore::Model *model, const std::vector<mindspore::MSTensor> &inputs,
std::vector<mindspore::MSTensor> *outputs) {
auto ret = model->Predict(inputs, outputs);
if (ret != mindspore::kSuccess) {
std::cerr << "Predict error " << ret << std::endl;
return -1;
}
PrintOutputsTensor(*outputs);
return 0;
}
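// RAII helper: runs the supplied release function when the guard leaves scope,
// so the host/device buffers below are freed on every return path.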
class ResourceGuard {
public:
explicit ResourceGuard(std::function<void()> rel_func) : rel_func_(rel_func) {}
~ResourceGuard() {
if (rel_func_) {
rel_func_();
}
}
private:
std::function<void()> rel_func_ = nullptr;
};
int TestHostDeviceInput(mindspore::Model *model, uint32_t batch_size) {
// Get Input
auto inputs = model->GetInputs();
std::vector<std::vector<int64_t>> input_shapes;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(input_shapes), [batch_size](auto &item) {
auto shape = item.Shape();
shape[0] = batch_size;
return shape;
});
if (model->Resize(inputs, input_shapes) != mindspore::kSuccess) {
std::cerr << "Failed to resize model batch size to " << batch_size << std::endl;
return -1;
}
std::cout << "Success resize model batch size to " << batch_size << std::endl;
// Generate random data as input data.
std::vector<uint8_t *> host_buffers;
ResourceGuard host_rel([&host_buffers]() {
for (auto &item : host_buffers) {
delete[] item;
}
});
std::vector<void *> device_buffers;
ResourceGuard device_rel([&device_buffers]() {
for (auto &item : device_buffers) {
FreeDeviceMemory(item);
}
});
auto ret = GenerateRandomInputData(inputs, &host_buffers);
if (ret != 0) {
std::cerr << "Generate Random Input Data failed." << std::endl;
return -1;
}
// empty outputs
std::vector<mindspore::MSTensor> outputs;
// Model Predict, input host memory
SetHostData(inputs, host_buffers);
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
// Model Predict, input device memory
outputs.clear();
SetDeviceData(inputs, host_buffers, &device_buffers);
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
return 0;
}
int TestHostDeviceOutput(mindspore::Model *model, uint32_t batch_size) {
// Get Input
auto inputs = model->GetInputs();
std::vector<std::vector<int64_t>> input_shapes;
std::transform(inputs.begin(), inputs.end(), std::back_inserter(input_shapes), [batch_size](auto &item) {
auto shape = item.Shape();
shape[0] = batch_size;
return shape;
});
if (model->Resize(inputs, input_shapes) != mindspore::kSuccess) {
std::cerr << "Failed to resize model batch size to " << batch_size << std::endl;
return -1;
}
std::cout << "Success resize model batch size to " << batch_size << std::endl;
// Generate random data as input data.
std::vector<uint8_t *> host_buffers;
ResourceGuard host_rel([&host_buffers]() {
for (auto &item : host_buffers) {
delete[] item;
}
});
std::vector<void *> device_buffers;
ResourceGuard device_rel([&device_buffers]() {
for (auto &item : device_buffers) {
FreeDeviceMemory(item);
}
});
auto ret = GenerateRandomInputData(inputs, &host_buffers);
if (ret != 0) {
std::cerr << "Generate Random Input Data failed." << std::endl;
return -1;
}
// Get Output from model
auto outputs = model->GetOutputs();
// ---------------------- output host data
std::vector<uint8_t *> output_host_buffers;
ResourceGuard output_host_rel([&output_host_buffers]() {
for (auto &item : output_host_buffers) {
delete[] item;
}
});
if (SetOutputHostData(outputs, &output_host_buffers) != 0) {
std::cerr << "Failed to set output host data" << std::endl;
return -1;
}
// Model Predict, input host memory
if (SetHostData(inputs, host_buffers) != 0) {
std::cerr << "Failed to set input device data" << std::endl;
return -1;
}
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
// Model Predict, input device memory
if (SetDeviceData(inputs, host_buffers, &device_buffers) != 0) {
std::cerr << "Failed to set input device data" << std::endl;
return -1;
}
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
// ---------------------- output device data
std::vector<void *> output_device_buffers;
ResourceGuard output_device_rel([&output_device_buffers]() {
for (auto &item : output_device_buffers) {
FreeDeviceMemory(item);
}
});
if (SetOutputDeviceData(outputs, &output_device_buffers) != 0) {
std::cerr << "Failed to set output device data" << std::endl;
return -1;
}
// Model Predict, input host memory
if (SetHostData(inputs, host_buffers) != 0) {
std::cerr << "Failed to set input device data" << std::endl;
return -1;
}
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
// Model Predict, input device memory
if (SetDeviceData(inputs, host_buffers, &device_buffers) != 0) {
std::cerr << "Failed to set input device data" << std::endl;
return -1;
}
if (Predict(model, inputs, &outputs) != 0) {
return -1;
}
return 0;
}
int QuickStart(int argc, const char **argv) {
if (argc < 2) {
std::cerr << "Model file must be provided.\n";
return -1;
}
// Read model file.
std::string model_path = argv[1];
if (model_path.empty()) {
std::cerr << "Model path " << model_path << " is invalid.";
return -1;
}
auto model_buf = ReadFile(model_path);
if (model_buf.empty()) {
std::cerr << "Read model file failed." << std::endl;
return -1;
}
// Create and init context, add CPU device info
auto context = std::make_shared<mindspore::Context>();
if (context == nullptr) {
std::cerr << "New context failed." << std::endl;
return -1;
}
auto &device_list = context->MutableDeviceInfo();
#ifdef ENABLE_ASCEND
auto device_info = std::make_shared<mindspore::AscendDeviceInfo>();
#else
auto device_info = std::make_shared<mindspore::GPUDeviceInfo>();
#endif
device_info->SetDeviceID(0);
if (device_info == nullptr) {
std::cerr << "New CPUDeviceInfo failed." << std::endl;
return -1;
}
device_list.push_back(device_info);
mindspore::Model model;
// Build model
auto build_ret = model.Build(model_buf.data(), model_buf.size(), mindspore::kMindIR, context);
if (build_ret != mindspore::kSuccess) {
std::cerr << "Build model error " << build_ret << std::endl;
return -1;
}
TestHostDeviceInput(&model, 1);
TestHostDeviceOutput(&model, 1);
return 0;
}
int main(int argc, const char **argv) { return QuickStart(argc, argv); }


@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLE_ASCEND_MEM_H
#define MINDSPORE_LITE_EXAMPLE_ASCEND_MEM_H
#include <string>
#include "acl/acl.h"
void *MallocDeviceMemory(size_t data_size) {
void *device_data = nullptr;
auto ret = aclrtMalloc(&device_data, data_size, ACL_MEM_MALLOC_NORMAL_ONLY);
if (ret != ACL_ERROR_NONE) {
std::cerr << "Malloc device buffer failed , buffer size " << data_size;
return nullptr;
}
return device_data;
}
void FreeDeviceMemory(void *device_data) {
if (device_data) {
aclrtFree(device_data);
}
}
int CopyMemoryHost2Device(void *device_data, size_t dst_size, void *host_data, size_t src_size) {
auto ret = aclrtMemcpy(device_data, dst_size, host_data, src_size, ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
std::cerr << "Acl memcpy host data to device failed, src size: " << src_size << ", dst size: " << dst_size
<< std::endl;
return -1;
}
return 0;
}
int CopyMemoryDevice2Host(void *host_data, size_t dst_size, void *device_data, size_t src_size) {
auto ret = aclrtMemcpy(host_data, dst_size, device_data, src_size, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
std::cerr << "Acl memcpy device data to host failed, src size: " << src_size << ", dst size: " << dst_size
<< std::endl;
return -1;
}
return 0;
}
#endif // MINDSPORE_LITE_EXAMPLE_ASCEND_MEM_H


@@ -0,0 +1,57 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_EXAMPLE_GPU_MEM_H
#define MINDSPORE_LITE_EXAMPLE_GPU_MEM_H
#include <cuda_runtime.h>
#include <string>
void *MallocDeviceMemory(size_t data_size) {
void *device_data = nullptr;
auto ret = cudaMalloc(&device_data, data_size);
if (ret != cudaSuccess) {
std::cerr << "Malloc device buffer failed , buffer size " << data_size;
return nullptr;
}
return device_data;
}
void FreeDeviceMemory(void *device_data) {
if (device_data) {
cudaFree(device_data);
}
}
int CopyMemoryHost2Device(void *device_data, size_t dst_size, void *host_data, size_t src_size) {
auto ret = cudaMemcpy(device_data, host_data, src_size, cudaMemcpyHostToDevice);
if (ret != cudaSuccess) {
std::cerr << "Cuda memcpy host data to device failed, src size: " << src_size << ", dst size: " << dst_size
<< std::endl;
return -1;
}
return 0;
}
int CopyMemoryDevice2Host(void *host_data, size_t dst_size, void *device_data, size_t src_size) {
auto ret = cudaMemcpy(host_data, device_data, src_size, cudaMemcpyDeviceToHost);
if (ret != cudaSuccess) {
std::cerr << "Cuda memcpy device data to host failed, src size: " << src_size << ", dst size: " << dst_size
<< std::endl;
return -1;
}
return 0;
}
#endif // MINDSPORE_LITE_EXAMPLE_GPU_MEM_H


@@ -52,6 +52,7 @@ if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE)
${CMAKE_CURRENT_SOURCE_DIR}/utils/tensor_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/runtime_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/serialization.cc
${CMAKE_CURRENT_SOURCE_DIR}/utils/func_graph_utils.cc
${CMAKE_CURRENT_SOURCE_DIR}/session/delegate_session.cc
${CMAKE_CURRENT_SOURCE_DIR}/session/factory.cc
${CMAKE_CURRENT_SOURCE_DIR}/delegate/factory.cc


@@ -35,32 +35,6 @@ constexpr auto kImageSizeHwNum = 2;
constexpr auto kUnknownDim = -1;
} // namespace
int DynShapeProcess::ProcDynamicInput(std::vector<KernelTensorPtr> *const original_datas,
std::vector<KernelTensorPtr> *const inputs) {
MS_CHECK_TRUE_MSG(acl_options_ != nullptr, lite::RET_ERROR, "Acl options ptr is nullptr.");
if (!acl_options_->batch_size.empty() && !acl_options_->image_size.empty()) {
MS_LOG(ERROR) << "Batch size and image size can't be set at the same time.";
return lite::RET_ERROR;
}
MS_CHECK_TRUE_MSG(original_datas != nullptr, lite::RET_ERROR, "Original Data is nullptr.");
MS_CHECK_TRUE_MSG(inputs != nullptr, lite::RET_ERROR, "Inputs is nullptr.");
MS_CHECK_TRUE_MSG((*original_datas).size() == (*inputs).size(), lite::RET_ERROR,
"The size of Original Data and Input is not equal.");
if (!acl_options_->batch_size.empty()) {
if (AddBatchSizeInput(original_datas, inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Add batch size input failed.";
return lite::RET_ERROR;
}
}
if (!acl_options_->image_size.empty()) {
if (AddImageSizeInput(original_datas, inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Add Image size input failed.";
return lite::RET_ERROR;
}
}
return lite::RET_OK;
}
std::string GenResultStr(const std::vector<int64_t> &input_vec) {
std::string res;
for (size_t i = 0; i < input_vec.size(); ++i) {
@@ -72,181 +46,155 @@ std::string GenResultStr(const std::vector<int64_t> &input_vec) {
return res;
}
int DynShapeProcess::CheckBatchSize(std::vector<KernelTensorPtr> *const original_datas,
std::vector<KernelTensorPtr> *const inputs) {
if (input_data_idx_ >= inputs->size()) {
MS_LOG(ERROR) << " Input data index " << input_data_idx_ << " is larger than input size " << inputs->size();
return lite::RET_ERROR;
bool DynShapeProcess::Init(const AclDynamicShapeOptions &options, size_t input_data_idx) {
acl_options_ = options;
input_data_idx_ = input_data_idx;
if (input_data_idx >= acl_options_.input_shapes.size()) {
MS_LOG(ERROR) << "Input data index " << input_data_idx
<< " is invalid, inputs count: " << acl_options_.input_shapes.size();
return false;
}
auto original_tensor = (*original_datas)[input_data_idx_];
auto cur_tensor = (*inputs)[input_data_idx_];
std::vector<int64_t> original_shape = original_tensor->GetShapeVector();
std::vector<int64_t> cur_shape = cur_tensor->GetShapeVector();
return true;
}
bool DynShapeProcess::CheckAndGetBatchSize(const std::vector<ShapeVector> &new_shapes, int32_t *batch_size) {
if (acl_options_.batch_size.empty()) {
MS_LOG(ERROR) << "Not support dynamic batch size";
return false;
}
if (batch_size == nullptr) {
MS_LOG(ERROR) << "Input parameter batch size cannot be nullptr";
return false;
}
if (!CheckBatchSize(new_shapes)) {
return false;
}
return GetRealBatchSize(new_shapes, batch_size);
}
bool DynShapeProcess::CheckAndGetImageSize(const std::vector<ShapeVector> &new_shapes, int32_t *height,
int32_t *width) {
if (acl_options_.image_size.empty()) {
MS_LOG(ERROR) << "Not support image batch size";
return false;
}
if (height == nullptr || width == nullptr) {
MS_LOG(ERROR) << "Input parameter image size cannot be nullptr";
return false;
}
if (!CheckImageSize(new_shapes)) {
return false;
}
return GetRealImageSize(new_shapes, height, width);
}
bool DynShapeProcess::CheckBatchSize(const std::vector<ShapeVector> &new_shapes) {
if (input_data_idx_ >= new_shapes.size()) {
MS_LOG(ERROR) << " Input data index " << input_data_idx_ << " is larger than input size " << new_shapes.size();
return false;
}
std::vector<int64_t> original_shape = acl_options_.input_shapes[input_data_idx_];
std::vector<int64_t> cur_shape = new_shapes[input_data_idx_];
if (cur_shape.empty() || original_shape.empty()) {
MS_LOG(ERROR) << "Shape is empty, input index = " << input_data_idx_;
return lite::RET_ERROR;
return false;
}
for (uint32_t i = 1; i < cur_shape.size(); ++i) {
for (size_t i = 1; i < cur_shape.size(); ++i) {
if (cur_shape[i] <= 0) {
MS_LOG(ERROR) << "Invalid new shape " << cur_shape << " for input " << i;
return false;
}
if (original_shape[i] != kUnknownDim && (original_shape[i] != cur_shape[i])) {
MS_LOG(ERROR) << "Shape Conflict: Original Shape:[" << GenResultStr(original_shape) << "], Current Shape:["
<< GenResultStr(cur_shape) << "]";
return lite::RET_ERROR;
return false;
}
}
return lite::RET_OK;
return true;
}
int DynShapeProcess::CheckImageSize(std::vector<KernelTensorPtr> *const original_datas,
std::vector<KernelTensorPtr> *const inputs) {
if (input_data_idx_ >= inputs->size() || input_data_idx_ >= acl_options_->input_format.size()) {
MS_LOG(ERROR) << "Input data index " << input_data_idx_ << " is invalid, inputs size " << inputs->size()
<< " input formats size " << acl_options_->input_format.size();
return lite::RET_ERROR;
bool DynShapeProcess::CheckImageSize(const std::vector<ShapeVector> &new_shapes) {
if (input_data_idx_ >= new_shapes.size() || input_data_idx_ >= acl_options_.input_format.size()) {
MS_LOG(ERROR) << "Input data index " << input_data_idx_ << " is invalid, inputs size " << new_shapes.size()
<< " input formats size " << acl_options_.input_format.size();
return false;
}
auto original_tensor = (*original_datas)[input_data_idx_];
auto cur_tensor = (*inputs)[input_data_idx_];
std::vector<int64_t> original_shape = original_tensor->GetShapeVector();
std::vector<int64_t> cur_shape = cur_tensor->GetShapeVector();
std::vector<int64_t> original_shape = acl_options_.input_shapes[input_data_idx_];
std::vector<int64_t> cur_shape = new_shapes[input_data_idx_];
if (original_shape.size() != kInputDimNum) {
MS_LOG(ERROR) << "Shape size " << original_shape.size() << " is invalid, input index = " << input_data_idx_;
return lite::RET_ERROR;
return false;
}
if (cur_shape.size() != kInputDimNum) {
MS_LOG(ERROR) << "Shape size " << cur_shape.size() << " is invalid, input index = " << input_data_idx_;
return lite::RET_ERROR;
return false;
}
auto format = acl_options_->input_format[input_data_idx_];
for (size_t i = 1; i < cur_shape.size(); ++i) {
if (cur_shape[i] <= 0) {
MS_LOG(ERROR) << "Invalid new shape " << cur_shape << " for input " << i;
return false;
}
if (original_shape[i] != kUnknownDim && (original_shape[i] != cur_shape[i])) {
MS_LOG(ERROR) << "Shape Conflict: Original Shape:[" << GenResultStr(original_shape) << "], Current Shape:["
<< GenResultStr(cur_shape) << "]";
return false;
}
}
auto format = acl_options_.input_format[input_data_idx_];
if (format == mindspore::Format::NHWC) {
if ((original_shape[kNHWCCIdx] != kUnknownDim && (original_shape[kNHWCCIdx] != cur_shape[kNHWCCIdx])) ||
(original_shape[kNHWCNIdx] != kUnknownDim && (original_shape[kNHWCNIdx] != cur_shape[kNHWCNIdx]))) {
MS_LOG(ERROR) << "Shape Conflict: Original Shape:[" << GenResultStr(original_shape) << "], Current Shape:["
<< GenResultStr(cur_shape) << "]";
return lite::RET_ERROR;
return false;
}
} else {
if ((original_shape[kNCHWCIdx] != kUnknownDim && (original_shape[kNCHWCIdx] != cur_shape[kNCHWCIdx])) ||
(original_shape[kNCHWNIdx] != kUnknownDim && (original_shape[kNCHWNIdx] != cur_shape[kNCHWNIdx]))) {
MS_LOG(ERROR) << "Shape Conflict: Original Shape:[" << GenResultStr(original_shape) << "], Current Shape:["
<< GenResultStr(cur_shape) << "]";
return lite::RET_ERROR;
return false;
}
}
return lite::RET_OK;
return true;
}
int DynShapeProcess::AddBatchSizeInput(std::vector<KernelTensorPtr> *const original_datas,
std::vector<KernelTensorPtr> *const inputs) {
int32_t *batch_size_addr = reinterpret_cast<int32_t *>(malloc(sizeof(int32_t)));
if (batch_size_addr == nullptr) {
MS_LOG(ERROR) << "Malloc batch size failed.";
return lite::RET_ERROR;
bool DynShapeProcess::GetRealBatchSize(const std::vector<ShapeVector> &new_shapes, int32_t *batch_size) {
if (input_data_idx_ >= new_shapes.size()) {
MS_LOG(ERROR) << " Input data index " << input_data_idx_ << " is larger than input size " << new_shapes.size();
return false;
}
if (CheckBatchSize(original_datas, inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Check dynamic batch size failed.";
free(batch_size_addr);
return lite::RET_ERROR;
}
if (GetRealBatchSize(inputs, batch_size_addr) != lite::RET_OK) {
MS_LOG(ERROR) << "Get real batch size failed.";
free(batch_size_addr);
return lite::RET_ERROR;
}
batch_size_ptr_ = std::make_shared<Address>(batch_size_addr, sizeof(int32_t));
if (batch_size_ptr_ == nullptr) {
MS_LOG(ERROR) << "Create Address failed.";
free(batch_size_addr);
return lite::RET_ERROR;
}
auto tensor_ptr = std::make_shared<KernelTensor>();
if (tensor_ptr == nullptr) {
MS_LOG(ERROR) << "Create KernelTensor failed.";
free(batch_size_addr);
return lite::RET_ERROR;
}
tensor_ptr->SetData(batch_size_ptr_);
auto abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
tensor_ptr->SetAbstract(abstract);
inputs->emplace_back(tensor_ptr);
return lite::RET_OK;
}
int DynShapeProcess::AddImageSizeInput(std::vector<KernelTensorPtr> *const original_datas,
std::vector<KernelTensorPtr> *const inputs) {
int32_t *image_size_addr = reinterpret_cast<int32_t *>(malloc(kImageSizeHwNum * sizeof(int32_t)));
if (image_size_addr == nullptr) {
MS_LOG(ERROR) << "Malloc image size failed.";
return lite::RET_ERROR;
}
if (CheckImageSize(original_datas, inputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Check dynamic image size failed.";
free(image_size_addr);
return lite::RET_ERROR;
}
if (GetRealImageSize(inputs, image_size_addr, kImageSizeHwNum) != lite::RET_OK) {
MS_LOG(ERROR) << "Get real image size failed.";
free(image_size_addr);
return lite::RET_ERROR;
}
image_size_ptr_ = std::make_shared<Address>(image_size_addr, kImageSizeHwNum * sizeof(int32_t));
if (image_size_ptr_ == nullptr) {
MS_LOG(ERROR) << "Create Address failed.";
free(image_size_addr);
return lite::RET_ERROR;
}
auto tensor_ptr = std::make_shared<KernelTensor>();
if (tensor_ptr == nullptr) {
MS_LOG(ERROR) << "Create KernelTensor failed.";
free(image_size_addr);
return lite::RET_ERROR;
}
tensor_ptr->SetData(image_size_ptr_);
auto abstract = std::make_shared<abstract::AbstractTensor>(kInt32, std::vector<int64_t>());
tensor_ptr->SetAbstract(abstract);
inputs->emplace_back(tensor_ptr);
return lite::RET_OK;
}
int DynShapeProcess::GetRealBatchSize(std::vector<KernelTensorPtr> *const inputs, int32_t *batch_size) {
MS_CHECK_TRUE_MSG(batch_size != nullptr, lite::RET_ERROR, "Batch size ptr is nullptr.");
if (input_data_idx_ >= inputs->size()) {
MS_LOG(ERROR) << " Input data index " << input_data_idx_ << " is larger than input size " << inputs->size();
return lite::RET_ERROR;
}
auto tensor = (*inputs)[input_data_idx_];
std::vector<int64_t> shape = tensor->GetShapeVector();
std::vector<int64_t> shape = new_shapes[input_data_idx_];
if (shape.empty()) {
MS_LOG(ERROR) << "Shape is empty, input index = " << input_data_idx_;
return lite::RET_ERROR;
return false;
}
int32_t cur_batch_size = static_cast<int32_t>(shape[0]);
auto iter = acl_options_->batch_size.find(cur_batch_size);
if (iter == acl_options_->batch_size.end()) {
auto iter = acl_options_.batch_size.find(cur_batch_size);
if (iter == acl_options_.batch_size.end()) {
MS_LOG(ERROR) << "Current batch size " << cur_batch_size << " is invalid, please check device info of context";
return lite::RET_ERROR;
return false;
}
*batch_size = cur_batch_size;
MS_LOG(DEBUG) << "Current batch size " << cur_batch_size;
return lite::RET_OK;
return true;
}
int DynShapeProcess::GetRealImageSize(std::vector<KernelTensorPtr> *const inputs, int32_t *image_size, int32_t num) {
MS_CHECK_TRUE_MSG(image_size != nullptr, lite::RET_ERROR, "Image size ptr is nullptr.");
if (input_data_idx_ >= inputs->size() || input_data_idx_ >= acl_options_->input_format.size()) {
MS_LOG(ERROR) << "Input data index " << input_data_idx_ << " is invalid, inputs size " << inputs->size()
<< " input formats size " << acl_options_->input_format.size();
return lite::RET_ERROR;
bool DynShapeProcess::GetRealImageSize(const std::vector<ShapeVector> &new_shapes, int32_t *height_p,
int32_t *width_p) {
if (input_data_idx_ >= new_shapes.size() || input_data_idx_ >= acl_options_.input_format.size()) {
MS_LOG(ERROR) << "Input data index " << input_data_idx_ << " is invalid, inputs size " << new_shapes.size()
<< " input formats size " << acl_options_.input_format.size();
return false;
}
auto tensor = (*inputs)[input_data_idx_];
std::vector<int64_t> shape = tensor->GetShapeVector();
std::vector<int64_t> shape = new_shapes[input_data_idx_];
if (shape.size() != kInputDimNum) {
MS_LOG(ERROR) << "Shape size " << shape.size() << " is invalid, input index = " << input_data_idx_;
return lite::RET_ERROR;
return false;
}
auto format = acl_options_->input_format[input_data_idx_];
uint64_t height;
uint64_t width;
auto format = acl_options_.input_format[input_data_idx_];
int64_t height;
int64_t width;
if (format == mindspore::Format::NHWC) {
height = shape[kNHWCHeightIdx];
width = shape[kNHWCWidthIdx];
@@ -254,41 +202,17 @@ int DynShapeProcess::GetRealImageSize(std::vector<KernelTensorPtr> *const inputs
height = shape[kNCHWHeightIdx];
width = shape[kNCHWWidthIdx];
}
auto cur_image_size = std::pair<int32_t, int32_t>(static_cast<uint64_t>(height), static_cast<uint64_t>(width));
auto iter = acl_options_->image_size.find(cur_image_size);
if (iter == acl_options_->image_size.end()) {
auto cur_image_size = std::pair<int32_t, int32_t>(static_cast<int32_t>(height), static_cast<int32_t>(width));
auto iter = acl_options_.image_size.find(cur_image_size);
if (iter == acl_options_.image_size.end()) {
MS_LOG(ERROR) << "Image size height " << height << ",weight " << width
<< " is invalid, please check device info of context.";
return lite::RET_ERROR;
return false;
}
if (num != kImageSizeHwNum) {
MS_LOG(ERROR) << "The hw num should be " << kImageSizeHwNum << ",real num " << num;
return lite::RET_ERROR;
}
image_size[0] = height;
image_size[1] = width;
*height_p = LongToInt(height);
*width_p = LongToInt(width);
MS_LOG(DEBUG) << "Current height " << height << " width " << width;
return lite::RET_OK;
}
void DynShapeProcess::DestroyDynamicInput(std::vector<KernelTensorPtr> *const inputs) {
if (inputs == nullptr) {
MS_LOG(ERROR) << "Inputs ptr is nullptr.";
return;
}
if (batch_size_ptr_ != nullptr && batch_size_ptr_->addr != nullptr) {
free(batch_size_ptr_->addr);
batch_size_ptr_->addr = nullptr;
batch_size_ptr_->size = 0;
}
if (image_size_ptr_ != nullptr && image_size_ptr_->addr != nullptr) {
free(image_size_ptr_->addr);
image_size_ptr_->addr = nullptr;
image_size_ptr_->size = 0;
}
if (!inputs->empty()) {
(*inputs).pop_back();
}
return true;
}
} // namespace acl
} // namespace mindspore::kernel


@@ -28,24 +28,18 @@ namespace mindspore::kernel {
namespace acl {
class DynShapeProcess {
public:
explicit DynShapeProcess(const AclModelOptionsPtr &options, size_t input_data_idx)
: acl_options_(options), input_data_idx_(input_data_idx), batch_size_ptr_(nullptr), image_size_ptr_(nullptr) {}
int ProcDynamicInput(std::vector<KernelTensorPtr> *const original_datas, std::vector<KernelTensorPtr> *const inputs);
void DestroyDynamicInput(std::vector<KernelTensorPtr> *const inputs);
bool Init(const AclDynamicShapeOptions &options, size_t input_data_idx);
bool CheckAndGetBatchSize(const std::vector<ShapeVector> &new_shapes, int32_t *batch_size);
bool CheckAndGetImageSize(const std::vector<ShapeVector> &new_shapes, int32_t *height, int32_t *width);
private:
int CheckBatchSize(std::vector<KernelTensorPtr> *const original_datas, std::vector<KernelTensorPtr> *const inputs);
int CheckImageSize(std::vector<KernelTensorPtr> *const original_datas, std::vector<KernelTensorPtr> *const inputs);
int AddBatchSizeInput(std::vector<KernelTensorPtr> *const original_datas, std::vector<KernelTensorPtr> *const inputs);
int AddImageSizeInput(std::vector<KernelTensorPtr> *const original_datas, std::vector<KernelTensorPtr> *const inputs);
int GetRealBatchSize(std::vector<KernelTensorPtr> *const inputs, int32_t *batch_size);
int GetRealImageSize(std::vector<KernelTensorPtr> *const inputs, int32_t *image_size, int32_t num);
bool CheckBatchSize(const std::vector<ShapeVector> &new_shapes);
bool CheckImageSize(const std::vector<ShapeVector> &new_shapes);
bool GetRealBatchSize(const std::vector<ShapeVector> &new_shapes, int32_t *batch_size);
bool GetRealImageSize(const std::vector<ShapeVector> &new_shapes, int32_t *height, int32_t *width);
AclModelOptionsPtr acl_options_;
AclDynamicShapeOptions acl_options_;
size_t input_data_idx_;
AddressPtr batch_size_ptr_;
AddressPtr image_size_ptr_;
};
using DynShapeProcPtr = std::shared_ptr<DynShapeProcess>;
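For orientation, a minimal sketch of the refactored call flow (the model_infer_ accessors come from this patch; new_shapes and the surrounding setup are assumptions, not the exact wiring in model_process.cc):

  // Collect the dynamic-shape capabilities reported by the loaded OM model.
  AclDynamicShapeOptions options;
  options.batch_size = model_infer_->GetDynamicBatch();
  options.image_size = model_infer_->GetDynamicImage();
  options.input_format = model_infer_->GetInputFormat();
  options.input_shapes = model_infer_->GetInputShape();
  DynShapeProcess dyn_shape_proc;
  if (!dyn_shape_proc.Init(options, /*input_data_idx=*/0)) {
    return false;  // invalid data-input index
  }
  // On Resize(), validate the new shapes and extract the real batch size.
  int32_t batch_size = 0;
  if (dyn_shape_proc.CheckAndGetBatchSize(new_shapes, &batch_size)) {
    // The model was converted with a dynamic batch dimension; pass batch_size to ACL.
  }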


@@ -30,32 +30,32 @@ ModelInfer::ModelInfer(const Buffer &om_data, const AclModelOptionsPtr &options)
model_process_(options),
acl_env_(nullptr) {}
STATUS ModelInfer::Init() {
bool ModelInfer::Init() {
if (init_flag_) {
MS_LOG(INFO) << "Acl has been initialized, skip.";
return lite::RET_OK;
return true;
}
if (options_ == nullptr) {
MS_LOG(ERROR) << "Acl options is nullptr.";
return lite::RET_ERROR;
return false;
}
acl_env_ = AclEnvGuard::GetAclEnv(options_->dump_cfg_path);
if (acl_env_ == nullptr) {
MS_LOG(ERROR) << "Acl init failed.";
return lite::RET_ERROR;
return false;
}
int32_t device_id = options_->device_id;
aclError ret = aclrtSetDevice(device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl open device " << device_id << " failed.";
return lite::RET_ERROR;
return false;
}
MS_LOG(INFO) << "Open device " << device_id << " success.";
ret = aclrtCreateContext(&context_, device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl create context failed.";
return lite::RET_ERROR;
return false;
}
MS_LOG(INFO) << "Create context success.";
@@ -63,7 +63,7 @@ STATUS ModelInfer::Init() {
ret = aclrtGetRunMode(&run_mode);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Acl get run mode failed.";
return lite::RET_ERROR;
return false;
}
bool is_device = (run_mode == ACL_DEVICE);
model_process_.SetIsDevice(is_device);
@@ -71,25 +71,24 @@ STATUS ModelInfer::Init() {
MS_LOG(INFO) << "Init model success, device id " << device_id;
init_flag_ = true;
return lite::RET_OK;
return true;
}
STATUS ModelInfer::Finalize() {
bool ModelInfer::Finalize() {
if (!init_flag_) {
MS_LOG(WARNING) << "Init is not ok, no need to finalize.";
return lite::RET_OK;
MS_LOG(INFO) << "Init is not ok, no need to finalize.";
return true;
}
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed.";
return lite::RET_ERROR;
return false;
}
if (load_flag_) {
auto ret = model_process_.UnLoad();
if (ret != lite::RET_OK) {
if (!model_process_.UnLoad()) {
MS_LOG(ERROR) << "Unload model inner failed.";
return ret;
return false;
}
}
if (context_ != nullptr) {
@@ -108,15 +107,14 @@ STATUS ModelInfer::Finalize() {
MS_LOG(INFO) << "End to reset device " << options_->device_id;
init_flag_ = false;
load_flag_ = false;
return lite::RET_OK;
return true;
}
STATUS ModelInfer::Load() {
bool ModelInfer::Load() {
if (!load_flag_) {
int ret = LoadAclModel(om_data_);
if (ret != lite::RET_OK) {
if (!model_process_.Load(om_data_)) {
MS_LOG(ERROR) << "Load model model failed.";
return ret;
return false;
}
load_flag_ = true;
}
@@ -124,51 +122,25 @@ STATUS ModelInfer::Load() {
aclError rt_ret = aclrtSetCurrentContext(context_);
if (rt_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set the ascend device context failed, ret = " << rt_ret;
return lite::RET_ERROR;
return false;
}
return lite::RET_OK;
return true;
}
STATUS ModelInfer::LoadAclModel(const Buffer &om_data) {
MS_LOG(INFO) << "Start load model model.";
// model load model
uint32_t acl_model_id;
auto acl_ret = aclmdlLoadFromMem(om_data.Data(), om_data.DataSize(), &acl_model_id);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclmdlLoadFromMem failed, ret = " << acl_ret;
return lite::RET_ERROR;
}
// model init model resource
model_process_.set_model_id(acl_model_id);
int ret = model_process_.PreInitModelResource();
if (ret != lite::RET_OK) {
(void)aclmdlUnload(acl_model_id);
MS_LOG(ERROR) << "Pre init model resource failed.";
return ret;
}
MS_LOG(INFO) << "Load model model success.";
return lite::RET_OK;
}
STATUS ModelInfer::Inference(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs) {
if (Load() != lite::RET_OK) {
bool ModelInfer::Inference(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs) {
if (!Load()) {
MS_LOG(ERROR) << "Prepare model resource failed.";
return lite::RET_ERROR;
return false;
}
return model_process_.PredictFromHost(inputs, outputs);
}
std::set<uint64_t> ModelInfer::GetDynamicBatch() { return model_process_.GetDynamicBatch(); }
// need to be called after model load;
std::set<std::pair<uint64_t, uint64_t>> ModelInfer::GetDynamicImage() { return model_process_.GetDynamicImage(); }
std::vector<Format> ModelInfer::GetInputFormat() { return model_process_.GetInputFormat(); }
const std::vector<ShapeVector> ModelInfer::GetOutputShape() { return model_process_.GetOutputShape(); }
const std::vector<ShapeVector> ModelInfer::GetInputShape() { return model_process_.GetInputShape(); }
const std::vector<TypeId> ModelInfer::GetInputDataType() { return model_process_.GetInputDataType(); }
bool ModelInfer::Resize(const std::vector<ShapeVector> &new_shapes) { return model_process_.Resize(new_shapes); }
} // namespace acl
} // namespace mindspore::kernel


@@ -37,22 +37,18 @@ class ModelInfer {
ModelInfer(const Buffer &om_data, const AclModelOptionsPtr &options);
~ModelInfer() = default;
STATUS Init();
STATUS Finalize();
STATUS Load();
STATUS Inference(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
// need to be called after model load
std::set<uint64_t> GetDynamicBatch();
// need to be called after model load
std::set<std::pair<uint64_t, uint64_t>> GetDynamicImage();
bool Init();
bool Finalize();
bool Load();
bool Inference(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
std::vector<Format> GetInputFormat();
const std::vector<ShapeVector> GetOutputShape();
const std::vector<ShapeVector> GetInputShape();
const std::vector<TypeId> GetInputDataType();
private:
STATUS LoadAclModel(const Buffer &om_data);
bool Resize(const std::vector<ShapeVector> &new_shapes);
private:
bool init_flag_;
bool load_flag_;
std::string device_type_;


@@ -22,6 +22,8 @@
#include <map>
#include <set>
#include <utility>
#include <functional>
#include <memory>
#include "acl/acl.h"
#include "acl/acl_mdl.h"
#include "acl/acl_rt.h"
@@ -29,14 +31,15 @@
#include "include/errorcode.h"
#include "kernel/kernel.h"
#include "extendrt/kernel/ascend/options/acl_model_options.h"
#include "extendrt/kernel/ascend/model/dyn_shape_process.h"
namespace mindspore::kernel {
namespace acl {
using mindspore::lite::STATUS;
struct AclTensorInfo {
void *cur_device_data;
void *device_data;
size_t buffer_size;
size_t malloc_buffer_size;
aclDataType data_type;
std::vector<int64_t> dims;
std::string name;
@@ -44,26 +47,16 @@ struct AclTensorInfo {
class ModelProcess {
public:
explicit ModelProcess(const AclModelOptionsPtr &options)
: options_(options),
model_id_(0xffffffff),
is_run_on_device_(false),
model_desc_(nullptr),
inputs_(nullptr),
outputs_(nullptr),
input_infos_(),
output_infos_() {}
explicit ModelProcess(const AclModelOptionsPtr &options) : options_(options) {}
~ModelProcess() {}
STATUS UnLoad();
STATUS PredictFromHost(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
STATUS PreInitModelResource();
bool Load(const Buffer &om_data);
bool UnLoad();
bool PredictFromHost(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
// override this method to avoid request/reply data copy
void SetIsDevice(bool is_device) { is_run_on_device_ = is_device; }
void set_model_id(uint32_t model_id) { model_id_ = model_id; }
uint32_t model_id() const { return model_id_; }
std::set<uint64_t> GetDynamicBatch();
std::set<std::pair<uint64_t, uint64_t>> GetDynamicImage();
std::vector<Format> GetInputFormat();
@@ -71,39 +64,45 @@ class ModelProcess {
const std::vector<ShapeVector> GetInputShape();
const std::vector<TypeId> GetInputDataType();
bool Resize(const std::vector<ShapeVector> &new_shapes);
private:
STATUS CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
STATUS CheckAndInitInput(const std::vector<KernelTensorPtr> &inputs);
STATUS CheckTensorByTensorInfo(const std::vector<KernelTensorPtr> &tensor,
const std::vector<AclTensorInfo> &tensor_info);
STATUS GetOutputs(const std::vector<KernelTensorPtr> &outputs);
void UpdateOutputInfo(const std::vector<KernelTensorPtr> &outputs);
STATUS ConstructTensor(const std::vector<KernelTensorPtr> &outputs);
STATUS SetBatchSize(const std::vector<KernelTensorPtr> &inputs);
STATUS SetImageSize(const std::vector<KernelTensorPtr> &inputs);
STATUS InitInputsBuffer();
STATUS InitOutputsBuffer();
STATUS ResetOutputSize();
STATUS ProcDynamicShape(const std::vector<KernelTensorPtr> &inputs);
std::string VectorToString(const std::vector<int64_t> &);
bool PreInitModelResource();
bool InitInputsBuffer();
bool InitOutputsBuffer();
void DestroyInputsBuffer();
void DestroyOutputsBuffer();
bool CreateDataBuffer(void **data_mem_buffer, size_t buffer_size, aclmdlDataset *dataset);
bool CheckAndInitInput(const std::vector<KernelTensorPtr> &inputs);
bool CheckAndInitOutput(const std::vector<KernelTensorPtr> &outputs);
bool CheckInputTensors(const std::vector<KernelTensorPtr> &inputs);
bool CheckOutputTensors(const std::vector<KernelTensorPtr> &outputs);
bool GetOutputs(const std::vector<KernelTensorPtr> &outputs);
bool ResetInputSize(const std::vector<ShapeVector> &new_shapes);
bool ResetOutputSize();
bool IsDynamicShape();
bool IsDynamicBatchSize();
bool IsDynamicImageSize();
void DestroyInputsDataset();
void DestroyInputsDataMem();
void DestroyInputsBuffer();
void DestroyOutputsBuffer();
void UpdateBufferSize(const std::vector<KernelTensorPtr> &inputs);
AclModelOptionsPtr options_;
uint32_t model_id_;
uint32_t model_id_ = UINT32_MAX;
// if running on device (AICPU), there is no need to alloc device memory and copy inputs to (/outputs from) device
bool is_run_on_device_;
aclmdlDesc *model_desc_;
aclmdlDataset *inputs_;
aclmdlDataset *outputs_;
bool is_run_on_device_ = false;
aclmdlDesc *model_desc_ = nullptr;
aclmdlDataset *inputs_ = nullptr;
aclmdlDataset *outputs_ = nullptr;
bool loaded_ = false;
size_t data_input_num_ = 0;
std::vector<AclTensorInfo> input_infos_;
std::vector<AclTensorInfo> output_infos_;
AclDynamicShapeOptions dynamic_shape_options_;
DynShapeProcess dyn_shape_proc_;
std::vector<ShapeVector> cur_input_shapes_;
};
} // namespace acl
} // namespace mindspore::kernel


@@ -29,11 +29,15 @@ namespace acl {
struct AclModelOptions {
int32_t device_id;
std::string dump_cfg_path;
AclModelOptions() : device_id(0) {}
};
struct AclDynamicShapeOptions {
std::set<uint64_t> batch_size;
std::set<std::pair<uint64_t, uint64_t>> image_size;
std::vector<Format> input_format;
AclModelOptions() : device_id(0) {}
std::vector<std::vector<int64_t>> input_shapes;
};
using AclModelOptionsPtr = std::shared_ptr<AclModelOptions>;


@@ -16,6 +16,7 @@
#include "extendrt/kernel/ascend/src/custom_ascend_kernel.h"
#include <utility>
#include <algorithm>
#include "acl/acl_base.h"
#include "acl/acl_rt.h"
#include "include/registry/register_kernel.h"
@@ -30,12 +31,11 @@
namespace mindspore::kernel {
namespace acl {
CustomAscendKernelMod::CustomAscendKernelMod()
: load_model_(false), acl_options_(nullptr), dyn_shape_proc_(nullptr), model_infer_(nullptr), input_data_idx_(0) {}
: load_model_(false), acl_options_(nullptr), model_infer_(nullptr), input_data_idx_(0) {}
CustomAscendKernelMod::~CustomAscendKernelMod() {
if (load_model_) {
int ret = model_infer_->Finalize();
if (ret != lite::RET_OK) {
if (!model_infer_->Finalize()) {
MS_LOG(ERROR) << "Model finalize failed.";
}
}
@@ -95,17 +95,6 @@ bool CustomAscendKernelMod::InitParam(const std::vector<KernelTensorPtr> &inputs
MS_LOG(ERROR) << "Create ModelInfer failed.";
return false;
}
RecordInputDataIndex(inputs);
dyn_shape_proc_ = std::make_shared<DynShapeProcess>(acl_options_, input_data_idx_);
if (dyn_shape_proc_ == nullptr) {
MS_LOG(ERROR) << "Create DynShapeProcess failed.";
return false;
}
if (inputs[idx]->GetData()->addr != nullptr) {
free(inputs[idx]->GetData()->addr);
inputs[idx]->GetData()->addr = nullptr;
inputs[idx]->GetData()->size = 0;
}
return true;
}
@@ -113,7 +102,7 @@ bool CustomAscendKernelMod::Init(const BaseOperatorPtr &base_operator, const std
const std::vector<KernelTensorPtr> &outputs) {
if (load_model_) {
MS_LOG(INFO) << "Om has been loaded in custom kernel.";
return lite::RET_OK;
return true;
}
auto kernel_ptr = std::dynamic_pointer_cast<ops::Custom>(base_operator);
@@ -125,32 +114,28 @@
MS_LOG(ERROR) << "Init param failed.";
return false;
}
if (LoadModel() != lite::RET_OK) {
if (!LoadModel()) {
MS_LOG(ERROR) << "Load model failed.";
return false;
}
load_model_ = true;
return true;
}
int CustomAscendKernelMod::LoadModel() {
int ret = model_infer_->Init();
if (ret != lite::RET_OK) {
bool CustomAscendKernelMod::LoadModel() {
if (!model_infer_->Init()) {
MS_LOG(ERROR) << "Model infer init failed.";
return lite::RET_ERROR;
return false;
}
ret = model_infer_->Load();
if (ret != lite::RET_OK) {
if (!model_infer_->Load()) {
MS_LOG(ERROR) << "Load om data failed.";
return lite::RET_ERROR;
return false;
}
acl_options_->batch_size = model_infer_->GetDynamicBatch();
acl_options_->image_size = model_infer_->GetDynamicImage();
acl_options_->input_format = model_infer_->GetInputFormat();
UpdateInputKernelTensorInfo();
(void)RetrieveOutputShape();
MS_LOG(INFO) << "Load om data success.";
return lite::RET_OK;
return true;
}
int CustomAscendKernelMod::Resize(const BaseOperatorPtr &base_operator, const std::vector<KernelTensorPtr> &inputs,
@@ -167,24 +152,26 @@ int CustomAscendKernelMod::Resize(const BaseOperatorPtr &base_operator, const st
MS_LOG(ERROR) << "inputs size is less than one.";
return lite::RET_ERROR;
}
original_data_ = inputs_;
inputs_.assign(inputs.begin(), inputs.end() - 1);
if (!OnNewInputShapes(inputs)) {
MS_LOG(ERROR) << "Failed to resize inputs";
return lite::RET_ERROR;
}
return lite::RET_OK;
}
template <typename T, typename U>
static int UpdateCheckInputNums(const std::vector<T> &update_info, const std::vector<U> &inputs,
size_t input_weight = 0) {
static bool UpdateCheckInputNums(const std::vector<T> &update_info, const std::vector<U> &inputs,
size_t input_weight = 0) {
if (update_info.empty()) {
MS_LOG(ERROR) << "check update info size empty";
return lite::RET_ERROR;
return false;
}
if (update_info.size() + input_weight != inputs.size()) {
MS_LOG(ERROR) << "update info size and inputs size check failed. update info size: " << update_info.size()
<< ". inputs' size: " << inputs.size() << ". input weight: " << input_weight;
return lite::RET_ERROR;
return false;
}
return lite::RET_OK;
return true;
}
// In DVPP, model input shape and data type get modified
@@ -197,9 +184,8 @@ void CustomAscendKernelMod::UpdateInputKernelTensorInfo() {
const std::vector<TypeId> types = model_infer_->GetInputDataType();
const std::vector<Format> formats = model_infer_->GetInputFormat();
MS_LOG(INFO) << "check input kernel tensor info nums";
if ((UpdateCheckInputNums(shapes, inputs_) != lite::RET_OK) ||
(UpdateCheckInputNums(types, inputs_) != lite::RET_OK) ||
(UpdateCheckInputNums(formats, inputs_) != lite::RET_OK)) {
if (!UpdateCheckInputNums(shapes, inputs_) || !UpdateCheckInputNums(types, inputs_) ||
!UpdateCheckInputNums(formats, inputs_)) {
return;
}
@@ -218,80 +204,64 @@ std::vector<KernelTensorPtr> CustomAscendKernelMod::GetInputKernelTensor() {
return inputs_;
}
int CustomAscendKernelMod::SetInputAndOutputAddr(const std::vector<AddressPtr> &inputs,
const std::vector<AddressPtr> &outputs) {
if ((inputs_.size() + 1) != inputs.size()) {
MS_LOG(ERROR) << "Size of inputs in init [" << (inputs_.size() + 1) << "] and "
<< "size of inputs in launch [" << inputs.size() << "] are not equal.";
return lite::RET_ERROR;
}
if (outputs_.size() != outputs.size()) {
MS_LOG(ERROR) << "Size of outputs in init (" << outputs_.size() << ") and "
<< "size of outputs in launch (" << outputs.size() << ") are not equal.";
return lite::RET_ERROR;
bool CustomAscendKernelMod::ResetInputOutputShapes() {
auto input_shapes = model_infer_->GetInputShape();
if (input_shapes.size() != inputs_.size()) {
MS_LOG(ERROR) << "The number of input shapes size " << input_shapes.size() << " != the number of inputs "
<< inputs_.size();
return false;
}
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs[i] == nullptr || inputs_[i] == nullptr) {
MS_LOG(ERROR) << "Input " << i << " is nullptr.";
return lite::RET_ERROR;
}
if (inputs[i]->addr == nullptr || inputs[i]->size == 0) {
MS_LOG(ERROR) << "Input " << i << " addr is invalid.";
return lite::RET_ERROR;
}
inputs_[i]->SetData(inputs[i]);
inputs_[i]->SetShapeVector(input_shapes[i]);
}
auto output_shapes = model_infer_->GetOutputShape();
if (output_shapes.size() != outputs_.size()) {
MS_LOG(ERROR) << "The number of output shapes size " << output_shapes.size() << " != the number of outputs "
<< outputs_.size();
return false;
}
for (size_t i = 0; i < outputs_.size(); ++i) {
if (outputs[i] == nullptr || outputs_[i] == nullptr) {
MS_LOG(ERROR) << "Output " << i << " is nullptr.";
return lite::RET_ERROR;
}
outputs_[i]->SetData(outputs[i]);
}
return lite::RET_OK;
}
bool CustomAscendKernelMod::IsDynamicInput() {
if (acl_options_->batch_size.empty() && acl_options_->image_size.empty()) {
MS_LOG(INFO) << "Inputs are not dynamic mode.";
return false;
outputs_[i]->SetShapeVector(output_shapes[i]);
}
return true;
}
void CustomAscendKernelMod::UpdateOutputAddr(const std::vector<AddressPtr> &outputs) {
for (size_t i = 0; i < outputs.size(); ++i) {
if ((outputs[i]->addr != outputs_[i]->GetData()->addr) || (outputs[i]->size != outputs_[i]->GetData()->size)) {
outputs[i]->addr = outputs_[i]->GetData()->addr;
outputs[i]->size = outputs_[i]->GetData()->size;
bool CustomAscendKernelMod::OnNewInputShapes(const std::vector<KernelTensorPtr> &new_inputs) {
auto input_shapes = model_infer_->GetInputShape();
if (input_shapes.size() != new_inputs.size()) {
MS_LOG(ERROR) << "Invalid new input size " << new_inputs.size() << ", expect input size " << input_shapes.size();
return false;
}
bool input_shape_changed = false;
for (size_t i = 0; i < new_inputs.size(); i++) {
auto new_shape = new_inputs[i]->GetShapeVector();
if (input_shapes[i] != new_shape) {
input_shape_changed = true;
}
}
if (!input_shape_changed) {
return true;
}
std::vector<ShapeVector> new_shapes;
std::transform(new_inputs.begin(), new_inputs.end(), std::back_inserter(new_shapes),
[](auto &t) { return t->GetShapeVector(); });
if (!model_infer_->Resize(new_shapes)) {
MS_LOG(ERROR) << "Failed to Resize";
return false;
}
return ResetInputOutputShapes();
}
bool CustomAscendKernelMod::Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
const std::vector<AddressPtr> &outputs, void *stream_ptr) {
bool CustomAscendKernelMod::Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &, void *) {
if (!load_model_) {
MS_LOG(ERROR) << "Init custom ascend kernel has been not ready.";
return false;
}
if (SetInputAndOutputAddr(inputs, outputs) != lite::RET_OK) {
MS_LOG(ERROR) << "Check input and output param failed.";
return false;
}
if (IsDynamicInput()) {
if (dyn_shape_proc_->ProcDynamicInput(&original_data_, &inputs_) != lite::RET_OK) {
MS_LOG(ERROR) << "Proc dynamic batch size input failed.";
return false;
}
}
if (model_infer_->Inference(inputs_, outputs_) != lite::RET_OK) {
if (!model_infer_->Inference(inputs_, outputs_)) {
MS_LOG(ERROR) << "Custom kernel execute failed.";
return false;
}
if (IsDynamicInput()) {
dyn_shape_proc_->DestroyDynamicInput(&inputs_);
}
UpdateOutputAddr(outputs);
return true;
}
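// Note: the new overload deliberately ignores its AddressPtr arguments; input and
// output bindings are expected to be set on the cached kernel tensors before Launch.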

View File

@ -23,7 +23,6 @@
#include <map>
#include "extendrt/kernel/ascend/options/acl_model_options.h"
#include "extendrt/kernel/ascend/model/model_infer.h"
#include "extendrt/kernel/ascend/model/dyn_shape_process.h"
#include "include/api/types.h"
#include "include/api/context.h"
#include "kernel/kernel.h"
@ -56,18 +55,18 @@ class CustomAscendKernelMod : public kernel::KernelMod {
void RecordInputDataIndex(const std::vector<KernelTensorPtr> &inputs);
void SetDeviceId();
bool InitParam(const std::vector<KernelTensorPtr> &inputs, const std::vector<KernelTensorPtr> &outputs);
int SetInputAndOutputAddr(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &outputs);
int LoadModel();
bool IsDynamicInput();
bool LoadModel();
void UpdateOutputAddr(const std::vector<AddressPtr> &outputs);
void UpdateInputKernelTensorInfo();
bool ResetInputOutputShapes();
bool OnNewInputShapes(const std::vector<KernelTensorPtr> &new_shapes);
bool load_model_;
std::vector<KernelTensorPtr> original_data_;
std::vector<KernelTensorPtr> inputs_;
std::vector<KernelTensorPtr> outputs_;
AclModelOptionsPtr acl_options_;
DynShapeProcPtr dyn_shape_proc_;
ModelInferPtr model_infer_;
size_t input_data_idx_;
};

View File

@ -18,6 +18,7 @@
#include <functional>
#include <string>
#include <vector>
#include <map>
#include "src/extendrt/session/single_op_session.h"
#include "src/extendrt/infer_device_address.h"
@ -32,6 +33,8 @@
#include "extendrt/session/factory.h"
#include "extendrt/utils/runtime_utils.h"
#include "extendrt/utils/tensor_default_impl.h"
#include "extendrt/utils/func_graph_utils.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
const size_t tensor_max_size = 0x1000000;
@ -51,11 +54,6 @@ Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context)
auto ascend_device_info = device_info->Cast<mindspore::AscendDeviceInfo>();
MS_EXCEPTION_IF_NULL(ascend_device_info);
device_id_ = ascend_device_info->GetDeviceID();
// AIPP config path is specified, DVPP mode
if (ascend_device_info->GetInsertOpConfigPath() != "") {
is_dvpp_ = true;
}
return kSuccess;
}
}
@ -66,7 +64,6 @@ Status SingleOpInferSession::AscendInit(const std::shared_ptr<Context> &context)
Status SingleOpInferSession::Init(const std::shared_ptr<Context> &context) {
MS_LOG(INFO) << "SingleOpInferSession::Init";
MS_EXCEPTION_IF_NULL(context);
kernel_graph_utils_ = std::make_shared<mindspore::KernelGraphUtils>();
if (AscendInit(context) != kSuccess) {
MS_LOG(ERROR) << "Init ascend failed.";
return kLiteError;
@ -74,200 +71,149 @@ Status SingleOpInferSession::Init(const std::shared_ptr<Context> &context) {
return kSuccess;
}
// Removed by this commit (kernel-graph sizing path):
void InitInputSizeList(const std::shared_ptr<CNode> &kernel_node, std::vector<size_t> *input_size_list) {
  MS_EXCEPTION_IF_NULL(input_size_list);
  size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
  for (size_t input_index = 0; input_index < input_num; ++input_index) {
    TypeId type_id = AnfAlgo::GetInputDeviceDataType(kernel_node, input_index);
    size_t type_size = GetTypeByte(TypeIdToType(type_id));
    auto shape = AnfAlgo::GetInputDeviceShape(kernel_node, input_index);
    size_t tensor_size;
    if (std::any_of(shape.begin(), shape.end(), [](int64_t tmp) { return tmp < 0; })) {
      tensor_size = type_size;
    } else {
      tensor_size =
        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
    }
    tensor_size = std::max(tensor_size, type_size);
    (void)input_size_list->emplace_back(tensor_size);
  }
}
// Removed by this commit (kernel-graph input update path):
Status SingleOpInferSession::UpdateKernelGraphInputs(const std::vector<std::vector<int64_t>> &dims,
                                                     const std::vector<TypeId> &type_ids, bool use_type_from_graph) {
  auto graph_inputs = RuntimeUtils::GetGraphDataInputs(kernel_graph_);
  if (graph_inputs.size() != dims.size()) {
    MS_LOG(ERROR) << "Number of graph inputs [" << graph_inputs.size() << "] is not equal to the given dims num ["
                  << dims.size() << "]";
    return kLiteError;
  }
  if (!use_type_from_graph && (graph_inputs.size() != type_ids.size())) {
    MS_LOG(ERROR) << "Number of graph inputs [" << graph_inputs.size() << "] is not equal to the given type ids num ["
                  << type_ids.size() << "]";
    return kLiteError;
  }
  for (size_t i = 0; i < graph_inputs.size(); ++i) {
    auto &graph_input = graph_inputs[i];
    if (utils::isa<mindspore::abstract::AbstractTuplePtr>(graph_input->abstract())) {
      MS_LOG(ERROR) << "The abstract of input does not support abstract tuple.";
      return kLiteError;
    }
    auto graph_input_addr = AnfAlgo::GetMutableOutputAddr(graph_input, 0);
    if (graph_input_addr == nullptr) {
      MS_LOG(ERROR) << "Graph input addr is nullptr.";
      return kLiteError;
    }
    TypeId type_id = graph_input_addr->type_id();
    if (!use_type_from_graph) {
      type_id = type_ids[i];
    }
    size_t type_size = GetTypeByte(TypeIdToType(type_id));
    const std::vector<int64_t> &dim = dims[i];
    size_t tensor_size =
      dim.empty() ? type_size : std::accumulate(dim.begin(), dim.end(), type_size, std::multiplies<size_t>());
    // update input size
    if (graph_input_addr->ptr_ != nullptr) {
      free(graph_input_addr->ptr_);
      auto new_addr = malloc(tensor_size);
      if (new_addr == nullptr) {
        MS_LOG(ERROR) << " malloc memory of input " << i << " failed, memory size " << tensor_size;
        return kLiteError;
      }
      graph_input_addr->set_ptr(new_addr);
      graph_input_addr->SetSize(tensor_size);
    }
    // update input shape
    auto abstract = std::make_shared<abstract::AbstractTensor>(TypeIdToType(type_id), dim);
    if (abstract == nullptr) {
      MS_LOG(ERROR) << "Abstract is nullptr.";
      return kLiteError;
    }
    graph_input->set_abstract(abstract);
  }
  return kSuccess;
}
// Added by this commit:
Status SingleOpInferSession::BuildCustomAscendKernel(const CNodePtr &cnode) {
  auto kernel_name = kNameCustomAscend;
  std::shared_ptr<kernel::KernelMod> kernel_mod = kernel::Factory<kernel::KernelMod>::Instance().Create(kernel_name);
  if (kernel_mod == nullptr) {
    MS_LOG(ERROR) << "Kernel mod is nullptr, kernel name: " << kernel_name;
    return mindspore::kLiteError;
  }
  MS_LOG(INFO) << "SingleOpInferSession::Kernels " << kernel_name;
  kernel_mod->SetDevicedId(device_id_);
  auto make_kernel_tensor = [](TypeId type_id, const ShapeVector &shape) {
    auto kernel_tensor = std::make_shared<kernel::KernelTensor>();
    auto abstract = abstract::MakeAbstract(std::make_shared<abstract::Shape>(shape),
                                           std::make_shared<TensorType>(TypeIdToType(type_id)));
    kernel::TensorInfo tensor_info;
    tensor_info.abstract_base = abstract;
    tensor_info.device_shape_adaptively = shape;
    kernel_tensor->SetAbstract(abstract);
    return kernel_tensor;
  };
  kernel::KernelArgs args;
  if (!FuncGraphUtils::GetCNodeOperator(cnode, &args.op)) {
    MS_LOG(ERROR) << "Failed to create operator for cnode " << cnode->fullname_with_scope();
    return mindspore::kLiteError;
  }
  std::vector<tensor::TensorPtr> tensor_cache;
  std::map<AnfWithOutIndex, kernel::KernelTensorPtr> kernel_tensor_map;
  std::vector<AnfWithOutIndex> inputs;
  std::vector<AnfWithOutIndex> outputs;
  FuncGraphUtils::GetCNodeInputsOutputs(cnode, &inputs, &outputs);
  for (size_t i = 0; i < inputs.size(); i++) {
    auto &input = inputs[i];
    auto data_type = FuncGraphUtils::GetTensorDataType(input);
    auto shape = FuncGraphUtils::GetTensorShape(input);
    auto kernel_tensor = make_kernel_tensor(static_cast<TypeId>(data_type), shape);
    auto tensor_data = FuncGraphUtils::GetConstNodeValue(input.first);
    if (tensor_data) {
      tensor_cache.push_back(tensor_data);
      kernel_tensor->SetData(std::make_shared<kernel::Address>(tensor_data->data_c(), tensor_data->Size()));
    }
    args.inputs.push_back(kernel_tensor);
    kernel_tensor_map[input] = kernel_tensor;
  }
  for (size_t i = 0; i < outputs.size(); i++) {
    auto &output = outputs[i];
    kernel::KernelTensorPtr kernel_tensor;
    auto it = kernel_tensor_map.find(output);
    if (it != kernel_tensor_map.end()) {  // use input as output
      kernel_tensor = it->second;
    } else {
      auto data_type = FuncGraphUtils::GetTensorDataType(output);
      auto shape = FuncGraphUtils::GetTensorShape(output);
      kernel_tensor = make_kernel_tensor(static_cast<TypeId>(data_type), shape);
    }
    args.outputs.push_back(kernel_tensor);
  }
  auto ret = kernel_mod->Init(args.op, args.inputs, args.outputs);
  MS_LOG(INFO) << "SingleOpInferSession::Kernels ret " << ret;
  if (!ret) {
    MS_LOG(ERROR) << "kernel init failed " << kernel_name;
    return mindspore::kLiteError;
  }
  // remove const input, OM graph data input
  args.inputs = kernel_mod->GetInputKernelTensor();
  args.outputs = kernel_mod->RetrieveOutputShape();
  kernel_mod_ = kernel_mod;
  kernel_args_ = args;
  return kSuccess;
}
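// BuildCustomAscendKernel above attaches constant input data before Init, then replaces
// args.inputs/args.outputs with the kernel's own tensors so constants folded into the
// OM model no longer appear as graph data inputs.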
// Removed by this commit (DVPP input update path):
Status SingleOpInferSession::UpdateGraphInputsForDVPP(const std::vector<kernel::KernelTensorPtr> &inputs) {
  std::vector<std::vector<int64_t>> dims = {};
  std::vector<TypeId> type_ids = {};
  for (auto &input : inputs) {
    dims.push_back(input->GetShapeVector());
    type_ids.push_back(input->GetDtype());
  }
  auto ret = UpdateKernelGraphInputs(dims, type_ids, false);
  if (ret != kSuccess) {
    return ret;
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    // update session inputs_
    auto data_type = static_cast<enum DataType>(type_ids[i]);
    auto impl = std::make_shared<TensorDefaultImpl>(input_names_[i], data_type, dims[i]);
    impl->SetFormat(inputs[i]->GetFormat());
    inputs_.push_back(impl);
  }
  return kSuccess;
}
// Added by this commit:
Status SingleOpInferSession::InitInputOutputInfos(const FuncGraphPtr &graph) {
  std::vector<AnfWithOutIndex> input_tensors;
  std::vector<AnfWithOutIndex> output_tensors;
  FuncGraphUtils::GetFuncGraphInputs(graph, &input_tensors);
  FuncGraphUtils::GetFuncGraphOutputs(graph, &output_tensors);
  if (kernel_args_.inputs.size() != input_tensors.size()) {
    MS_LOG(ERROR) << "Graph inputs size " << input_tensors.size() << " != custom inputs size "
                  << kernel_args_.inputs.size();
    return kCoreFailed;
  }
  if (kernel_args_.outputs.size() != output_tensors.size()) {
    MS_LOG(ERROR) << "Graph outputs size " << output_tensors.size() << " != custom outputs size "
                  << kernel_args_.outputs.size();
    return kCoreFailed;
  }
  for (size_t i = 0; i < input_tensors.size(); i++) {
    auto &tensor = input_tensors[i];
    auto &kernel_tensor = kernel_args_.inputs[i];
    auto tensor_name = FuncGraphUtils::GetTensorName(tensor);
    auto data_type = static_cast<DataType>(kernel_tensor->GetDtype());
    auto shape = kernel_tensor->GetShapeVector();
    inputs_.push_back(std::make_shared<TensorDefaultImpl>(tensor_name, data_type, shape));
    input_names_.push_back(tensor_name);
  }
  for (size_t i = 0; i < output_tensors.size(); i++) {
    auto &tensor = output_tensors[i];
    auto &kernel_tensor = kernel_args_.outputs[i];
    auto tensor_name = FuncGraphUtils::GetTensorName(tensor);
    auto data_type = static_cast<DataType>(kernel_tensor->GetDtype());
    auto shape = kernel_tensor->GetShapeVector();
    outputs_.push_back(std::make_shared<TensorDefaultImpl>(tensor_name, data_type, shape));
    output_names_.push_back(tensor_name);
  }
  return kSuccess;
}
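// InitInputOutputInfos above mirrors the kernel's input/output tensors into the
// session-level tensor impls, so GetInputs/GetOutputs report the post-Init shapes.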
// Previous implementation, removed by this commit:
Status SingleOpInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size) {
  MS_LOG(INFO) << "SingleOpInferSession::CompileGraph";
  std::vector<KernelGraphPtr> all_out_graph;
  kernel_graph_ = kernel_graph_utils_->ConstructKernelGraph(graph, &all_out_graph, mindspore::device::DeviceType::kCPU);
  MS_EXCEPTION_IF_NULL(kernel_graph_);
  auto &nodes = kernel_graph_->nodes();
  for (const auto &node : nodes) {
    std::string node_name = common::AnfAlgo::GetCNodeName(node);
    MS_LOG(INFO) << "SingleOpInferSession::Nodes " << node_name;
  }
  auto &kernel_nodes = kernel_graph_->execution_order();
  bool update_flag = false;
  std::vector<kernel::KernelTensorPtr> update_inputs;
  std::vector<kernel::KernelTensorPtr> update_outputs;
  for (const auto &kernel_node : kernel_nodes) {
    mindspore::infer::SetKernelInfo(kernel_node);
    std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    std::shared_ptr<kernel::KernelMod> kernel_mod = kernel::Factory<kernel::KernelMod>::Instance().Create(kernel_name);
    if (kernel_mod == nullptr) {
      MS_LOG(EXCEPTION) << "Kernel mod is nullptr, kernel name: " << kernel_name;
    }
    MS_LOG(INFO) << "SingleOpInferSession::Kernels " << kernel_name;
    if (kernel_name == kNameCustomAscend) {
      kernel_mod->SetDevicedId(device_id_);
    }
    auto args = kernel::AbstractArgsFromCNode(kernel_node);
    mindspore::infer::CopyInputWeights(kernel_node, args.inputs);
    auto ret = kernel_mod->Init(args.op, args.inputs, args.outputs);
    MS_LOG(INFO) << "SingleOpInferSession::Kernels ret " << ret;
    if (!ret) {
      MS_LOG(EXCEPTION) << "kernel init failed " << kernel_name;
    }
    std::vector<size_t> input_size_list;
    std::vector<size_t> output_size_list;
    InitInputSizeList(kernel_node, &input_size_list);
    size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      TypeId type_id = AnfAlgo::GetOutputDeviceDataType(kernel_node, output_index);
      size_t type_size = GetTypeByte(TypeIdToType(type_id));
      auto shape = AnfAlgo::GetOutputDeviceShape(kernel_node, output_index);
      size_t tensor_size =
        shape.empty() ? type_size : std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
      tensor_size = std::max(tensor_size, type_size);
      (void)output_size_list.emplace_back(tensor_size);
    }
    kernel_mod->SetInputSizeList(input_size_list);
    kernel_mod->SetOutputSizeList(output_size_list);
    AnfAlgo::SetKernelMod(kernel_mod, kernel_node.get());
    if (kernel_name == kNameCustomAscend) {
      if (is_dvpp_) {
        update_inputs = kernel_mod->GetInputKernelTensor();
      }
      update_flag = true;
      update_outputs = kernel_mod->RetrieveOutputShape();
    }
  }
  RuntimeUtils::AssignKernelGraphAddress(kernel_graph_);
  std::vector<tensor::TensorPtr> graph_inputs, graph_outputs;
  kernel_graph_utils_->GetModelInputsInfo(kernel_graph_->graph_id(), &graph_inputs, &input_names_);
  kernel_graph_utils_->GetModelOutputsInfo(kernel_graph_->graph_id(), &graph_outputs, &output_names_);
  if (graph_inputs.size() != input_names_.size()) {
    MS_LOG(ERROR) << "Graph input size " << graph_inputs.size() << " != input names size " << input_names_.size();
    return kCoreFailed;
  }
  if (graph_outputs.size() != output_names_.size()) {
    MS_LOG(ERROR) << "Graph output size " << graph_outputs.size() << " != output names size " << output_names_.size();
    return kCoreFailed;
  }
  if (is_dvpp_) {
    MS_LOG(INFO) << "Update input kernel tensor shape, data type, and format for CustomAscend DVPP";
    auto ret = UpdateGraphInputsForDVPP(update_inputs);
    if (ret != kSuccess) {
      return ret;
    }
  } else {
    for (size_t i = 0; i < input_names_.size(); i++) {
      auto &input = graph_inputs[i];
      auto data_type = static_cast<enum DataType>(input->data_type());
      auto impl = std::make_shared<TensorDefaultImpl>(input_names_[i], data_type, input->shape_c());
      inputs_.push_back(impl);
    }
  }
  for (size_t i = 0; i < output_names_.size(); i++) {
    auto &output = graph_outputs[i];
    auto data_type = static_cast<enum DataType>(output->data_type());
    auto impl = std::make_shared<TensorDefaultImpl>(output_names_[i], data_type, output->shape_c());
    outputs_.push_back(impl);
  }
  if (update_flag) {
    for (size_t i = 0; i < update_outputs.size(); ++i) {
      outputs_.at(i)->SetShape(update_outputs.at(i)->GetShapeVector());
    }
  }
  return kSuccess;
}
// New implementation:
Status SingleOpInferSession::CompileGraph(FuncGraphPtr graph, const void *data, size_t size) {
  MS_LOG(INFO) << "SingleOpInferSession::CompileGraph";
  auto nodes = graph->TopoSort(graph->get_return());
  if (nodes.empty()) {
    MS_LOG(ERROR) << "There are no nodes in the graph";
    return mindspore::kLiteNullptr;
  }
  size_t cnode_count = 0;
  for (const auto &node : nodes) {
    auto cnode = node->cast<CNodePtr>();
    if (!cnode || !AnfUtils::IsRealKernel(cnode)) {
      continue;
    }
    std::string kernel_name = common::AnfAlgo::GetCNodeName(cnode);
    if (kernel_name != kNameCustomAscend) {
      MS_LOG(ERROR) << "Only support " << kNameCustomAscend << ", but got " << kernel_name << ", node "
                    << cnode->fullname_with_scope();
      return kLiteError;
    }
    cnode_count += 1;
    if (cnode_count > 1) {
      MS_LOG(ERROR) << "Only support one " << kNameCustomAscend << " node, but got " << kernel_name << ", node "
                    << cnode->fullname_with_scope();
      return kLiteError;
    }
    auto ret = BuildCustomAscendKernel(cnode);
    if (ret != kSuccess) {
      MS_LOG(ERROR) << "Failed to Build custom ascend kernel";
      return ret;
    }
  }
  auto ret = InitInputOutputInfos(graph);
  if (ret != kSuccess) {
    MS_LOG(ERROR) << "Failed to init graph input and output infos";
    return ret;
  }
  return kSuccess;
}
@ -278,87 +224,129 @@ Status SingleOpInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs,
}
Status SingleOpInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) {
MS_LOG(INFO) << "SingleOpInferSession::RunGraph with input and outputs";
MS_EXCEPTION_IF_NULL(kernel_graph_);
RuntimeUtils::CopyInputTensorsToKernelGraph(inputs, kernel_graph_);
auto &kernel_nodes = kernel_graph_->execution_order();
for (const auto &kernel_node : kernel_nodes) {
std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
MS_LOG(INFO) << "SingleOpInferSession::RunGraph " << kernel_name;
auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
MS_EXCEPTION_IF_NULL(kernel_mod);
std::vector<kernel::AddressPtr> kernel_inputs;
size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; ++i) {
auto device_address = AnfAlgo::GetPrevNodeMutableOutputAddr(kernel_node, i);
auto input = RuntimeUtils::GetAddressFromDevice(device_address);
kernel_inputs.push_back(input);
}
std::vector<kernel::AddressPtr> kernel_outputs;
size_t output_num = common::AnfAlgo::GetOutputTensorNum(kernel_node);
for (size_t i = 0; i < output_num; ++i) {
auto device_address = AnfAlgo::GetMutableOutputAddr(kernel_node, i);
auto output = RuntimeUtils::GetAddressFromDevice(device_address);
kernel_outputs.push_back(output);
}
std::vector<kernel::AddressPtr> kernel_workspaces;
bool ret = true;
try {
ret = kernel_mod->Launch(kernel_inputs, kernel_workspaces, kernel_outputs, 0);
RuntimeUtils::UpdateKernelNodeOutputInfo(kernel_node, kernel_outputs);
} catch (std::exception &e) {
MS_LOG(EXCEPTION) << e.what();
}
if (!ret) {
MS_LOG(EXCEPTION) << "Launch kernel failed.";
}
}
RuntimeUtils::CopyOutputTensorsFromKernelGraph(outputs, kernel_graph_);
return kSuccess;
}
// Removed by this commit (kernel-graph resize path):
Status SingleOpInferSession::ResizeGraphInputs(const std::vector<tensor::Tensor> &inputs,
                                               const std::vector<std::vector<int64_t>> &dims) {
  if (inputs_.size() != inputs.size()) {
    MS_LOG(ERROR) << "Graph inputs tensor size[" << inputs_.size() << "] is not equal with user input tensor size["
                  << inputs.size() << "]";
    return kLiteError;
  }
  auto ret = UpdateKernelGraphInputs(dims, {}, true);
  if (ret != kSuccess) {
    return ret;
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    // update session inputs_
    inputs_[i]->SetShape(dims[i]);
  }
  return kSuccess;
}
// New implementation of RunGraph:
Status SingleOpInferSession::RunGraph(const std::vector<tensor::Tensor> &inputs, std::vector<tensor::Tensor> *outputs) {
  if (outputs == nullptr) {
    MS_LOG(ERROR) << "outputs cannot be nullptr";
    return kLiteError;
  }
  MS_LOG(INFO) << "SingleOpInferSession::RunGraph with input and outputs";
  std::vector<ShapeVector> new_shapes;
  std::transform(inputs.begin(), inputs.end(), std::back_inserter(new_shapes), [](auto &t) { return t.shape_c(); });
  auto ret = OnNewInputShapes(new_shapes);
  if (ret != kSuccess) {
    return ret;
  }
  if (inputs.size() != kernel_args_.inputs.size()) {
    MS_LOG(ERROR) << "Given inputs size " << inputs.size() << " != graph inputs size " << kernel_args_.inputs.size();
    return kLiteError;
  }
  for (size_t i = 0; i < inputs.size(); i++) {
    auto &input = inputs[i];
    auto &kernel_input = kernel_args_.inputs[i];
    if (input.Size() != kernel_input->GetSizeInBytes()) {
      MS_LOG(ERROR) << "Byte size of input " << i << " != the size expected, given size " << input.Size()
                    << ", expected size " << kernel_input->GetSizeInBytes()
                    << ", input shape: " << kernel_input->GetShapeVector();
      return kLiteError;
    }
    auto input_device_address = input.device_address();
    if (input_device_address != nullptr && input_device_address->GetMutablePtr() != nullptr) {
      auto device_ptr = input_device_address->GetMutablePtr();
      kernel_args_.inputs[i]->SetData(std::make_shared<kernel::Address>(device_ptr, input.Size()));
      kernel_args_.inputs[i]->SetHostData(nullptr);
    } else {
      kernel_args_.inputs[i]->SetHostData(std::make_shared<kernel::Address>(input.data_c(), input.Size()));
      kernel_args_.inputs[i]->SetData(nullptr);
    }
  }
  if (outputs->empty()) {
    std::transform(kernel_args_.outputs.begin(), kernel_args_.outputs.end(), std::back_inserter(*outputs),
                   [](auto &item) { return tensor::Tensor(item->GetDtype(), item->GetShapeVector()); });
  }
  if (outputs->size() != kernel_args_.outputs.size()) {
    MS_LOG(ERROR) << "Given outputs size " << outputs->size() << " != graph outputs size "
                  << kernel_args_.outputs.size();
    return kLiteError;
  }
  for (size_t i = 0; i < outputs->size(); i++) {
    auto &output = (*outputs)[i];
    auto &kernel_output = kernel_args_.outputs[i];
    if (output.Size() != kernel_output->GetSizeInBytes()) {
      MS_LOG(ERROR) << "Byte size of output " << i << " != the size expected, given size " << output.Size()
                    << ", expected size " << kernel_output->GetSizeInBytes()
                    << ", output shape: " << kernel_output->GetShapeVector();
      return kLiteError;
    }
    auto output_device_address = output.device_address();
    if (output_device_address != nullptr && output_device_address->GetMutablePtr() != nullptr) {
      auto device_ptr = output_device_address->GetMutablePtr();
      kernel_args_.outputs[i]->SetData(std::make_shared<kernel::Address>(device_ptr, output.Size()));
      kernel_args_.outputs[i]->SetHostData(nullptr);
    } else {
      kernel_args_.outputs[i]->SetHostData(std::make_shared<kernel::Address>(output.data_c(), output.Size()));
      kernel_args_.outputs[i]->SetData(nullptr);
    }
  }
  if (kernel_mod_ == nullptr) {
    MS_LOG(ERROR) << "Model has not been built";
    return kLiteError;
  }
  try {
    std::vector<kernel::AddressPtr> ignore_datas;
    if (!kernel_mod_->Launch(ignore_datas, ignore_datas, ignore_datas, nullptr)) {
      MS_LOG(ERROR) << "Failed to launch kernel";
      return kLiteError;
    }
  } catch (std::exception &e) {
    MS_LOG(ERROR) << "Failed to launch kernel, exception: " << e.what();
    return kLiteError;
  }
  return kSuccess;
}
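// Data binding rule in RunGraph above: a tensor that already owns a device address is
// passed through as device data (SetData), otherwise its host buffer is handed over
// (SetHostData); the unused side is cleared so the kernel can tell the two apart.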
// Previous implementation, removed by this commit:
Status SingleOpInferSession::Resize(const std::vector<tensor::Tensor> &inputs,
                                    const std::vector<std::vector<int64_t>> &dims) {
  if (ResizeGraphInputs(inputs, dims) != kSuccess) {
    MS_LOG(EXCEPTION) << "Resize graph input error.";
  }
  auto &kernel_nodes = kernel_graph_->execution_order();
  for (const auto &kernel_node : kernel_nodes) {
    std::string kernel_name = common::AnfAlgo::GetCNodeName(kernel_node);
    MS_LOG(INFO) << "SingleOpInferSession::Resize " << kernel_name;
    auto kernel_mod = AnfAlgo::GetKernelMod(kernel_node);
    if (kernel_mod == nullptr) {
      MS_LOG(EXCEPTION) << "Kernel mod is nullptr, kernel name: " << kernel_name;
    }
    auto args = kernel::AbstractArgsFromCNode(kernel_node);
    if (kernel_mod->Resize(args.op, args.inputs, args.outputs) != kSuccess) {
      MS_LOG(EXCEPTION) << "Kernel mod resize failed, kernel name: " << kernel_name;
    }
  }
  return kSuccess;
}
Status SingleOpInferSession::OnNewInputShapes(const std::vector<ShapeVector> &new_shapes) {
  if (inputs_.size() != new_shapes.size()) {
    MS_LOG(ERROR) << "Graph inputs size " << inputs_.size() << " != resize input size " << new_shapes.size();
    return kLiteError;
  }
  auto input_changed = false;
  for (size_t i = 0; i < inputs_.size(); i++) {
    auto new_shape = new_shapes[i];
    if (std::any_of(new_shape.begin(), new_shape.end(), [](auto dim) { return dim < 0; })) {
      MS_LOG(ERROR) << "New shape of input " << i << " cannot be dynamic, new shape: " << new_shape;
      return kLiteError;
    }
    if (inputs_[i]->Shape() != new_shapes[i]) {
      input_changed = true;
      kernel_args_.inputs[i]->SetShapeVector(new_shapes[i]);
    }
  }
  if (!input_changed) {
    return kSuccess;
  }
  MS_LOG(INFO) << "SingleOpInferSession::Resize";
  if (kernel_mod_ == nullptr) {
    MS_LOG(ERROR) << "Model has not been built";
    return kLiteError;
  }
  if (kernel_mod_->Resize(kernel_args_.op, kernel_args_.inputs, kernel_args_.outputs) != kSuccess) {
    MS_LOG(ERROR) << "Failed to resize custom ascend kernel";
    return kLiteError;
  }
  // shapes of inputs and outputs should be updated in CustomAscendKernelMod::Resize
  for (size_t i = 0; i < inputs_.size(); i++) {
    inputs_[i]->SetShape(kernel_args_.inputs[i]->GetShapeVector());
  }
  for (size_t i = 0; i < outputs_.size(); i++) {
    outputs_[i]->SetShape(kernel_args_.outputs[i]->GetShapeVector());
  }
  return kSuccess;
}
Status SingleOpInferSession::Resize(const std::vector<tensor::Tensor> &,
const std::vector<std::vector<int64_t>> &dims) {
return OnNewInputShapes(dims);
}
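// Illustrative call sequence for this session (a sketch only, using the methods
// declared above; `func_graph`, `inputs`, `new_dims` and `outputs` are placeholders):
//   session.CompileGraph(func_graph, nullptr, 0);  // builds the CustomAscend kernel
//   session.Resize(inputs, new_dims);              // optional, static target shapes only
//   session.RunGraph(inputs, &outputs);            // binds buffers and launches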
std::vector<MutableTensorImplPtr> SingleOpInferSession::GetOutputs() { return outputs_; }
std::vector<MutableTensorImplPtr> SingleOpInferSession::GetInputs() { return inputs_; }
std::vector<std::string> SingleOpInferSession::GetOutputNames() { return output_names_; }

View File

@ -21,6 +21,7 @@
#include <vector>
#include "src/extendrt/infer_session.h"
#include "extendrt/utils/kernel_graph_utils.h"
#include "mindspore/ccsrc/kernel/common_utils.h"
namespace mindspore {
/// \brief Single Op Session implementation, used in Ascend Device Context.
@ -43,19 +44,18 @@ class SingleOpInferSession : public InferSession {
MutableTensorImplPtr GetInputByTensorName(const std::string &name) override;
private:
Status UpdateKernelGraphInputs(const std::vector<std::vector<int64_t>> &dims, const std::vector<TypeId> &type_ids,
bool use_type_from_graph);
Status UpdateGraphInputsForDVPP(const std::vector<kernel::KernelTensorPtr> &inputs);
Status ResizeGraphInputs(const std::vector<tensor::Tensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status OnNewInputShapes(const std::vector<ShapeVector> &new_shapes);
Status BuildCustomAscendKernel(const CNodePtr &node);
Status InitInputOutputInfos(const FuncGraphPtr &graph);
KernelGraphUtilsPtr kernel_graph_utils_;
KernelGraphPtr kernel_graph_;
std::vector<MutableTensorImplPtr> inputs_;
std::vector<std::string> input_names_;
std::vector<MutableTensorImplPtr> outputs_;
std::vector<std::string> output_names_;
uint32_t device_id_ = 0;
bool is_dvpp_ = false;
kernel::KernelModPtr kernel_mod_ = nullptr;
kernel::KernelArgs kernel_args_;
};
} // namespace mindspore

View File

@ -0,0 +1,223 @@
/**
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <string>
#include <algorithm>
#include <utility>
#include <vector>
#include <map>
#include <memory>
#include "src/extendrt/utils/func_graph_utils.h"
#include "include/common/utils/anfalgo.h"
#include "include/common/utils/convert_utils.h"
#include "mindspore/ccsrc/backend/common/optimizer/helper.h"
#include "tools/optimizer/common/gllo_utils.h"
namespace mindspore {
ValuePtr FuncGraphUtils::GetNodeValuePtr(AnfNodePtr input_node) {
if (input_node == nullptr) {
return nullptr;
}
if (IsPrimitiveCNode(input_node, prim::kPrimDepend)) {
input_node = AnfUtils::VisitKernel(input_node, 0).first;
}
ValuePtr value = nullptr;
if (input_node->isa<ValueNode>() && !HasAbstractMonad(input_node)) {
auto value_node = input_node->cast<ValueNodePtr>();
if (value_node) {
value = value_node->value();
}
} else if (input_node->isa<Parameter>()) {
auto parameter = input_node->cast<ParameterPtr>();
if (parameter->has_default()) {
value = parameter->default_param();
}
}
return value;
}
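// GetConstNodeValue normalizes the supported constant kinds (Tensor, Scalar, ValueTuple,
// Type) into a tensor::TensorPtr, returning nullptr for anything it cannot fold.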
tensor::TensorPtr FuncGraphUtils::GetConstNodeValue(AnfNodePtr input_node) {
ValuePtr value = GetNodeValuePtr(input_node);
if (value == nullptr) {
return nullptr;
}
if (value->isa<tensor::Tensor>()) {
auto tensor = value->cast<tensor::TensorPtr>();
if (tensor == nullptr || tensor->data().const_data() == nullptr) {
return nullptr;
}
return tensor;
}
if (value->isa<Scalar>()) {
return ScalarToTensor(value->cast<ScalarPtr>());
}
if (value->isa<ValueTuple>()) {
return opt::CreateTupleTensor(value->cast<ValueTuplePtr>());
}
if (value->isa<Type>()) {
auto type_ptr = value->cast<TypePtr>();
if (type_ptr == nullptr) {
return nullptr;
}
return std::make_shared<tensor::Tensor>(static_cast<int64_t>(type_ptr->type_id()), type_ptr->type());
}
MS_LOG(WARNING) << "Unexpected value type " << value->type_name() << " for " << input_node->fullname_with_scope();
return nullptr;
}
bool FuncGraphUtils::GetCNodeOperator(const mindspore::CNodePtr &cnode,
mindspore::kernel::BaseOperatorPtr *base_operator) {
if (!cnode || !base_operator) {
MS_LOG(ERROR) << "Input cnode or base_operator cannot be nullptr";
return false;
}
  auto prim = GetValueNode<PrimitivePtr>(cnode->input(0));
  if (prim == nullptr) {
MS_LOG(ERROR) << "Primitive of cnode " << cnode->fullname_with_scope() << " cannot be nullptr";
return false;
}
auto kernel_name = prim->name();
ops::PrimitiveCPtr primc_ptr = nullptr;
static auto primc_fns = ops::OpPrimCRegister::GetInstance().GetPrimCMap();
if (primc_fns.find(kernel_name) != primc_fns.end()) {
primc_ptr = primc_fns[kernel_name]();
(void)primc_ptr->SetAttrs(prim->attrs());
}
if (primc_ptr == nullptr) {
MS_LOG(ERROR) << "OpPrimCRegister can not find " << kernel_name;
return false;
}
*base_operator = nullptr;
static auto operator_fns = ops::OperatorRegister::GetInstance().GetOperatorMap();
if (operator_fns.find(kernel_name) != operator_fns.end()) {
*base_operator = operator_fns[kernel_name](primc_ptr);
}
if (*base_operator == nullptr) {
MS_LOG(ERROR) << "Failed to create operator of type " << kernel_name;
return false;
}
return true;
}
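// GetCNodeOperator above bridges a Primitive to a BaseOperator in two registry hops:
// a PrimitiveC instance is created and given the cnode's attrs, then the operator
// functor registered under the same name wraps it.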
bool FuncGraphUtils::GetCNodeInputsOutputs(const mindspore::CNodePtr &cnode,
std::vector<AnfWithOutIndex> *input_tensors,
std::vector<AnfWithOutIndex> *output_tensors) {
if (!cnode || !input_tensors || !output_tensors) {
MS_LOG(ERROR) << "Input cnode, input_tensors or output_tensors cannot be nullptr";
return false;
}
// Makeup input tensors.
*input_tensors = opt::GetNodeInputs(cnode);
// Makeup output tensors.
output_tensors->clear();
auto output_num = common::AnfAlgo::GetOutputTensorNum(cnode);
for (size_t output_idx = 0; output_idx < output_num; ++output_idx) {
session::KernelWithIndex tensor_id = {cnode, output_idx};
output_tensors->push_back(tensor_id);
}
return true;
}
bool FuncGraphUtils::GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs) {
if (!func_graph || !inputs) {
MS_LOG(ERROR) << "Input func_graph or inputs cannot be nullptr";
return false;
}
auto graph_inputs = func_graph->get_inputs();
// find parameters of graph inputs
for (size_t i = 0; i < graph_inputs.size(); ++i) {
auto input = graph_inputs[i];
auto parameter = input->cast<ParameterPtr>();
    if (!parameter) {
      MS_LOG(ERROR) << "Input " << input->fullname_with_scope() << " of FuncGraph is not type of Parameter.";
      return false;
}
if (common::AnfAlgo::IsParameterWeight(parameter)) {
continue;
}
inputs->push_back(std::make_pair(input, 0));
}
return true;
}
bool FuncGraphUtils::GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs) {
if (!func_graph || !outputs) {
MS_LOG(ERROR) << "Input func_graph or outputs cannot be nullptr";
return false;
}
*outputs = opt::GetNodeInputs(func_graph->get_return());
return true;
}
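// For the accessors below, constant nodes report the dtype/shape of their folded
// value, while other nodes fall back to the inferred abstract of the indexed output.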
DataType FuncGraphUtils::GetTensorDataType(const AnfWithOutIndex &tensor) {
auto node = tensor.first;
auto output_idx = tensor.second;
auto tensor_val = GetConstNodeValue(node);
TypeId type_id;
if (tensor_val) {
type_id = tensor_val->Dtype()->type_id();
} else {
type_id = common::AnfAlgo::GetOutputInferDataType(node, output_idx);
}
return static_cast<enum DataType>(type_id);
}
ShapeVector FuncGraphUtils::GetTensorShape(const AnfWithOutIndex &tensor) {
auto node = tensor.first;
auto output_idx = tensor.second;
auto tensor_val = GetConstNodeValue(node);
ShapeVector shape;
if (tensor_val) {
shape = tensor_val->shape_c();
} else {
shape = common::AnfAlgo::GetOutputInferShape(node, output_idx);
}
return shape;
}
std::string FuncGraphUtils::GetTensorName(const AnfWithOutIndex &tensor) {
auto node = tensor.first;
auto idx = tensor.second;
MS_EXCEPTION_IF_NULL(node);
AbstractBasePtr abstract = node->abstract();
MS_EXCEPTION_IF_NULL(abstract);
if (utils::isa<abstract::AbstractTuplePtr>(abstract)) {
auto abstract_tuple = utils::cast<abstract::AbstractTuplePtr>(abstract);
MS_EXCEPTION_IF_NULL(abstract_tuple);
auto abstract_list = abstract_tuple->elements();
if (abstract_list.size() <= idx) {
MS_LOG(ERROR) << "AbstractTuple's size[" << abstract_list.size() << "] is smaller than expect size[" << idx
<< "]";
return "";
}
abstract = abstract_list[idx];
MS_EXCEPTION_IF_NULL(abstract);
}
MS_EXCEPTION_IF_NULL(abstract);
std::string output_name;
if (!abstract->name().empty()) {
output_name = abstract->name();
} else if (idx > 0) {
output_name = node->fullname_with_scope() + ":" + std::to_string(idx);
} else {
output_name = node->fullname_with_scope();
}
return output_name;
}
} // namespace mindspore

View File

@ -0,0 +1,56 @@
/**
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
*
* Copyright 2022 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_
#include <utility>
#include <string>
#include <vector>
#include "ir/anf.h"
#include "ir/dtype/type.h"
#include "ir/func_graph.h"
#include "include/api/data_type.h"
#include "mindspore/ccsrc/kernel/kernel.h"
namespace mindspore {
using AnfWithOutIndex = std::pair<AnfNodePtr, size_t>;
using kernel::BaseOperatorPtr;
class FuncGraphUtils {
public:
static tensor::TensorPtr GetConstNodeValue(AnfNodePtr input_node);
static bool GetCNodeOperator(const CNodePtr &cnode, BaseOperatorPtr *base_operator);
static bool GetCNodeInputsOutputs(const CNodePtr &cnode, std::vector<AnfWithOutIndex> *input_tensors,
std::vector<AnfWithOutIndex> *output_tensors);
static bool GetFuncGraphInputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *inputs);
static bool GetFuncGraphOutputs(const FuncGraphPtr &func_graph, std::vector<AnfWithOutIndex> *outputs);
static DataType GetTensorDataType(const AnfWithOutIndex &tensor);
static ShapeVector GetTensorShape(const AnfWithOutIndex &tensor);
static std::string GetTensorName(const AnfWithOutIndex &tensor);
private:
static ValuePtr GetNodeValuePtr(AnfNodePtr input_node);
};
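// Typical use (illustrative only; `cnode` is assumed to be a CNodePtr taken from a
// parsed func graph):
//   std::vector<AnfWithOutIndex> ins;
//   std::vector<AnfWithOutIndex> outs;
//   FuncGraphUtils::GetCNodeInputsOutputs(cnode, &ins, &outs);
//   auto name = FuncGraphUtils::GetTensorName(ins[0]);
//   auto dtype = FuncGraphUtils::GetTensorDataType(ins[0]);
//   auto shape = FuncGraphUtils::GetTensorShape(ins[0]);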
} // namespace mindspore
#endif // MINDSPORE_LITE_SRC_EXTENDRT_UTILS_FUNC_GRAPH_UTILS_H_