forked from mindspore-Ecosystem/mindspore
!25660 update dynamic stitch
Merge pull request !25660 from VectorSL/dynamic_stitch-2
This commit is contained in:
commit
f0b38f53e8
|
@ -44,7 +44,7 @@ bool DynamicStitchKernel::Init(const CNodePtr &kernel_node) {
|
|||
auto data_type = AnfAlgo::GetInputDeviceDataType(kernel_node, n_);
|
||||
// Index type is restricted to int32 by kernel prim.
|
||||
size_t index_type_size = sizeof(int);
|
||||
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, false));
|
||||
data_type_size_ = GetDtypeNbyte(TypeIdToString(data_type, true));
|
||||
auto first_data_shape = AnfAlgo::GetInputRealDeviceShapeIfExist(kernel_node, n_);
|
||||
one_data_ele_num_ = first_data_shape[first_data_shape.size() - 1];
|
||||
for (size_t i = 0; i < n_; i++) {
|
||||
|
|
|
@ -32,7 +32,7 @@ class Net(nn.Cell):
|
|||
def construct(self, indices, data):
|
||||
return self.stitch(indices, data)
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_net_int32():
|
||||
|
|
Loading…
Reference in New Issue