forked from mindspore-Ecosystem/mindspore
!22342 [Auto parallel] Adjust the device matrix of OneHot operator
Merge pull request !22342 from Xiaoda/85-adapt-onehot
This commit is contained in:
commit
9b28bd6308
|
@ -84,7 +84,11 @@ Status OneHotInfo::InferDevMatrixShape() {
|
|||
}
|
||||
}
|
||||
old_dev_matrix_back_ = dev_matrix_shape_.back();
|
||||
repeated_num_in_dev_matrix_right_ = false;
|
||||
if (old_dev_matrix_back_ == 1) {
|
||||
repeated_num_in_dev_matrix_right_ = true;
|
||||
} else {
|
||||
repeated_num_in_dev_matrix_right_ = false;
|
||||
}
|
||||
return SUCCESS;
|
||||
}
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo, InferDevMatrixShape2) {
|
|||
ASSERT_EQ(status, SUCCESS);
|
||||
Shape dev_matrix_shape = onehot_info->dev_matrix_shape();
|
||||
|
||||
Shape expect = {2, 4, 1};
|
||||
Shape expect = {4, 1, 2};
|
||||
ASSERT_EQ(dev_matrix_shape, expect);
|
||||
}
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ TEST_F(TestOneHotInfo2, InferDevMatrixShape2) {
|
|||
ASSERT_EQ(status, SUCCESS);
|
||||
Shape dev_matrix_shape = onehot_info2->dev_matrix_shape();
|
||||
|
||||
Shape expect = {2, 4, 1};
|
||||
Shape expect = {4, 1, 2};
|
||||
ASSERT_EQ(dev_matrix_shape, expect);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue