chatterbot中的trains.py详细解释

来源:互联网 发布:淘宝商家怎么找到淘客 编辑:程序博客网 时间:2024/06/07 00:10


该文件中,Trainer是以下四个类的基类

ListTrainer、ChatterBotCorpusTrainer、TwitterTrainer、UbuntuCorpusTrainer

四种用法如下:

ListTrainer用法:

http://blog.csdn.net/elmo66/article/details/53469937

ChatterBotCorpusTrainer用法:

http://blog.csdn.net/appleyuchi/article/details/77388397

TwitterTrainer用法:

没啥意义,略过

UbuntuCorpusTrainer用法:

http://blog.csdn.net/appleyuchi/article/details/77371533






doxygen+graphviz图示trainers.py文件结构如下:


ListTrainer是把对话语料直接放到代码里

TwitterTrainer这个是接API的。

这两种没有太大的实际意义。

重点是ChatterBotCorpusTrainer和UbuntuCorpusTrainer

这两个类最大的不同是:

ChatterBotCorpusTrainer是在处理yml格式的语料数据文件。

UbuntuCorpusTrainer是在处理tsv格式的语料数据文件。

另外,虽然取名字中都带有train,其实只是把语料数据存入storage而已,不存在机器学习意义上的训练。

import logging#/usr/lib/python3.5/loggingimport osimport sysfrom .conversation import Statement, Responsefrom .utils import print_progress_barclass Trainer(object):    """    Base class for all other trainer classes.    """    def __init__(self, storage, **kwargs):        self.storage = storage#来自~/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot/storage路径中的        #sql_storage.py文件中的class SQLStorageAdapter object        self.logger = logging.getLogger(__name__)#/usr/lib/python3.5/logging        print("self.storage=",self.storage)        print("self.logger=",self.logger)        print("__name__=", __name__)        print("以上是初始化函数的输出")    def train(self, *args, **kwargs):#这个函数没有用到,暂时不明用途        print("train函数是否用到")        """        This class must be overridden by a class the inherits from 'Trainer'.        """        print("监控这里的输出效果")        raise self.TrainerInitializationException()        print("监控这里的输出效果")    def get_or_create(self, statement_text):        """        Return a statement if it exists.        Create and return the statement if it does not exist.        """        print("以下输出是对get_or_create函数的监控")        statement = self.storage.find(statement_text)        print("statement=",statement)        print("statement_text=",statement_text)        print("以上输出是对get_or_create函数的监控")        if not statement:            statement = Statement(statement_text)#如果找不到就创建一个statement        return statement    class TrainerInitializationException(Exception):#异常处理        """        Exception raised when a base class has not overridden        the required methods on the Trainer base class.        """        def __init__(self, value=None):            default = (                'A training class must be specified before calling train(). ' +                'See http://chatterbot.readthedocs.io/en/stable/training.html'            )            self.value = value or default        def __str__(self):            return repr(self.value)    def _generate_export_data(self):#这个函数没有用到,暂时不明用途        print("当前处于generate_export_data函数中")        result = []        for statement in self.storage.filter():            print("self.storage.filter()=",self.storage.filter())            for response in statement.in_response_to:                result.append([response.text, statement.text])        return result    def export_for_training(self, file_path='./export.json'):#这个函数没有用到,暂时不明用途        """        Create a file from the database that can be used to        train other chat bots.        """        import json        export = {'conversations': self._generate_export_data()}        with open(file_path, 'w+') as jsonfile:            print("以下处于export_for_training函数中")            print("file_path=",file_path)            json.dumps(export, jsonfile, ensure_ascii=False)class ListTrainer(Trainer):    """    Allows a chat bot to be trained using a list of strings    where the list represents a conversation.    """    def train(self, conversation):        """        Train the chat bot based on the provided list of        statements that represents a single conversation.        """        print("是否进入ListTrainer中的train函数")        previous_statement_text = None        for conversation_count, text in enumerate(conversation):            print_progress_bar("List Trainer", conversation_count + 1, len(conversation))            statement = self.get_or_create(text)            if previous_statement_text:                statement.add_response(                    Response(previous_statement_text)                )            previous_statement_text = statement.text            self.storage.update(statement)class ChatterBotCorpusTrainer(Trainer):    """    Allows the chat bot to be trained using data from the    ChatterBot dialog corpus.    """    def __init__(self, storage, **kwargs):        super(ChatterBotCorpusTrainer, self).__init__(storage, **kwargs)        from .corpus import Corpus#从一个python文件中导入一个Corpus类        #文件来自/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus        self.corpus = Corpus()        print("self.corpus=",self.corpus)        #self.corpus=<chatterbot_corpus.corpus.Corpus object at 0x7f6eb0a06ba8>#从一个python文件中导入一个Corpus类    def train(self, *corpus_paths):        print("corpus_paths=",corpus_paths)        # Allow a list of corpora to be passed instead of arguments本例中下面的代码不予以执行        if len(corpus_paths) == 1:            if isinstance(corpus_paths[0], list):                corpus_paths = corpus_paths[0]        # Train the chat bot with each statement and response pair        for corpus_path in corpus_paths:            print("corpus_path=",corpus_path)#字符串类型:chatterbot.corpus.chinese            print("corpus_paths=", corpus_paths)#元组:('chatterbot.corpus.chinese',)            corpora = self.corpus.load_corpus(corpus_path)#corpora是一个路径下所有谈话yml中对话内容的合集。            print("corpora=",corpora)            corpus_files = self.corpus.list_corpus_files(corpus_path)#corpos是yml文件构成的列表            print("corpus_files=",corpus_files)            print("-------------------------------------------下面进入for循环-------------------------------------------")            #corpus_files= ['/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/conversations.yml', '/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/greetings.yml', '/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/trivia.yml']            for corpus_count, corpus in enumerate(corpora):#这个代码的意思是从corpora中取一个corpus                print("corpus_count=",corpus_count)#corpus_count指代的是第几个yml文件                print("corpus=",corpus)#corpus代表一个yml文件中的所有对话内容                for conversation_count, conversation in enumerate(corpus):                    print("conversation_count=", conversation_count)#这个是对话段落数量。                    print("conversation=",conversation)                    print_progress_bar(#这个是进度条,没有进行进度条内的所谓的训练                        str(os.path.basename(corpus_files[corpus_count])) + " Training",#corpus_files应该是个元祖                        conversation_count + 1,                        len(corpus),                    )                    print("len(corpus)=", len(corpus))                    previous_statement_text = None                    for text in conversation:                        print("当前的text=",text)#text指的是一个conversation中的一段文字,也就是说是某个人一次说的话。                        print("进入新的一轮for text in conversation")                        statement = self.get_or_create(text)#statement的类型<class 'chatterbot.conversation.statement.Statement'>                        print("statement类型=",type(statement))                        print("进入底层for循环")                        print("text=",text)                        print("statement=",statement)                        #下面这个if语句,处理的前后对statement没有影响,如梭输入的语句在语料库中可以搜索的到,那么相应的回答会被统计次数。                        #如果搜索不到,则add_reponse函数会调用append函数进行处理,append函数目前是个空函数。                        if previous_statement_text:                            statement.add_response(                                Response(previous_statement_text)#previous_statement_text中的内容与add_response函数中的in_response_to的内容一致。                            )                        print("这里进入if语句")                        print("previous_statement_text=",previous_statement_text)                        print("statement=",statement)                        previous_statement_text = statement.text                        self.storage.update(statement)class TwitterTrainer(Trainer):#这个class没有用    """    Allows the chat bot to be trained using data    gathered from Twitter.    :param random_seed_word: The seed word to be used to get random tweets from the Twitter API.                             This parameter is optional. By default it is the word 'random'.    """    def __init__(self, storage, **kwargs):        super(TwitterTrainer, self).__init__(storage, **kwargs)        from twitter import Api as TwitterApi        # The word to be used as the first search term when searching for tweets        self.random_seed_word = kwargs.get('random_seed_word', 'random')        self.api = TwitterApi(            consumer_key=kwargs.get('twitter_consumer_key'),            consumer_secret=kwargs.get('twitter_consumer_secret'),            access_token_key=kwargs.get('twitter_access_token_key'),            access_token_secret=kwargs.get('twitter_access_token_secret')        )    def random_word(self, base_word):        """        Generate a random word using the Twitter API.        Search twitter for recent tweets containing the term 'random'.        Then randomly select one word from those tweets and do another        search with that word. Return a randomly selected word from the        new set of results.        """        import random        random_tweets = self.api.GetSearch(term=base_word, count=5)        random_words = self.get_words_from_tweets(random_tweets)        random_word = random.choice(list(random_words))        tweets = self.api.GetSearch(term=random_word, count=5)        words = self.get_words_from_tweets(tweets)        word = random.choice(list(words))        return word    def get_words_from_tweets(self, tweets):        """        Given a list of tweets, return the set of        words from the tweets.        """        words = set()        for tweet in tweets:            tweet_words = tweet.text.split()            for word in tweet_words:                # If the word contains only letters with a length from 4 to 9                if word.isalpha() and len(word) > 3 and len(word) <= 9:                    words.add(word)        return words    def get_statements(self):        """        Returns list of random statements from the API.        """        from twitter import TwitterError        statements = []        # Generate a random word        random_word = self.random_word(self.random_seed_word)        self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word))        tweets = self.api.GetSearch(term=random_word, count=50)        for tweet in tweets:            statement = Statement(tweet.text)            if tweet.in_reply_to_status_id:                try:                    status = self.api.GetStatus(tweet.in_reply_to_status_id)                    statement.add_response(Response(status.text))                    statements.append(statement)                except TwitterError as error:                    self.logger.warning(str(error))        self.logger.info('Adding {} tweets with responses'.format(len(statements)))        return statements    def train(self):        for _ in range(0, 10):            statements = self.get_statements()            for statement in statements:                self.storage.update(statement)class UbuntuCorpusTrainer(Trainer):#UbuntuCOrpusTrain继承自Trainer,这个类的核心部分是train函数,其他部分都在负责下载和解压语料库以及处理路径等问题。    """    Allow chatbots to be trained with the data from    the Ubuntu Dialog Corpus.    """    def __init__(self, storage, **kwargs):#构造函数        print("storage=",storage)        print("kwargs=",kwargs)        super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs)#这个super函数的作用是为了改变父类名字时,为了保持着这里的继承关系,不需要修改入口参数Trainer还能保持继承关系。        self.data_download_url = kwargs.get(            'ubuntu_corpus_data_download_url',            'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz'        )        self.data_directory = kwargs.get(            'ubuntu_corpus_data_directory',            './data/'        )        self.extracted_data_directory = os.path.join(#拼接路径的两个部分,构成完成的解压路径            self.data_directory, 'ubuntu_dialogs'        )        # Create the data directory if it does not already exist        if not os.path.exists(self.data_directory):            os.makedirs(self.data_directory)    def is_downloaded(self, file_path):#检查数据文件是否已经下载        """        Check if the data file is already downloaded.        """        if os.path.exists(file_path):            self.logger.info('File is already downloaded')            return True        return False    def is_extracted(self, file_path):#检查数据文件是否已经解压        """        Check if the data file is already extracted.        """        print("这里判断是否解压")        if os.path.isdir(file_path):            self.logger.info('File is already extracted')            return True        return False    def download(self, url, show_status=True):#这个下载函数应该可以当做黑盒使用,直接复用。        """        Download a file from the given url.        Show a progress indicator for the download status.        Based on: http://stackoverflow.com/a/15645088/1547223        """        import requests        file_name = url.split('/')[-1]        file_path = os.path.join(self.data_directory, file_name)        # Do not download the data if it already exists        if self.is_downloaded(file_path):            return file_path        with open(file_path, 'wb') as open_file:            print('Downloading %s' % url)            response = requests.get(url, stream=True)            total_length = response.headers.get('content-length')            if total_length is None:                # No content length header                open_file.write(response.content)            else:                download = 0                total_length = int(total_length)                for data in response.iter_content(chunk_size=4096):                    download += len(data)                    open_file.write(data)                    if show_status:#这里是下载进度显示。用=来表示下载进度                        done = int(50 * download / total_length)                        sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done)))                        sys.stdout.flush()            # Add a new line after the download bar            sys.stdout.write('\n')        print('Download location: %s' % file_path)        return file_path    def extract(self, file_path):        """        Extract a tar file at the specified file path.        """        import tarfile        print('Extracting {}'.format(file_path))        if not os.path.exists(self.extracted_data_directory):            os.makedirs(self.extracted_data_directory)        def track_progress(members):#解压语料库的进度条显示。            sys.stdout.write('.')            for member in members:                # This will be the current file being extracted                yield member        with tarfile.open(file_path) as tar:#用python对压缩包进行解压操作。            tar.extractall(path=self.extracted_data_directory, members=track_progress(tar))        self.logger.info('File extracted to {}'.format(self.extracted_data_directory))        return True    def train(self):        import glob#用来进行文件搜索的库,支持通配符操作。        import csv#用来处理后面的tsv文件        # Download and extract the Ubuntu dialog corpus if needed        corpus_download_path = self.download(self.data_download_url)        # Extract if the directory doesn not already exists        if not self.is_extracted(self.extracted_data_directory):#如果没有在指定路径解压,就进行解压。            self.extract(corpus_download_path)        extracted_corpus_path = os.path.join(#各个语料库文件的路径。            self.extracted_data_directory,            '**', '**', '*.tsv'        )        file_kwargs = {}#初始化为字典类型        if sys.version_info[0] > 2:#也就是说尽量在python3.x下面进行。            # Specify the encoding in Python versions 3 and up            file_kwargs['encoding'] = 'utf-8'            # WARNING: This might fail to read a unicode corpus file in Python 2.x        for file in glob.iglob(extracted_corpus_path):#对于路径下的所有文件,而iglob函数用来获取匹配括号中路径的文件的集合。            print("glob.iglob(参数)=",glob.iglob(extracted_corpus_path))            self.logger.info('Training from: {}'.format(file))            with open(file, 'r', **file_kwargs) as tsv:                reader = csv.reader(tsv, delimiter='\t')#分隔符是'\t',等号左边的reader是个对象名,等号右侧的reader是class名                print("reader=",reader)                previous_statement_text = None                for row in reader:#row表示语料tsv文件中的每一行数据。                    print("row = ",row)                    if len(row) > 0:                        text = row[3]#这个是根据数据集的特性来决定写代码的,因为语料库的第四列是对话内容。所以这里是row[3]                        print("进入新一轮循环,来看一下statement")                        statement = self.get_or_create(text)                        print("statement=",statement)                        print("------------------------------------------------------------")                        print(text, len(row))                        statement.add_extra_data('datetime', row[0])#这个函数用来添加字典类型数据,例如此处添加的数据就是                        # {datetime,row[0]}                        statement.add_extra_data('speaker', row[1])                        print("row[2].strip()=",row[2].strip())                        if row[2].strip():#这个代码的意思是,对话发起时,是不知道哪个是听众的,因为是在论坛发帖子。                            #因为语料数据来自论坛交谈,论坛发帖子是不知道谁会回复的,所以出事状态下,发了一个新帖子以后,在帖子还没有回复的情况下,row[2]默认是空                            #只有有回复者以后,row[2]才会有值。                            #所以刚发完帖子后,由于row[2]是空,所以if下面的语句将不会被执行。                            #上面的if语句中,Python strip() 方法用于移除字符串头尾指定的字符(默认为空格)                            #所以上面的if语句的意思是,row[2]这个数据集合中的第三个属性,进行去除空格处理,得到的数据是否为空(即判断这是否是个新发的帖子)                            print("进入if语句")                            statement.add_extra_data('addressing_speaker', row[2])#因为人们的交谈是一次只能有一个人说话,A说话时,B就只能听着,                            #可用来增加字典类型数据,如果用法如上,则函数调用前后无变化。                            #所以row[1]和row[2]表示row[1]对row[2]讲话                        print("*************************************")                        print("statement=",statement)                        if previous_statement_text:#这个previous_statement在这里代表上一次某人说的话,因为语料库来自论坛对话,所以论坛帖子没发以前,这个变量肯定是空的。                            statement.add_response(                                Response(previous_statement_text)                            )                        print("previous_statement_text=",previous_statement_text)                        print("statement=", statement)                        print("离开最后一个if")                        print("*************************************")                        previous_statement_text = statement.text#为下一轮for循环做准备,方便取得此次回答的下一轮回答。                        #这样进入下一轮循环的时候,就可以正常进入if语句。                        self.storage.update(statement)#两重for循环,第一重针对指定路径下的每个文件,第二重针对每个文件中的每一行,最终目的是更新storage。


http://blog.csdn.net/appleyuchi/article/details/77388397

中,相关语句是:

  1. chatbot.train("chatterbot.corpus.chinese")  

train函数来自chatterbot.py文件中的train函数定义,该train函数定义如下:

@propertydef train(self):    """    Proxy method to the chat bot's trainer class.    """    return self.trainer.train

也就是说,函数的具体定义外包给了trainer class,这个class就处于上文分析的trains.py中。