!10044 fix get num classes of concat

From: @luoyang42
Reviewed-by: @liucunwei,@heleiwang,@liucunwei
Signed-off-by: @liucunwei,@liucunwei
This commit is contained in:
mindspore-ci-bot 2020-12-17 19:30:50 +08:00 committed by Gitee
commit 6035823c4b
3 changed files with 35 additions and 0 deletions

View File

@ -196,5 +196,20 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
}
// Gets the number of classes
Status ConcatOp::GetNumClasses(int64_t *num_classes) {
int64_t max_num_classes = -1;
for (const auto &child : child_) {
// Choose a dataset which can get valid num_classes
int64_t tmp_num_classes = -1;
child->GetNumClasses(&tmp_num_classes);
if (tmp_num_classes > max_num_classes) {
max_num_classes = tmp_num_classes;
}
}
*num_classes = max_num_classes;
return Status::OK();
}
} // namespace dataset
} // namespace mindspore

View File

@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
/// \return Status of the node visit
Status PreAccept(NodePass *p, bool *modified) override;
/// \brief Gets the number of classes
/// \param[out] num_classes the number of classes
/// \return Status - The status code return
Status GetNumClasses(int64_t *num_classes) override;
private:
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);

View File

@ -113,6 +113,20 @@ def test_manifest_dataset_multi_label_onehot():
count = count + 1
def test_manifest_dataset_get_num_class():
data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
assert data.num_classes() == 3
padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
padded_ds = ds.PaddedDataset(padded_samples)
data = data.repeat(2)
padded_ds = padded_ds.repeat(2)
data1 = data + padded_ds
assert data1.num_classes() == 3
if __name__ == '__main__':
test_manifest_dataset_train()
test_manifest_dataset_eval()
@ -120,3 +134,4 @@ if __name__ == '__main__':
test_manifest_dataset_get_class_index()
test_manifest_dataset_multi_label()
test_manifest_dataset_multi_label_onehot()
test_manifest_dataset_get_num_class()