!23794 Fix minddata code clean

Merge pull request !23794 from xiefangqi/md_fix_code_clean_sep
This commit is contained in:
i-robot 2021-09-23 01:40:36 +00:00 committed by Gitee
commit 32093585a9
10 changed files with 103 additions and 46 deletions

View File

@ -215,7 +215,10 @@ Execute::~Execute() {
#ifdef ENABLE_ACL
if (device_type_ == MapTargetDevice::kAscend310) {
if (device_resource_) {
device_resource_->FinalizeResource();
auto rc = device_resource_->FinalizeResource();
if (rc.IsError()) {
MS_LOG(ERROR) << "Device resource release failed, error msg is " << rc;
}
} else {
MS_LOG(ERROR) << "Device resource is nullptr which is illegal under case Ascend310";
}
@ -525,7 +528,11 @@ std::string Execute::AippCfgGenerator() {
std::string config_location = "./aipp.cfg";
#ifdef ENABLE_ACL
if (info_->init_with_shared_ptr_) {
ParseTransforms();
auto rc = ParseTransforms();
if (rc.IsError()) {
MS_LOG(ERROR) << "Parse transforms failed, error msg is " << rc;
return "";
}
info_->init_with_shared_ptr_ = false;
}
std::vector<uint32_t> paras; // Record the parameters value of each Ascend operators
@ -538,7 +545,11 @@ std::string Execute::AippCfgGenerator() {
}
// Define map between operator name and parameter name
ops_[i]->to_json(&ir_info);
auto rc = ops_[i]->to_json(&ir_info);
if (rc.IsError()) {
MS_LOG(ERROR) << "IR information serialize to json failed, error msg is " << rc;
return "";
}
// Collect the information of operators
for (auto pos = info_->op2para_map_.equal_range(ops_[i]->Name()); pos.first != pos.second; ++pos.first) {
@ -601,7 +612,11 @@ std::string Execute::AippCfgGenerator() {
std::vector<float> aipp_std = AippStdFilter(normalize_paras);
std::map<std::string, std::string> aipp_options;
AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std);
auto rc = AippInfoCollection(&aipp_options, aipp_size, aipp_mean, aipp_std);
if (rc.IsError()) {
MS_LOG(ERROR) << "Aipp information initialization failed, error msg is " << rc;
return "";
}
std::string tab_char(4, ' ');
outfile << "aipp_op {" << std::endl;

View File

@ -57,13 +57,13 @@ PYBIND_REGISTER(
return out;
})
.def("get_edges_from_nodes",
[](gnn::GraphData &g, std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> node_list) {
[](gnn::GraphData &g, const std::vector<std::pair<gnn::NodeIdType, gnn::NodeIdType>> &node_list) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetEdgesFromNodes(node_list, &out));
return out;
})
.def("get_all_neighbors",
[](gnn::GraphData &g, std::vector<gnn::NodeIdType> node_list, gnn::NodeType neighbor_type,
[](gnn::GraphData &g, const std::vector<gnn::NodeIdType> &node_list, gnn::NodeType neighbor_type,
OutputFormat format) {
std::shared_ptr<Tensor> out;
THROW_IF_ERROR(g.GetAllNeighbors(node_list, neighbor_type, format, &out));

View File

@ -93,7 +93,7 @@ PYBIND_REGISTER(CityscapesNode, 2, ([](const py::module *m) {
(void)py::class_<CityscapesNode, DatasetNode, std::shared_ptr<CityscapesNode>>(
*m, "CityscapesNode", "to create a CityscapesNode")
.def(py::init([](std::string dataset_dir, std::string usage, std::string quality_mode,
std::string task, bool decode, py::handle sampler) {
std::string task, bool decode, const py::handle &sampler) {
auto cityscapes = std::make_shared<CityscapesNode>(dataset_dir, usage, quality_mode, task, decode,
toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(cityscapes->ValidateParams());
@ -118,7 +118,7 @@ PYBIND_REGISTER(CocoNode, 2, ([](const py::module *m) {
(void)py::class_<CocoNode, DatasetNode, std::shared_ptr<CocoNode>>(*m, "CocoNode",
"to create a CocoNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, std::string task,
bool decode, py::handle sampler, bool extra_metadata) {
bool decode, const py::handle &sampler, bool extra_metadata) {
std::shared_ptr<CocoNode> coco = std::make_shared<CocoNode>(
dataset_dir, annotation_file, task, decode, toSamplerObj(sampler), nullptr, extra_metadata);
THROW_IF_ERROR(coco->ValidateParams());
@ -154,7 +154,7 @@ PYBIND_REGISTER(DIV2KNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(
FlickrNode, 2, ([](const py::module *m) {
(void)py::class_<FlickrNode, DatasetNode, std::shared_ptr<FlickrNode>>(*m, "FlickrNode", "to create a FlickrNode")
.def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, py::handle sampler) {
.def(py::init([](std::string dataset_dir, std::string annotation_file, bool decode, const py::handle &sampler) {
auto flickr =
std::make_shared<FlickrNode>(dataset_dir, annotation_file, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(flickr->ValidateParams());
@ -213,7 +213,7 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
(void)py::class_<MindDataNode, DatasetNode, std::shared_ptr<MindDataNode>>(*m, "MindDataNode",
"to create a MindDataNode")
.def(py::init([](std::string dataset_file, py::list columns_list, py::handle sampler,
py::dict padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
@ -225,7 +225,7 @@ PYBIND_REGISTER(MindDataNode, 2, ([](const py::module *m) {
return minddata;
}))
.def(py::init([](py::list dataset_file, py::list columns_list, py::handle sampler,
py::dict padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
const py::dict &padded_sample, int64_t num_padded, ShuffleMode shuffle_mode) {
nlohmann::json padded_sample_json;
std::map<std::string, std::string> sample_bytes;
THROW_IF_ERROR(ToJson(padded_sample, &padded_sample_json, &sample_bytes));
@ -268,7 +268,7 @@ PYBIND_REGISTER(RandomNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(SBUNode, 2, ([](const py::module *m) {
(void)py::class_<SBUNode, DatasetNode, std::shared_ptr<SBUNode>>(*m, "SBUNode",
"to create an SBUNode")
.def(py::init([](std::string dataset_dir, bool decode, py::handle sampler) {
.def(py::init([](std::string dataset_dir, bool decode, const py::handle &sampler) {
auto sbu = std::make_shared<SBUNode>(dataset_dir, decode, toSamplerObj(sampler), nullptr);
THROW_IF_ERROR(sbu->ValidateParams());
return sbu;
@ -326,7 +326,8 @@ PYBIND_REGISTER(USPSNode, 2, ([](const py::module *m) {
PYBIND_REGISTER(VOCNode, 2, ([](const py::module *m) {
(void)py::class_<VOCNode, DatasetNode, std::shared_ptr<VOCNode>>(*m, "VOCNode", "to create a VOCNode")
.def(py::init([](std::string dataset_dir, std::string task, std::string usage,
py::dict class_indexing, bool decode, py::handle sampler, bool extra_metadata) {
const py::dict &class_indexing, bool decode, const py::handle &sampler,
bool extra_metadata) {
std::shared_ptr<VOCNode> voc =
std::make_shared<VOCNode>(dataset_dir, task, usage, toStringMap(class_indexing), decode,
toSamplerObj(sampler), nullptr, extra_metadata);

View File

@ -32,7 +32,7 @@ Status DCShiftOp::Compute(const std::shared_ptr<Tensor> &input, std::shared_ptr<
return DCShift<double>(input, output, shift_, limiter_gain_);
} else {
std::shared_ptr<Tensor> tmp;
TypeCast(input, &tmp, DataType(DataType::DE_FLOAT32));
RETURN_IF_NOT_OK(TypeCast(input, &tmp, DataType(DataType::DE_FLOAT32)));
return DCShift<float>(tmp, output, shift_, limiter_gain_);
}
}

View File

@ -530,9 +530,7 @@ Status Tensor::GetItemPtr(uchar **ptr, const std::vector<dsize_t> &index, offset
RETURN_IF_NOT_OK(shape_.ToFlatIndex(index, &flat_idx));
offset_t length_temp = 0;
RETURN_IF_NOT_OK(GetStringAt(flat_idx, ptr, &length_temp));
if (length != nullptr) {
*length = length_temp;
}
*length = length_temp;
return Status::OK();
} else {
std::string err = "data type not compatible";

View File

@ -164,7 +164,7 @@ Status AlbumNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetNode
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<AlbumNode>(dataset_dir, data_schema, column_names, decode, sampler, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
(void)((*ds)->SetNumWorkers(json_obj["num_parallel_workers"]));
return Status::OK();
}
#endif

View File

@ -175,7 +175,7 @@ Status TextFileNode::from_json(nlohmann::json json_obj, std::shared_ptr<DatasetN
std::shared_ptr<DatasetCache> cache = nullptr;
RETURN_IF_NOT_OK(DatasetCache::from_json(json_obj, &cache));
*ds = std::make_shared<TextFileNode>(dataset_files, num_samples, shuffle, num_shards, shard_id, cache);
(*ds)->SetNumWorkers(json_obj["num_parallel_workers"]);
(void)((*ds)->SetNumWorkers(json_obj["num_parallel_workers"]));
return Status::OK();
}

View File

@ -21,6 +21,7 @@
#include "minddata/dataset/include/dataset/constants.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/kernels/image/image_utils.h"
#include "minddata/dataset/util/status.h"
namespace {
const int BUFFER_SIZE = 2048;
@ -39,7 +40,7 @@ mode_t SetFileDefaultUmask() { return umask(DEFAULT_FILE_PERMISSION); }
*/
MDAclProcess::MDAclProcess(uint32_t resizeWidth, uint32_t resizeHeight, uint32_t cropWidth, uint32_t cropHeight,
aclrtContext context, bool is_crop, aclrtStream stream,
std::shared_ptr<DvppCommon> dvppCommon)
const std::shared_ptr<DvppCommon> &dvppCommon)
: resizeWidth_(resizeWidth),
resizeHeight_(resizeHeight),
cropWidth_(cropWidth),
@ -51,7 +52,7 @@ MDAclProcess::MDAclProcess(uint32_t resizeWidth, uint32_t resizeHeight, uint32_t
processedInfo_(nullptr) {}
MDAclProcess::MDAclProcess(uint32_t ParaWidth, uint32_t ParaHeight, aclrtContext context, bool is_crop,
aclrtStream stream, std::shared_ptr<DvppCommon> dvppCommon)
aclrtStream stream, const std::shared_ptr<DvppCommon> &dvppCommon)
: contain_crop_(is_crop), context_(context), stream_(stream), dvppCommon_(dvppCommon), processedInfo_(nullptr) {
if (is_crop) {
resizeWidth_ = 0;
@ -67,7 +68,7 @@ MDAclProcess::MDAclProcess(uint32_t ParaWidth, uint32_t ParaHeight, aclrtContext
}
MDAclProcess::MDAclProcess(aclrtContext context, bool is_crop, aclrtStream stream,
std::shared_ptr<DvppCommon> dvppCommon)
const std::shared_ptr<DvppCommon> &dvppCommon)
: resizeWidth_(0),
resizeHeight_(0),
cropWidth_(0),
@ -192,9 +193,18 @@ APP_ERROR MDAclProcess::H2D_Sink(const std::shared_ptr<mindspore::dataset::Tenso
const mindspore::dataset::DataType dvpp_data_type(mindspore::dataset::DataType::DE_UINT8);
const mindspore::dataset::TensorShape dvpp_shape({1, 1, 1});
mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, &device_input);
device_input->SetAttributes(deviceInputData->data, deviceInputData->dataSize, deviceInputData->width,
deviceInputData->widthStride, deviceInputData->height, deviceInputData->heightStride);
auto rc = mindspore::dataset::DeviceTensor::CreateEmpty(dvpp_shape, dvpp_data_type, &device_input);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to allocate memory, error msg is " << rc;
return APP_ERR_ACL_BAD_ALLOC;
}
rc =
device_input->SetAttributes(deviceInputData->data, deviceInputData->dataSize, deviceInputData->width,
deviceInputData->widthStride, deviceInputData->height, deviceInputData->heightStride);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to initialize device attribution, error msg is " << rc;
return APP_ERR_ACL_INVALID_PARAM;
}
return APP_ERR_OK;
}
@ -225,8 +235,16 @@ APP_ERROR MDAclProcess::D2H_Pop(const std::shared_ptr<mindspore::dataset::Device
uint32_t _output_height_ = device_output->GetYuvStrideShape()[2];
uint32_t _output_heightStride_ = device_output->GetYuvStrideShape()[3];
const mindspore::dataset::DataType dvpp_data_type(mindspore::dataset::DataType::DE_UINT8);
mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, &output);
output->SetYuvShape(_output_width_, _output_widthStride_, _output_height_, _output_heightStride_);
auto rc = mindspore::dataset::Tensor::CreateFromMemory(dvpp_shape, dvpp_data_type, ret_ptr, &output);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to allocate memory, error msg is " << rc;
return APP_ERR_ACL_BAD_ALLOC;
}
rc = output->SetYuvShape(_output_width_, _output_widthStride_, _output_height_, _output_heightStride_);
if (rc.IsError()) {
MS_LOG(ERROR) << "Failed to set yuv shape, error msg is " << rc;
return APP_ERR_ACL_INVALID_PARAM;
}
if (!output->HasData()) {
return APP_ERR_COMM_ALLOC_MEM;
}
@ -399,7 +417,11 @@ APP_ERROR MDAclProcess::JPEG_R_(const DvppDataInfo &ImageInfo) {
uint32_t pri_w = decoded_image->widthStride;
// Define the resize shape
DvppDataInfo resizeOut;
ResizeConfigFilter(resizeOut, pri_w, pri_h);
ret = ResizeConfigFilter(resizeOut, pri_w, pri_h);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to config resize parameter, ret = " << ret << ".";
return ret;
}
ret = dvppCommon_->CombineResizeProcess(*decoded_image, resizeOut, true);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to process resize, ret = " << ret << ".";
@ -422,8 +444,12 @@ APP_ERROR MDAclProcess::JPEG_R_(std::string &last_step) {
uint32_t pri_h = input_image->heightStride;
uint32_t pri_w = input_image->widthStride;
DvppDataInfo resizeOut;
ResizeConfigFilter(resizeOut, pri_w, pri_h);
APP_ERROR ret = dvppCommon_->CombineResizeProcess(*input_image, resizeOut, true);
auto ret = ResizeConfigFilter(resizeOut, pri_w, pri_h);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to config resize, ret = " << ret << ".";
return ret;
}
ret = dvppCommon_->CombineResizeProcess(*input_image, resizeOut, true);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to process resize, ret = " << ret << ".";
return ret;
@ -760,7 +786,11 @@ APP_ERROR MDAclProcess::JPEG_DRC_(const RawData &ImageInfo) {
uint32_t pri_w = decodeOutData->widthStride;
// Define output of resize jpeg image
DvppDataInfo resizeOut;
ResizeConfigFilter(resizeOut, pri_w, pri_h);
ret = ResizeConfigFilter(resizeOut, pri_w, pri_h);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to config resize, ret = " << ret << ".";
return ret;
}
// Run resize application function
ret = dvppCommon_->CombineResizeProcess(*decodeOutData, resizeOut, true);
if (ret != APP_ERR_OK) {
@ -882,7 +912,11 @@ APP_ERROR MDAclProcess::JPEG_DR_(const RawData &ImageInfo) {
uint32_t pri_h = decodeOutData->heightStride;
uint32_t pri_w = decodeOutData->widthStride;
DvppDataInfo resizeOut;
ResizeConfigFilter(resizeOut, pri_w, pri_h);
ret = ResizeConfigFilter(resizeOut, pri_w, pri_h);
if (ret != APP_ERR_OK) {
MS_LOG(ERROR) << "Failed to config resize, ret = " << ret << ".";
return ret;
}
// Run resize application function
ret = dvppCommon_->CombineResizeProcess(*decodeOutData, resizeOut, true);
if (ret != APP_ERR_OK) {
@ -934,7 +968,8 @@ void MDAclProcess::CropConfigFilter(CropRoiConfig &cfg, DvppCropInputInfo &cropi
cropinfo.roi = cfg;
}
APP_ERROR MDAclProcess::ResizeConfigFilter(DvppDataInfo &resizeinfo, const uint32_t pri_w_, const uint32_t pri_h_) {
APP_ERROR MDAclProcess::ResizeConfigFilter(DvppDataInfo &resizeinfo, const uint32_t pri_w_,
const uint32_t pri_h_) const {
if (resizeWidth_ != 0) { // 如果输入参数个数为2按指定参数缩放
resizeinfo.width = resizeWidth_;
resizeinfo.widthStride = DVPP_ALIGN_UP(resizeWidth_, VPC_STRIDE_WIDTH);

View File

@ -43,13 +43,13 @@ class MDAclProcess {
public:
MDAclProcess(uint32_t resizeWidth, uint32_t resizeHeight, uint32_t cropWidth, uint32_t cropHeight,
aclrtContext context, bool is_crop = true, aclrtStream stream = nullptr,
std::shared_ptr<DvppCommon> dvppCommon = nullptr);
const std::shared_ptr<DvppCommon> &dvppCommon = nullptr);
MDAclProcess(uint32_t ParaWidth, uint32_t ParaHeight, aclrtContext context, bool is_crop = false,
aclrtStream stream = nullptr, std::shared_ptr<DvppCommon> dvppCommon = nullptr);
aclrtStream stream = nullptr, const std::shared_ptr<DvppCommon> &dvppCommon = nullptr);
MDAclProcess(aclrtContext context, bool is_crop = false, aclrtStream stream = nullptr,
std::shared_ptr<DvppCommon> dvppCommon = nullptr);
const std::shared_ptr<DvppCommon> &dvppCommon = nullptr);
~MDAclProcess(){};
@ -109,7 +109,7 @@ class MDAclProcess {
// Crop definition
void CropConfigFilter(CropRoiConfig &cfg, DvppCropInputInfo &cropinfo, DvppDataInfo &resizeinfo);
// Resize definition
APP_ERROR ResizeConfigFilter(DvppDataInfo &resizeinfo, const uint32_t pri_w_, const uint32_t pri_h_);
APP_ERROR ResizeConfigFilter(DvppDataInfo &resizeinfo, const uint32_t pri_w_, const uint32_t pri_h_) const;
// Initialize DVPP modules used by this sample
APP_ERROR InitModule();
// Dvpp process with crop

View File

@ -114,7 +114,7 @@ int calc_coeff(int input_size, int out_size, int input0, int input1, struct inte
for (; x < kernel_size; x++) {
coeff[x] = 0;
}
region[xx * 2 + 0] = x_min;
region[xx * 2] = x_min;
region[xx * 2 + 1] = x_max;
}
@ -150,12 +150,12 @@ Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int k
// obtain the ptr of output, and put calculated value into it
uint8_t *bgr_buf = output_ptr;
for (int xx = 0; xx < output.width_; xx++) {
int x_min = regions[xx * 2 + 0];
int x_min = regions[xx * 2];
int x_max = regions[xx * 2 + 1];
k = &kk[xx * kernel_size];
ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1);
for (int x = 0; x < x_max; x++) {
ss0 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 0]) * k[x];
ss0 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3]) * k[x];
ss1 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 1]) * k[x];
ss2 += (input_ptr[(yy + offset) * input_width + (x + x_min) * 3 + 2]) * k[x];
}
@ -169,8 +169,8 @@ Status ImagingHorizontalInterp(LiteMat &output, LiteMat input, int offset, int k
return Status::OK();
}
Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int kernel_size,
const std::vector<int> &regions, const std::vector<double> &prekk) {
Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int kernel_size, const std::vector<int> &regions,
const std::vector<double> &prekk) {
int ss0, ss1, ss2;
// normalize previous calculated coefficients
@ -185,12 +185,12 @@ Status ImagingVerticalInterp(LiteMat &output, LiteMat input, int offset, int ker
// obtain the ptr of output, and put calculated value into it
uint8_t *bgr_buf = output_ptr;
int32_t *k = &kk[yy * kernel_size];
int y_min = regions[yy * 2 + 0];
int y_min = regions[yy * 2];
int y_max = regions[yy * 2 + 1];
for (int xx = 0; xx < output.width_; xx++) {
ss0 = ss1 = ss2 = 1 << (PrecisionBits - 1);
for (int y = 0; y < y_max; y++) {
ss0 += (input_ptr[(y + y_min) * input_width + xx * 3 + 0]) * k[y];
ss0 += (input_ptr[(y + y_min) * input_width + xx * 3]) * k[y];
ss1 += (input_ptr[(y + y_min) * input_width + xx * 3 + 1]) * k[y];
ss2 += (input_ptr[(y + y_min) * input_width + xx * 3 + 2]) * k[y];
}
@ -236,7 +236,11 @@ bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size,
}
temp.Init(x_size, rect_y1 - rect_y0, 3, LDataType::UINT8, false);
ImagingHorizontalInterp(temp, input, rect_y0, horiz_kernel, horiz_region, horiz_coeff);
auto rc = ImagingHorizontalInterp(temp, input, rect_y0, horiz_kernel, horiz_region, horiz_coeff);
if (rc.IsError()) {
MS_LOG(ERROR) << "Image horizontal resize failed, error msg is " << rc;
return false;
}
if (temp.IsEmpty()) {
return false;
}
@ -247,7 +251,11 @@ bool ImageInterpolation(LiteMat input, LiteMat &output, int x_size, int y_size,
if (vertical_interp) {
output.Init(input.width_, y_size, 3, LDataType::UINT8, false);
if (!output.IsEmpty()) {
ImagingVerticalInterp(output, input, 0, vert_kernel, vert_region, vert_coeff);
auto rc = ImagingVerticalInterp(output, input, vert_kernel, vert_region, vert_coeff);
if (rc.IsError()) {
MS_LOG(ERROR) << "Image vertical resize failed, error msg is " << rc;
return false;
}
}
if (output.IsEmpty()) {
return false;