!58302 [MS][LITE][ASCEND] device data for ascend (python api and C++ api)

Merge pull request !58302 from yefeng/625-device_data_for_master
i-robot 2023-08-30 17:01:49 +00:00 committed by Gitee
commit 5dc7431fcb
39 changed files with 1237 additions and 82 deletions

View File

@ -223,6 +223,8 @@
"mindspore/mindspore/lite/test/st/python/import_ms_and_mslite/" "unused-import"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/insert_simplified_keys.py" "duplicate-key"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/ascendc/cmake/util/replay_codegen.py" "bad-continuation"
"mindspore/mindspore/lite/python/api/_checkparam.py" "len-as-condition"
"mindspore/mindspore/lite/python/api/base_model.py" "len-as-condition"
# ascend samples
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_dsl/sample/" "wrong-import-order"
"mindspore/mindspore/lite/tools/kernel_builder/ascend/tbe_tik/sample/" "wrong-import-order"

View File

@ -73,10 +73,25 @@ class MS_API MSTensor {
/// \param[in] shape The shape of the MSTensor.
/// \param[in] data The data pointer that points to allocated memory.
/// \param[in] data_len The length of the memory, in bytes.
/// \param[in] device The device type of the MSTensor.
/// \param[in] device_id The device id of the MSTensor.
///
/// \return A pointer of MSTensor.
static inline MSTensor *CreateTensor(const std::string &name, DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
const void *data, size_t data_len, const std::string &device = "",
int device_id = -1) noexcept;
/// \brief Creates a MSTensor object, whose data needs to be copied before being accessed by Model, must be used in pairs
/// with DestroyTensorPtr.
///
/// \param[in] name The name of the MSTensor.
/// \param[in] tensor The source tensor.
/// \param[in] device The device type of the MSTensor.
/// \param[in] device_id The device id of the MSTensor.
///
/// \return A pointer of MSTensor.
static inline MSTensor *CreateTensor(const std::string &name, const MSTensor &tensor, const std::string &device = "",
int device_id = -1) noexcept;
/// \brief Creates a MSTensor object, whose data can be directly accessed by Model, must be used in pairs with
/// DestroyTensorPtr.
@ -181,6 +196,16 @@ class MS_API MSTensor {
/// \return The length of the data of the MSTensor, in bytes.
size_t DataSize() const;
/// \brief Get the device id of the MSTensor.
///
/// \return The device id of the MSTensor.
int GetDeviceId() const;
/// \brief Get the device type of the MSTensor.
///
/// \return The device type of the MSTensor.
std::string GetDevice() const;
/// \brief Get whether the MSTensor data is const data
///
/// \return Const flag of MSTensor
@ -294,7 +319,10 @@ class MS_API MSTensor {
private:
// api without std::string
static MSTensor *CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept;
const void *data, size_t data_len, const std::vector<char> &device = {},
int device_id = -1) noexcept;
static MSTensor *CreateTensor(const std::vector<char> &name, const MSTensor &tensor, const std::vector<char> &device,
int device_id = -1) noexcept;
static MSTensor *CreateRefTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len, bool own_data) noexcept;
static MSTensor CreateDeviceTensor(const std::vector<char> &name, enum DataType type,
@ -334,8 +362,13 @@ class MS_API Buffer {
};
MSTensor *MSTensor::CreateTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
return CreateTensor(StringToChar(name), type, shape, data, data_len);
const void *data, size_t data_len, const std::string &device, int device_id) noexcept {
return CreateTensor(StringToChar(name), type, shape, data, data_len, StringToChar(device), device_id);
}
MSTensor *MSTensor::CreateTensor(const std::string &name, const MSTensor &tensor, const std::string &device,
int device_id) noexcept {
return CreateTensor(StringToChar(name), tensor, StringToChar(device), device_id);
}
MSTensor *MSTensor::CreateRefTensor(const std::string &name, enum DataType type, const std::vector<int64_t> &shape,
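The new device/device_id parameters let CreateTensor place the tensor's backing memory on a device instead of host memory. They surface to users through the Python Tensor API added later in this change; an illustrative usage sketch, assuming an Ascend-enabled (cloud inference) build of mindspore_lite and an available Ascend device:

import numpy as np
import mindspore_lite as mslite

# Host data that will back the new tensor.
host_data = np.ones((1, 3, 224, 224), dtype=np.float32)

# device="ascend:0" maps onto the new CreateTensor(..., device, device_id)
# overloads: the data is copied into memory allocated on Ascend device 0.
dev_tensor = mslite.Tensor(host_data, device="ascend:0")
print(dev_tensor.device)  # -> "ascend:0"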

View File

@ -141,7 +141,8 @@ class TensorReferenceImpl : public MSTensor::Impl {
};
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
const void *data, size_t data_len, const std::vector<char> &device,
int device_id) noexcept {
std::string name_str = CharToString(name);
try {
std::shared_ptr<Impl> impl = std::make_shared<TensorDefaultImpl>(name_str, type, shape, data, data_len);
@ -156,6 +157,12 @@ MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType ty
}
}
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, const MSTensor &tensor, const std::vector<char> &device,
int device_id) noexcept {
MS_LOG(ERROR) << "Invalid implement.";
return nullptr;
}
MSTensor *MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type,
const std::vector<int64_t> &shape, const void *data, size_t data_len,
bool) noexcept {

View File

@ -121,6 +121,7 @@ class BACKEND_EXPORT KernelTensor {
meta_ = copy_tensor.meta_;
data_ = copy_tensor.data_;
host_data_ = copy_tensor.host_data_;
device_id_ = copy_tensor.device_id_;
}
KernelTensor &operator=(const KernelTensor &copy_tensor) {
if (&copy_tensor == this) {
@ -130,6 +131,7 @@ class BACKEND_EXPORT KernelTensor {
meta_ = copy_tensor.meta_;
data_ = copy_tensor.data_;
host_data_ = copy_tensor.host_data_;
device_id_ = copy_tensor.device_id_;
dyn_output_data_ = nullptr;
return *this;
}
@ -191,6 +193,8 @@ class BACKEND_EXPORT KernelTensor {
// deprecated field for dynamic shape
const ShapeVector &GetDeviceShapeAdaptively() const;
void SetDeviceShapeAdaptively(const ShapeVector &device_shape_adaptively);
int32_t GetDeviceId() const { return device_id_; }
void SetDeviceId(int32_t device_id) { device_id_ = device_id; }
private:
TypeId meta_type_{kObjectTypeTensorType};
@ -200,6 +204,7 @@ class BACKEND_EXPORT KernelTensor {
AddressPtr host_data_{nullptr}; // Host data address.
std::unique_ptr<uint8_t[]> dyn_output_data_{nullptr}; // Create new output memory buffer for dynamic output
string GetAbstractName() const;
int32_t device_id_{0};
};
using KernelTensorPtr = std::shared_ptr<KernelTensor>;

View File

@ -32,7 +32,8 @@ int SoftMaxInferShape(const TensorC *const *inputs, size_t inputs_size, TensorC
if (!InferFlag(inputs, inputs_size)) {
return NNACL_INFER_INVALID;
}
if (input->shape_size_ > 5) {
// There is a model with an 8-dim input, which runs on Ascend 910, so allow up to DIMENSION_8D.
if (input->shape_size_ > DIMENSION_8D) {
return NNACL_ERR;
}

View File

@ -43,11 +43,15 @@ namespace tensor {
// Includes the format, data type and host format of a tensor.
struct DeviceInfo {
explicit DeviceInfo(std::string format = "DefaultFormat", TypePtr data_type = nullptr,
std::string host_format = "DefaultFormat")
: format_(std::move(format)), data_type_(std::move(data_type)), host_format_(std::move(host_format)) {}
std::string host_format = "DefaultFormat", int32_t device_id = 0)
: format_(std::move(format)),
data_type_(std::move(data_type)),
host_format_(std::move(host_format)),
device_id_(device_id) {}
std::string format_ = "DefaultFormat";
TypePtr data_type_ = nullptr;
std::string host_format_ = "DefaultFormat";
int32_t device_id_ = 0;
};
// brief Metadata of Tensor

View File

@ -689,6 +689,7 @@ Tensor::Tensor(const Tensor &tensor)
tensor_name_(tensor.tensor_name_),
address_future_(tensor.address_future_) {
user_data_ = tensor.user_data_;
set_device_info(tensor.device_info());
}
Tensor::Tensor(const Tensor &tensor, TypeId data_type)
@ -715,6 +716,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
tensor_name_(tensor.tensor_name_),
address_future_(tensor.address_future_) {
user_data_ = tensor.user_data_;
set_device_info(tensor.device_info());
}
Tensor &Tensor::operator=(const Tensor &tensor) {

View File

@ -15,6 +15,7 @@ if(Python3_FOUND)
include_directories(${Python3_NumPy_INCLUDE_DIRS})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../../)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../core/)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../ccsrc/plugin/device/cpu/kernel/)
if(MSLITE_ENABLE_CLOUD_FUSION_INFERENCE OR MSLITE_ENABLE_CLOUD_INFERENCE)
add_compile_definitions(MSLITE_ENABLE_CLOUD_INFERENCE)

View File

@ -75,3 +75,24 @@ def check_config_info(config_info_name, config_info, enable_none=False):
raise TypeError(f"{config_info_name} val must be str, but got "
f"{type(config_info[key])} at key {key}.")
return config_info
def check_tensor_input_param(shape=None, device=None):
"""Check tensor input param"""
if shape is not None:
if not isinstance(shape, (list, tuple)):
raise TypeError(f"shape must be list or tuple, but got {type(shape)}.")
for i, element in enumerate(shape):
if not isinstance(element, int):
raise TypeError(f"shape element must be int, but got {type(element)} at index {i}.")
if device is None:
return
if not isinstance(device, str):
raise TypeError(f"device must be str, but got {type(device)}.")
split_device = device.split(":")
if len(split_device) > 2:
raise TypeError("device must be in the form 'ascend:device_id', e.g. 'ascend:0'.")
if split_device[0] != "ascend":
raise TypeError("only the ascend device is supported for now.")
if len(split_device) == 2 and not split_device[1].isdigit():
raise TypeError("device id must be a non-negative integer.")

View File

@ -15,7 +15,7 @@
"""
BaseModel.
"""
from mindspore_lite.tensor import Tensor
from mindspore_lite.tensor import Tensor, TensorMeta
class BaseModel:
@ -35,25 +35,52 @@ class BaseModel:
inputs.append(Tensor(_tensor))
return inputs
def predict(self, inputs):
def get_outputs(self):
"""
Obtains all output TensorMeta of the model.
"""
outputs_metadata = []
for _tensor in self._model.get_outputs():
out_tensor = Tensor(_tensor)
output_meta = TensorMeta()
output_meta.name = out_tensor.name
output_meta.dtype = out_tensor.dtype
output_meta.shape = out_tensor.shape
output_meta.format = out_tensor.format
output_meta.element_num = out_tensor.element_num
output_meta.data_size = out_tensor.data_size
outputs_metadata.append(output_meta)
return tuple(outputs_metadata)
def predict(self, inputs, outputs=None):
"""
Run inference on the model.
"""
if not isinstance(inputs, list):
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
_inputs = []
_outputs = []
for i, element in enumerate(inputs):
if not isinstance(element, Tensor):
raise TypeError(f"inputs element must be Tensor, but got "
f"{type(element)} at index {i}.")
# pylint: disable=protected-access
_inputs.append(element._tensor)
outputs = self._model.predict(_inputs)
if not outputs:
if outputs is not None:
if not isinstance(outputs, list):
raise TypeError("inputs must be list, but got {}.".format(type(inputs)))
for i, element in enumerate(outputs):
if not isinstance(element, Tensor):
raise TypeError(f"outputs element must be Tensor, but got "
f"{type(element)} at index {i}.")
# pylint: disable=protected-access
_outputs.append(element._tensor)
predict_result = self._model.predict(_inputs, _outputs)
if predict_result is None or len(predict_result) == 0:
raise RuntimeError(f"predict failed!")
predict_outputs = []
for output in outputs:
predict_outputs.append(Tensor(output))
for output_tensor in predict_result:
predict_outputs.append(Tensor(output_tensor))
return predict_outputs
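For reference, a sketch of how the new get_outputs metadata can be inspected; it assumes an Ascend-enabled mindspore_lite build, and the model file name is hypothetical:

import mindspore_lite as mslite

context = mslite.Context()
context.target = ["ascend"]
model = mslite.Model()
model.build_from_file("net.mindir", mslite.ModelType.MINDIR, context)  # hypothetical model file

for meta in model.get_outputs():
    # TensorMeta only carries metadata (no data buffer), unlike get_inputs(), which returns Tensors.
    print(meta.name, meta.dtype, meta.shape, meta.data_size)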
def resize(self, inputs, dims):

View File

@ -19,6 +19,7 @@ from __future__ import absolute_import
import os
import logging
from enum import Enum
import numpy
from mindspore_lite._checkparam import check_isinstance
from mindspore_lite.context import Context
@ -233,6 +234,16 @@ class Model(BaseModel):
if not ret.IsOk():
raise RuntimeError(f"build_from_file failed! Error is {ret.ToString()}")
def get_outputs(self):
"""
Obtains the TensorMeta information of all outputs of the model.
Returns:
list[TensorMeta], the output TensorMeta list of the model.
"""
# pylint: disable=useless-super-delegation
return super(Model, self).get_outputs()
def get_inputs(self):
"""
Obtains all input Tensors of the model.
@ -249,12 +260,14 @@ class Model(BaseModel):
# pylint: disable=useless-super-delegation
return super(Model, self).get_inputs()
def predict(self, inputs):
def predict(self, inputs, outputs=None):
"""
Run inference on the model.
Args:
inputs (list[Tensor]): A list that includes all input Tensors in order.
outputs (list[Tensor], optional): A list that includes all output Tensors in order;
these Tensors provide the output data buffers. Default: ``None``.
Returns:
list[Tensor], the output Tensor list of the model.
@ -321,7 +334,21 @@ class Model(BaseModel):
outputs' shape: (1,1001)
"""
# pylint: disable=useless-super-delegation
return super(Model, self).predict(inputs)
if not isinstance(inputs, (list, tuple)):
raise TypeError("inputs must be list or tuple, but got {}.".format(type(inputs)))
model_input_tensors = self.get_inputs()
if len(model_input_tensors) != len(inputs):
raise RuntimeError(f"inputs size is wrong.")
inputs_tensor = []
for i, in_tensor in enumerate(inputs):
if isinstance(in_tensor, numpy.ndarray):
model_input_tensors[i].set_data_from_numpy(in_tensor)
inputs_tensor.append(model_input_tensors[i])
elif isinstance(in_tensor, Tensor):
inputs_tensor.append(in_tensor)
else:
raise TypeError("inputs element must be Tensor, of numpy.")
return super(Model, self).predict(inputs_tensor, outputs)
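A usage sketch of the reworked predict, reusing the model built in the earlier sketch and assuming it takes a single float32 input of shape (1, 3, 224, 224): numpy inputs are copied into the model's own input tensors, while the optional outputs list can hand in preallocated (for example device-resident) output tensors:

import numpy as np
import mindspore_lite as mslite

img = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Plain numpy inputs: copied via set_data_from_numpy into the model inputs.
outputs = model.predict([img])

# Tensor inputs resident on Ascend device 0, with a preallocated output tensor.
in_dev = mslite.Tensor(img, device="ascend:0")
out_dev = mslite.Tensor(shape=outputs[0].shape, dtype=outputs[0].dtype, device="ascend:0")
outputs = model.predict([in_dev], [out_dev])
print(outputs[0].get_data_to_numpy().shape)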
def resize(self, inputs, dims):
"""

View File

@ -21,8 +21,35 @@ from enum import Enum
import numpy
from mindspore_lite.lib import _c_lite_wrapper
from mindspore_lite._checkparam import check_tensor_input_param, check_isinstance
__all__ = ['DataType', 'Format', 'Tensor']
__all__ = ['TensorMeta', 'DataType', 'Format', 'Tensor']
class TensorMeta:
"""
The `TensorMeta` class holds the metadata of a Tensor in MindSpore Lite, such as its name, dtype, shape,
format, element number and data size. It takes no construction arguments; its fields are typically filled
in from a model's outputs, e.g. by `get_outputs`.
"""
def __init__(self):
self.name = ""
self.dtype = DataType.UNKNOWN
self.shape = []
self.format = Format.DEFAULT
self.element_num = 0
self.data_size = 0
def __str__(self):
res = f"name: {self.name},\n" \
f"dtype: {self.dtype},\n" \
f"shape: {self.shape},\n" \
f"format: {self.format},\n" \
f"element_num: {self.element_num},\n" \
f"data_size: {self.data_size}."
return res
class DataType(Enum):
@ -190,6 +217,21 @@ numpy_data_type_map = {
numpy.float64: DataType.FLOAT64,
}
ms_to_numpy_data_type_map = {
DataType.BOOL: numpy.bool_,
DataType.INT8: numpy.int8,
DataType.INT16: numpy.int16,
DataType.INT32: numpy.int32,
DataType.INT64: numpy.int64,
DataType.UINT8: numpy.uint8,
DataType.UINT16: numpy.uint16,
DataType.UINT32: numpy.uint32,
DataType.UINT64: numpy.uint64,
DataType.FLOAT16: numpy.float16,
DataType.FLOAT32: numpy.float32,
DataType.FLOAT64: numpy.float64,
}
format_py_cxx_map = {
Format.DEFAULT: _c_lite_wrapper.Format.DEFAULT_FORMAT,
Format.NCHW: _c_lite_wrapper.Format.NCHW,
@ -244,7 +286,12 @@ class Tensor:
Args:
tensor(Tensor, optional): The data to be stored in a new Tensor. It can be from another Tensor.
Default: ``None``.
shape(list, optional): The shape of the Tensor.
Default: ``None``.
dtype(DataType, optional): The dtype of the Tensor.
Default: ``None``.
device(str, optional): The device of the Tensor, in the form ``"ascend"`` or ``"ascend:device_id"``, e.g. ``"ascend:0"``.
Default: ``None``.
Raises:
TypeError: `tensor` is neither a Tensor nor ``None``.
@ -276,23 +323,58 @@ class Tensor:
data_size: 48.
"""
def __init__(self, tensor=None):
def __init__(self, tensor=None, shape=None, dtype=None, device=None):
# check shape, dtype and device
check_tensor_input_param(shape, device)
device_type = ""
device_id = -1
if device is not None:
device_type = device.split(":")[0]
if len(device.split(":")) == 2:
device_id = int(device.split(":")[1])
check_isinstance("dtype", dtype, DataType, True)
if tensor is not None:
# use tensor to init tensor
if isinstance(tensor, _c_lite_wrapper.TensorBind):
self._tensor = tensor
elif isinstance(tensor, Tensor):
tensor_shape = tensor.shape
if shape is not None and list(shape) != list(tensor_shape):
raise TypeError(
f"user set shape is not equal to the src tensor shape, user's shape: {shape}, "
f"tensor shape: {tensor_shape}.")
tensor_dtype = tensor.dtype
if dtype is not None and tensor_dtype != dtype:
raise TypeError(
f"user set dtype is not equal to the src tensor dtype, user's dtype: {dtype}, "
f"tensor dtype: {tensor_dtype}.")
numpy_data = tensor.get_data_to_numpy()
self._tensor = _c_lite_wrapper.create_tensor_by_numpy(numpy_data, device_type, device_id)
# use numpy to init tensor
elif isinstance(tensor, numpy.ndarray):
shape = tensor.shape
dtype = tensor.dtype
if dtype.type not in numpy_data_type_map:
raise TypeError(f"Unsupported numpy dtype value {dtype}")
dtype = numpy_data_type_map.get(dtype.type)
self._tensor = _c_lite_wrapper.create_tensor(data_type_py_cxx_map.get(dtype), shape)
self.set_data_from_numpy(tensor)
numpy_shape = tensor.shape
numpy_dtype = tensor.dtype
if numpy_dtype.type not in numpy_data_type_map:
raise TypeError(f"Unsupported numpy dtype value {numpy_dtype}")
ms_dtype = numpy_data_type_map.get(numpy_dtype.type)
if shape is not None and list(shape) != list(numpy_shape):
raise TypeError(
f"user set shape is not equal numpy shape, user shape: {shape}, "
f"numpy shape is: {numpy_shape}.")
if dtype is not None and ms_dtype != dtype:
raise TypeError(
f"user set dtype is not equal numpy dtype, user dtype: {dtype}, "
f"numpy dtype is: {numpy_dtype}.")
self._tensor = _c_lite_wrapper.create_tensor_by_numpy(tensor, device_type, device_id)
else:
raise TypeError(
f"tensor must be MindSpore Lite's Tensor._tensor or numpy ndarray, but got {type(tensor)}.")
else:
self._tensor = _c_lite_wrapper.create_tensor(data_type_py_cxx_map.get(DataType.FLOAT32), ())
if dtype is not None and shape is not None:
self._tensor = _c_lite_wrapper.create_tensor(data_type_py_cxx_map.get(dtype), shape, device_type,
device_id)
else:
self._tensor = _c_lite_wrapper.create_tensor(data_type_py_cxx_map.get(DataType.FLOAT32), (), "", -1)
def __str__(self):
res = f"name: {self.name},\n" \
@ -300,7 +382,8 @@ class Tensor:
f"shape: {self.shape},\n" \
f"format: {self.format},\n" \
f"element_num: {self.element_num},\n" \
f"data_size: {self.data_size}."
f"data_size: {self.data_size}.\n" \
f"device: {self.device}."
return res
@property
@ -514,3 +597,13 @@ class Tensor:
raise RuntimeError(
f"data size not equal! Numpy size: {numpy_obj.nbytes}, Tensor size: {self.data_size}")
self._tensor.set_data_from_numpy(numpy_obj)
@property
def device(self):
"""
Get the device of the Tensor.
Returns:
str, the device of the Tensor in the form 'device_type:device_id', e.g. 'ascend:0'.
"""
return self._tensor.get_tensor_device_type()
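A short sketch of the extended constructor and the new device property, assuming an Ascend-enabled mindspore_lite build:

import numpy as np
import mindspore_lite as mslite

data = np.arange(6, dtype=np.float32).reshape(2, 3)

# shape/dtype, when given, must match the numpy array, otherwise a TypeError is raised.
t_host = mslite.Tensor(data, shape=[2, 3], dtype=mslite.DataType.FLOAT32)

# Copying an existing tensor onto Ascend device 0; the data is round-tripped
# through numpy and copied into device memory.
t_dev = mslite.Tensor(t_host, device="ascend:0")
print(t_dev.device)   # -> "ascend:0"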

View File

@ -44,13 +44,17 @@ std::vector<MSTensor> MSTensorPtrToMSTensor(const std::vector<MSTensorPtr> &tens
return tensors;
}
std::vector<MSTensorPtr> PyModelPredict(Model *model, const std::vector<MSTensorPtr> &inputs_ptr) {
std::vector<MSTensorPtr> PyModelPredict(Model *model, const std::vector<MSTensorPtr> &inputs_ptr,
const std::vector<MSTensorPtr> &outputs_ptr) {
if (model == nullptr) {
MS_LOG(ERROR) << "Model object cannot be nullptr";
return {};
}
std::vector<MSTensor> inputs = MSTensorPtrToMSTensor(inputs_ptr);
std::vector<MSTensor> outputs;
if (!outputs_ptr.empty()) {
outputs = MSTensorPtrToMSTensor(outputs_ptr);
}
if (!model->Predict(inputs, &outputs).IsOk()) {
return {};
}

View File

@ -13,9 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "include/api/types.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#include "pybind11/stl.h"
namespace mindspore::lite {
namespace py = pybind11;
@ -29,8 +30,12 @@ void LiteInferPyBind(const py::module &m);
void ModelParallelRunnerPyBind(const py::module &m);
void ModelGroupPyBind(const py::module &m);
void TensorPyBind(const py::module &m);
std::shared_ptr<MSTensor> create_tensor(DataType data_type, const std::vector<int64_t> &shape);
std::shared_ptr<MSTensor> create_tensor(DataType data_type, const std::vector<int64_t> &shape,
const std::string &device_type, int device_id);
std::shared_ptr<MSTensor> create_tensor_by_tensor(const MSTensor &tensor, const std::string &device_type,
int device_id);
std::shared_ptr<MSTensor> create_tensor_by_numpy(const py::array &input, const std::string &device_type,
int32_t device_id);
PYBIND11_MODULE(_c_lite_wrapper, m) {
m.doc() = "MindSpore Lite";
ContextPyBind(m);
@ -45,5 +50,7 @@ PYBIND11_MODULE(_c_lite_wrapper, m) {
ModelGroupPyBind(m);
TensorPyBind(m);
m.def("create_tensor", &create_tensor);
m.def("create_tensor_by_tensor", &create_tensor_by_tensor);
m.def("create_tensor_by_numpy", &create_tensor_by_numpy);
}
} // namespace mindspore::lite

View File

@ -28,26 +28,32 @@
#include "common/mutable_tensor_impl.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace py = pybind11;
namespace mindspore {
class TensorNumpyImpl : public MutableTensorImpl {
public:
TensorNumpyImpl(const std::string &name, py::buffer_info &&buffer, const std::vector<int64_t> &ms_shape,
mindspore::DataType data_type)
: name_(name), buffer_(std::move(buffer)), ms_shape_(ms_shape), data_type_(data_type) {}
TensorNumpyImpl(const std::string &name, py::buffer_info &&buffer, const std::vector<int64_t> &ms_shape)
: name_(name), buffer_(std::move(buffer)), ms_shape_(ms_shape) {}
~TensorNumpyImpl() {
{
py::gil_scoped_acquire acquire;
{ buffer_ = py::buffer_info(); }
}
if (device_data_ != nullptr) {
MS_LOG(INFO) << "free device data in tensor numpy impl.";
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_);
}
}
const std::vector<int64_t> &Shape() const override { return ms_shape_; }
void SetShape(const std::vector<int64_t> &shape) override {
MS_LOG(ERROR) << "Cannot call SetShape for numpy tensor";
}
enum DataType DataType() const override { return data_type_; }
enum DataType DataType() const override { return GetDataType(buffer_); }
void SetDataType(mindspore::DataType data_type) override {
MS_LOG(ERROR) << "Cannot call SetDataType for numpy tensor";
}
@ -71,8 +77,19 @@ class TensorNumpyImpl : public MutableTensorImpl {
int64_t ElementNum() const override { return buffer_.size; }
size_t DataSize() const override { return buffer_.size * buffer_.itemsize; }
void SetDeviceData(void *data) override { MS_LOG(ERROR) << "Cannot call SetDeviceData for numpy tensor"; }
void *GetDeviceData() override { return nullptr; }
void SetDeviceData(void *data) override {
#ifdef ENABLE_CLOUD_INFERENCE
if (device_data_ != nullptr) {
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_);
}
device_data_ = data;
return;
#endif
MS_LOG(ERROR) << "not support.";
}
void *GetDeviceData() override { return device_data_; }
bool IsConst() const override { return false; }
void SetIsConst(bool is_const) { MS_LOG(ERROR) << "Cannot call SetIsConst for numpy tensor"; }
@ -85,6 +102,14 @@ class TensorNumpyImpl : public MutableTensorImpl {
void SetData(void *data, bool own_data) override { MS_LOG(ERROR) << "Cannot call SetData for numpy tensor"; }
int GetDeviceId() const override { return device_id_; }
void SetDeviceId(int device_id) override { device_id_ = device_id; }
std::string GetDevice() const override { return device_; }
void SetDevice(const std::string &device) override { device_ = device; }
void *MutableData() override { return buffer_.ptr; }
std::shared_ptr<Impl> Clone() const override {
@ -143,9 +168,9 @@ class TensorNumpyImpl : public MutableTensorImpl {
py::buffer_info buffer_;
std::vector<int64_t> ms_shape_;
private:
mindspore::DataType data_type_;
void *device_data_ = nullptr;
std::string device_ = "";
int device_id_ = -1;
};
} // namespace mindspore

View File

@ -17,18 +17,27 @@
#include "include/api/data_type.h"
#include "include/api/format.h"
#include "src/common/log_adapter.h"
#include "src/litert/cxx_api/tensor_utils.h"
#include "third_party/securec/include/securec.h"
#include "mindspore/lite/src/common/mutable_tensor_impl.h"
#include "mindspore/lite/python/src/tensor_numpy_impl.h"
#include "mindspore/core/ir/api_tensor_impl.h"
#include "pybind11/pybind11.h"
#include "pybind11/numpy.h"
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include "numpy/arrayobject.h"
#include "pybind11/stl.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace mindspore::lite {
namespace {
bool IsCContiguous(const py::array &input) {
auto flags = static_cast<unsigned int>(input.flags());
return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
}
} // namespace
namespace py = pybind11;
using MSTensorPtr = std::shared_ptr<MSTensor>;
@ -94,6 +103,14 @@ void TensorPyBind(const py::module &m) {
.def("set_data", &MSTensor::SetData)
.def("get_data", &MSTensor::MutableData)
.def("is_null", [](const MSTensorPtr &tensor) { return tensor == nullptr; })
.def("get_tensor_device_type",
[](const MSTensorPtr &tensor) {
std::string device = "None";
if (!tensor->GetDevice().empty()) {
device = tensor->GetDevice();
}
return device + ":" + std::to_string(tensor->GetDeviceId());
})
.def("set_data_from_numpy",
[](const MSTensorPtr &tensor, const py::array &input) { return SetTensorNumpyData(tensor, input); })
.def("get_data_to_numpy", [](const MSTensorPtr &tensor) -> py::array {
@ -107,17 +124,72 @@ void TensorPyBind(const py::module &m) {
});
}
MSTensorPtr create_tensor(DataType data_type, const std::vector<int64_t> &shape) {
auto tensor = mindspore::MSTensor::CreateTensor("", data_type, shape, nullptr, 0);
MSTensorPtr create_tensor(DataType data_type, const std::vector<int64_t> &shape, const std::string &device_type,
int device_id) {
auto tensor = mindspore::MSTensor::CreateTensor("", data_type, shape, nullptr, 0, device_type, device_id);
if (tensor == nullptr) {
MS_LOG(ERROR) << "create tensor failed.";
return {};
return nullptr;
}
mindspore::Format data_format = NCHW;
tensor->SetFormat(data_format);
return MSTensorPtr(tensor);
}
MSTensorPtr create_tensor_by_tensor(const MSTensor &tensor, const std::string &device_type, int device_id) {
auto new_tensor = mindspore::MSTensor::CreateTensor("", tensor, device_type, device_id);
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "create tensor failed.";
return nullptr;
}
new_tensor->SetFormat(tensor.format());
return MSTensorPtr(new_tensor);
}
MSTensorPtr create_tensor_by_numpy(const py::array &input, const std::string &device_type, int32_t device_id) {
// Check format.
if (!IsCContiguous(input)) {
MS_LOG(ERROR) << "Numpy array is not C Contiguous";
return nullptr;
}
auto py_buffer_info = input.request();
auto py_data_type = TensorNumpyImpl::GetDataType(py_buffer_info);
auto py_data_size = py_buffer_info.size * py_buffer_info.itemsize;
auto py_shape = py_buffer_info.shape;
auto data_size = mindspore::CalTensorDataSize(py_shape, py_data_type);
if (py_data_size != static_cast<int64_t>(data_size)) {
MS_LOG(ERROR) << "Expect data size " << data_size << ", but got " << py_data_size;
return nullptr;
}
auto tensor_impl = std::make_shared<TensorNumpyImpl>("", std::move(py_buffer_info), py_shape);
tensor_impl->SetDevice(device_type);
tensor_impl->SetDeviceId(device_id);
auto numpy_tensor = std::make_shared<MSTensor>(tensor_impl);
if (numpy_tensor == nullptr) {
MS_LOG(ERROR) << "Create numpy tensor failed.";
return nullptr;
}
#ifdef ENABLE_CLOUD_INFERENCE
if (device_type == "ascend") {
kernel::AscendAllocatorPlugin::GetInstance().Register();
device_id = device_id == -1 ? kernel::AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id;
auto device_data = kernel::AscendAllocatorPlugin::GetInstance().Malloc(data_size, device_id);
if (device_data == nullptr) {
MS_LOG(ERROR) << "Malloc device data for numpy tensor failed.";
return nullptr;
}
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(numpy_tensor->MutableData(),
device_data, data_size);
if (status != kSuccess) {
MS_LOG(ERROR) << "tensor has device data, then copy host data to device failed.";
return nullptr;
}
numpy_tensor->SetDeviceData(device_data);
}
#endif
return numpy_tensor;
}
std::string GetPyTypeFormat(DataType data_type) {
switch (data_type) {
case DataType::kNumberTypeFloat32:
@ -152,11 +224,6 @@ std::string GetPyTypeFormat(DataType data_type) {
}
}
bool IsCContiguous(const py::array &input) {
auto flags = static_cast<unsigned int>(input.flags());
return (flags & pybind11::detail::npy_api::NPY_ARRAY_C_CONTIGUOUS_) != 0;
}
bool SetTensorNumpyData(const MSTensorPtr &tensor_ptr, const py::array &input) {
auto &tensor = *tensor_ptr;
// Check format.
@ -177,9 +244,23 @@ bool SetTensorNumpyData(const MSTensorPtr &tensor_ptr, const py::array &input) {
<< tensor.Shape() << ", got shape " << py_buffer_info.shape;
return false;
}
auto tensor_impl = std::make_shared<TensorNumpyImpl>(tensor.Name(), std::move(py_buffer_info), tensor.Shape(),
static_cast<mindspore::DataType>(py_data_type));
tensor = MSTensor(tensor_impl);
auto tensor_impl = std::make_shared<TensorNumpyImpl>(tensor.Name(), std::move(py_buffer_info), tensor.Shape());
tensor_impl->SetDevice(tensor.GetDevice());
tensor_impl->SetDeviceId(tensor.GetDeviceId());
auto numpy_tensor = MSTensor(tensor_impl);
#ifdef ENABLE_CLOUD_INFERENCE
if (tensor.GetDeviceData() != nullptr) {
MS_LOG(INFO) << "device tensor data ptr is not nullptr, need copy host data to device data.";
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(
numpy_tensor.MutableData(), tensor.GetDeviceData(), tensor.DataSize());
if (status != kSuccess) {
MS_LOG(ERROR) << "tensor has device data, then copy host data to device failed.";
return false;
}
numpy_tensor.SetDeviceData(tensor.GetDeviceData());
}
#endif
tensor = numpy_tensor;
return true;
}
@ -195,6 +276,19 @@ py::buffer_info GetPyBufferInfo(const MSTensorPtr &tensor) {
strides[i] = element_num * item_size;
element_num *= shape[i];
}
#ifdef ENABLE_CLOUD_INFERENCE
auto device_data = tensor->GetDeviceData();
if (device_data != nullptr) {
MS_LOG(INFO) << "need copy host data to device.";
// device data is not nullptr, data in device, need copy device data to host.
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost(device_data, tensor->MutableData(),
tensor->DataSize());
if (status != kSuccess) {
MS_LOG(ERROR) << "tensor has device data, then copy device data to host failed.";
return py::buffer_info{nullptr, 0, format, 0, {}, {}};
}
}
#endif
return py::buffer_info{tensor->MutableData(), item_size, format, ndim, shape, strides};
}
} // namespace mindspore::lite
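At the Python level, the two copies above mean numpy data round-trips transparently through device memory: get_data_to_numpy (via GetPyBufferInfo) copies device data back to host before exposing the buffer, and set_data_from_numpy copies new host data into an existing device buffer. A sketch, assuming an Ascend-enabled build:

import numpy as np
import mindspore_lite as mslite

src = np.full((2, 2), 7.0, dtype=np.float32)
dev_tensor = mslite.Tensor(src, device="ascend:0")       # host -> device copy on creation

# Reading back triggers a device -> host copy inside GetPyBufferInfo.
host_view = dev_tensor.get_data_to_numpy()
assert np.allclose(host_view, src)

# Writing new host data copies it into the existing device buffer.
dev_tensor.set_data_from_numpy(np.zeros((2, 2), dtype=np.float32))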

View File

@ -38,6 +38,10 @@ class MutableTensorImpl : public MSTensor::Impl {
virtual void SetQuantParams(const std::vector<QuantParam> &quant_param) = 0;
virtual void SetDeviceData(void *data) = 0;
virtual void *GetDeviceData() = 0;
virtual std::string GetDevice() const = 0;
virtual int GetDeviceId() const = 0;
virtual void SetDeviceId(int device_id) = 0;
virtual void SetDevice(const std::string &device) = 0;
virtual int64_t ElementNum() const {
const auto &shape = Shape();
int64_t ele_num = 1;

View File

@ -13,8 +13,9 @@ file(GLOB_RECURSE ASCEND_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"model/*.cc"
"profiling/*.cc"
)
set(ASCEND_SRC ${ASCEND_SRC} ${TOP_DIR}/mindspore/lite/src/litert/kernel/ascend/src/acl_mem_manager.cc)
set(ASCEND_SRC ${ASCEND_SRC} ${TOP_DIR}/mindspore/lite/src/litert/kernel/ascend/src/acl_mem_manager.cc
${CMAKE_CURRENT_SOURCE_DIR}/model/acl_allocator.cc
)
file(GLOB_RECURSE ACL_SRC ${CMAKE_CURRENT_SOURCE_DIR}
"../acl/*.cc")

View File

@ -0,0 +1,209 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/kernel/ascend/model/acl_allocator.h"
#include "src/common/log_adapter.h"
#include "acl/acl.h"
namespace mindspore::kernel {
namespace acl {
AclAllocator *CreateAclAllocator() {
MS_LOG(INFO) << "CreateAclAllocator..";
return new AclAllocator();
}
uint32_t AclAllocator::GetDeviceCount() {
std::unique_lock<std::mutex> l(acl_allocator_mutex_);
if (device_count_ != 0) {
return device_count_;
}
auto ret = aclrtGetDeviceCount(&device_count_);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "GetDeviceCount failed.";
return 0;
}
return device_count_;
}
void AclAllocator::ResetDeviceId(int device_id) {
auto ret = aclrtSetDevice(device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrt Set device failed.";
return;
}
return;
}
int AclAllocator::GetCurrentDeviceId() {
int32_t current_device_id;
auto ret = aclrtGetDevice(&current_device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(INFO) << "not init device id, need set device id before get device id.";
ret = aclrtSetDevice(0);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtSetDevice failed.";
return -1;
}
return 0;
}
return current_device_id;
}
void *AclAllocator::Malloc(size_t size, int device_id) {
if (size == 0) {
MS_LOG(WARNING) << "malloc device data size is zero.";
return nullptr;
}
auto current_device_id = GetCurrentDeviceId();
if (current_device_id == -1) {
MS_LOG(ERROR) << "get current device id failed.";
return nullptr;
}
if (device_id == -1) {
device_id = current_device_id;
}
auto device_count = GetDeviceCount();
if (device_id >= static_cast<int>(device_count)) {
MS_LOG(ERROR) << "device id is wrong, device id: " << device_id << ", device count: " << device_count;
return nullptr;
}
if (current_device_id != device_id) {
auto ret = aclrtSetDevice(device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtSetDevice failed.";
return nullptr;
}
}
void *device_data = nullptr;
auto acl_ret = aclrtMalloc(&device_data, size, ACL_MEM_MALLOC_HUGE_FIRST);
if (acl_ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Call aclrtMalloc failed, err_code = " << acl_ret;
return nullptr;
}
if (GetCurrentDeviceId() != current_device_id) {
ResetDeviceId(current_device_id);
}
return device_data;
}
void AclAllocator::Free(void *device_data) {
if (device_data != nullptr) {
aclrtFree(device_data);
device_data = nullptr;
}
}
Status AclAllocator::CopyDeviceDataToHost(void *device_data, void *host_data, size_t data_size) {
if (device_data == nullptr || host_data == nullptr) {
MS_LOG(ERROR) << "device data or host data ptr is nullptr.";
return kLiteMemoryFailed;
}
auto ret = aclrtMemcpy(host_data, data_size, device_data, data_size, ACL_MEMCPY_DEVICE_TO_HOST);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "copy device data to host failed, data size: " << data_size;
return kLiteMemoryFailed;
}
return kSuccess;
}
Status AclAllocator::CopyHostDataToDevice(void *host_data, void *device_data, size_t data_size) {
if (device_data == nullptr || host_data == nullptr) {
MS_LOG(ERROR) << "device data or host data ptr is nullptr.";
return kLiteMemoryFailed;
}
auto ret = aclrtMemcpy(device_data, data_size, host_data, data_size, ACL_MEMCPY_HOST_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "copy host data to device failed, data size: " << data_size;
return kLiteMemoryFailed;
}
return kSuccess;
}
Status AclAllocator::CopyDeviceDataToDevice(void *src_device_data, void *dst_device_data, size_t data_size,
int src_device_id, int dst_device_id) {
MS_LOG(INFO) << "src device id: " << src_device_id << ", dst device id: " << dst_device_id;
if (dst_device_id == -1 || src_device_id == -1) {
MS_LOG(ERROR) << "device data copy device data, need set src device id and dst device id, now src device id: "
<< src_device_id << ", dst device id: " << dst_device_id;
return kLiteError;
}
auto device_count = GetDeviceCount();
if (dst_device_id >= static_cast<int>(device_count) || src_device_id >= static_cast<int>(device_count)) {
MS_LOG(ERROR) << "device id is more than device count, src device id: " << src_device_id
<< ", dst device id: " << dst_device_id << ", device count: " << device_count;
return kLiteError;
}
if (src_device_id == dst_device_id) {
auto ret = aclrtMemcpy(dst_device_data, data_size, src_device_data, data_size, ACL_MEMCPY_DEVICE_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtMemcpy failed.";
return kLiteError;
}
return kSuccess;
}
aclrtContext curr_context;
auto ret = aclrtGetCurrentContext(&curr_context);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Get current runtime context failed.";
return kLiteError;
}
int32_t can_access_peer;
ret = aclrtDeviceCanAccessPeer(&can_access_peer, src_device_id, dst_device_id);
if (ret != ACL_ERROR_NONE || can_access_peer != 1) {
MS_LOG(ERROR) << "ret: " << ret << ", can_access_peer: " << can_access_peer;
return kLiteError;
}
auto current_device_id = GetCurrentDeviceId();
if (current_device_id != dst_device_id) {
ret = aclrtSetDevice(dst_device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtSetDevice failed.";
return kLiteError;
}
}
ret = aclrtDeviceEnablePeerAccess(src_device_id, 0);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtDeviceEnablePeerAccess failed.";
return kLiteError;
}
ret = aclrtSetDevice(src_device_id);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtSetDevice failed.";
return kLiteError;
}
ret = aclrtDeviceEnablePeerAccess(dst_device_id, 0);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtDeviceEnablePeerAccess failed.";
return kLiteError;
}
ret = aclrtMemcpy(dst_device_data, data_size, src_device_data, data_size, ACL_MEMCPY_DEVICE_TO_DEVICE);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "aclrtMemcpy failed.";
return kLiteError;
}
if (current_device_id != GetCurrentDeviceId()) {
ResetDeviceId(current_device_id);
}
ret = aclrtSetCurrentContext(curr_context);
if (ret != ACL_ERROR_NONE) {
MS_LOG(ERROR) << "Set runtime context failed.";
return kLiteError;
}
return kSuccess;
}
} // namespace acl
} // namespace mindspore::kernel

View File

@ -0,0 +1,49 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_
#include <mutex>
#include "include/api/status.h"
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
namespace mindspore::kernel {
namespace acl {
class AclAllocator : public AscendAllocatorPluginImpl {
public:
AclAllocator() = default;
~AclAllocator() = default;
int GetCurrentDeviceId() override;
void *Malloc(size_t size, int device_id = -1) override;
void Free(void *device_data) override;
Status CopyDeviceDataToHost(void *device_data, void *host_data, size_t data_size) override;
Status CopyHostDataToDevice(void *host_data, void *device_data, size_t data_size) override;
Status CopyDeviceDataToDevice(void *src_device, void *dst_device, size_t data_size, int src_device_id,
int dst_device_id) override;
private:
uint32_t GetDeviceCount();
void ResetDeviceId(int device_id);
uint32_t device_count_ = 0;
std::mutex acl_allocator_mutex_;
};
extern "C" MS_API AclAllocator *CreateAclAllocator();
} // namespace acl
} // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ACL_ALLOCATOR_H_

View File

@ -23,6 +23,7 @@
#include "src/common/utils.h"
#include "src/common/log_util.h"
#include "src/litert/kernel/ascend/src/acl_mem_manager.h"
#include "src/extendrt/kernel/ascend/model/acl_allocator.h"
namespace mindspore::kernel {
namespace acl {
@ -914,7 +915,18 @@ bool ModelProcess::CheckAndInitInput(const std::vector<KernelTensorPtr> &inputs)
auto device_data = input->GetData();
auto host_data = input->GetHostData();
if (device_data && device_data->addr) {
auto input_device_id = input->GetDeviceId();
if (input_device_id == device_id_) {
input_buffer = device_data->addr;
} else {
// memcpy device data from src device to current device.
if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(
device_data->addr, info.device_data, info.buffer_size, input_device_id, device_id_) != kSuccess) {
MS_LOG(ERROR) << "Copy input data from device to current device failed.";
return false;
}
input_buffer = info.device_data;
}
} else {
auto data = host_data->addr;
auto size = host_data->size;
@ -959,7 +971,8 @@ bool ModelProcess::CheckAndInitOutput(const std::vector<KernelTensorPtr> &output
void *output_buffer = nullptr;
auto device_data = output->GetData();
auto host_data = output->GetHostData();
if (device_data && device_data->addr) {
auto output_device_id = output->GetDeviceId();
if (device_data && device_data->addr && output_device_id == device_id_) {
output_buffer = device_data->addr;
} else if (host_data && host_data->addr && is_run_on_device_) {
output_buffer = host_data->addr;
@ -1099,6 +1112,7 @@ bool ModelProcess::GetOutputs(const std::vector<KernelTensorPtr> &outputs) {
continue;
}
auto host_data = output->GetHostData();
auto output_device_id = output->GetDeviceId();
if (host_data && host_data->addr && !is_run_on_device_) {
if (host_data->size != output_info.buffer_size) {
MS_LOG(ERROR) << "Specified output host data size " << host_data->size << " != execute output data size "
@ -1112,6 +1126,14 @@ bool ModelProcess::GetOutputs(const std::vector<KernelTensorPtr> &outputs) {
<< " to host failed, memory size " << output_info.buffer_size << ",ret: " << ret;
return false;
}
} else if (output_device_id != device_id_) {
// memcpy output data from current device to output device.
if (AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(output_info.cur_device_data,
output->GetData()->addr, output_info.buffer_size,
device_id_, output_device_id) != kSuccess) {
MS_LOG(ERROR) << "Copy output data from device to current device failed.";
return false;
}
}
}
return true;

View File

@ -47,7 +47,7 @@ struct AclTensorInfo {
class ModelProcess {
public:
explicit ModelProcess(const AclModelOptionsPtr &options) : options_(options) {}
explicit ModelProcess(const AclModelOptionsPtr &options) : options_(options), device_id_(options->device_id) {}
~ModelProcess();
bool Load(const void *om_data, size_t om_data_size);
@ -121,6 +121,7 @@ class ModelProcess {
aclmdlIODims *dynamic_dims_ = nullptr;
void *weight_ptr_ = nullptr;
bool is_sharing_workspace_ = false;
int32_t device_id_ = 0;
};
} // namespace acl
} // namespace mindspore::kernel

View File

@ -0,0 +1,203 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#include <memory>
#if !defined(_WIN32)
#include "src/extendrt/cxx_api/dlutils.h"
#endif
namespace mindspore::kernel {
namespace {
constexpr auto kAscendKernelPluginSoName = "libascend_kernel_plugin.so";
constexpr auto kFunCreateAscendAllocatorPluginImpl = "CreateAclAllocator";
#if !defined(_WIN32)
std::mutex mutex_;
#endif
} // namespace
AscendAllocatorPlugin::AscendAllocatorPlugin() = default;
AscendAllocatorPlugin::~AscendAllocatorPlugin() {
#if !defined(_WIN32)
std::lock_guard<std::mutex> l(mutex_);
MS_LOG(INFO) << "AscendAllocatorPlugin::~AscendAllocatorPlugin() begin.";
ascend_allocator_plugin_impl_ = nullptr;
DLSoClose(handle_);
handle_ = nullptr;
MS_LOG(INFO) << "AscendAllocatorPlugin::~AscendAllocatorPlugin() end.";
#endif
}
AscendAllocatorPlugin &AscendAllocatorPlugin::GetInstance() {
#if !defined(_WIN32)
std::lock_guard<std::mutex> l(mutex_);
#endif
static AscendAllocatorPlugin instance;
return instance;
}
bool AscendAllocatorPlugin::Register() {
#if !defined(_WIN32)
std::lock_guard<std::mutex> l(mutex_);
if (is_registered_) {
return true;
}
MS_LOG(INFO) << "AscendAllocatorPlugin Register.";
auto ret =
DLSoPath({"libmindspore-lite.so", "_c_lite", "tools/converter/lib"}, kAscendkernelPluginSoNmae, &plugin_path_);
if (ret != kSuccess) {
MS_LOG(ERROR) << "get real path of " << kAscendkernelPluginSoNmae << " failed.";
return false;
}
MS_LOG(INFO) << "find ascend allocator plugin so path: " << plugin_path_;
void *func = nullptr;
ret = DLSoOpen(plugin_path_, kFunCreateAscendAllocatorPluginImpl, &handle_, &func);
if (ret != kSuccess) {
MS_LOG(ERROR) << "DLSoOpen failed, so path: " << plugin_path_
<< " , func name: " << kFunCreateAscendAllocatorPluginImpl << ", err: " << ret.ToString();
return false;
}
auto create_plugin_impl_func = reinterpret_cast<AscendAllocatorPluginImpl *(*)(void)>(func);
if (create_plugin_impl_func == nullptr) {
MS_LOG(ERROR) << "cast " << kFunCreateAscendAllocatorPluginImpl << " failed.";
return false;
}
ascend_allocator_plugin_impl_ = std::shared_ptr<AscendAllocatorPluginImpl>(create_plugin_impl_func());
if (ascend_allocator_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "create ascend allocator plugin impl failed.";
return false;
}
is_registered_ = true;
MS_LOG(INFO) << "register ascend allocator success.";
#endif
return true;
}
int AscendAllocatorPlugin::GetCurrentDeviceId() {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return -1;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
return -1;
}
auto device_id = ascend_allocator_plugin_impl_->GetCurrentDeviceId();
return device_id;
#endif
return -1;
}
void *AscendAllocatorPlugin::Malloc(size_t size, int device_id) {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return nullptr;
}
if (device_id < -1) {
MS_LOG(ERROR) << "device id must more than 0";
return nullptr;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "ascend_allocator_plugin_impl_ is nullptr.";
return nullptr;
}
auto device_data = ascend_allocator_plugin_impl_->Malloc(size, device_id);
return device_data;
#endif
return nullptr;
}
void AscendAllocatorPlugin::Free(void *device_data) {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
MS_LOG(ERROR) << "ascend_allocator_plugin_impl_ is nullptr.";
return;
}
if (device_data == nullptr) {
MS_LOG(ERROR) << "device data is nullptr.";
return;
}
ascend_allocator_plugin_impl_->Free(device_data);
#endif
return;
}
Status AscendAllocatorPlugin::CopyDeviceDataToHost(void *device_data, void *host_data, size_t data_size) {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return kLiteMemoryFailed;
}
if (device_data == nullptr) {
MS_LOG(INFO) << "device data is nullptr.";
return kLiteMemoryFailed;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
return kLiteMemoryFailed;
}
return ascend_allocator_plugin_impl_->CopyDeviceDataToHost(device_data, host_data, data_size);
#endif
return kSuccess;
}
Status AscendAllocatorPlugin::CopyDeviceDataToDevice(void *src_device, void *dst_device, size_t data_size,
int src_device_id, int dst_device_id) {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return kLiteMemoryFailed;
}
if (src_device_id < -1 || dst_device_id < -1) {
MS_LOG(ERROR) << "device id is wrong, src device id: " << src_device_id << ", dst device id: " << dst_device_id;
return kLiteError;
}
if (dst_device == nullptr || src_device == nullptr) {
MS_LOG(INFO) << "device data is nullptr.";
return kLiteMemoryFailed;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
return kLiteMemoryFailed;
}
return ascend_allocator_plugin_impl_->CopyDeviceDataToDevice(src_device, dst_device, data_size, src_device_id,
dst_device_id);
#endif
return kSuccess;
}
Status AscendAllocatorPlugin::CopyHostDataToDevice(void *host_data, void *device_data, size_t data_size) {
#if !defined(_WIN32)
if (!is_registered_) {
MS_LOG(ERROR) << "AscendAllocatorPlugin is not registered.";
return kLiteMemoryFailed;
}
if (device_data == nullptr) {
MS_LOG(INFO) << "device data is nullptr.";
return kLiteMemoryFailed;
}
if (ascend_allocator_plugin_impl_ == nullptr) {
return kLiteMemoryFailed;
}
return ascend_allocator_plugin_impl_->CopyHostDataToDevice(host_data, device_data, data_size);
#endif
return kSuccess;
}
} // namespace mindspore::kernel

View File

@ -0,0 +1,60 @@
/**
* Copyright 2023 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_
#define MINDSPORE_LITE_SRC_EXTENDRT_KERNEL_ASCEND_ASCEND_ALLOCATOR_PLUGIN_H_
#include <string>
#include <memory>
#include "include/api/status.h"
namespace mindspore::kernel {
class AscendAllocatorPluginImpl {
public:
AscendAllocatorPluginImpl() = default;
virtual ~AscendAllocatorPluginImpl() = default;
virtual int GetCurrentDeviceId() = 0;
virtual void *Malloc(size_t size, int device_id = -1) = 0;
virtual void Free(void *device_data) = 0;
virtual Status CopyDeviceDataToHost(void *device_data, void *host_data, size_t data_size) = 0;
virtual Status CopyHostDataToDevice(void *host_data, void *device_data, size_t data_size) = 0;
virtual Status CopyDeviceDataToDevice(void *src_device, void *dst_device, size_t data_size, int src_device_id,
int dst_device_id) = 0;
};
class MS_API AscendAllocatorPlugin {
public:
static AscendAllocatorPlugin &GetInstance();
bool Register();
int GetCurrentDeviceId();
void *Malloc(size_t size, int device_id = -1);
void Free(void *device_data);
Status CopyDeviceDataToHost(void *device_data, void *host_data, size_t data_size);
Status CopyHostDataToDevice(void *host_data, void *device_data, size_t data_size);
Status CopyDeviceDataToDevice(void *src_device, void *dst_device, size_t data_size, int src_device_id,
int dst_device_id);
private:
AscendAllocatorPlugin();
~AscendAllocatorPlugin();
std::string plugin_path_;
void *handle_ = nullptr;
bool is_registered_ = false;
std::shared_ptr<AscendAllocatorPluginImpl> ascend_allocator_plugin_impl_ = nullptr;
};
} // namespace mindspore::kernel
#endif

View File

@ -345,6 +345,7 @@ Status SingleOpInferSession::InitInputOutputData(const std::vector<tensor::Tenso
kernel_args_.inputs[i]->SetHostData(std::make_shared<kernel::Address>(input.data_c(), input.Size()));
kernel_args_.inputs[i]->SetData(nullptr);
}
kernel_args_.inputs[i]->SetDeviceId(input.device_info().device_id_);
}
if (outputs->empty()) {
std::transform(kernel_args_.outputs.begin(), kernel_args_.outputs.end(), std::back_inserter(*outputs),
@ -373,6 +374,7 @@ Status SingleOpInferSession::InitInputOutputData(const std::vector<tensor::Tenso
kernel_args_.outputs[i]->SetHostData(std::make_shared<kernel::Address>(output.data_c(), output.Size()));
kernel_args_.outputs[i]->SetData(nullptr);
}
kernel_args_.outputs[i]->SetDeviceId(output.device_info().device_id_);
}
return kSuccess;
}

View File

@ -29,6 +29,7 @@
#include "include/backend/device_address.h"
#include "common/utils.h"
#include "common/mutable_tensor_impl.h"
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
namespace mindspore {
class TensorDefaultImpl : public MutableTensorImpl {
@ -58,6 +59,11 @@ class TensorDefaultImpl : public MutableTensorImpl {
if (own_data_ && data_ != nullptr && data_ != buffer_.Data()) {
free(const_cast<void *>(data_));
}
if (device_data_ != nullptr && own_data_) {
MS_LOG(INFO) << "free device data in tensor default impl.";
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_);
device_data_ = nullptr;
}
}
void SetShape(const std::vector<int64_t> &shape) override { shape_ = shape; }
void SetDataType(mindspore::DataType data_type) override { type_ = data_type; }
@ -78,7 +84,23 @@ class TensorDefaultImpl : public MutableTensorImpl {
size_t DataSize() const override { return ElementNum() * lite::DataTypeSize(static_cast<enum TypeId>(type_)); }
void SetDeviceData(void *data) override { device_data_ = data; }
std::string GetDevice() const override { return device_; }
int GetDeviceId() const override { return device_id_; }
void SetDeviceId(int device_id) override { device_id_ = device_id; }
void SetDevice(const std::string &device) override { device_ = device; }
void SetDeviceData(void *data) override {
if (own_data_ && device_data_ != nullptr) {
MS_LOG(INFO) << "tensor has own device data, now release device data and set new device data.";
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data_);
}
device_data_ = data;
own_data_ = false;
}
void *GetDeviceData() override { return device_data_; }
bool IsConst() const override { return is_const_; }
void SetIsConst(bool is_const) { is_const_ = is_const; }
@ -121,6 +143,9 @@ class TensorDefaultImpl : public MutableTensorImpl {
std::vector<QuantParam> quant_param_;
void *device_data_ = nullptr;
std::string device_ = "";
int device_id_ = -1;
mutable Buffer buffer_;
mutable const void *data_ = nullptr;
bool own_data_ = false;

View File

@ -120,6 +120,9 @@ std::vector<mindspore::tensor::Tensor> TensorUtils::MSTensorToTensor(const std::
if (device_address != nullptr) {
auto lite_device_address = std::make_shared<LiteDeviceAddress>(device_address, ms_tensor.DataSize());
tensor.set_device_address(lite_device_address);
// only use device_id now.
auto device_info = tensor::DeviceInfo("DefaultFormat", nullptr, "DefaultFormat", ms_tensor.GetDeviceId());
tensor.set_device_info(device_info);
}
tensors.emplace_back(std::move(tensor));
}

View File

@ -33,7 +33,9 @@
#include "kernel/kernel.h"
#include "src/tensor.h"
#include "infer/tensor.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace mindspore {
class TensorRefData : public tensor::TensorData {
public:
@ -91,6 +93,26 @@ class TensorTensorImpl : public MutableTensorImpl {
return std::shared_ptr<const void>(tensor_->data_c(), [](const void *) {});
}
void SetDeviceId(int device_id) override {
MS_EXCEPTION_IF_NULL(tensor_);
device_id_ = device_id;
}
void SetDevice(const std::string &device) override {
MS_EXCEPTION_IF_NULL(tensor_);
device_ = device;
}
int GetDeviceId() const override {
MS_EXCEPTION_IF_NULL(tensor_);
return device_id_;
}
std::string GetDevice() const override {
MS_EXCEPTION_IF_NULL(tensor_);
return device_;
}
void *MutableData() override {
MS_EXCEPTION_IF_NULL(tensor_);
return tensor_->data_c();
@ -98,9 +120,17 @@ class TensorTensorImpl : public MutableTensorImpl {
void SetDeviceData(void *data) override {
MS_EXCEPTION_IF_NULL(tensor_);
auto old_device_data = GetDeviceData();
MS_LOG(ERROR) << "set device data in tensor utils.";
#ifdef ENABLE_CLOUD_INFERENCE
if (old_device_data != nullptr && device_own_data_) {
kernel::AscendAllocatorPlugin::GetInstance().Free(old_device_data);
}
#endif
auto data_size = DataSize();
auto device_address = std::make_shared<LiteDeviceAddress>(data, data_size);
tensor_->set_device_address(device_address);
device_own_data_ = false;
}
void *GetDeviceData() override {
MS_EXCEPTION_IF_NULL(tensor_);
@ -168,6 +198,9 @@ class TensorTensorImpl : public MutableTensorImpl {
private:
std::shared_ptr<tensor::Tensor> tensor_ = nullptr;
std::string device_ = "";
int device_id_ = -1;
bool device_own_data_ = true;
};
class TensorUtils {

View File

@ -25,7 +25,9 @@
#include "src/litert/cxx_api/tensor_utils.h"
#include "src/tensor.h"
#include "src/common/string_utils.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace mindspore {
using mindspore::lite::RET_OK;
@ -82,7 +84,14 @@ void LiteTensorImpl::SetDeviceData(void *data) {
MS_LOG(ERROR) << "Invalid tensor.";
return;
}
#ifdef ENABLE_CLOUD_INFERENCE
if (GetDeviceData() != nullptr && own_data_) {
MS_LOG(INFO) << "free device data in tensor impl.";
kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData());
}
#endif
lite_tensor_->set_device_data(data);
own_data_ = false;
}
void *LiteTensorImpl::GetDeviceData() {

View File

@ -31,6 +31,9 @@
#include "src/common/log_adapter.h"
#include "ir/api_tensor_impl.h"
#include "common/mutable_tensor_impl.h"
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
#include "src/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace mindspore {
using mindspore::lite::RET_OK;
@ -43,6 +46,13 @@ class LiteTensorImpl : public MutableTensorImpl {
if (lite_tensor_ == nullptr) {
return;
}
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
if (GetDeviceData() != nullptr && own_data_) {
MS_LOG(INFO) << "free device data in tensor impl.";
kernel::AscendAllocatorPlugin::GetInstance().Free(GetDeviceData());
SetDeviceData(nullptr);
}
#endif
if (!from_session_) {
if (!own_data_) {
lite_tensor_->set_data(nullptr);
@ -126,6 +136,24 @@ class LiteTensorImpl : public MutableTensorImpl {
return lite_shape_;
}
std::string GetDevice() const override { return lite_tensor_->get_device(); }
void SetDevice(const std::string &device) override {
#if defined(ENABLE_CLOUD_FUSION_INFERENCE) || defined(ENABLE_CLOUD_INFERENCE)
void *device_data = GetDeviceData();
if (device_data != nullptr && own_data_) {
MS_LOG(INFO) << "free device data in tensor impl.";
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data);
}
#endif
lite_tensor_->set_device(device);
own_data_ = false;
}
int GetDeviceId() const override { return lite_tensor_->get_device_id(); }
void SetDeviceId(int device_id) override { lite_tensor_->set_device_id(device_id); }
std::shared_ptr<mindspore::MSTensor::Impl> Clone() const override { return nullptr; }
void SetShape(const std::vector<int64_t> &shape) override {

View File

@ -19,6 +19,15 @@
#include "src/tensor.h"
namespace mindspore {
size_t MS_API CalTensorDataSize(const std::vector<int64_t> &shape, enum DataType type) {
size_t element_size = lite::DataTypeSize(static_cast<enum TypeId>(type));
for (size_t i = 0; i < shape.size(); i++) {
auto dim = shape[i];
element_size *= static_cast<size_t>(dim);
}
return element_size;
}
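A quick worked check of CalTensorDataSize above (the call site and header path follow the declarations shown in this change; the arithmetic is the point):

#include <cassert>
#include "src/litert/cxx_api/tensor_utils.h"

// Sketch: shape {2, 3, 4} with float32 elements gives 2 * 3 * 4 * 4 bytes = 96 bytes.
void CheckCalTensorDataSize() {
  size_t bytes = mindspore::CalTensorDataSize({2, 3, 4}, mindspore::DataType::kNumberTypeFloat32);
  assert(bytes == 96);
}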
std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size) {
std::vector<int32_t> empty;

View File

@ -29,6 +29,8 @@ namespace mindspore {
std::vector<int32_t> MS_API TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
bool verify_size);
size_t MS_API CalTensorDataSize(const std::vector<int64_t> &shape, enum DataType type);
Status MS_API LiteTensorToMSTensor(lite::Tensor *srcTensor, MSTensor *dstTensor, bool fromSession = true);
std::vector<MSTensor> MS_API LiteTensorsToMSTensors(const std::vector<mindspore::lite::Tensor *> &srcTensors,

View File

@ -21,12 +21,14 @@
#include "include/api/status.h"
#include "include/api/dual_abi_helper.h"
#include "src/litert/cxx_api/tensor/tensor_impl.h"
#include "src/litert/cxx_api/tensor_utils.h"
#include "src/common/log_adapter.h"
#ifdef ENABLE_CLOUD_INFERENCE
#include <fstream>
#include "utils/file_utils.h"
#include "ir/dtype.h"
#include "utils/convert_utils_base.h"
#include "extendrt/kernel/ascend/plugin/ascend_allocator_plugin.h"
#endif
namespace mindspore {
@ -109,7 +111,56 @@ bool MSTensor::operator==(const MSTensor &tensor) const {
bool MSTensor::operator!=(const MSTensor &tensor) const { return !operator==(tensor); }
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType type, const std::vector<int64_t> &shape,
const void *data, size_t data_len) noexcept {
const void *data, size_t data_len, const std::vector<char> &device,
int device_id) noexcept {
MS_LOG(INFO) << "device id: " << device_id << ", device type: " << device;
auto device_type = CharToString(device);
if (!device_type.empty() && device_type == "ascend") {
#ifdef ENABLE_CLOUD_INFERENCE
kernel::AscendAllocatorPlugin::GetInstance().Register();
// check device id
device_id = device_id == -1 ? kernel::AscendAllocatorPlugin::GetInstance().GetCurrentDeviceId() : device_id;
// check device data size
size_t element_size = CalTensorDataSize(shape, type);
MS_CHECK_FALSE_MSG(data_len != 0 && element_size != data_len, nullptr, "data_len does not match the byte size computed from shape and data type.");
// malloc device data
void *device_data = kernel::AscendAllocatorPlugin::GetInstance().Malloc(element_size, device_id);
MS_CHECK_TRUE_MSG(device_data != nullptr, nullptr, "malloc device data failed.");
// create tensor
auto impl = LiteTensorImpl::CreateTensorImpl(CharToString(name), type, shape, nullptr, 0);
if (impl == nullptr) {
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data);
MS_LOG(ERROR) << "Allocate tensor impl failed.";
return nullptr;
}
if (data != nullptr) {
// init device data by host data buf
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyHostDataToDevice(const_cast<void *>(data),
device_data, element_size);
if (status != kSuccess) {
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data);
MS_LOG(ERROR) << "copy host data to device data failed.";
return nullptr;
}
}
// init impl
impl->SetDeviceData(device_data);
impl->SetDeviceId(device_id);
impl->SetDevice(device_type);
auto ms_tensor = new (std::nothrow) MSTensor(impl);
if (ms_tensor == nullptr) {
kernel::AscendAllocatorPlugin::GetInstance().Free(device_data);
MS_LOG(ERROR) << "Allocate MSTensor failed.";
return nullptr;
}
impl->set_own_data(true);
return ms_tensor;
#endif
MS_LOG(ERROR) << "Unsupported Feature.";
return nullptr;
}
if (data_len > MAX_MALLOC_SIZE) {
MS_LOG(ERROR) << "data_len is error.";
return nullptr;
@ -135,24 +186,87 @@ MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType ty
auto impl = LiteTensorImpl::CreateTensorImpl(CharToString(name), type, shape, new_data, data_len);
if (impl == nullptr) {
MS_LOG(ERROR) << "Allocate tensor impl failed.";
if (new_data != nullptr) {
free(new_data);
}
return nullptr;
}
auto ms_tensor = new (std::nothrow) MSTensor(impl);
if (ms_tensor == nullptr) {
MS_LOG(ERROR) << "Allocate MSTensor failed.";
if (new_data != nullptr) {
free(new_data);
}
return nullptr;
}
impl->set_own_data(true);
return ms_tensor;
}
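A hedged usage sketch of the new overload above: it assumes a build with cloud inference enabled and an Ascend device 0; the tensor name, shape, and data are illustrative.

#include <vector>
#include "include/api/types.h"

// Create a float32 4x4 tensor whose storage lives on Ascend device 0, initialized
// from host data; CreateTensor must be paired with DestroyTensorPtr.
mindspore::MSTensor *CreateAscendInput() {
  std::vector<float> host_data(16, 1.0f);
  auto *dev_tensor = mindspore::MSTensor::CreateTensor(
      "input0", mindspore::DataType::kNumberTypeFloat32, {4, 4},
      host_data.data(), host_data.size() * sizeof(float), "ascend", 0);
  if (dev_tensor != nullptr) {
    void *device_ptr = dev_tensor->GetDeviceData();  // device-side buffer
    (void)device_ptr;
  }
  return dev_tensor;
}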
MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, const MSTensor &tensor, const std::vector<char> &device,
int device_id) noexcept {
#ifdef ENABLE_CLOUD_INFERENCE
kernel::AscendAllocatorPlugin::GetInstance().Register();
auto dst_device_type = CharToString(device);
if (!dst_device_type.empty() && dst_device_type != "ascend") {
MS_LOG(ERROR) << "only support create ascend device tensor.";
return nullptr;
}
auto src_device_type = tensor.GetDevice();
if (!src_device_type.empty() && src_device_type != "ascend") {
MS_LOG(ERROR) << "only tensor tensor is ascend device tensor.";
return nullptr;
}
if (src_device_type.empty() && static_cast<MSTensor>(tensor).GetDeviceData() != nullptr) {
MS_LOG(ERROR) << "tensor tensor is device tensor, but device data is nullptr.";
return nullptr;
}
if (src_device_type.empty() && dst_device_type.empty()) {
MS_LOG(INFO) << "copy host tensor to host tensor.";
if (tensor.Data() != nullptr) {
return CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), static_cast<MSTensor>(tensor).MutableData(),
tensor.DataSize());
} else {
return CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), nullptr, 0);
}
} else if (src_device_type == "ascend" && dst_device_type == "ascend") {
MS_LOG(INFO) << "copy device tensor to device tensor.";
auto new_tensor =
CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), nullptr, tensor.DataSize(), "ascend", device_id);
if (new_tensor == nullptr) {
MS_LOG(ERROR) << "create dst device tensor failed.";
return nullptr;
}
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToDevice(
static_cast<MSTensor>(tensor).GetDeviceData(), new_tensor->GetDeviceData(), tensor.DataSize(),
tensor.GetDeviceId(), new_tensor->GetDeviceId());
if (status != kSuccess) {
MS_LOG(ERROR) << "copy device data to device failed.";
DestroyTensorPtr(new_tensor);
return nullptr;
}
return new_tensor;
} else if (src_device_type.empty() && dst_device_type == "ascend") {
MS_LOG(INFO) << "copy host tensor to device tensor.";
return CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), static_cast<MSTensor>(tensor).MutableData(),
tensor.DataSize(), dst_device_type, device_id);
} else if (src_device_type == "ascend" && dst_device_type.empty()) {
MS_LOG(INFO) << "copy device tensor to host tensor.";
auto host_from_device = malloc(tensor.DataSize());
MS_CHECK_FALSE_MSG(host_from_device == nullptr, nullptr, "malloc host buf failed.");
auto status = kernel::AscendAllocatorPlugin::GetInstance().CopyDeviceDataToHost(
static_cast<MSTensor>(tensor).GetDeviceData(), host_from_device, tensor.DataSize());
if (status != kSuccess) {
MS_LOG(ERROR) << "copy device data to host failed.";
free(host_from_device);
return nullptr;
}
auto new_tensor =
CreateTensor(tensor.Name(), tensor.DataType(), tensor.Shape(), host_from_device, tensor.DataSize());
free(host_from_device);
host_from_device = nullptr;
return new_tensor;
} else {
MS_LOG(ERROR) << "device type is wrong.";
return nullptr;
}
#endif
MS_LOG(ERROR) << "Unsupported Feature.";
return nullptr;
}
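And a matching sketch for the tensor-copy overload above: an empty device string requests a host-side copy of an existing Ascend tensor; the helper name is an assumption.

#include <string>
#include "include/api/types.h"

// Copy an Ascend device tensor back into host memory via the new
// CreateTensor(name, tensor, device, device_id) overload.
mindspore::MSTensor *ToHost(const mindspore::MSTensor &ascend_tensor) {
  return mindspore::MSTensor::CreateTensor(ascend_tensor.Name() + "_host", ascend_tensor, "", -1);
}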
MSTensor *MSTensor::CreateRefTensor(const std::vector<char> &name, enum DataType type,
const std::vector<int64_t> &shape, const void *data, size_t data_len,
bool own_data) noexcept {
@ -442,6 +556,22 @@ size_t MSTensor::DataSize() const {
return impl_->DataSize();
}
std::string MSTensor::GetDevice() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";
return "";
}
return std::static_pointer_cast<MutableTensorImpl>(impl_)->GetDevice();
}
int MSTensor::GetDeviceId() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";
return -1;
}
return std::static_pointer_cast<MutableTensorImpl>(impl_)->GetDeviceId();
}
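A short hedged example of the accessors defined above, inspecting where a model's outputs reside; the model variable is assumed to be an already-built mindspore::Model.

#include <iostream>
#include "include/api/model.h"

// Report the device placement of each output tensor (sketch).
void DumpOutputPlacement(mindspore::Model *model) {
  for (auto &out : model->GetOutputs()) {
    std::cout << out.Name() << ": device=" << out.GetDevice()
              << ", device_id=" << out.GetDeviceId() << std::endl;
  }
}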
bool MSTensor::IsDevice() const {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Invalid tensor implement.";

View File

@ -277,6 +277,14 @@ class Tensor {
bool get_shape_changed() const { return tensor_c_.shape_changed_; }
int get_device_id() const { return device_id_; }
void set_device_id(int device_id) { device_id_ = device_id; }
std::string get_device() { return device_; }
void set_device(const std::string &device) { device_ = device; }
TensorC *ConvertToTensorC() { return &tensor_c_; }
private:
@ -306,6 +314,8 @@ class Tensor {
void *device_data_ = nullptr;
CompressType compress_type_ = kNoCompression;
size_t compressed_size_ = 0;
std::string device_ = "";
int device_id_ = -1;
};
} // namespace lite
} // namespace mindspore
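Finally, a small sketch of the internal lite::Tensor plumbing above: a runtime component can record a tensor's placement so the C++ front end can report it; the helper name is an assumption, and the "ascend" string follows the convention used elsewhere in this change.

#include <string>
#include "src/tensor.h"

// Tag an internal runtime tensor with its device placement (sketch).
void TagDevicePlacement(mindspore::lite::Tensor *tensor, const std::string &device, int device_id) {
  if (tensor == nullptr) {
    return;
  }
  tensor->set_device(device);
  tensor->set_device_id(device_id);
}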

View File

@ -560,10 +560,8 @@ def test_model_parallel_runner_predict_02():
# ============================ Tensor ============================
def test_tensor_type_error():
with pytest.raises(TypeError) as raise_info:
tensor1 = mslite.Tensor()
tensor2 = mslite.Tensor(tensor=tensor1)
assert "tensor must be MindSpore Lite's Tensor._tensor" in str(raise_info.value)
def test_tensor():

View File

@ -560,10 +560,8 @@ def test_model_parallel_runner_predict_02():
# ============================ Tensor ============================
def test_tensor_type_error():
with pytest.raises(TypeError) as raise_info:
tensor1 = mslite.Tensor()
tensor2 = mslite.Tensor(tensor=tensor1)
assert "tensor must be MindSpore Lite's Tensor._tensor" in str(raise_info.value)
def test_tensor():

View File

@ -311,6 +311,7 @@ def test_model_build_from_file_config_path_not_exist_error():
config_path="test.cfg")
assert "config_path does not exist" in str(raise_info.value)
def test_model_build_from_file_config_dict_type_error():
with pytest.raises(TypeError) as raise_info:
model = mslite.Model()
@ -318,6 +319,7 @@ def test_model_build_from_file_config_dict_type_error():
config_dict="test.cfg")
assert "config_dict must be dict" in str(raise_info.value)
def test_model_build_from_file_config_dict_key_type_error():
with pytest.raises(TypeError) as raise_info:
model = mslite.Model()
@ -326,6 +328,7 @@ def test_model_build_from_file_config_dict_key_type_error():
config_dict=dict_0)
assert "config_dict_key must be str" in str(raise_info.value)
def test_model_build_from_file_config_dict_value_type_error():
with pytest.raises(TypeError) as raise_info:
model = mslite.Model()
@ -334,6 +337,7 @@ def test_model_build_from_file_config_dict_value_type_error():
config_dict=dict_1)
assert "config_dict_value must be dict" in str(raise_info.value)
def test_model_build_from_file_config_dict_value_key_type_error():
with pytest.raises(TypeError) as raise_info:
model = mslite.Model()
@ -342,6 +346,7 @@ def test_model_build_from_file_config_dict_value_key_type_error():
config_dict=dict_2)
assert "config_dict_value_key must be str" in str(raise_info.value)
def test_model_build_from_file_config_dict_value_value_type_error():
with pytest.raises(TypeError) as raise_info:
model = mslite.Model()
@ -350,6 +355,7 @@ def test_model_build_from_file_config_dict_value_value_type_error():
config_dict=dict_3)
assert "config_dict_value_value must be str" in str(raise_info.value)
def get_model():
context = mslite.Context()
context.target = ["cpu"]
@ -460,10 +466,8 @@ def test_model_predict_02():
# ============================ Tensor ============================
def test_tensor_type_error():
with pytest.raises(TypeError) as raise_info:
tensor1 = mslite.Tensor()
tensor2 = mslite.Tensor(tensor=tensor1)
assert "tensor must be MindSpore Lite's Tensor._tensor" in str(raise_info.value)
tensor2 = mslite.Tensor(tensor=tensor1) # now supported
def test_tensor():

View File

@ -168,6 +168,7 @@ set(LITE_SRC ${API_SRC}
${SRC_DIR}/control_flow/control_flow_scheduler.cc
${SRC_DIR}/control_flow/control_subgraph_creator.cc
${SRC_DIR}/litert/kernel/ascend/plugin/ascend_kernel_plugin.cc
${SRC_DIR}/extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc
)
if(MSLITE_ENABLE_CUSTOM_KERNEL)

View File

@ -14,6 +14,7 @@ set(REG_SRC ${CONVERT_REG_SRC}
${KERNEL_REG_DIR}/../common/utils.cc
${KERNEL_REG_DIR}/../extendrt/delegate/tensorrt/distribution/distribution_base.cc
${KERNEL_REG_DIR}/../extendrt/delegate/plugin/tensorrt_executor_plugin.cc
${KERNEL_REG_DIR}/../extendrt/kernel/ascend/plugin/ascend_allocator_plugin.cc
${CONVERTER_DIR}/converter_context.cc
${TOP_DIR}/mindspore/ccsrc/plugin/device/cpu/kernel/nnacl/tensor_c_utils.c
${TOP_DIR}/mindspore/lite/src/common/file_utils.cc