chatterbot中的trains.py详细解释
来源:互联网 发布:淘宝商家怎么找到淘客 编辑:程序博客网 时间:2024/06/07 00:10
该文件中,Trainer是以下四个类的基类
ListTrainer、ChatterBotCorpusTrainer、TwitterTrainer、UbuntuCorpusTrainer
四种用法如下:
ListTrainer用法:
http://blog.csdn.net/elmo66/article/details/53469937
ChatterBotCorpusTrainer用法:
http://blog.csdn.net/appleyuchi/article/details/77388397
TwitterTrainer用法:
没啥意义,略过
UbuntuCorpusTrainer用法:
http://blog.csdn.net/appleyuchi/article/details/77371533
doxygen+graphviz图示trainers.py文件结构如下:
ListTrainer是把对话语料直接放到代码里
TwitterTrainer这个是接API的。
这两种没有太大的实际意义。
重点是ChatterBotCorpusTrainer和UbuntuCorpusTrainer
这两个类最大的不同是:
ChatterBotCorpusTrainer是在处理yml格式的语料数据文件。
UbuntuCorpusTrainer是在处理tsv格式的语料数据文件。
另外,虽然取名字中都带有train,其实只是把语料数据存入storage而已,不存在机器学习意义上的训练。
import logging#/usr/lib/python3.5/loggingimport osimport sysfrom .conversation import Statement, Responsefrom .utils import print_progress_barclass Trainer(object): """ Base class for all other trainer classes. """ def __init__(self, storage, **kwargs): self.storage = storage#来自~/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot/storage路径中的 #sql_storage.py文件中的class SQLStorageAdapter object self.logger = logging.getLogger(__name__)#/usr/lib/python3.5/logging print("self.storage=",self.storage) print("self.logger=",self.logger) print("__name__=", __name__) print("以上是初始化函数的输出") def train(self, *args, **kwargs):#这个函数没有用到,暂时不明用途 print("train函数是否用到") """ This class must be overridden by a class the inherits from 'Trainer'. """ print("监控这里的输出效果") raise self.TrainerInitializationException() print("监控这里的输出效果") def get_or_create(self, statement_text): """ Return a statement if it exists. Create and return the statement if it does not exist. """ print("以下输出是对get_or_create函数的监控") statement = self.storage.find(statement_text) print("statement=",statement) print("statement_text=",statement_text) print("以上输出是对get_or_create函数的监控") if not statement: statement = Statement(statement_text)#如果找不到就创建一个statement return statement class TrainerInitializationException(Exception):#异常处理 """ Exception raised when a base class has not overridden the required methods on the Trainer base class. """ def __init__(self, value=None): default = ( 'A training class must be specified before calling train(). ' + 'See http://chatterbot.readthedocs.io/en/stable/training.html' ) self.value = value or default def __str__(self): return repr(self.value) def _generate_export_data(self):#这个函数没有用到,暂时不明用途 print("当前处于generate_export_data函数中") result = [] for statement in self.storage.filter(): print("self.storage.filter()=",self.storage.filter()) for response in statement.in_response_to: result.append([response.text, statement.text]) return result def export_for_training(self, file_path='./export.json'):#这个函数没有用到,暂时不明用途 """ Create a file from the database that can be used to train other chat bots. """ import json export = {'conversations': self._generate_export_data()} with open(file_path, 'w+') as jsonfile: print("以下处于export_for_training函数中") print("file_path=",file_path) json.dumps(export, jsonfile, ensure_ascii=False)class ListTrainer(Trainer): """ Allows a chat bot to be trained using a list of strings where the list represents a conversation. """ def train(self, conversation): """ Train the chat bot based on the provided list of statements that represents a single conversation. """ print("是否进入ListTrainer中的train函数") previous_statement_text = None for conversation_count, text in enumerate(conversation): print_progress_bar("List Trainer", conversation_count + 1, len(conversation)) statement = self.get_or_create(text) if previous_statement_text: statement.add_response( Response(previous_statement_text) ) previous_statement_text = statement.text self.storage.update(statement)class ChatterBotCorpusTrainer(Trainer): """ Allows the chat bot to be trained using data from the ChatterBot dialog corpus. """ def __init__(self, storage, **kwargs): super(ChatterBotCorpusTrainer, self).__init__(storage, **kwargs) from .corpus import Corpus#从一个python文件中导入一个Corpus类 #文件来自/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus self.corpus = Corpus() print("self.corpus=",self.corpus) #self.corpus=<chatterbot_corpus.corpus.Corpus object at 0x7f6eb0a06ba8>#从一个python文件中导入一个Corpus类 def train(self, *corpus_paths): print("corpus_paths=",corpus_paths) # Allow a list of corpora to be passed instead of arguments本例中下面的代码不予以执行 if len(corpus_paths) == 1: if isinstance(corpus_paths[0], list): corpus_paths = corpus_paths[0] # Train the chat bot with each statement and response pair for corpus_path in corpus_paths: print("corpus_path=",corpus_path)#字符串类型:chatterbot.corpus.chinese print("corpus_paths=", corpus_paths)#元组:('chatterbot.corpus.chinese',) corpora = self.corpus.load_corpus(corpus_path)#corpora是一个路径下所有谈话yml中对话内容的合集。 print("corpora=",corpora) corpus_files = self.corpus.list_corpus_files(corpus_path)#corpos是yml文件构成的列表 print("corpus_files=",corpus_files) print("-------------------------------------------下面进入for循环-------------------------------------------") #corpus_files= ['/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/conversations.yml', '/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/greetings.yml', '/home/appleyuchi/.virtualenvs/python3.5/lib/python3.5/site-packages/chatterbot_corpus/data/chinese/trivia.yml'] for corpus_count, corpus in enumerate(corpora):#这个代码的意思是从corpora中取一个corpus print("corpus_count=",corpus_count)#corpus_count指代的是第几个yml文件 print("corpus=",corpus)#corpus代表一个yml文件中的所有对话内容 for conversation_count, conversation in enumerate(corpus): print("conversation_count=", conversation_count)#这个是对话段落数量。 print("conversation=",conversation) print_progress_bar(#这个是进度条,没有进行进度条内的所谓的训练 str(os.path.basename(corpus_files[corpus_count])) + " Training",#corpus_files应该是个元祖 conversation_count + 1, len(corpus), ) print("len(corpus)=", len(corpus)) previous_statement_text = None for text in conversation: print("当前的text=",text)#text指的是一个conversation中的一段文字,也就是说是某个人一次说的话。 print("进入新的一轮for text in conversation") statement = self.get_or_create(text)#statement的类型<class 'chatterbot.conversation.statement.Statement'> print("statement类型=",type(statement)) print("进入底层for循环") print("text=",text) print("statement=",statement) #下面这个if语句,处理的前后对statement没有影响,如梭输入的语句在语料库中可以搜索的到,那么相应的回答会被统计次数。 #如果搜索不到,则add_reponse函数会调用append函数进行处理,append函数目前是个空函数。 if previous_statement_text: statement.add_response( Response(previous_statement_text)#previous_statement_text中的内容与add_response函数中的in_response_to的内容一致。 ) print("这里进入if语句") print("previous_statement_text=",previous_statement_text) print("statement=",statement) previous_statement_text = statement.text self.storage.update(statement)class TwitterTrainer(Trainer):#这个class没有用 """ Allows the chat bot to be trained using data gathered from Twitter. :param random_seed_word: The seed word to be used to get random tweets from the Twitter API. This parameter is optional. By default it is the word 'random'. """ def __init__(self, storage, **kwargs): super(TwitterTrainer, self).__init__(storage, **kwargs) from twitter import Api as TwitterApi # The word to be used as the first search term when searching for tweets self.random_seed_word = kwargs.get('random_seed_word', 'random') self.api = TwitterApi( consumer_key=kwargs.get('twitter_consumer_key'), consumer_secret=kwargs.get('twitter_consumer_secret'), access_token_key=kwargs.get('twitter_access_token_key'), access_token_secret=kwargs.get('twitter_access_token_secret') ) def random_word(self, base_word): """ Generate a random word using the Twitter API. Search twitter for recent tweets containing the term 'random'. Then randomly select one word from those tweets and do another search with that word. Return a randomly selected word from the new set of results. """ import random random_tweets = self.api.GetSearch(term=base_word, count=5) random_words = self.get_words_from_tweets(random_tweets) random_word = random.choice(list(random_words)) tweets = self.api.GetSearch(term=random_word, count=5) words = self.get_words_from_tweets(tweets) word = random.choice(list(words)) return word def get_words_from_tweets(self, tweets): """ Given a list of tweets, return the set of words from the tweets. """ words = set() for tweet in tweets: tweet_words = tweet.text.split() for word in tweet_words: # If the word contains only letters with a length from 4 to 9 if word.isalpha() and len(word) > 3 and len(word) <= 9: words.add(word) return words def get_statements(self): """ Returns list of random statements from the API. """ from twitter import TwitterError statements = [] # Generate a random word random_word = self.random_word(self.random_seed_word) self.logger.info(u'Requesting 50 random tweets containing the word {}'.format(random_word)) tweets = self.api.GetSearch(term=random_word, count=50) for tweet in tweets: statement = Statement(tweet.text) if tweet.in_reply_to_status_id: try: status = self.api.GetStatus(tweet.in_reply_to_status_id) statement.add_response(Response(status.text)) statements.append(statement) except TwitterError as error: self.logger.warning(str(error)) self.logger.info('Adding {} tweets with responses'.format(len(statements))) return statements def train(self): for _ in range(0, 10): statements = self.get_statements() for statement in statements: self.storage.update(statement)class UbuntuCorpusTrainer(Trainer):#UbuntuCOrpusTrain继承自Trainer,这个类的核心部分是train函数,其他部分都在负责下载和解压语料库以及处理路径等问题。 """ Allow chatbots to be trained with the data from the Ubuntu Dialog Corpus. """ def __init__(self, storage, **kwargs):#构造函数 print("storage=",storage) print("kwargs=",kwargs) super(UbuntuCorpusTrainer, self).__init__(storage, **kwargs)#这个super函数的作用是为了改变父类名字时,为了保持着这里的继承关系,不需要修改入口参数Trainer还能保持继承关系。 self.data_download_url = kwargs.get( 'ubuntu_corpus_data_download_url', 'http://cs.mcgill.ca/~jpineau/datasets/ubuntu-corpus-1.0/ubuntu_dialogs.tgz' ) self.data_directory = kwargs.get( 'ubuntu_corpus_data_directory', './data/' ) self.extracted_data_directory = os.path.join(#拼接路径的两个部分,构成完成的解压路径 self.data_directory, 'ubuntu_dialogs' ) # Create the data directory if it does not already exist if not os.path.exists(self.data_directory): os.makedirs(self.data_directory) def is_downloaded(self, file_path):#检查数据文件是否已经下载 """ Check if the data file is already downloaded. """ if os.path.exists(file_path): self.logger.info('File is already downloaded') return True return False def is_extracted(self, file_path):#检查数据文件是否已经解压 """ Check if the data file is already extracted. """ print("这里判断是否解压") if os.path.isdir(file_path): self.logger.info('File is already extracted') return True return False def download(self, url, show_status=True):#这个下载函数应该可以当做黑盒使用,直接复用。 """ Download a file from the given url. Show a progress indicator for the download status. Based on: http://stackoverflow.com/a/15645088/1547223 """ import requests file_name = url.split('/')[-1] file_path = os.path.join(self.data_directory, file_name) # Do not download the data if it already exists if self.is_downloaded(file_path): return file_path with open(file_path, 'wb') as open_file: print('Downloading %s' % url) response = requests.get(url, stream=True) total_length = response.headers.get('content-length') if total_length is None: # No content length header open_file.write(response.content) else: download = 0 total_length = int(total_length) for data in response.iter_content(chunk_size=4096): download += len(data) open_file.write(data) if show_status:#这里是下载进度显示。用=来表示下载进度 done = int(50 * download / total_length) sys.stdout.write('\r[%s%s]' % ('=' * done, ' ' * (50 - done))) sys.stdout.flush() # Add a new line after the download bar sys.stdout.write('\n') print('Download location: %s' % file_path) return file_path def extract(self, file_path): """ Extract a tar file at the specified file path. """ import tarfile print('Extracting {}'.format(file_path)) if not os.path.exists(self.extracted_data_directory): os.makedirs(self.extracted_data_directory) def track_progress(members):#解压语料库的进度条显示。 sys.stdout.write('.') for member in members: # This will be the current file being extracted yield member with tarfile.open(file_path) as tar:#用python对压缩包进行解压操作。 tar.extractall(path=self.extracted_data_directory, members=track_progress(tar)) self.logger.info('File extracted to {}'.format(self.extracted_data_directory)) return True def train(self): import glob#用来进行文件搜索的库,支持通配符操作。 import csv#用来处理后面的tsv文件 # Download and extract the Ubuntu dialog corpus if needed corpus_download_path = self.download(self.data_download_url) # Extract if the directory doesn not already exists if not self.is_extracted(self.extracted_data_directory):#如果没有在指定路径解压,就进行解压。 self.extract(corpus_download_path) extracted_corpus_path = os.path.join(#各个语料库文件的路径。 self.extracted_data_directory, '**', '**', '*.tsv' ) file_kwargs = {}#初始化为字典类型 if sys.version_info[0] > 2:#也就是说尽量在python3.x下面进行。 # Specify the encoding in Python versions 3 and up file_kwargs['encoding'] = 'utf-8' # WARNING: This might fail to read a unicode corpus file in Python 2.x for file in glob.iglob(extracted_corpus_path):#对于路径下的所有文件,而iglob函数用来获取匹配括号中路径的文件的集合。 print("glob.iglob(参数)=",glob.iglob(extracted_corpus_path)) self.logger.info('Training from: {}'.format(file)) with open(file, 'r', **file_kwargs) as tsv: reader = csv.reader(tsv, delimiter='\t')#分隔符是'\t',等号左边的reader是个对象名,等号右侧的reader是class名 print("reader=",reader) previous_statement_text = None for row in reader:#row表示语料tsv文件中的每一行数据。 print("row = ",row) if len(row) > 0: text = row[3]#这个是根据数据集的特性来决定写代码的,因为语料库的第四列是对话内容。所以这里是row[3] print("进入新一轮循环,来看一下statement") statement = self.get_or_create(text) print("statement=",statement) print("------------------------------------------------------------") print(text, len(row)) statement.add_extra_data('datetime', row[0])#这个函数用来添加字典类型数据,例如此处添加的数据就是 # {datetime,row[0]} statement.add_extra_data('speaker', row[1]) print("row[2].strip()=",row[2].strip()) if row[2].strip():#这个代码的意思是,对话发起时,是不知道哪个是听众的,因为是在论坛发帖子。 #因为语料数据来自论坛交谈,论坛发帖子是不知道谁会回复的,所以出事状态下,发了一个新帖子以后,在帖子还没有回复的情况下,row[2]默认是空 #只有有回复者以后,row[2]才会有值。 #所以刚发完帖子后,由于row[2]是空,所以if下面的语句将不会被执行。 #上面的if语句中,Python strip() 方法用于移除字符串头尾指定的字符(默认为空格) #所以上面的if语句的意思是,row[2]这个数据集合中的第三个属性,进行去除空格处理,得到的数据是否为空(即判断这是否是个新发的帖子) print("进入if语句") statement.add_extra_data('addressing_speaker', row[2])#因为人们的交谈是一次只能有一个人说话,A说话时,B就只能听着, #可用来增加字典类型数据,如果用法如上,则函数调用前后无变化。 #所以row[1]和row[2]表示row[1]对row[2]讲话 print("*************************************") print("statement=",statement) if previous_statement_text:#这个previous_statement在这里代表上一次某人说的话,因为语料库来自论坛对话,所以论坛帖子没发以前,这个变量肯定是空的。 statement.add_response( Response(previous_statement_text) ) print("previous_statement_text=",previous_statement_text) print("statement=", statement) print("离开最后一个if") print("*************************************") previous_statement_text = statement.text#为下一轮for循环做准备,方便取得此次回答的下一轮回答。 #这样进入下一轮循环的时候,就可以正常进入if语句。 self.storage.update(statement)#两重for循环,第一重针对指定路径下的每个文件,第二重针对每个文件中的每一行,最终目的是更新storage。
在
http://blog.csdn.net/appleyuchi/article/details/77388397
中,相关语句是:
- chatbot.train("chatterbot.corpus.chinese")
train函数来自chatterbot.py文件中的train函数定义,该train函数定义如下:
@propertydef train(self): """ Proxy method to the chat bot's trainer class. """ return self.trainer.train
也就是说,函数的具体定义外包给了trainer class,这个class就处于上文分析的trains.py中。
- chatterbot中的trains.py详细解释
- chatterbot中的remove_stopwords函数用法
- Trains
- SWT中的FormLayout 详细解释
- Java中的ClassLoader详细解释
- javascript中的prototype详细解释
- web.py解释
- UCOS-III中的OS_CFG.H 详细解释
- cat.py解释sys.argv
- 详细解释JAVA中的静态绑定和动态绑定
- Linux服务器中的TCP连接状态详细解释
- Ext中的dom节点查找DomQuery详细解释
- USB3.0中的8/10b编码技术详细解释
- 编程语言中的冒号【 超详细解释】无私奉献
- Linux服务器中的TCP连接状态详细解释
- windows系统中的dll的作用详细解释
- C/C++中的日期和时间time_t详细解释
- 详细解释数据挖掘中的10大算法
- 一个通用的事件监听函数
- 获得maven地址,如何从中央仓库中下载jar
- 关于pdf转html的个人方法
- Depth Only
- MYSQL在15分钟插入千万条数据
- chatterbot中的trains.py详细解释
- 将darknet生成的.weight转化为.pb文件
- HDU2049 不容易系列之(4)——考新郎
- CSS Transform / Transition / Animation 属性的区别
- 网络传输数据格式的选择
- 2017 Multi-University Training Contest
- (一)VUE学习地址
- Multiple dex files define Landroid/support/v4/accessibilityservice解决方法
- C++ MFC listcontrol简单例子参考