pytorch实现带标签格式数据的模型训练

来源:互联网 发布:淘宝运费险换货赔吗 编辑:程序博客网 时间:2024/06/05 03:27

1.训练数据读入

注:以下模拟数据,主要讲解方法。

标签数据


下面函数即为实现标签数据的读入

def reader(txt):

    fh = open(txt)  
    c=0  
    imgs=[]  
    class_names=[]  
    for line in  fh.readlines():  
        if c==0:  
            class_names=[n.strip() for n in line.rstrip().split('   ')]  
        else:  
            cls = line.split()   
            fn = cls.pop(0)
            imgs.append((fn, tuple([float(v) for v in cls])))  
        c=c+1

    return class_names,imgs

其中,返回imgs是标签元组,即[1,0,0,1],class_names为属性名,即sex。

如人脸特征数据,也可以通过reader()读入。

2.简单模型设计(以全连层为例)

cmodel=nn.Linear(100, 2) ,(或者nn.Sequential(nn.Linear(100, 2))

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.classify=cmodel
    def forward(self, x):
        x=self.classify(x)
        return x,

3.模型训练

训练集读入

train_data_loader = torch.utils.data.DataLoader(  \
         ImageFloder(root = "./fea.txt", label = "./label.txt"), batch_size= 2, shuffle= False, num_workers= 4)

其中,root,label分别是特征与标签文件地址, ImageFloder类定义如下:

class ImageFloder(data.Dataset):  
    def __init__(self, root, label):

self.classes1,self.imgs1 = reader(label)
        self.classes2,self.imgs2 = reader(root)

    def __getitem__(self, index):  
        fn1, label1 = self.imgs1[index]
        fn2, label2 = self.imgs2[index]

return torch.Tensor(label1),torch.Tensor(label2)

    def __len__(self):  
        return len(self.imgs1)

训练代码详见项目:

https://github.com/eeric/pytorch-model-training-label

阅读全文
0 0
原创粉丝点击