[DL]基于Pytorch的Linear classified model
来源:互联网 发布:身份证人脸比对知乎 编辑:程序博客网 时间:2024/06/06 08:45
留个坑,后面再添加详细解释
data = [("我 的 家 乡 在 哪 里".split(), "CHINESE"), ("Give it to me".split(), "ENGLISH"), ("今 天 天 气 怎 么 样".split(), "CHINESE"), ("No it is not a good idea to get lost at sea".split(), "ENGLISH")]test_data = [("天 气 在 哪 里".split(), "CHINESE"), ("it is lost on me".split(), "ENGLISH")]word2index={}for sen,_ in data+test_data: for word in sen: if word not in word2index: word2index[word]=len(word2index)print(word2index)VOCAB_SIZE=len(word2index)NUM_CLASSES=2label2index={'CHINESE':0,'ENGLISH':1}import torch.nn as nnimport torch.autograd as autogradimport torch.functional as Fclass BoWClassifier(torch.nn.Module): def __init__(self,num_labels,vocab_size): super(BoWClassifier,self).__init__() self.linear=nn.Linear(vocab_size,num_labels) self.softmax=nn.Softmax() def forward(self,bow_vec): return self.softmax(self.linear(bow_vec))def make_bow_vector(sentence,word2index): vec=torch.zeros(len(word2index)) for word in sentence: vec[word2index[word]]+=1 return vec.view(1,-1)def make_target(label,label2index): return torch.LongTensor([label2index[label]])model=BoWClassifier(NUM_CLASSES,VOCAB_SIZE)# for param in model.parameters():# print(param)log_prob=model(autograd.Variable(make_bow_vector(data[0][0],word2index)))print(log_prob)loss_function=nn.NLLLoss()optimizer=torch.optim.SGD(model.parameters(),lr=0.1)for epoch in range(100): for instance,label in data: model.zero_grad() bow_vec=autograd.Variable(make_bow_vector(instance,word2index)) target=autograd.Variable(make_target(label,label2index))# print(target) log_prob=model(bow_vec)# print(log_prob) loss=loss_function(log_prob,target) loss.backward() optimizer.step()for instance,label in test_data: bow_vec=autograd.Variable(make_bow_vector(instance,word2index)) log_prob=model(bow_vec)# print(log_prob)print(model(autograd.Variable(make_bow_vector(['Give','good','good','good','good'],word2index))))print(model(autograd.Variable(make_bow_vector(['我','我','我','我','我'],word2index))))
阅读全文
0 0
- [DL]基于Pytorch的Linear classified model
- [DL]基于Pytorch的N-gram Language Model
- [DL]基于Pytorch的seq2seq模型
- [DL]基于pytorch的Elman RNN语言模型
- 基于DL的图像分割
- pytorch-generative-model-collections
- pytorch model 2 coreml
- 基于pytorch的图像分类框架
- generalized linear model, GLM
- Global Linear Model
- 1- Simple Linear Model
- 2. Linear Model
- 线性模型(Linear Model)
- 基于DL的目标检测概述
- R: anova或linear model 的 MAPE计算
- 基于RHadoop的linear-least-squares算法
- 基于DL的计算机视觉(11)-- 基于DL的快速图像检索系统
- 基于PyTorch的深度学习入门教程(一)——PyTorch安装和配置
- apply提高数组找出最大最小值的方式(性能)
- java防止xss脚本注入攻击,采用spring工具类方式
- 功能测试框架
- HDU 5976 Detachment
- 蓝桥杯 算法训练 友好数
- [DL]基于Pytorch的Linear classified model
- Java 中的调试
- JAVA必背面试题和项目面试通关要点
- 智能家居标准ZHA
- 拷贝工具BeanUtilsBean扩展
- 三、读第八、九章
- Spring Boot整合Morphia访问MongoDB
- MQ入门总结(三)ActiveMQ的用法和实现
- LiteORM框架导入Android Studio步骤简介