forked from mindspore-Ecosystem/mindspore
!1237 adopt weight initializator modification in resnet
Merge pull request !1237 from gengdongjie/master
This commit is contained in:
commit
122a6e03a9
|
@ -64,11 +64,11 @@ if __name__ == '__main__':
|
|||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
if not config.label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
|
|
|
@ -61,11 +61,11 @@ if __name__ == '__main__':
|
|||
if isinstance(cell, nn.Conv2d):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.XavierUniform(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
if isinstance(cell, nn.Dense):
|
||||
cell.weight.default_input = weight_init.initializer(weight_init.TruncatedNormal(),
|
||||
cell.weight.default_input.shape(),
|
||||
cell.weight.default_input.dtype())
|
||||
cell.weight.default_input.dtype()).to_tensor()
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
|
||||
|
|
Loading…
Reference in New Issue