Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

New test.py #53

Open
Lamborghini1709 opened this issue Jul 29, 2020 · 6 comments
Open

New test.py #53

Lamborghini1709 opened this issue Jul 29, 2020 · 6 comments

Comments

@Lamborghini1709
Copy link

#coding=utf-8
import os
import json
import csv
import argparse
import pandas as pd
import numpy as np
from math import ceil
from tqdm import tqdm
import pickle
import shutil

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import CrossEntropyLoss
from torchvision import datasets, models
import torch.backends.cudnn as cudnn
import torch.nn.functional as F

from transforms import transforms
from models.LoadModel import MainModel
from utils.dataset_DCL import collate_fn4train, collate_fn4test, collate_fn4val, dataset
from config import LoadConfig, load_data_transformers
from utils.test_tool import set_text, save_multi_img, cls_base_acc

import pdb

# Select GPUs before any CUDA context is created.
# CUDA_DEVICE_ORDER=PCI_BUS_ID requests that device indices follow PCI bus
# order; the variable name must be spelled exactly or it is ignored.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Expose the first four GPUs to this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'

def parse_args(argv=None):
    """Parse the command-line options of the test script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads the arguments from ``sys.argv[1:]``, which is the
            behavior of the original zero-argument call.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``backbone``,
        ``batch_size``, ``num_workers``, ``version``, ``resume``,
        ``resize_resolution``, ``crop_resolution``, ``save_suffix``,
        ``acc_report`` and ``swap_num``.
    """
    parser = argparse.ArgumentParser(description='dcl parameters')
    parser.add_argument('--data', dest='dataset',
                        default='CUB', type=str)
    parser.add_argument('--backbone', dest='backbone',
                        default='resnet50', type=str)
    parser.add_argument('--b', dest='batch_size',
                        default=16, type=int)
    parser.add_argument('--nw', dest='num_workers',
                        default=16, type=int)
    parser.add_argument('--ver', dest='version',
                        default='test', type=str)
    # NOTE: the flag is spelled --save but it is stored as ``resume``.
    parser.add_argument('--save', dest='resume',
                        default=None, type=str)
    parser.add_argument('--size', dest='resize_resolution',
                        default=512, type=int)
    parser.add_argument('--crop', dest='crop_resolution',
                        default=448, type=int)
    parser.add_argument('--ss', dest='save_suffix',
                        default=None, type=str)
    parser.add_argument('--acc_report', dest='acc_report',
                        action='store_true')
    parser.add_argument('--swap_num', default=[7, 7],
                        nargs=2, metavar=('swap1', 'swap2'),
                        type=int, help='specify a range')
    return parser.parse_args(argv)

# Script entry point: load a checkpoint, run the model over the test
# annotations and, when --acc_report is given, write per-image top-3
# predictions and an accuracy report to disk.
if __name__ == '__main__':
    args = parse_args()
    print(args)
    # Fail early with a clear message; the checkpoint path is needed both by
    # torch.load and to derive the output file names below.
    if args.resume is None:
        raise ValueError('--save <checkpoint path> is required')

    # This script always runs on the test annotations, whatever --ver says.
    args.version = 'test'
    Config = LoadConfig(args, args.version)
    transformers = load_data_transformers(args.resize_resolution, args.crop_resolution, args.swap_num)
    data_set = dataset(Config,
                       anno=Config.val_anno if args.version == 'val' else Config.test_anno,
                       swap=transformers["None"],
                       totensor=transformers['test_totensor'],
                       test=True)

    dataloader = torch.utils.data.DataLoader(data_set,
                                             batch_size=args.batch_size,
                                             shuffle=False,
                                             num_workers=args.num_workers,
                                             collate_fn=collate_fn4test)

    setattr(dataloader, 'total_item_len', len(data_set))

    cudnn.benchmark = True
    print('****************')
    Config.cls_2xmul = True
    print(Config.cls_2xmul)
    model = MainModel(Config)
    model_dict = model.state_dict()
    pretrained_dict = torch.load(args.resume)
    # Drop the first 7 characters of every checkpoint key and keep only the
    # keys the model knows about.
    # NOTE(review): presumably strips a 'module.' prefix left by
    # nn.DataParallel when the checkpoint was saved -- confirm.
    pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items() if k[7:] in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    model.cuda()
    model = nn.DataParallel(model)

    # Checkpoint file name, reused for every output file name.
    ckpt_name = args.resume.split('/')[-1]
    # The accuracy report divides the correct-counts by the dataset size, so
    # they must be accumulated whenever the report is requested. Counting only
    # for version == 'val' left them at 0 because version is forced to 'test'.
    count_corrects = args.version == 'val' or args.acc_report

    model.train(False)
    with torch.no_grad():
        val_corrects1 = 0
        val_corrects2 = 0
        val_corrects3 = 0
        result_gather = {}
        count_bar = tqdm(total=len(dataloader))
        for batch_cnt_val, data_val in enumerate(dataloader):
            count_bar.update(1)
            inputs, labels, img_name = data_val
            inputs = inputs.cuda()
            labels = torch.from_numpy(np.array(labels)).long().cuda()

            outputs = model(inputs)
            # Sum the first output with both numcls-wide halves of the second.
            outputs_pred = outputs[0] + outputs[1][:, 0:Config.numcls] + outputs[1][:, Config.numcls:2*Config.numcls]
            print('outputs_pred:', outputs_pred)

            top3_val, top3_pos = torch.topk(outputs_pred, 3)

            if count_corrects:
                # Cumulative counts: top-k includes every hit at rank <= k.
                batch_corrects1 = torch.sum((top3_pos[:, 0] == labels)).data.item()
                val_corrects1 += batch_corrects1
                batch_corrects2 = torch.sum((top3_pos[:, 1] == labels)).data.item()
                val_corrects2 += (batch_corrects2 + batch_corrects1)
                batch_corrects3 = torch.sum((top3_pos[:, 2] == labels)).data.item()
                val_corrects3 += (batch_corrects3 + batch_corrects2 + batch_corrects1)

            if args.acc_report:
                for sub_name, sub_cat, sub_val, sub_label in zip(img_name, top3_pos.tolist(), top3_val.tolist(), labels.tolist()):
                    result_gather[sub_name] = {'top1_cat': sub_cat[0], 'top2_cat': sub_cat[1], 'top3_cat': sub_cat[2],
                                               'top1_val': sub_val[0], 'top2_val': sub_val[1], 'top3_val': sub_val[2],
                                               'label': sub_label}

    if args.acc_report:
        # [:-4] drops the last four characters of the file name (e.g. '.pth').
        with open('result_gather_%s' % ckpt_name[:-4] + '.json', 'w') as result_gather_json:
            json.dump(result_gather, result_gather_json)
        torch.save(result_gather, 'result_gather_%s' % ckpt_name[:-4] + '.pt')

    count_bar.close()
    print(args.acc_report)
    if args.acc_report:

        val_acc1 = val_corrects1 / len(data_set)
        val_acc2 = val_corrects2 / len(data_set)
        val_acc3 = val_corrects3 / len(data_set)
        print('%sacc1 %f%s\n%sacc2 %f%s\n%sacc3 %f%s\n'%(8*'-', val_acc1, 8*'-', 8*'-', val_acc2, 8*'-', 8*'-',  val_acc3, 8*'-'))

        cls_top1, cls_top3, cls_count = cls_base_acc(result_gather)

        with open('acc_report_%s_%s.json' % (args.save_suffix, ckpt_name), 'w') as acc_report_io:
            json.dump({'val_acc1': val_acc1,
                       'val_acc2': val_acc2,
                       'val_acc3': val_acc3,
                       'cls_top1': cls_top1,
                       'cls_top3': cls_top3,
                       'cls_count': cls_count}, acc_report_io)

Run the test with: python test.py --save ./net_model/training_descibe_72721_CUB/weights_36_4999_0.8608_0.9998.pth --acc_report

@yunchangxiaoguan
Copy link

Hello, I used this code, but I get an error:

`python new_test.py --save net_model/_8514_CUB/weights_20_0_1.0000_1.0000.pth --acc_report
Namespace(acc_report=True, backbone='resnet50', batch_size=16, crop_resolution=448, dataset='CUB', num_workers=16, resize_resolution=512, resume='net_model/_8514_CUB/weights_20_0_1.0000_1.0000.pth', save_suffix=None, swap_num=[7, 7], version='test')


True
resnet50
Traceback (most recent call last):
File "new_test.py", line 95, in
model.load_state_dict(model_dict)
File "/home/guanxiao/anaconda3/envs/py36/lib/python3.6/site-packages/torch/nn/modules/module.py", line 847, in load_state_dict
self.class.name, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for MainModel:
size mismatch for classifier_swap.weight: copying a param with shape torch.Size([2, 2048]) from checkpoint, the shape in current model is torch.Size([6, 2048]).
`
How can I fix it? Thanks.

@Lamborghini1709
Copy link
Author

check out your num_classes

@yunchangxiaoguan
Copy link

check out your num_classes

Thanks, I have solved it.

@BaofengZan
Copy link

@Lamborghini1709 @yunchangxiaoguan 你好,能不能分享下训练好的cub模型,这个模型是真的需要硬件。小batch训练太慢。

@JiCheng12138
Copy link

@yunchangxiaoguan 您好请问这个size mismatch for classifier_swap.weight:问题您如何解决的呢,谢谢

@Lamborghini1709
Copy link
Author

@yunchangxiaoguan 您好请问这个size mismatch for classifier_swap.weight:问题您如何解决的呢,谢谢

检查你的输出类别数 num_classes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

4 participants