产品概述
产品定价
客户价值
应用场景








# In validate(), upload the training metrics acc1 and acc5.
def validate(val_loader, model, criterion, epoch, args):
    """Evaluate ``model`` on ``val_loader`` and report top-1 / top-5 accuracy.

    Runs one pass over the validation data with the model in eval mode and
    gradients disabled, tracking batch time, loss, Acc@1 and Acc@5 with
    ``AverageMeter``. A progress line is displayed every ``args.print_freq``
    batches. After the pass, the averaged accuracies (rounded to 3 decimals)
    are pushed through the Tikit ``client`` together with the current Unix
    timestamp and ``epoch``, and a ``TIACC`` summary line is printed.

    NOTE(review): ``client``, ``AverageMeter``, ``ProgressMeter`` and
    ``accuracy`` are module-level names defined outside this snippet;
    ``client`` is presumably a Tikit client instance - confirm at the call site.

    Args:
        val_loader: Sized iterable yielding ``(images, target)`` batches.
        model: Network to evaluate; ``model.eval()`` is called and the model
            is left in eval mode on return.
        criterion: Loss function called as ``criterion(output, target)``.
        epoch: Epoch number attached to the pushed metrics and the summary.
        args: Namespace providing ``gpu`` (device index, or ``None`` to skip
            the explicit ``.cuda`` transfer) and ``print_freq``.

    Returns:
        ``top1.avg``, the averaged top-1 accuracy over the pass.
    """
    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(val_loader),
        [batch_time, losses, top1, top5],
        prefix='Test: ')

    # switch to evaluate mode
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(val_loader):
            if args.gpu is not None:
                images = images.cuda(args.gpu, non_blocking=True)
                target = target.cuda(args.gpu, non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            # assumes accuracy() returns one sequence per k whose first
            # element is the batch accuracy - TODO confirm against its definition
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            losses.update(loss.item(), images.size(0))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                progress.display(i)

        # Upload the metrics through the Tikit client.
        # float() accepts plain floats as well as single-element tensors, and
        # round() yields the same value as the old float(format(x, '.3f'))
        # without a round trip through a string.
        metrics = {
            "acc1": round(float(top1.avg), 3),
            "acc5": round(float(top5.avg), 3),
        }
        client.push_training_metrics(int(time.time()), metrics, epoch=epoch)

        # TODO: this should also be done with the ProgressMeter
        print('TIACC - * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f} Epoch={epoch}'
              .format(top1=top1, top5=top5, epoch=epoch))

    return top1.avg
文档反馈