forked from mindspore-Ecosystem/mindspore
!10044 fix get num classes of concat
From: @luoyang42 Reviewed-by: @liucunwei,@heleiwang,@liucunwei Signed-off-by: @liucunwei,@liucunwei
This commit is contained in:
commit
6035823c4b
|
@ -196,5 +196,20 @@ Status ConcatOp::PreAccept(NodePass *p, bool *modified) {
|
|||
return p->PreRunOnNode(shared_from_base<ConcatOp>(), modified);
|
||||
}
|
||||
|
||||
// Gets the number of classes
|
||||
Status ConcatOp::GetNumClasses(int64_t *num_classes) {
|
||||
int64_t max_num_classes = -1;
|
||||
for (const auto &child : child_) {
|
||||
// Choose a dataset which can get valid num_classes
|
||||
int64_t tmp_num_classes = -1;
|
||||
child->GetNumClasses(&tmp_num_classes);
|
||||
if (tmp_num_classes > max_num_classes) {
|
||||
max_num_classes = tmp_num_classes;
|
||||
}
|
||||
}
|
||||
*num_classes = max_num_classes;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace dataset
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -111,6 +111,11 @@ class ConcatOp : public PipelineOp {
|
|||
/// \return Status of the node visit
|
||||
Status PreAccept(NodePass *p, bool *modified) override;
|
||||
|
||||
/// \brief Gets the number of classes
|
||||
/// \param[out] num_classes the number of classes
|
||||
/// \return Status - The status code return
|
||||
Status GetNumClasses(int64_t *num_classes) override;
|
||||
|
||||
private:
|
||||
Status Verify(int32_t id, const std::unique_ptr<DataBuffer> &buf);
|
||||
|
||||
|
|
|
@ -113,6 +113,20 @@ def test_manifest_dataset_multi_label_onehot():
|
|||
count = count + 1
|
||||
|
||||
|
||||
def test_manifest_dataset_get_num_class():
|
||||
data = ds.ManifestDataset(DATA_FILE, decode=True, shuffle=False)
|
||||
assert data.num_classes() == 3
|
||||
|
||||
padded_samples = [{'image': np.zeros(1, np.uint8), 'label': np.array(1, np.int32)}]
|
||||
padded_ds = ds.PaddedDataset(padded_samples)
|
||||
|
||||
data = data.repeat(2)
|
||||
padded_ds = padded_ds.repeat(2)
|
||||
|
||||
data1 = data + padded_ds
|
||||
assert data1.num_classes() == 3
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_manifest_dataset_train()
|
||||
test_manifest_dataset_eval()
|
||||
|
@ -120,3 +134,4 @@ if __name__ == '__main__':
|
|||
test_manifest_dataset_get_class_index()
|
||||
test_manifest_dataset_multi_label()
|
||||
test_manifest_dataset_multi_label_onehot()
|
||||
test_manifest_dataset_get_num_class()
|
||||
|
|
Loading…
Reference in New Issue