forked from OSSInnovation/mindspore
fix weight init and add data aug
This commit is contained in:
parent
6ea2aa4e73
commit
34864fbc56
|
@ -63,7 +63,7 @@ def random_sample_crop(image, boxes):
|
|||
if not drop_mask.any():
|
||||
continue
|
||||
|
||||
if overlap[drop_mask].min() < min_iou:
|
||||
if overlap[drop_mask].min() < min_iou and overlap[drop_mask].max() > (min_iou + 0.2):
|
||||
continue
|
||||
|
||||
image_t = image_t[rect[0]:rect[2], rect[1]:rect[3], :]
|
||||
|
|
|
@ -14,16 +14,17 @@
|
|||
# ============================================================================
|
||||
"""Parameters utils"""
|
||||
|
||||
from mindspore import Tensor
|
||||
from mindspore.common.initializer import initializer, TruncatedNormal
|
||||
import numpy as np
|
||||
|
||||
def init_net_param(network, initialize_mode='TruncatedNormal'):
|
||||
"""Init the parameters in net."""
|
||||
params = network.trainable_params()
|
||||
for p in params:
|
||||
if isinstance(p.data, Tensor) and 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
if 'beta' not in p.name and 'gamma' not in p.name and 'bias' not in p.name:
|
||||
np.random.seed(seed=1)
|
||||
if initialize_mode == 'TruncatedNormal':
|
||||
p.set_parameter_data(initializer(TruncatedNormal(0.03), p.data.shape, p.data.dtype))
|
||||
p.set_parameter_data(initializer(TruncatedNormal(), p.data.shape, p.data.dtype))
|
||||
else:
|
||||
p.set_parameter_data(initialize_mode, p.data.shape, p.data.dtype)
|
||||
|
||||
|
|
Loading…
Reference in New Issue