代码理解—IRGAN(QA)_dataPrepare

来源:互联网 发布:淘宝昵称怎么改不了 编辑:程序博客网 时间:2024/06/05 19:22

如何读在具体的QA task中的IRGAN代码呢?(Answer Selection in IRGAN)?以前也看了 answer selection 使用cnn实现的代码)
这里写图片描述

一般是直接按照开源的代码先运行看下效果(可能是环境配置或者是版本不对应,总之直接这样运行我是出现了点问题,所以打算直接看代码,然后再慢慢调吧),然后再看其中的代码(具有先后顺序)。

这里写图片描述

那就先看dataPrepare。(除了一开始的c和c++语言是上课时候的课程还从一点点学起外,对于java,c#,python等语言,都只是边看到哪个问题,边去百度或者google找下,没有具体的先去全部学习一遍。)

使用python一开始有几个很不适应的问题,主要的是它以缩进作为代码逻辑,中间有次想运行别人的代码,结果报错,是因为缩进不统一(linux的gedit,eclipse,pychram显示都不一样)。后面自己写代码或者运行在linux上时都是修改为统一的tab缩进。

看别人的代码,如果乱得飞起 那就是个非常痛苦的事情,当然如果是自己乱的飞起,那完全没压力。
所以只能看代码实际运行时需要使用的部分,理通了先后执行调用关系,去找到相应的代码才好理解。
所以突然想起了关于python中main函数的问题:
一般的main是程序执行的起点,所以经常看到有下面的这段代码,平时没有注意,Python中,也有类似的运行机制,Python使用缩进对齐组织代码的执行,所有没有直接缩进的代码,都会在载入时自动执行,这些代码,可以认为是Python的main函数。执行顺序,还是从上到下。

开始代码的学习(奈何水平低,很简单的东西都要查…):

第一步:加载数据

answers=load("original/answers")print ("have %d answers" % len(answers))vocabulary=load("original/vocabulary")print ("have %d words" % len(vocabulary))

将二进制文件通过pickle load 加载还原为python对象:

def load(file_name):  return pickle.load(open(os.path.join(path, file_name), 'rb'))

这里写图片描述

第二步:调用convertALL,将文件列表中的所有文件都调用convert2TSV转换。

if __name__ == "__main__":    # parseTrain()    convertAll(subset_size=0)                #
def convertAll(subset_size=0):    for rawFilename in ["dev","test1","test2"]:        filename=convert2TSV(rawFilename)        temp_file=format_file(filename,subset_size)        os.remove(filename)

第三步:调用convert2TSV,由于原始的数据是编码过的且保存为二进制文件,下图通过load方法还原为python对象可以看到数据的每行的元素是列表,列表里保存着三个字典 bad ,good, question,对应着vocabulary里的值,进行解码才可以。
**注意这是验证集和测试集,一般一个问题只有一个正确的答案,其余的都是错误的答案,所以意思是模型最后从众多的错误答案中挑选出一个正确的答案。
bad (表示一个问题对应多个错误答案),good(表示一个问题对应一个正确答案), question。**

def convert2TSV(rawFilename):    test1=load("original/"+rawFilename)    lines=[]    print (len(test1))    for item in test1:        question=item["question"]        bad=item["bad"]        good=item["good"]        q_words= " ".join([ vocabulary[w]  for w in question])        for sen in good:            correct_words=" ".join( [ vocabulary[w]  for w in answers[sen]])            line= "\t".join( (q_words, correct_words, "1" ))            lines.append(line)        for sen in bad:            uncorrect_words= " ".join( [ vocabulary[w]  for w in answers[sen]])            line= "\t".join( (q_words, uncorrect_words, "0" ))            lines.append(line)        # print lines    filename= "original/insurance_%s.tsv" %(rawFilename)    with open(filename, "w") as f:        f.write("\n".join(lines) )    return filename

这里写图片描述

第四步:调用format_file,格式化数据文件

def convertAll(subset_size=0):    for rawFilename in ["dev","test1","test2"]:        filename=convert2TSV(rawFilename)        temp_file=format_file(filename,subset_size)        os.remove(filename)

下面的数据格式需要注意的是项之间只是以 \t 作为分隔符,另外现在直接填充为固定形式,不好吧,个人是准备换个分割符且在统计词后再padding到同一长度。。

总体上代码最后完成的数据处理结果为(注意的是正负答案与标记相匹配):
原始从字典中解码得到的是:question+ positive(or negetive) question +lable
这里写图片描述

后面形式化数据后得到的是:lable+ qid:xxx + question_(表示填充为特定长度) +answer_(表示填充到特定长度)。
这里写图片描述

~~~~~~~~~附加的学习~~~~~~~~~~~~~~
原代码+注释

import pickle,ospath=""def load(file_name):    return pickle.load(open(os.path.join(path, file_name), 'rb'))answers=load("original/answers")print ("have %d answers" % len(answers))#answer 单个样子为 24244:[14834,21507,,,,]vocabulary=load("original/vocabulary")print ("have %d words" % len(vocabulary))def convert2TSV(rawFilename):    test1=load("original/"+rawFilename)    #调用的是声明的load方法,加载还原二进制文件成python的对象    lines=[]    #最后的lines 包含了的格式为:  question \t positive answer \t 1    #   question  \t negetive answer \t 0    print (len(test1))    for item in test1:        #原始的item一个列表 里面包含三个字典        question=item["question"]        bad=item["bad"]        good=item["good"]#对问题句子进行解码,每个单词以空格连接  比如:#do homeowner insurance cover damage to vehicle        q_words= " ".join([ vocabulary[w]  for w in question])        #good代表的是positive answer        for sen in good:            #good 只有一个            #意思是good这里面的编码还需要对应answers这个空间进行解码?            correct_words=" ".join( [ vocabulary[w]  for w in answers[sen]])            line= "\t".join( (q_words, correct_words, "1" ))            #print(line)            lines.append(line)        for sen in bad:            #bad对应多个            uncorrect_words= " ".join( [ vocabulary[w]  for w in answers[sen]])            line= "\t".join( (q_words, uncorrect_words, "0" ))            #line = "[#]".join((q_words, uncorrect_words, "0"))            lines.append(line)        # print lines#保存到tsv文件,假如文件不存在呢?    #filename= "D:/My Documents/Downloads/irgan-master/Question-Answer/original/insurance_%s.tsv" %(rawFilename)    filename = "original/insurance_%s.tsv" % (rawFilename)    with open(filename, "w") as f:        f.write("\n".join(lines) )        print("write done")    return filenamedef convertAll(subset_size=0):    for rawFilename in ["dev","test1","test2"]:        filename=convert2TSV(rawFilename)        temp_file=format_file(filename,subset_size)        #os.remove(filename)def parseTrain():    train= load ("train")    lines=[]    for item in train:        question=item["question"]        q_words= " ".join([ vocabulary[w]  for w in question])        answerIDs=item["answers"]        for sen in answerIDs:            correct_words=" ".join( [ vocabulary[w]  for w in answers[sen]])            line= "\t".join( (q_words, correct_words, "1" ))            lines.append(line)    filename= "original/insurance_%s.tsv" %("train")    with open(filename, "w") as f:        f.write("\n".join(lines) )def format_file(filename="insurance_dev.tsv",subset_size=1800):    temp_file="insuranceQA"+"/"+filename[filename.index("_")+1:filename.index(".")]    with open(filename) as f, open(temp_file,"w") as out:        for index, line in enumerate(f):            question,answer,label=line.strip().split("\t")            #question, answer, label = line.strip().split("[#]")            newline="%s qid:%s" %(label, index/500)            if subset_size!=0 and index / 500 >=subset_size:                break            for sen in [question,answer]:                tokens=sen.split()                fill=max(0,200-len(tokens))                tokens.extend(['<a>']*fill)                newline+=" "+"_".join( tokens)+"_"            out.write(newline+"\n")    return temp_fileif __name__ == "__main__":    # parseTrain()    convertAll(subset_size=0)                #

遇到的不懂问题:

  1. os.path:os.path模块主要用于文件的属性获取,在编程中经常用到。

    1.os.path.abspath(path)
    返回path规范化的绝对路径。

    os.path.abspath(‘test.csv’)
    ‘C:\Python25\test.csv’

    os.path.abspath(‘c:\test.csv’)
    ‘c:\test.csv’

    2.os.path.join(path1[, path2[, …]])
    将多个路径组合后返回,第一个绝对路径之前的参数将被忽略。

    os.path.join(‘c:\’, ‘csv’, ‘test.csv’)
    ‘c:\csv\test.csv’

    os.path.join(‘windows\temp’, ‘c:\’, ‘csv’, ‘test.csv’)
    ‘c:\csv\test.csv’

    os.path.join(‘/home/aa’,’/home/aa/bb’,’/home/aa/bb/c’)
    ‘/home/aa/bb/c’

  2. pickle: python的pickle模块实现了基本的数据(二进制序列化)序列和反序列化。通过pickle模块的序列化操作我们能够将程序中运行的对象信息保存到文件中去,永久存储;通过pickle模块的反序列化操作,我们能够从文件中创建上一次程序保存的对象。

pickle.dump(obj, file, [,protocol])
  注解:将对象obj保存到文件file中去。
     protocol为序列化使用的协议版本,0:ASCII协议,所序列化的对象使用可打印的ASCII码表示;1:老式的二进制协议;2:2.3版本引入的新二进制协议,较以前的更高效。其中协议0和1兼容老版本的python。protocol默认值为0。
     file:对象保存到的类文件对象。file必须有write()接口, file可以是一个以’w’方式打开的文件或者一个StringIO对象或者其他任何实现write()接口的对象。如果protocol>=1,文件对象需要是二进制模式打开的。
     #使用pickle模块将数据对象保存到文件

import pickledata1 = {'a': [1, 2.0, 3, 4+6j],         'b': ('string', u'Unicode string'),         'c': None}selfref_list = [1, 2, 3]selfref_list.append(selfref_list)output = open('data.pkl', 'wb')    # Pickle dictionary using protocol 0.pickle.dump(data1, output)    # Pickle the list using the highest protocol available.pickle.dump(selfref_list, output, -1)output.close()

pickle.load(file)
  注解:从file中读取一个字符串,并将它重构为原来的python对象。
  file:类文件对象,有read()和readline()接口。

    #使用pickle模块从文件中重构python对象    import pprint, pickle    pkl_file = open('data.pkl', 'rb')    data1 = pickle.load(pkl_file)    pprint.pprint(data1)    data2 = pickle.load(pkl_file)    pprint.pprint(data2)    pkl_file.close()

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

原创粉丝点击