fix faceattribute network bug

This commit is contained in:
zhanghuiyao 2021-03-26 12:42:08 +08:00
parent 669a37739e
commit fce3c2d3b2
1 changed files with 11 additions and 13 deletions

View File

@ -18,6 +18,7 @@ import time
import datetime
import argparse
import mindspore
import mindspore.nn as nn
from mindspore import context
from mindspore import Tensor
@ -43,16 +44,16 @@ context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs
class BuildTrainNetwork(nn.Cell):
'''Build train network.'''
def __init__(self, network, criterion):
def __init__(self, my_network, my_criterion):
super(BuildTrainNetwork, self).__init__()
self.network = network
self.criterion = criterion
self.network = my_network
self.criterion = my_criterion
self.print = P.Print()
def construct(self, input_data, label):
logit0, logit1, logit2 = self.network(input_data)
loss = self.criterion(logit0, logit1, logit2, label)
return loss
loss0 = self.criterion(logit0, logit1, logit2, label)
return loss0
def parse_args():
@ -64,13 +65,14 @@ def parse_args():
parser.add_argument('--local_rank', type=int, default=0, help='current rank to support distributed')
parser.add_argument('--world_size', type=int, default=8, help='current process number to support distributed')
args, _ = parser.parse_known_args()
arg, _ = parser.parse_known_args()
return args
return arg
def train():
'''train function.'''
if __name__ == "__main__":
mindspore.set_seed(1)
# logger
args = parse_args()
@ -226,7 +228,3 @@ def train():
i += 1
args.logger.info('--------- trains out ---------')
if __name__ == "__main__":
train()