!25660 update dynamic stitch

Merge pull request !25660 from VectorSL/dynamic_stitch-2
This commit is contained in:
i-robot 2021-10-30 10:04:22 +00:00 committed by Gitee
commit f0b38f53e8
2 changed files with 2 additions and 2 deletions

View File

@ -44,7 +44,7 @@ bool DynamicStitchKernel::Init(const CNodePtr &kernel_node) {
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, n_);
// Index type is restricted to int32 by kernel prim.
size_t index_type_size = sizeof(int);
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, false));
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, true));
auto first_data_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, n_);
one_data_ele_num_ = first_data_shape[first_data_shape.size() - 1];
for (size_t i = 0; i < n_; i++) {

View File

@ -32,7 +32,7 @@ class Net(nn.Cell):
def construct(self, indices, data):
return self.stitch(indices, data)
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_net_int32():