forked from mindspore-Ecosystem/mindspore
API to change tensor weights for Site AI
This commit is contained in:
parent
2828f317f7
commit
8433b1be54
@@ -64,6 +64,14 @@ class MS_API Model {
   /// \return Status.
   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
 
+  /// \brief Change the size and/or content of weight tensors.
+  ///
+  /// \param[in] new_weights A vector of tensors with new shapes and data to use in the model.
+  /// If the data pointer is null, the data of the original tensors will be copied to the new ones.
+  ///
+  /// \return Status.
+  Status UpdateWeights(const std::vector<MSTensor> &new_weights);
+
   /// \brief Inference model.
   ///
   /// \param[in] inputs A vector where model inputs are arranged in sequence.
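For orientation, a minimal usage sketch of the new public API. It is not part of the commit; the tensor name "fc3.bias" and its shape are illustrative, mirroring the unit test at the end of this diff.

  #include <vector>

  #include "include/api/model.h"
  #include "include/api/types.h"

  // Sketch: resize one weight tensor of an already-built model. Because the
  // data pointer is null, the runtime copies the original tensor's data into
  // the resized tensor (cyclically, if the new tensor is larger).
  mindspore::Status ResizeBias(mindspore::Model *model) {
    auto *bias = mindspore::MSTensor::CreateTensor(
        "fc3.bias", mindspore::DataType::kNumberTypeFloat32, {10}, nullptr, 0);
    if (bias == nullptr) {
      return mindspore::kLiteNullptr;
    }
    std::vector<mindspore::MSTensor> new_weights = {*bias};
    return model->UpdateWeights(new_weights);  // downstream shapes are re-inferred
  }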
@@ -14,8 +14,8 @@
  * limitations under the License.
  */
 
-#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
-#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
+#ifndef MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
+#define MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
 
 #ifndef NOT_USE_STL
 #include <unordered_map>
@@ -190,6 +190,14 @@ class MS_API LiteSession {
     return mindspore::lite::RET_ERROR;
   }
 
+  /// \brief Change the size and/or content of weight tensors.
+  ///
+  /// \param[in] new_weights A vector of tensors with new shapes and data to use in the model.
+  /// If the data pointer is null, the data of the original tensors will be copied to the new ones.
+  ///
+  /// \return STATUS as an error code of the operation, STATUS is defined in errorcode.h.
+  virtual int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) { return mindspore::lite::RET_ERROR; }
+
   /// \brief Get model featuremap MindSpore Lite MSTensors of Training model prediction
   ///
   /// \return a vector of output tensors (MindSpore Lite MSTensor).
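Note that the base-class method is deliberately a stub returning RET_ERROR: only sessions that actually support weight updates (TrainSession, later in this diff) override it. A hedged caller-side sketch at the session level; the wrapper function is illustrative, not part of the commit:

  #include <iostream>
  #include <vector>

  #include "include/errorcode.h"
  #include "include/lite_session.h"

  // Sketch: session-level update returns lite error codes (errorcode.h),
  // not the cxx-api Status type used by Model::UpdateWeights.
  int TryUpdateWeights(mindspore::session::LiteSession *session,
                       const std::vector<mindspore::tensor::MSTensor *> &weights) {
    int ret = session->UpdateWeights(weights);
    if (ret != mindspore::lite::RET_OK) {
      // Inference-only sessions land here via the base-class RET_ERROR stub.
      std::cerr << "UpdateWeights unsupported or failed, code " << ret << std::endl;
    }
    return ret;
  }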
@@ -233,4 +241,4 @@ class MS_API LiteSession {
 };
 }  // namespace session
 }  // namespace mindspore
-#endif  // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H
+#endif  // MINDSPORE_LITE_INCLUDE_LITE_SESSION_H_
@@ -102,6 +102,14 @@ Status Model::Resize(const std::vector<MSTensor> &inputs, const std::vector<std:
   return impl_->Resize(inputs, dims);
 }
 
+Status Model::UpdateWeights(const std::vector<MSTensor> &new_weights) {
+  if (impl_ == nullptr) {
+    MS_LOG(ERROR) << "Model implement is null.";
+    return kLiteNullptr;
+  }
+  return impl_->UpdateWeights(new_weights);
+}
+
 Status Model::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs,
                       const MSKernelCallBack &before, const MSKernelCallBack &after) {
   if (impl_ == nullptr) {
@@ -559,6 +559,29 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
   return static_cast<StatusCode>(ret);
 }
 
+Status ModelImpl::UpdateWeights(const std::vector<MSTensor> &new_weights) {
+  if (session_ == nullptr) {
+    MS_LOG(ERROR) << "Session is null.";
+    return kLiteNullptr;
+  }
+  if (new_weights.empty()) {
+    MS_LOG(ERROR) << "New weights are empty.";
+    return kLiteInputParamInvalid;
+  }
+  std::vector<tensor::MSTensor *> inner_weights;
+  inner_weights.resize(new_weights.size());
+  for (size_t i = 0; i < new_weights.size(); i++) {
+    auto weight = new_weights[i];
+    if (weight.impl_ == nullptr || weight.impl_->lite_tensor() == nullptr) {
+      MS_LOG(ERROR) << "Input tensor " << weight.Name() << " is null.";
+      return kLiteInputTensorError;
+    }
+    inner_weights[i] = weight.impl_->lite_tensor();
+  }
+  auto ret = session_->UpdateWeights(inner_weights);
+  return static_cast<StatusCode>(ret);
+}
+
 session::LiteSession *ModelImpl::CreateLiteSession(lite::InnerContext *context) {
   auto session = new (std::nothrow) lite::LiteSession();
   if (session == nullptr) {
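ModelImpl::UpdateWeights validates and unwraps each public MSTensor into its backing lite tensor before delegating to the session. A consequence worth noting, inferred from the checks above rather than stated in the commit: a tensor with no backing implementation is rejected before it ever reaches the session.

  // Illustrative only: a default-constructed MSTensor has no backing lite
  // tensor, so the update should fail with kLiteInputTensorError.
  std::vector<mindspore::MSTensor> weights = {mindspore::MSTensor()};
  mindspore::Status status = model.UpdateWeights(weights);  // != kSuccess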
@@ -63,6 +63,7 @@ class ModelImpl {
                const std::shared_ptr<Context> &model_context);
   Status Build(const std::string &model_path, ModelType model_type, const std::shared_ptr<Context> &model_context);
   Status Resize(const std::vector<MSTensor> &inputs, const std::vector<std::vector<int64_t>> &dims);
+  Status UpdateWeights(const std::vector<MSTensor> &new_weights);
 
   Status Predict(const std::vector<MSTensor> &inputs, std::vector<MSTensor> *outputs, const MSKernelCallBack &before,
                 const MSKernelCallBack &after);
@@ -91,7 +91,7 @@ MSTensor *MSTensor::CreateTensor(const std::vector<char> &name, enum DataType ty
     return nullptr;
   }
   if (data_len > 0 && data == nullptr) {
-    MS_LOG(ERROR) << "Mull data ptr of tensor.";
+    MS_LOG(ERROR) << "Null data ptr of tensor.";
     return nullptr;
   }
   auto impl = Impl::CreateTensorImpl(CharToString(name), type, shape, nullptr, data_len);
@@ -28,7 +28,7 @@ using mindspore::lite::RET_OK;
 using mindspore::schema::PrimitiveType_SparseSoftmaxCrossEntropyWithLogits;
 
 namespace mindspore::kernel {
-int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return RET_OK; }
+int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ReSize() { return Prepare(); }
 
 int SparseSoftmaxCrossEntropyWithLogitsCPUKernel::ForwardPostExecute(const int *labels, const float *losses,
                                                                      float *output) const {
@@ -50,8 +50,6 @@ int StridedSliceGradCPUKernel::Prepare() {
     MS_LOG(ERROR) << "Not supported data type: " << input->data_type();
     return RET_ERROR;
   }
-  FillEmptyDims();
-  FillOutputDim();
   return ReSize();
 }
 
@@ -113,7 +111,11 @@ void StridedSliceGradCPUKernel::FillOutputDim() {
   }
 }
 
-int StridedSliceGradCPUKernel::ReSize() { return RET_OK; }
+int StridedSliceGradCPUKernel::ReSize() {
+  FillEmptyDims();
+  FillOutputDim();
+  return RET_OK;
+}
 
 int StridedSliceGradImpl(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
   CHECK_NULL_RETURN(cdata);
@@ -176,6 +176,89 @@ int TrainSession::InitCallBack() {
   return RET_OK;
 }
 
+static int ReshapeWeightTensor(Tensor *orig_tensor, tensor::MSTensor *new_tensor) {
+  if (orig_tensor->data_type() != new_tensor->data_type()) {
+    MS_LOG(ERROR) << "Cannot reshape tensor of different type: " << new_tensor->tensor_name();
+    return RET_PARAM_INVALID;
+  }
+
+  if (orig_tensor->category() != lite::Category::CONST_TENSOR) {
+    MS_LOG(ERROR) << "Cannot reshape non const tensor: " << new_tensor->tensor_name();
+    return RET_ERROR;
+  }
+
+  auto orig_size = orig_tensor->Size();
+  uint8_t *new_data = reinterpret_cast<uint8_t *>(new_tensor->data());
+  if (new_data == nullptr) {
+    // Copy original data into new_tensor
+    new_data = reinterpret_cast<uint8_t *>(new_tensor->MutableData());
+    if (new_data == nullptr) {
+      MS_LOG(ERROR) << "Allocation of data failed: " << new_tensor->tensor_name();
+      return RET_ERROR;
+    }
+    if (orig_size == 0) {
+      MS_LOG(ERROR) << "Operation failed: both the new tensor and the original one have no data";
+      return RET_ERROR;
+    }
+    uint8_t *orig_data = reinterpret_cast<uint8_t *>(orig_tensor->data());
+    for (size_t loc = 0; loc < new_tensor->Size(); loc++) {
+      new_data[loc] = orig_data[loc % orig_size];
+    }
+  }
+
+  orig_tensor->FreeData();
+  orig_tensor->set_data(nullptr);
+  orig_tensor->set_shape(new_tensor->shape());
+
+  uint8_t *dst_data = reinterpret_cast<uint8_t *>(orig_tensor->MutableData());
+  if (dst_data == nullptr) {
+    MS_LOG(ERROR) << "Allocation of data failed";
+    return RET_ERROR;
+  }
+  std::copy(new_data, new_data + orig_tensor->Size(), dst_data);
+  return RET_OK;
+}
+
+int TrainSession::UpdateWeights(std::vector<tensor::MSTensor *> modify_tensors) {
+  unsigned int num_of_found_tensors = 0;
+  for (auto tensor : tensors_) {
+    for (auto modify : modify_tensors) {
+      if (modify == nullptr) {
+        MS_LOG(ERROR) << "Tensor is nullptr";
+        return RET_PARAM_INVALID;
+      }
+      if (modify->tensor_name() == tensor->tensor_name()) {
+        auto ret = ReshapeWeightTensor(tensor, modify);
+        num_of_found_tensors++;
+        if (ret != RET_OK) {
+          return ret;
+        }
+        break;
+      }
+    }
+  }
+  if (num_of_found_tensors != modify_tensors.size()) {
+    MS_LOG(ERROR) << "Did not find all the given tensors in the model";
+    return RET_ERROR;
+  }
+  auto ret = ReSizeKernels(kernels_);
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "Resize kernels failed!";
+    return ret;
+  }
+
+  bool is_eval = IsEval();
+  ret = Train();  // This will trigger proper allocation of static data
+  if (ret != RET_OK) {
+    MS_LOG(ERROR) << "General failure occurred during update of weights";
+    return ret;
+  }
+  if (is_eval) {
+    ret = Eval();
+  }
+  return ret;
+}
+
 int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels) {
   if (!IS_STATIC_ALLOCATOR(allocator_)) return RET_OK;
   OptAllocator allocator;
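Two details of the block above are worth spelling out. First, UpdateWeights finishes by calling ReSizeKernels, which is why the two kernels' ReSize methods earlier in this diff were changed to redo their shape work. Second, when the caller leaves the new tensor's data null, ReshapeWeightTensor fills the freshly allocated buffer cyclically from the original tensor (new_data[loc] = orig_data[loc % orig_size]). A standalone sketch of that cyclic fill, assuming byte-sized elements (illustrative, not the commit's code):

  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  // Repeat the original bytes cyclically until the new buffer is full,
  // matching the indexing used by ReshapeWeightTensor.
  std::vector<uint8_t> CyclicFill(const std::vector<uint8_t> &orig, size_t new_size) {
    std::vector<uint8_t> out(new_size);
    for (size_t loc = 0; loc < new_size; loc++) {
      out[loc] = orig[loc % orig.size()];
    }
    return out;
  }

  int main() {
    for (uint8_t b : CyclicFill({1, 2, 3}, 7)) {
      printf("%u ", b);  // prints: 1 2 3 1 2 3 1
    }
    return 0;
  }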
@@ -199,8 +282,12 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels)
     }
   }
   // Set Tensor data
+  auto size = allocator.total_size();
+  if (size > tensors_data_size_) {
+    free(tensors_data_);
+    tensors_data_ = nullptr;
+  }
   if (tensors_data_ == nullptr) {
-    auto size = allocator.total_size();
     auto buf = malloc(size);
     if (buf == nullptr) {
       MS_LOG(ERROR) << "cannot allocate buffer size" << size;
@@ -209,6 +296,7 @@ int TrainSession::AllocTensors(const std::vector<kernel::LiteKernel *> &kernels)
     StaticAllocator *alloc = reinterpret_cast<StaticAllocator *>(allocator_.get());
     alloc->SetContex(buf, size);
     tensors_data_ = buf;
+    tensors_data_size_ = size;
   }
   for (auto kernel : train_kernels_) {
     for (auto tensor : kernel->out_tensors()) {
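AllocTensors now caches the buffer size (tensors_data_size_, added to the header below) and reallocates only when the newly computed total exceeds it, so repeated UpdateWeights/resize cycles reuse the same arena when it is already big enough. The same grow-only pattern in isolation, a sketch under that interpretation rather than the commit's code:

  #include <cstddef>
  #include <cstdlib>

  // Grow-only arena: keep the current allocation unless more space is needed.
  struct GrowOnlyBuffer {
    void *data = nullptr;
    size_t capacity = 0;

    void *Reserve(size_t size) {
      if (size > capacity) {  // too small: drop and reallocate below
        free(data);
        data = nullptr;
        capacity = 0;
      }
      if (data == nullptr) {
        data = malloc(size);
        if (data == nullptr) return nullptr;  // allocation failed
        capacity = size;
      }
      return data;
    }
    ~GrowOnlyBuffer() { free(data); }
  };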
@@ -85,6 +85,7 @@ class TrainSession : virtual public lite::LiteSession {
     return lite::LiteSession::GetOutputByTensorName(tensor_name);
   }
   int Resize(const std::vector<tensor::MSTensor *> &inputs, const std::vector<std::vector<int>> &dims) override;
+  int UpdateWeights(std::vector<tensor::MSTensor *> new_weights) override;
 
   std::vector<tensor::MSTensor *> GetPredictions() const override {
     std::vector<tensor::MSTensor *> outputs;
@@ -166,6 +167,7 @@ class TrainSession : virtual public lite::LiteSession {
   SchedCallBack sched_mix_precision_callback_;
   bool train_mode_ = false;
   void *tensors_data_ = nullptr;
+  unsigned int tensors_data_size_ = 0;
   std::shared_ptr<Allocator> allocator_;
 };
 
@@ -229,4 +229,29 @@ TEST_F(TestCxxApiLiteModel, test_fp16_SUCCESS) {
   train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
   ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
 }
+
+#define NUM_OF_CLASSES 10
+#define FEATURE_SIZE 10
+TEST_F(TestCxxApiLiteModel, set_weights_FAILURE) {
+  Model model;
+  Graph graph;
+  auto context = std::make_shared<Context>();
+  auto cpu_context = std::make_shared<mindspore::CPUDeviceInfo>();
+  cpu_context->SetEnableFP16(true);
+  context->MutableDeviceInfo().push_back(cpu_context);
+  auto train_cfg = std::make_shared<TrainCfg>();
+  train_cfg->mix_precision_cfg_.is_raw_mix_precision_ = true;
+
+  ASSERT_TRUE(Serialization::Load("./nets/mix_lenet_tod.ms", ModelType::kMindIR, &graph) == kSuccess);
+  ASSERT_TRUE(model.Build(GraphCell(graph), context, train_cfg) == kSuccess);
+  std::vector<mindspore::MSTensor> changes;
+  ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess);
+  changes.push_back(
+      *MSTensor::CreateTensor("fc4.weight", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0));
+  ASSERT_TRUE(model.UpdateWeights(changes) != kSuccess);
+  changes.clear();
+  changes.push_back(
+      *MSTensor::CreateTensor("fc3.bias", mindspore::DataType::kNumberTypeFloat32, {NUM_OF_CLASSES}, nullptr, 0));
+  ASSERT_TRUE(model.UpdateWeights(changes) == kSuccess);
+}
 }  // namespace mindspore