基于CCA的图像文本交叉检索

来源:互联网 发布:js给input添加样式 编辑:程序博客网 时间:2024/06/05 00:30

1、研究内容摘要

  互联网的出现使得网络媒体内容在量和种类上出现井喷式的增长,跨媒体检索的需求也日益增大。本文主要关注基于CCA(典型关联分析)的“图像—文本”交叉检索算法的实现。要用计算机实现图像、文本这两种最常见的媒体内容之间的交叉检索,首先要分别把图像和文本各自用某种特征向量表示,即把图像数据映射到图像特征空间I1,文本数据映射到文本特征空间T1。然而特征空间I1T1之间并没有直接的联系,CCA算法则可以通过许多“图像—样本”样本对的训练把I1T1分别映射到I2T2,其中特征空间I2T2是线性相关的,可以直接度量I2T2中特征向量之间的相似性,从而为图像—文本交叉检索提供了理论基础。下面将结合具体的数据进行详细论述。

2、数据集

  本文所用的图像文本数据为维基百科公开数据库Wikipedia articles,可下载1.4GB的原始数据压缩文件,也可仅下载由原始数据提取的1.2MB的特征数据压缩文件。特征数据文件解压后如下图所示:
  特征数据文件
  其中raw_features.mat可以用matlab导入查看,其他三个.list文件也可以用windows自带的写字板打开查看。raw_features.mat包含四个变量如下图所示:
  这里写图片描述
  上图中I_tr 、T_tr分别为2173个“图片—文本”训练样本中的图像和文本提取出的特征。图像特征是128维的SIFT特征,文本特征是由10个主题的LDA文本模型生成的10维特征。而I_tr 、T_tr则相应的是693个测试样本对应的特征数据。
  testset_txt_img_cat.list文件用写字板打开如下图所示:
  这里写图片描述
  上图中每一行代表一个“图片—文本”测试样本,共有693行。可以看到每行被隔开成三部分,第一部分为文本部分对应的原始文件的文件名,第二部分为图片原始文件的文件名,第三部分为该行“图片—文本”对应的种类编号(1-10)。
  用以上raw_features.mat和testset_txt_img_cat.list两个文件就可以初步完成基于CCA的图像文本交叉检索算法。

3、代码解析

  算法的第一步自然是从原始的图像、文本数据提取各自的特征,如上文所述,本文直接使用已经处理好的特征数据,其中图像为128维SIFT特征、文本为10主题LDA模型提取的10维特征,具体提取算法并非本文重点,在此不进行深入探讨。
  有了2173个训练样本的图像特征和文本特征数据,就可以用CCA算法学习图像特征和文本特征的组合权重系数,使得变换后的图像特征和文本特征具有最大的相关性,CCA算法原理可参考“典型关联分析(CCA)算法原理”。
  根据参考文献《A New Approach to Cross-Modal Multimedia Retrieval》可以知道,可通过学习得到d组组合系数向量对,每一组组合系数向量对包括对应于128维图像特征的128维组合系数向量和对应于10维文本特征的10维组合系数向量,对于图像特征,每个组合系数向量把原来的128维数据通过线性组合生成一个数据,所以d个组合系数向量可以把128维特征映射为d维特征向量。对于文本特征也是如此,即把10维特征向量映射为d维特征向量。CCA的处理使得图像的d维特征和文本的d维特征是线性相关的,通过对训练数据的中心化预处理后,甚至可以认为这两个d维子空间是同一个子空间,这意味着图像特征生成的d维特征和文本特征生成的d维特征之间可以直接用来计算相似性,而这正是跨媒体检索的核心。
  运行CCA算法前的训练数据的中心化预处理非常重要,具体步骤是先减去随机的均值,再除以随机变量的方差的平方根。设训练数据矩阵为X、Y,对应的matlab代码为:

    vX = sqrt(var(X,1));    vY = sqrt(var(Y,1));    mX = mean(X,1);    mY = mean(Y,1);    X = (X - repmat(mX,size(X,1),1))./repmat(vX,size(X,1),1);    Y = (Y - repmat(mY,size(Y,1),1))./repmat(vY,size(Y,1),1);    X(find(isnan(X))) = 0;    Y(find(isnan(Y))) = 0;

  而CCA算法可以通过matlab自带的canoncorr函数实现。最后可以得到d组组合系数向量对,通过这些组合系数可以把用于测试的693个样本的图像特征矩阵I_te (693 x 128)转换为d维特征空间的矩阵Scaled_I_te(693 x d),文本特征矩阵T_te(693 x 10)也转换为693 x d大小的矩阵Scaled_T_te。
  之后就是实现图像文本交叉检索了,为了之后检索结果的统计方便,首先求出查询集(query set)和检索集(retrieve set)的完整距离矩阵distAll,distAll[i,j]为查询集中第i条数据和检索集中第j条数据之间的距离,其中距离的度量又有多种标准,本文中使用NC距离(normalised correlation),即归一化相关系数,来计算两个向量之间的相似性。如果是以图像来检索文本,则查询集为Scaled_I_te(693 x d)矩阵,检索集为同样大小的Scaled_T_te矩阵,生成的distAll为693 x 693矩阵。而查询集图像和训练集文本的ground_truth(即真实种类标签)是相同的,都是testset_txt_img_cat.list文件的最后一列(693 x 1)数据。
  最后的部分涉及信息检索方面的理论,常用的表征信息检索算法性能的主要有map(Mean Average Precision),pr曲线(准确率—召回率)等。
  求map的matlab代码如下所示:

function  [query, class] = ir_perquery2(gt_quer, distance_mtx, gt_retr)% distance_mtx : Distance of query from the rest of the datapoints%                Each row refers to one query% ground_truth : A ground truth vector of categories (starting from 1)%% per_query is a stucture with following elements:%   map   : Mean Average Precision, for all the queries%   apr   : struct with the following two fields:%      .P : Average P/n for all the queries%      .R : Average R/n for all the queries%           use the last two to compute P/R and P/n curves%   pr    : precision recal, to plot against 0:0.001:1%   cm    : confusion matrix[r,c]=size(distance_mtx);testPoints=c;queryPoints=r;%%cat_num = max(ground_truth);[r,c]=size(gt_quer);if r>c,    [r,c]=size(gt_retr);    if c>r,        gt_retr=gt_retr';    endelse    gt_quer=gt_quer';    [r,c]=size(gt_retr);    if c>r,        gt_retr=gt_retr';    endend;gt_cat=sort(unique([gt_quer;gt_retr]));cat_num = length(gt_cat); %cardinality of the categoriescat_card = zeros(1,cat_num);for i = 1 : cat_num    %the true sum of each category in all retrieve items    cat_card(i) = length(find(gt_retr == gt_cat(i)));  endMAP = zeros(queryPoints, cat_num); %rank accuracyfor itext = 1 : queryPoints    dist = distance_mtx(itext,:);    [foo ind] = sort(dist,'ascend'); %(give index in V)    %most similar image (take the original index)    %pick the class to which the query belongs to    for cls = 1:cat_num      classe = gt_cat(cls); %ground_truth(itext);      %classes of all quesries, from best to worst match      classeT = gt_retr(ind);      %make 0-1 GT      classeGT  = (classeT == classe);      %compute the indexes in the rank      ranks = find(classeGT)';      %compute AP for the query      map = sum((1:length(ranks))./ranks)/length(ranks);      %store      if isfinite(map),        MAP(itext,cls) = map;      end    end    %% change 1    %truemap(itext) = MAP(itext, gt_quer(itext));    idx_cls=find(gt_cat == gt_quer(itext));    truemap(itext) = MAP( itext, idx_cls );    %% end of change 1    classe = gt_quer(itext);    classeT = gt_retr(ind);    classeGT  = (classeT == classe)';    pn(itext,:) = cumsum(classeGT)./[1:testPoints];    %% change 2    %rprecision(itext) = pn( itext, cat_card(gt_quer(itext)) );    if cat_card(idx_cls)>0,        rprecision(itext) = pn( itext, cat_card(idx_cls) );    else        % no element on the retrieval set belonging to the query object's        % category. I.e. R-Precision = div by 0 (NaN)        rprecision(itext) = NaN;    endendquery.map = mean(truemap); % final map 

  pr曲线、ROC曲线等其他指标的求法在此不进行详细展开,可通过查阅信息检索领域的相关文献资料深入了解。

0 0
原创粉丝点击