forked from mindspore-Ecosystem/mindspore
synchronize issues
parent
258afe86f9
commit
cd3daba7cb
@@ -282,9 +282,9 @@ void Context::SetNPUFrequency(const std::shared_ptr<Context> &context, int freq)
   }
   auto iter = context->context_.find(kNPUFrequency);
   if (iter != context->context_.end()) {
-    iter->second = true;
+    iter->second = freq;
   } else {
-    context->context_.emplace(kNPUFrequency, true);
+    context->context_.emplace(kNPUFrequency, freq);
   }
 }
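The first hunk fixes a plain value bug: SetNPUFrequency stored the literal `true` into the context map on both paths, silently discarding the freq argument. Below is a minimal, self-contained sketch of the corrected find-or-emplace pattern; the std::map<std::string, std::any> store and the key string are stand-ins, not the real Context internals.

#include <any>
#include <iostream>
#include <map>
#include <string>

// Stand-in for the Context key/value store (the real context_ member may differ).
static const char kNPUFrequency[] = "npu_frequency";

int main() {
  std::map<std::string, std::any> context;
  int freq = 3;  // requested NPU frequency level

  // Fixed behavior: store freq itself; the old code stored `true` on both paths.
  auto iter = context.find(kNPUFrequency);
  if (iter != context.end()) {
    iter->second = freq;  // overwrite an existing entry
  } else {
    context.emplace(kNPUFrequency, freq);  // insert a new entry
  }

  std::cout << std::any_cast<int>(context.at(kNPUFrequency)) << std::endl;  // prints 3
  return 0;
}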
@@ -48,17 +48,12 @@ Model::Model(const GraphCell &graph, const std::shared_ptr<Context> &model_conte
   impl_ = std::shared_ptr<ModelImpl>(new (std::nothrow) ModelImpl());
   if (impl_ == nullptr || graph.GetGraph() == nullptr) {
     MS_LOG(ERROR) << "Invalid graph.";
+  } else if (model_context == nullptr) {
+    MS_LOG(ERROR) << "Invalid context.";
   } else {
-    if (model_context == nullptr) {
-      MS_LOG(INFO) << "Invalid context, use default context.";
-      auto context = std::shared_ptr<Context>(new (std::nothrow) Context());
-      Context::SetAsDefault(context);
-      impl_->SetContext(context);
-    } else {
-      impl_->SetContext(model_context);
-    }
     auto new_graph_cell = std::shared_ptr<GraphCell>(new (std::nothrow) GraphCell(graph));
     if (new_graph_cell != nullptr) {
+      impl_->SetContext(model_context);
       impl_->SetGraphCell(new_graph_cell);
     } else {
       MS_LOG(ERROR) << "New graphcell failed.";
@@ -71,7 +71,8 @@ Status ModelImpl::Build() {
   model_context.thread_num_ = Context::GetThreadNum(context_);
   model_context.device_list_.clear();
   if (Context::IfCPUEnabled(context_) && Context::IfGPUEnabled(context_) && Context::IfNPUEnabled(context_)) {
-    MS_LOG(INFO) << "CPU/GPU/NPU cannot be enabled at the same time.";
+    MS_LOG(ERROR) << "CPU/GPU/NPU cannot be enabled at the same time.";
+    return kLiteInputParamInvalid;
   }
   if (!Context::IfCPUEnabled(context_)) {
     MS_LOG(INFO) << "CPU is forced to be enabled.";
@@ -155,6 +156,7 @@ Status ModelImpl::Predict(const std::vector<MSTensor> &inputs, std::vector<MSTen
     MS_LOG(DEBUG) << "Empty outputs.";
+    return kLiteError;
   }
   outputs->clear();
   outputs->insert(outputs->end(), res.begin(), res.end());
   return kSuccess;
 }
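A one-line fix with the same flavor: Predict() used to log empty session outputs at DEBUG and then return kSuccess with an empty result vector. It now fails fast with kLiteError, so callers cannot mistake a failed run for a model that simply has no outputs.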
@@ -167,8 +169,13 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
   }
   std::vector<MSTensor> res;
   auto inputs = session_->GetInputs();
-  for (auto input : inputs) {
-    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(input));
+  if (inputs.empty()) {
+    MS_LOG(ERROR) << "The inputs of model is null.";
+    return empty;
+  }
+  res.resize(inputs.size());
+  for (size_t i = 0; i < inputs.size(); i++) {
+    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(inputs[i]));
     if (impl == nullptr) {
       MS_LOG(ERROR) << "Create tensor failed.";
       return empty;
@@ -178,7 +185,7 @@ std::vector<MSTensor> ModelImpl::GetInputs() {
       MS_LOG(ERROR) << "Create tensor failed.";
       return empty;
     }
-    res.push_back(tensor);
+    res[i] = tensor;
   }
   return res;
 }
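GetInputs() is restructured from push_back into preallocate-and-index, with an explicit empty check first. The two styles must not be mixed: once res.resize(inputs.size()) runs, a push_back in the loop would append after the default-constructed slots rather than fill them, which is exactly why the follow-up hunk changes res.push_back(tensor) to res[i] = tensor. A self-contained demonstration of the pitfall:

#include <cassert>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> src = {"a", "b", "c"};

  // Pattern used by the fix: preallocate, then assign by index.
  std::vector<std::string> res;
  res.resize(src.size());
  for (size_t i = 0; i < src.size(); i++) {
    res[i] = src[i];
  }
  assert(res.size() == 3);

  // Mixing the two styles is the bug the diff avoids: push_back after
  // resize() appends past the default-constructed slots.
  std::vector<std::string> bad;
  bad.resize(src.size());
  for (const auto &s : src) {
    bad.push_back(s);  // bad.size() ends up 6, with the first 3 entries empty
  }
  assert(bad.size() == 6);
  return 0;
}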
@@ -191,9 +198,22 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
   }
   std::vector<MSTensor> res;
   auto names = session_->GetOutputTensorNames();
+  if (names.empty()) {
+    MS_LOG(ERROR) << "The names of model is null.";
+    return empty;
+  }
   auto outputs = session_->GetOutputs();
-  for (auto name : names) {
-    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[name]));
+  if (outputs.empty()) {
+    MS_LOG(ERROR) << "The outputs of model is null.";
+    return empty;
+  }
+  if (names.size() != outputs.size()) {
+    MS_LOG(ERROR) << "The size of outputs does not match the size of names.";
+    return empty;
+  }
+  res.resize(names.size());
+  for (size_t i = 0; i < names.size(); i++) {
+    auto impl = std::shared_ptr<MSTensor::Impl>(new (std::nothrow) MSTensor::Impl(outputs[names[i]]));
     if (impl == nullptr) {
       MS_LOG(ERROR) << "Create tensor failed.";
       return empty;
@@ -203,7 +223,7 @@ std::vector<MSTensor> ModelImpl::GetOutputs() {
      MS_LOG(ERROR) << "Create tensor failed.";
      return empty;
    }
-    res.push_back(tensor);
+    res[i] = tensor;
  }
  return res;
}
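GetOutputs() gets the same preallocate-and-index treatment, plus three new guards: empty tensor names, an empty output map, and a name/output count mismatch are each reported and rejected before the conversion loop, instead of surfacing later as failed outputs[name] lookups inside it.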
@@ -213,26 +233,44 @@ Status ModelImpl::Resize(const std::vector<MSTensor> &inputs, const std::vector<
     MS_LOG(ERROR) << "Session is null.";
     return kLiteNullptr;
   }
+  if (inputs.empty()) {
+    MS_LOG(ERROR) << "Inputs is null.";
+    return kLiteInputParamInvalid;
+  }
+  if (dims.empty()) {
+    MS_LOG(ERROR) << "Dims is null.";
+    return kLiteInputParamInvalid;
+  }
   if (inputs.size() != dims.size()) {
-    MS_LOG(ERROR) << "The size of inputs is not equal to the size of dims.";
+    MS_LOG(ERROR) << "The size of inputs does not match the size of dims.";
     return kLiteInputParamInvalid;
   }
+  auto model_inputs = session_->GetInputs();
+  if (model_inputs.empty()) {
+    MS_LOG(ERROR) << "The inputs of model is null.";
+    return kLiteParamInvalid;
+  }
+  if (inputs.size() != model_inputs.size()) {
+    MS_LOG(ERROR) << "The size of inputs is incorrect.";
+    return kLiteInputParamInvalid;
+  }
   std::vector<tensor::MSTensor *> inner_input;
-  for (auto input : inputs) {
+  inner_input.resize(inputs.size());
+  std::vector<std::vector<int32_t>> truncated_shape;
+  truncated_shape.resize(inputs.size());
+  for (size_t i = 0; i < inputs.size(); i++) {
+    auto input = inputs[i];
     if (input.impl_ == nullptr || input.impl_->lite_tensor() == nullptr) {
       MS_LOG(ERROR) << "Input tensor " << input.Name() << " is null.";
       return kLiteInputTensorError;
     }
-    inner_input.push_back(input.impl_->lite_tensor());
-  }
-  std::vector<std::vector<int32_t>> truncated_shape;
-  for (size_t i = 0; i < inner_input.size(); i++) {
-    std::vector<int32_t> tmp = TruncateShape(dims.at(i), inner_input.at(i)->data_type(), inner_input.at(i)->Size());
-    if (tmp.empty()) {
-      MS_LOG(ERROR) << "Input dims[" << i << "]is invalid.";
+    inner_input[i] = input.impl_->lite_tensor();
+    std::vector<int32_t> shape = TruncateShape(dims[i], inner_input[i]->data_type(), inner_input[i]->Size(), false);
+    if (shape.empty() && !(dims[i].empty())) {
+      MS_LOG(ERROR) << "Input dims[" << i << "] is invalid.";
       return kLiteParamInvalid;
     }
-    truncated_shape.push_back(tmp);
+    truncated_shape[i] = shape;
   }
   auto ret = session_->Resize(inner_input, truncated_shape);
   return static_cast<StatusCode>(ret);
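Resize() now validates everything before touching the session: non-empty inputs and dims, matching sizes between them, and agreement with the model's own input count. The two loops are merged into one, and TruncateShape is called with verify_size=false because a resize changes the tensor's byte count by design; the shape.empty() && !(dims[i].empty()) test distinguishes a genuinely invalid dims entry from an intentionally empty one.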
@@ -28,11 +28,11 @@
 namespace mindspore {
 MSTensor::Impl::Impl(const std::string &name, enum DataType type, const std::vector<int64_t> &shape, const void *data,
                      size_t data_len) {
-  std::vector<int32_t> truncated_shape = TruncateShape(shape, static_cast<enum TypeId>(type), data_len);
-  if (!truncated_shape.empty()) {
-    lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast<enum TypeId>(type), truncated_shape, data);
-  } else {
+  std::vector<int32_t> truncated_shape = TruncateShape(shape, static_cast<enum TypeId>(type), data_len, true);
+  if (truncated_shape.empty() && !(shape.empty())) {
     lite_tensor_ = nullptr;
+  } else {
+    lite_tensor_ = new (std::nothrow) lite::Tensor(name, static_cast<enum TypeId>(type), truncated_shape, data);
   }
 }
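In the tensor constructor the condition is inverted to match the new helper: verify_size=true keeps strict byte-count checking for user-supplied data, and lite_tensor_ is set to nullptr only when truncation failed on a non-empty shape, so an empty shape no longer nulls the tensor outright.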
@@ -18,22 +18,30 @@
 #include "src/tensor.h"
 
 namespace mindspore {
-static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len) {
+static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, enum TypeId type, size_t data_len,
+                                          bool verify_size) {
   std::vector<int32_t> empty;
   if (shape.empty()) {
     return empty;
   }
   std::vector<int32_t> truncated_shape;
+  truncated_shape.resize(shape.size());
   size_t element_size = lite::DataTypeSize(type);
-  for (auto i : shape) {
-    if (i < 0 || i > INT_MAX || element_size > INT_MAX / static_cast<size_t>(i)) {
+  for (size_t i = 0; i < shape.size(); i++) {
+    auto dim = shape[i];
+    if (dim < 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
       MS_LOG(ERROR) << "Invalid shape.";
       return empty;
     } else {
-      element_size *= static_cast<size_t>(i);
-      truncated_shape.push_back(static_cast<int32_t>(i));
+      element_size *= static_cast<size_t>(dim);
+      truncated_shape[i] = static_cast<int32_t>(dim);
     }
   }
-  if (element_size != data_len) {
-    MS_LOG(ERROR) << "Invalid data size.";
-    return empty;
+  if (verify_size) {
+    if (element_size != data_len) {
+      MS_LOG(ERROR) << "Invalid data size.";
+      return empty;
+    }
   }
   return truncated_shape;
 }
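The helper itself, for reference, as a standalone compilable sketch. Two liberties are taken and labeled: the element size is passed in directly rather than derived via lite::DataTypeSize(type), and zero dims are rejected explicitly so the overflow test never divides by zero (the original assumes positive dims):

#include <climits>
#include <cstdint>
#include <iostream>
#include <vector>

// Standalone sketch of the updated helper; element_size replaces the
// lite::DataTypeSize(type) lookup of the real code.
static std::vector<int32_t> TruncateShape(const std::vector<int64_t> &shape, size_t element_size,
                                          size_t data_len, bool verify_size) {
  std::vector<int32_t> empty;
  if (shape.empty()) {
    return empty;
  }
  std::vector<int32_t> truncated_shape;
  truncated_shape.resize(shape.size());
  for (size_t i = 0; i < shape.size(); i++) {
    auto dim = shape[i];
    // Reject non-positive dims and dims over INT_MAX; the division form of the
    // overflow test (element_size * dim > INT_MAX) never overflows itself.
    if (dim <= 0 || dim > INT_MAX || element_size > INT_MAX / static_cast<size_t>(dim)) {
      return empty;
    }
    element_size *= static_cast<size_t>(dim);
    truncated_shape[i] = static_cast<int32_t>(dim);
  }
  // The new flag: Resize() passes false since resizing legitimately changes the
  // byte count; tensor construction passes true to keep the strict check.
  if (verify_size && element_size != data_len) {
    return empty;
  }
  return truncated_shape;
}

int main() {
  std::cout << TruncateShape({2, 3}, sizeof(float), 24, true).size() << std::endl;   // 2
  std::cout << TruncateShape({2, 3}, sizeof(float), 16, true).size() << std::endl;   // 0 (size mismatch)
  std::cout << TruncateShape({2, 3}, sizeof(float), 16, false).size() << std::endl;  // 2 (check skipped)
  return 0;
}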