API to change TensorsWeights for Site AI

This commit is contained in:
Emir Haleva 2021-11-03 16:02:09 +02:00
parent 2828f317f7
commit 8433b1be54
11 changed files with 174 additions and 9 deletions

View File

@ -64,6 +64,14 @@ class MS_API Model {
/// \return Status.
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
/// \brief Change the size and or content of weight tensors
///
/// \param[in] new_weights a vector of tensors with new shapes and data to use in the model
/// If data pointer is null, the data of the original tensors will be copied to the new ones
///
/// \return Status.
Status UpdateWeights(const std::vector<MSTensor> &new_weights);
/// \brief Inference model.
///
/// \param[in] inputs A vector where model inputs are arranged in sequence.

View File

@ -14,8 +14,8 @@
* limitations under the License.
*/
#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
#ifndef NOT_USE_STL
#include <unordered_map>
@ -190,6 +190,14 @@ class MS_API LiteSession {
return mindspore::lite::RET_ERROR;
}
/// \brief Change the size and or content of weight tensors
///
/// \param[in] new_weights a vector of tensors with new shapes and data to use in the model
/// If data pointer is null, the data of the original tensors will be copied to the new ones
///
/// \return STATUS as an error code of operation, STATUS is defined in errorcode.h.
virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }
/// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction
///
/// \return a vector of output tensors (MindSpore Lite MSTensor).
@ -233,4 +241,4 @@ class MS_API LiteSession {
};
} // namespace session
} // namespace mindspore
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
#endif // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_

View File

@ -102,6 +102,14 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
return impl_->Resize(inputs, dims);
}
Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) {
if (impl_ == nullptr) {
MS_LOG(ERROR) << "Model implement is null.";
return kLiteNullptr;
}
return impl_->UpdateWeights(new_weights);
}
Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
const MSKernelCallBack &before, const MSKernelCallBack &after) {
if (impl_ == nullptr) {

View File

@ -559,6 +559,29 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
return static_cast<StatusCode>(ret);
}
Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
if (session_ == nullptr) {
MS_LOG(ERROR) << "Session is null.";
return kLiteNullptr;
}
if (new_weights.empty()) {
MS_LOG(ERROR) << "New weights are empty.";
return kLiteInputParamInvalid;
}
std::vector<tensor::MSTensor *> inner_weights;
inner_weights.resize(new_weights.size());
for (size_t i = 0; i < new_weights.size(); i++) {
auto weight = new_weights[i];
if (weight.impl_ == nullptr || weight.impl_->lite_tensor() == nullptr) {
MS_LOG(ERROR) << "Input tensor " << weight.Name() << " is null.";
return kLiteInputTensorError;
}
inner_weights[i] = weight.impl_->lite_tensor();
}
auto ret = session_->UpdateWeights(inner_weights);
return static_cast<StatusCode>(ret);
}
session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
auto session = new (std::nothrow) lite::LiteSession();
if (session == nullptr) {

View File

@ -63,6 +63,7 @@ class ModelImpl {
const std::shared_ptr<Context> &model_context);
Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
Status UpdateWeights(const std::vector<MSTensor> &new_weights);
Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
const MSKernelCallBack &after);

View File

@ -91,7 +91,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType ty
return nullptr;
}
if (data_len > 0 && data == nullptr) {
MS_LOG(ERROR) << "Mull data ptr of tensor.";
MS_LOG(ERROR) << "Null data ptr of tensor.";
return nullptr;
}
auto impl = Impl::CreateTensorImpl(CharToString(name), type, shape, nullptr, data_len);

View File

@ -28,7 +28,7 @@ using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
namespace mindspore::kernel {
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; }
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return Prepare(); }
int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
float *output) const {

View File

@ -50,8 +50,6 @@ int StridedSliceGradCPUKernel::Prepare() {
MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
return RET_ERROR;
}
FillEmptyDims();
FillOutputDim();
return ReSize();
}
@ -113,7 +111,11 @@ void StridedSliceGradCPUKernel::FillOutputDim() {
}
}
int StridedSliceGradCPUKernel::ReSize() { return RET_OK; }
int StridedSliceGradCPUKernel::ReSize() {
FillEmptyDims();
FillOutputDim();
return RET_OK;
}
int StridedSliceGradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
CHECK_NULL_RETURN(cdata);

View File

@ -176,6 +176,89 @@ int TrainSession::InitCallBack() {
return RET_OK;
}
static int ReshapeWeightTensor(Tensor *orig_tensor, tensor::MSTensor *new_tensor) {
if (orig_tensor->data_type() != new_tensor->data_type()) {
MS_LOG(ERROR) << "Cannot reshape tensor of different type: " << new_tensor->tensor_name();
return RET_PARAM_INVALID;
}
if (orig_tensor->category() != lite::Category::CONST_TENSOR) {
MS_LOG(ERROR) << "Cannot reshape non const tensor: " << new_tensor->tensor_name();
return RET_ERROR;
}
auto orig_size = orig_tensor->Size();
uint8_t *new_data = reinterpret_cast<uint8_t *>(new_tensor->data());
if (new_data == nullptr) {
// Copy original data into new_tensor
new_data = reinterpret_cast<uint8_t *>(new_tensor->MutableData());
if (new_data == nullptr) {
MS_LOG(ERROR) << "Allocation of Data Failed" << new_tensor->tensor_name();
return RET_ERROR;
}
if (orig_size == 0) {
MS_LOG(ERROR) << "Operation failed: Both new tensors and original one have no data";
return RET_ERROR;
}
uint8_t *orig_data = reinterpret_cast<uint8_t *>(orig_tensor->data());
for (unsigned int loc = 0; loc < new_tensor->Size(); loc++) {
new_data[loc] = orig_data[loc % orig_size];
}
}
orig_tensor->FreeData();
orig_tensor->set_data(nullptr);
orig_tensor->set_shape(new_tensor->shape());
uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData());
if (dst_data == nullptr) {
MS_LOG(ERROR) << "Allocation of Data Failed";
return RET_ERROR;
}
std::copy(new_data, new_data + orig_tensor->Size(), dst_data);
return RET_OK;
}
int TrainSession::UpdateWeights(std::vector<tensor::MSTensor *> modify_tensors) {
unsigned int num_of_found_tensors = 0;
for (auto tensor : tensors_) {
for (auto modify : modify_tensors) {
if (modify == nullptr) {
MS_LOG(ERROR) << "Tensor is nullptr";
return RET_PARAM_INVALID;
}
if (modify->tensor_name() == tensor->tensor_name()) {
auto ret = ReshapeWeightTensor(tensor, modify);
num_of_found_tensors++;
if (ret != RET_OK) {
return ret;
}
break;
}
}
}
if (num_of_found_tensors != modify_tensors.size()) {
MS_LOG(ERROR) << "Did not find all the given tensors in the model";
return RET_ERROR;
}
auto ret = ReSizeKernels(kernels_);
if (ret != RET_OK) {
MS_LOG(ERROR) << "Resize kernels fail!";
return ret;
}
bool is_eval = IsEval();
ret = Train(); // This will trigger proper Allocation of static data;
if (ret != RET_OK) {
MS_LOG(ERROR) << "General failure occurred during Update of Weights";
return ret;
}
if (is_eval) {
ret = Eval();
}
return ret;
}
int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) {
if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
OptAllocator allocator;
@ -199,8 +282,12 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels)
}
}
// Set Tensor data
if (tensors_data_ == nullptr) {
auto size = allocator.total_size();
if (size > tensors_data_size_) {
free(tensors_data_);
tensors_data_ = nullptr;
}
if (tensors_data_ == nullptr) {
auto buf = malloc(size);
if (buf == nullptr) {
MS_LOG(ERROR) << "cannot allocate buffer size" << size;
@ -209,6 +296,7 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels)
StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
alloc->SetContex(buf, size);
tensors_data_ = buf;
tensors_data_size_ = size;
}
for (auto kernel : train_kernels_) {
for (auto tensor : kernel->out_tensors()) {

View File

@ -85,6 +85,7 @@ class TrainSession : virtual public lite::LiteSession {
return lite::LiteSession::GetOutputByTensorName(tensor_name);
}
int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) override;
std::vector<tensor::MSTensor *> GetPredictions() const override {
std::vector<tensor::MSTensor *> outputs;
@ -166,6 +167,7 @@ class TrainSession : virtual public lite::LiteSession {
SchedCallBack sched_mix_precision_callback_;
bool train_mode_ = false;
void *tensors_data_ = nullptr;
unsigned int tensors_data_size_ = 0;
std::shared_ptr<Allocator> allocator_;
};

View File

@ -229,4 +229,29 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
}
#define NUM_OF_CLASSES 10
#define FEATURE_SIZE 10
TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
Model model;
Graph graph;
auto context = std::make_shared<Context>();
auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
cpu_context->SetEnableFP16(true);
context->MutableDeviceInfo().push_back(cpu_context);
auto train_cfg = std::make_shared<TrainCfg>();
train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
std::vector<mindspore::MSTensor> changes;
ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess);
changes.push_back(
*MSTensor::CreateTensor("fc4.weight", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0));
ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess);
changes.clear();
changes.push_back(
*MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0));
ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess);
}
} // namespace mindspore