forked from OSSInnovation/mindspore
!7375 Change normalize input check
Merge pull request !7375 from EricZ/change_normalize_check
This commit is contained in:
commit
70bd405338
|
@ -518,13 +518,6 @@ bool NormalizeOperation::ValidateParams() {
|
|||
MS_LOG(ERROR) << "Normalize: mean vector has incorrect size: " << mean_.size();
|
||||
return false;
|
||||
}
|
||||
// check mean value
|
||||
for (int i = 0; i < mean_.size(); ++i) {
|
||||
if (mean_[i] < 0.0f || mean_[i] > 255.0f || CmpFloat(mean_[i], 0.0f)) {
|
||||
MS_LOG(ERROR) << "Normalize: mean vector has incorrect value: " << mean_[i];
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (std_.size() != 3) {
|
||||
MS_LOG(ERROR) << "Normalize: std vector has incorrect size: " << std_.size();
|
||||
return false;
|
||||
|
|
|
@ -90,11 +90,9 @@ AlbumOp::AlbumOp(int32_t num_wkrs, int32_t rows_per_buffer, std::string file_dir
|
|||
// Helper function for string comparison
|
||||
// album sorts the files via numerical values, so this is not a simple string comparison
|
||||
bool StrComp(const std::string &a, const std::string &b) {
|
||||
// returns 1 if string a represent a numeric value
|
||||
// less than string b
|
||||
// quite similar to strcmp operation
|
||||
// returns 1 if string "a" represent a numeric value less than string "b"
|
||||
// the following will always return name, provided there is only one "." character in name
|
||||
// "." character is guranteed since the extension is checked befor this function call.
|
||||
// "." character is guaranteed to exist since the extension is checked befor this function call.
|
||||
int64_t value_a = std::atoi(a.substr(1, a.find(".")).c_str());
|
||||
int64_t value_b = std::atoi(b.substr(1, b.find(".")).c_str());
|
||||
return value_a < value_b;
|
||||
|
|
|
@ -130,12 +130,12 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
}
|
||||
|
||||
/// \brief Check validity of input args
|
||||
/// \return - The error code return
|
||||
/// \return - The error code returned
|
||||
Status SanityCheck();
|
||||
|
||||
/// \brief The builder "build" method creates the final object.
|
||||
/// \param[inout] std::shared_ptr<AlbumOp> *op - DatasetOp
|
||||
/// \return - The error code return
|
||||
/// \return - The error code returned
|
||||
Status Build(std::shared_ptr<AlbumOp> *op);
|
||||
|
||||
private:
|
||||
|
@ -167,18 +167,18 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
~AlbumOp() = default;
|
||||
|
||||
/// \brief Initialize AlbumOp related var, calls the function to walk all files
|
||||
/// \return - The error code return
|
||||
/// \return - The error code returned
|
||||
Status PrescanEntry();
|
||||
|
||||
/// \brief Worker thread pulls a number of IOBlock from IOBlock Queue, make a buffer and push it to Connector
|
||||
/// \param[in] int32_t workerId - id of each worker
|
||||
/// \return Status - The error code return
|
||||
/// \return Status - The error code returned
|
||||
Status WorkerEntry(int32_t worker_id) override;
|
||||
|
||||
/// \brief Main Loop of AlbumOp
|
||||
/// Master thread: Fill IOBlockQueue, then goes to sleep
|
||||
/// Worker thread: pulls IOBlock from IOBlockQueue, work on it then put buffer to mOutConnector
|
||||
/// \return Status - The error code return
|
||||
/// \return Status - The error code returned
|
||||
Status operator()() override;
|
||||
|
||||
/// \brief A print method typically used for debugging
|
||||
|
@ -188,7 +188,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
/// \brief Check if image ia valid.Only support JPEG/PNG/GIF/BMP
|
||||
/// This function could be optimized to return the tensor to reduce open/closing files
|
||||
/// \return Status - The error code return
|
||||
/// \return Status - The error code returned
|
||||
Status CheckImageType(const std::string &file_name, bool *valid);
|
||||
|
||||
// Base-class override for NodePass visitor acceptor.
|
||||
|
@ -203,84 +203,84 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
|
||||
private:
|
||||
/// \brief Initialize Sampler, calls sampler->Init() within
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status InitSampler();
|
||||
|
||||
/// \brief Load image to tensor row
|
||||
/// \param[in] image_file Image name of file
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadImageTensor(const std::string &image_file, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load vector of ints to tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing multi-dimensional label
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadIntArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load vector of floatss to tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing array data
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadFloatArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load string array into a tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing string tensor
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadStringArrayTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load string into a tensor, append tensor to tensor row
|
||||
/// \param[in] json_obj Json object containing string tensor
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadStringTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load float value to tensor row
|
||||
/// \param[in] json_obj Json object containing float
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadFloatTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load int value to tensor row
|
||||
/// \param[in] json_obj Json object containing int
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadIntTensor(const nlohmann::json &json_obj, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load emtpy tensor to tensor row
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadEmptyTensor(uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load id from file name to tensor row
|
||||
/// \param[in] file The file name to get ID from
|
||||
/// \param[in] col_num Column num in schema
|
||||
/// \param[inout] row Tensor row to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadIDTensor(const std::string &file, uint32_t col_num, TensorRow *row);
|
||||
|
||||
/// \brief Load a tensor row according to a json file
|
||||
/// \param[in] ImageColumns file Json file location
|
||||
/// \param[inout] TensorRow row Json content stored into a tensor row
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadTensorRow(const std::string &file, TensorRow *row);
|
||||
|
||||
/// \param[in] const std::vector<int64_t> &keys Keys in ioblock
|
||||
/// \param[inout] std::unique_ptr<DataBuffer> db Databuffer to push to
|
||||
/// \return Status The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LoadBuffer(const std::vector<int64_t> &keys, std::unique_ptr<DataBuffer> *db);
|
||||
|
||||
/// \brief Called first when function is called
|
||||
/// \return The error code return
|
||||
/// \return Status The error code returned
|
||||
Status LaunchThreadsAndInitOp();
|
||||
|
||||
/// \brief reset Op
|
||||
|
@ -288,7 +288,7 @@ class AlbumOp : public ParallelOp, public RandomAccessOp {
|
|||
Status Reset() override;
|
||||
|
||||
// Private function for computing the assignment of the column name map.
|
||||
// @return - Status
|
||||
// @return Status The error code returned
|
||||
Status ComputeColMap() override;
|
||||
|
||||
int32_t rows_per_buffer_;
|
||||
|
|
|
@ -723,15 +723,9 @@ TEST_F(MindDataTestPipeline, TestNormalize) {
|
|||
TEST_F(MindDataTestPipeline, TestNormalizeFail) {
|
||||
MS_LOG(INFO) << "Doing MindDataTestPipeline-TestNormalizeFail with invalid parameters.";
|
||||
|
||||
// mean value 0.0
|
||||
std::shared_ptr<TensorOperation> normalize =
|
||||
mindspore::dataset::api::vision::Normalize({0.0, 115.0, 100.0}, {70.0, 68.0, 71.0});
|
||||
EXPECT_EQ(normalize, nullptr);
|
||||
// std value at 0.0
|
||||
normalize = mindspore::dataset::api::vision::Normalize({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0});
|
||||
EXPECT_EQ(normalize, nullptr);
|
||||
// mean value 300.0 greater than 255.0
|
||||
normalize = mindspore::dataset::api::vision::Normalize({300.0, 115.0, 100.0}, {70.0, 68.0, 71.0});
|
||||
std::shared_ptr<TensorOperation> normalize =
|
||||
mindspore::dataset::api::vision::Normalize({121.0, 115.0, 100.0}, {0.0, 68.0, 71.0});
|
||||
EXPECT_EQ(normalize, nullptr);
|
||||
// normalize with 2 values (not 3 values) for mean
|
||||
normalize = mindspore::dataset::api::vision::Normalize({121.0, 115.0}, {70.0, 68.0, 71.0});
|
||||
|
|
Loading…
Reference in New Issue