!22342 [Auto parallel] Adjust the device matrix of OneHot operator

Merge pull request !22342 from Xiaoda/85-adapt-onehot
This commit is contained in:
i-robot 2021-08-25 09:02:44 +00:00 committed by Gitee
commit 9b28bd6308
3 changed files with 7 additions and 3 deletions

View File

@ -84,7 +84,11 @@ Status OneHotInfo::InferDevMatrixShape() {
}
}
old_dev_matrix_back_ = dev_matrix_shape_.back();
repeated_num_in_dev_matrix_right_ = false;
if (old_dev_matrix_back_ == 1) {
repeated_num_in_dev_matrix_right_ = true;
} else {
repeated_num_in_dev_matrix_right_ = false;
}
return SUCCESS;
}

View File

@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
Shape expect = {2, 4, 1};
Shape expect = {4, 1, 2};
ASSERT_EQ(dev_matrix_shape, expect);
}

View File

@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
ASSERT_EQ(status, SUCCESS);
Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
Shape expect = {2, 4, 1};
Shape expect = {4, 1, 2};
ASSERT_EQ(dev_matrix_shape, expect);
}