fix nll_loss
This commit is contained in:
parent
78d51aa323
commit
094e306611
|
@ -1761,7 +1761,10 @@ class NLLLossGrad(PrimitiveWithInfer):
|
|||
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
|
||||
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
|
||||
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
|
||||
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
if len(x_shape) == 1:
|
||||
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
else:
|
||||
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, y_grad_dtype, t_dtype, w_dtype, tw_dtype):
|
||||
|
|
|
@ -1956,7 +1956,7 @@ class NLLLoss(PrimitiveWithInfer):
|
|||
def __init__(self, reduction="mean"):
|
||||
"""Initialize NLLLoss"""
|
||||
self.init_prim_io_names(inputs=['x', 'target', "weight"], outputs=['loss'])
|
||||
self.reduction = validator.check_string(reduction.lower(), ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||
self.add_prim_attr('reduction', self.reduction)
|
||||
|
||||
def infer_shape(self, x_shape, t_shape, w_shape):
|
||||
|
@ -1964,7 +1964,10 @@ class NLLLoss(PrimitiveWithInfer):
|
|||
validator.check_int(len(t_shape), 1, Rel.EQ, "target rank", self.name)
|
||||
validator.check_int(len(w_shape), 1, Rel.EQ, "weight rank", self.name)
|
||||
validator.check(f"input_shape[0]", x_shape[0], "target_shape", t_shape[0], Rel.EQ, self.name)
|
||||
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
if len(x_shape) == 1:
|
||||
validator.check(f"input_shape[0]", x_shape[0], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
else:
|
||||
validator.check(f"input_shape[1]", x_shape[1], "weight_shape", w_shape[0], Rel.EQ, self.name)
|
||||
if self.reduction == "none":
|
||||
return t_shape, ()
|
||||
return (), ()
|
||||
|
|
Loading…
Reference in New Issue