The K-SVD Algorithm
Source: Internet  Published: Do programmers have a future?  Editor: 程序博客网  Time: 2024/05/01 05:35
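Basic idea of the K-SVD algorithm:
Y is the matrix of training samples, D is the dictionary and X holds the sparse coefficients. Each iteration consists of two steps, Sparse Coding and Dictionary Update:
1: Sparse Coding: with the dictionary D fixed, a pursuit algorithm (for example OMP) finds the best sparse coefficient matrix X for the samples, i.e. the one that minimizes the representation error subject to a sparsity constraint.
2: Dictionary Update: the dictionary is updated one column at a time, following the criterion of reducing the MSE. Each column and its corresponding row of sparse coefficients are updated in turn using the SVD (singular value decomposition).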
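For reference, the objective that these two steps alternate on, in the form given by Aharon, Elad and Bruckstein (the paper cited in the KSVD.m header), is shown here. y_i and x_i are the i-th columns of Y and X, N is the number of training samples, and T_0 is the target sparsity, which plays the role of param.L in the code.

```latex
% Full K-SVD problem: learn D and X jointly
\min_{D,X} \; \|Y - DX\|_F^2 \quad \text{s.t.} \quad \|x_i\|_0 \le T_0 \quad \forall i

% Sparse-coding step (D fixed), solved signal by signal with a pursuit algorithm such as OMP
\min_{x_i} \; \|y_i - D x_i\|_2^2 \quad \text{s.t.} \quad \|x_i\|_0 \le T_0, \qquad i = 1, \dots, N
```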
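E_k is the residual with respect to the k-th column of the dictionary. Its physical meaning is the representation error when d_k is left out; in other words, it measures how much the k-th dictionary column contributes to representing Y.
Given this interpretation of E_k, the goal is to find the d_k (together with its coefficients) that reduces E_k as much as possible.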
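Written out, with d_j the j-th column of D and x_T^j the j-th row of X (the notation of the K-SVD paper), the product DX is a sum of K rank-one terms, so isolating the k-th term yields E_k:

```latex
\|Y - DX\|_F^2
  = \Big\| Y - \sum_{j=1}^{K} d_j x_T^j \Big\|_F^2
  = \Big\| \Big( Y - \sum_{j \neq k} d_j x_T^j \Big) - d_k x_T^k \Big\|_F^2
  = \big\| E_k - d_k x_T^k \big\|_F^2
```

Minimizing this over d_k and x_T^k is a best rank-one approximation of E_k, which is why the SVD appears in the update step.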
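To obtain d_k, compute the SVD E_k = UΔV^T, take the first column of U as the updated k-th dictionary column d_k, and take Δ(1,1) times the first column of V as the updated sparse coefficients (the k-th row of X). In K-SVD, E_k is first restricted to the samples that actually use d_k (the nonzero entries in the k-th row of X), so that the update keeps X sparse; the KSVD code later in this post does this through relevantDataIndices.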
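The single-atom update can be sketched in pure Python (lists of rows, no NumPy) as a minimal illustration. The helper names (matvec, transpose, leading_singular_triplet, ksvd_update_atom) are hypothetical, and MATLAB's svds(errors,1) is swapped for plain power iteration on E^T E, which is only an approximation that converges when the top singular value is well separated. Unlike KSVD.m, an atom used by no sample is simply left unchanged here rather than replaced by the worst-represented sample.

```python
import math
import random


def matvec(A, x):
    """Return A @ x, with A stored as a list of rows (hypothetical helper)."""
    return [sum(a * b for a, b in zip(row, x)) for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def leading_singular_triplet(E, iters=100, seed=0):
    """Approximate (u, s, v) for the largest singular value of E by power
    iteration on E^T E; a stand-in (assumption) for MATLAB's svds(E, 1)."""
    Et = transpose(E)
    rng = random.Random(seed)
    v = [rng.uniform(0.5, 1.5) for _ in range(len(Et))]  # random start vector
    for _ in range(iters):
        w = matvec(Et, matvec(E, v))
        norm = math.sqrt(sum(t * t for t in w))
        if norm == 0.0:
            break  # E v = 0: E has no energy in this direction
        v = [t / norm for t in w]
    Ev = matvec(E, v)
    s = math.sqrt(sum(t * t for t in Ev))
    u = [t / s for t in Ev] if s > 0.0 else Ev
    return u, s, v


def ksvd_update_atom(Y, D, X, k):
    """Update atom k of D and row k of X in place (one K-SVD dictionary step).
    Y is n x N, D is n x K, X is K x N, all stored as lists of rows."""
    n, K = len(D), len(D[0])
    support = [i for i, c in enumerate(X[k]) if c != 0.0]  # samples using atom k
    if not support:
        return  # unused atom: left unchanged in this sketch
    # Restricted residual E_k^R: leave out atom k, keep only the supporting samples.
    E = [[Y[r][i] - sum(D[r][j] * X[j][i] for j in range(K) if j != k)
          for i in support] for r in range(n)]
    u, s, v = leading_singular_triplet(E)
    if s == 0.0:
        for i in support:
            X[k][i] = 0.0  # other atoms already explain these samples exactly
        return
    for r in range(n):
        D[r][k] = u[r]  # new atom: first left singular vector
    for pos, i in enumerate(support):
        X[k][i] = s * v[pos]  # new coefficients: s times first right singular vector
```

For example, with Y = [[3, 0, 6], [4, 5, 8]], D = identity (2 x 2) and X = [[1, 0, 2], [0, 5, 0]], updating atom 0 turns it into (0.6, 0.8) with coefficients (5, 0, 10), after which D*X reproduces Y exactly.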
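Below is a simple demo that uses the K-SVD and OMP algorithms.
Code flow:
Step 1: read in the Lena image img.
Step 2: randomly generate a measurement matrix Phi.
Step 3: compute the measurements y = Phi*img.
Step 4: train the dictionary with [Dictionary, output] = KSVD(img, param).
Step 5: compute the sparse coefficient matrix with A = OMP(Phi*Dictionary, y, L).
Step 6: reconstruct the image as img_rec = Dictionary*A.

Demo_Code_1.m: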
```matlab
%==========================================================================
% The K-SVD basis is selected as the sparse representation dictionary;
% the OMP algorithm is used to recover the image.
% Author: zhang ben, ncuzhangben@qq.com
%==========================================================================

%***************************** read in the image **************************
img = imread('lena.bmp');      % read in the image "lena.bmp" (assumed grayscale)
img = double(img);
[N, n] = size(img);
img0 = img;                    % keep an original copy of the input signal

%*************** form the measurement matrix and dictionary ***************
% Phi multiplies img from the left, so it needs N columns. With M = N rows
% there is no compression; M < N would give a compressive measurement.
M = N;
Phi = randn(M, N);
Phi = Phi ./ repmat(sqrt(sum(Phi.^2, 1)), [M, 1]);   % normalize each column

% fix the parameters
param.L = 20;                  % number of elements in each linear combination
param.K = 150;                 % number of dictionary elements
param.numIteration = 50;       % number of iterations of the K-SVD algorithm
param.errorFlag = 0;           % 0: use a fixed number (param.L) of coefficients
                               % 1: code each signal until param.errorGoal is reached
%param.errorGoal = sigma;
param.preserveDCAtom = 0;
param.InitializationMethod = 'DataElements';  % initialize with the signals themselves
param.displayProgress = 1;     % display progress information

% train the dictionary on the image columns; Dictionary is N x param.K
[Dictionary, output] = KSVD(img, param);

%************************ projection **************************************
y = Phi * img;                 % treat each column as an independent signal
y0 = y;                        % keep an original copy of the measurements

%********************* recover using OMP *********************************
D = Phi * Dictionary;
% OMP requires unit-norm atoms, but the columns of Phi*Dictionary are not
% normalized: normalize them, run OMP, then undo the scaling.
colNorm = sqrt(sum(D.^2, 1));
A = OMP(D ./ repmat(colNorm, [M, 1]), y, param.L);
A = diag(1 ./ colNorm) * A;
imgr = Dictionary * A;

%*********************** show the results ********************************
figure(1)
subplot(2,2,1), imagesc(img0), title('original image')
subplot(2,2,2), imagesc(y0), title('measurement image')
subplot(2,2,3), imagesc(Dictionary), title('Dictionary')
psnrVal = 20*log10(255 / sqrt(mean((img(:) - imgr(:)).^2)));
subplot(2,2,4), imagesc(imgr), title(strcat('recover image (', num2str(psnrVal), 'dB)'))
disp('over')
```
OMP.m (code written by another user online):
```matlab
function [A] = OMP(D, X, L)
%=============================================
% Sparse coding of a group of signals based on a given
% dictionary and specified number of atoms to use.
% input arguments:
%   D - the dictionary (its columns MUST be normalized).
%   X - the signals to represent
%   L - the max. number of coefficients for each signal.
% output arguments:
%   A - sparse coefficient matrix.
%=============================================
[n, K] = size(D);
[n, P] = size(X);
A = sparse(K, P);                       % K x P coefficient matrix
for k = 1:P
    a = [];
    x = X(:, k);                        % the k-th signal (column) of X, n x 1
    residual = x;                       % n x 1
    indx = zeros(L, 1);                 % indices of the selected atoms, L x 1
    for j = 1:L
        proj = D' * residual;           % (K x n)*(n x 1) -> K x 1
        [maxVal, pos] = max(abs(proj)); % position of the largest projection
        pos = pos(1);
        indx(j) = pos;
        a = pinv(D(:, indx(1:j))) * x;  % least-squares fit on the selected atoms
        residual = x - D(:, indx(1:j)) * a;
        if sum(residual.^2) < 1e-6
            break;
        end
    end
    temp = zeros(K, 1);
    temp(indx(1:j)) = a;
    A(:, k) = sparse(temp);             % A is returned as a K x P matrix
end
return;
```
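For readers without MATLAB, the greedy loop of OMP.m can be sketched in pure Python for a single signal. This is a hypothetical translation, not the original author's code: the helper names (dot, solve, omp) are illustrative, and the least-squares step uses the normal equations with Gaussian elimination instead of pinv, so it assumes the selected atoms are linearly independent. Like OMP.m, it expects unit-norm dictionary columns and stops once the squared residual drops below 1e-6.

```python
def dot(u, w):
    return sum(a * b for a, b in zip(u, w))


def solve(A, b):
    """Solve the square system A z = b by Gaussian elimination with partial pivoting."""
    m = len(A)
    M = [list(row) + [rhs] for row, rhs in zip(A, b)]  # augmented matrix
    for c in range(m):
        p = max(range(c, m), key=lambda r: abs(M[r][c]))
        M[c], M[p] = M[p], M[c]
        for r in range(c + 1, m):
            f = M[r][c] / M[c][c]
            for cc in range(c, m + 1):
                M[r][cc] -= f * M[c][cc]
    z = [0.0] * m
    for r in range(m - 1, -1, -1):
        z[r] = (M[r][m] - sum(M[r][cc] * z[cc] for cc in range(r + 1, m))) / M[r][r]
    return z


def omp(D, x, L, tol=1e-6):
    """Orthogonal Matching Pursuit for one signal x (length n) over the
    unit-norm columns of D (n x K, list of rows); uses at most L atoms."""
    n, K = len(D), len(D[0])
    atoms = [[D[r][k] for r in range(n)] for k in range(K)]  # columns of D
    residual, chosen, a = list(x), [], []
    for _ in range(L):
        # Select the atom most correlated with the current residual.
        k = max(range(K), key=lambda i: abs(dot(atoms[i], residual)))
        if k in chosen:
            break  # residual is (numerically) orthogonal to all atoms
        chosen.append(k)
        # Least squares on the chosen atoms via the normal equations G a = D_S^T x.
        G = [[dot(atoms[i], atoms[j]) for j in chosen] for i in chosen]
        a = solve(G, [dot(atoms[i], x) for i in chosen])
        residual = [x[r] - sum(c * atoms[i][r] for c, i in zip(a, chosen))
                    for r in range(n)]
        if dot(residual, residual) < tol:
            break  # same stopping rule as OMP.m
    coef = [0.0] * K
    for c, i in zip(a, chosen):
        coef[i] = c
    return coef
```

With the three unit vectors plus the atom (1, 1, 0)/sqrt(2) as the dictionary, omp(D, [2, 0, 3], 2) first picks the third atom, then the first, and returns the coefficients [2, 0, 3, 0].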
KSVD.m, the K-SVD implementation:
```matlab
function [Dictionary, output] = KSVD(...
    Data, ...    % an nXN matrix that contains N signals (Y), each of dimension n.
    param)
% =========================================================================
%                          K-SVD algorithm
% =========================================================================
% The K-SVD algorithm finds a dictionary for linear representation of
% signals. Given a set of signals, it searches for the best dictionary that
% can sparsely represent each signal. Detailed discussion on the algorithm
% and possible applications can be found in "The K-SVD: An Algorithm for
% Designing of Overcomplete Dictionaries for Sparse Representation", written
% by M. Aharon, M. Elad, and A.M. Bruckstein and appeared in the IEEE Trans.
% On Signal Processing, Vol. 54, no. 11, pp. 4311-4322, November 2006.
% =========================================================================
% INPUT ARGUMENTS:
% Data      an nXN matrix that contains N signals (Y), each of dimension n.
% param     structure that includes all required parameters for the
%           K-SVD execution. Required fields are:
%   K, ...               the number of dictionary elements to train
%   numIteration, ...    number of iterations to perform.
%   errorFlag, ...       if =0, a fixed number of coefficients is used for
%                        representation of each signal. If so, param.L must
%                        be specified as the number of representing atoms.
%                        If =1, an arbitrary number of atoms represents each
%                        signal, until a specific representation error is
%                        reached. If so, param.errorGoal must be specified
%                        as the allowed error.
%   preserveDCAtom, ...  if =1 then the first atom in the dictionary is set
%                        to be constant, and does not ever change. This
%                        might be useful for working with natural images
%                        (in this case, only param.K-1 atoms are trained).
%   (optional, see errorFlag) L, ...         maximum coefficients to use in
%                        OMP coefficient calculations.
%   (optional, see errorFlag) errorGoal, ... allowed representation error
%                        in representing each signal.
%   InitializationMethod, ... method to initialize the dictionary, can be
%                        one of the following arguments:
%                        * 'DataElements' (initialization by the signals themselves), or:
%                        * 'GivenMatrix' (initialization by a given matrix param.initialDictionary).
%   (optional, see InitializationMethod) initialDictionary, ... if the
%                        initialization method is 'GivenMatrix', this is the
%                        matrix that will be used.
%   (optional) TrueDictionary, ... if specified, in each iteration the
%                        difference between this dictionary and the trained
%                        one is measured and displayed.
%   displayProgress, ... if =1 progress information is displayed. If
%                        param.errorFlag==0, the average representation error
%                        (RMSE) is displayed, while if param.errorFlag==1, the
%                        average number of required coefficients for
%                        representation of each signal is displayed.
% =========================================================================
% OUTPUT ARGUMENTS:
%  Dictionary   The extracted dictionary of size nX(param.K).
%  output       Struct that contains information about the current run.
%               It may include the following fields:
%    CoefMatrix  The final coefficients matrix (it should hold that Data
%                approximately equals Dictionary*output.CoefMatrix).
%    ratio       If the true dictionary was defined (in synthetic
%                experiments), this parameter holds a vector of length
%                param.numIteration that includes the detection ratios in
%                each iteration.
%    totalerr    The total representation error after each iteration
%                (defined only if param.displayProgress=1 and
%                param.errorFlag = 0)
%    numCoef     A vector of length param.numIteration that includes the
%                average number of coefficients required for representation
%                of each signal (in each iteration) (defined only if
%                param.displayProgress=1 and param.errorFlag = 1)
% =========================================================================

if (~isfield(param, 'displayProgress'))
    param.displayProgress = 0;
end
totalerr(1) = 99999;
if (isfield(param, 'errorFlag') == 0)
    param.errorFlag = 0;
end

if (isfield(param, 'TrueDictionary'))
    displayErrorWithTrueDictionary = 1;
    ErrorBetweenDictionaries = zeros(param.numIteration+1, 1);  % preallocate with zeros
    ratio = zeros(param.numIteration+1, 1);
else
    displayErrorWithTrueDictionary = 0;
    ratio = 0;
end
if (param.preserveDCAtom > 0)
    FixedDictionaryElement(1:size(Data,1), 1) = 1/sqrt(size(Data,1));
else
    FixedDictionaryElement = [];
end
% coefficient calculation method is OMP with fixed number of coefficients

output = struct();   % so that 'output' is assigned even on the early return below
if (size(Data,2) < param.K)
    disp('Size of data is smaller than the dictionary size. Trivial solution...');
    Dictionary = Data(:, 1:size(Data,2));
    return;
elseif (strcmp(param.InitializationMethod, 'DataElements'))
    Dictionary(:, 1:param.K-param.preserveDCAtom) = Data(:, 1:param.K-param.preserveDCAtom);
elseif (strcmp(param.InitializationMethod, 'GivenMatrix'))
    Dictionary(:, 1:param.K-param.preserveDCAtom) = param.initialDictionary(:, 1:param.K-param.preserveDCAtom);
end
% reduce the components in Dictionary that are spanned by the fixed elements
if (param.preserveDCAtom)
    tmpMat = FixedDictionaryElement \ Dictionary;
    Dictionary = Dictionary - FixedDictionaryElement*tmpMat;
end
% normalize the dictionary.
Dictionary = Dictionary*diag(1./sqrt(sum(Dictionary.*Dictionary)));
Dictionary = Dictionary.*repmat(sign(Dictionary(1,:)), size(Dictionary,1), 1); % multiply in the sign of the first element.
totalErr = zeros(1, param.numIteration);

% the K-SVD algorithm starts here.
for iterNum = 1:param.numIteration
    % find the coefficients
    if (param.errorFlag == 0)
        %CoefMatrix = mexOMPIterative2(Data, [FixedDictionaryElement,Dictionary], param.L);
        CoefMatrix = OMP([FixedDictionaryElement,Dictionary], Data, param.L);
    else
        %CoefMatrix = mexOMPerrIterative(Data, [FixedDictionaryElement,Dictionary], param.errorGoal);
        % OMPerr is not included in this post; it is needed only when param.errorFlag == 1
        CoefMatrix = OMPerr([FixedDictionaryElement,Dictionary], Data, param.errorGoal);
        param.L = 1;
    end

    replacedVectorCounter = 0;
    rPerm = randperm(size(Dictionary,2));
    for j = rPerm
        [betterDictionaryElement, CoefMatrix, addedNewVector] = I_findBetterDictionaryElement(Data, ...
            [FixedDictionaryElement,Dictionary], j+size(FixedDictionaryElement,2), ...
            CoefMatrix, param.L);
        Dictionary(:,j) = betterDictionaryElement;
        if (param.preserveDCAtom)
            tmpCoef = FixedDictionaryElement\betterDictionaryElement;
            Dictionary(:,j) = betterDictionaryElement - FixedDictionaryElement*tmpCoef;
            Dictionary(:,j) = Dictionary(:,j)./sqrt(Dictionary(:,j)'*Dictionary(:,j));
        end
        replacedVectorCounter = replacedVectorCounter+addedNewVector;
    end

    if (iterNum > 1 && param.displayProgress)
        if (param.errorFlag == 0)
            output.totalerr(iterNum-1) = sqrt(sum(sum((Data-[FixedDictionaryElement,Dictionary]*CoefMatrix).^2))/numel(Data));
            disp(['Iteration ', num2str(iterNum), ' Total error is: ', num2str(output.totalerr(iterNum-1))]);
        else
            output.numCoef(iterNum-1) = length(find(CoefMatrix))/size(Data,2);
            disp(['Iteration ', num2str(iterNum), ' Average number of coefficients: ', num2str(output.numCoef(iterNum-1))]);
        end
    end
    if (displayErrorWithTrueDictionary)
        [ratio(iterNum+1), ErrorBetweenDictionaries(iterNum+1)] = I_findDistanseBetweenDictionaries(param.TrueDictionary, Dictionary);
        disp(strcat(['Iteration ', num2str(iterNum), ' ratio of restored elements: ', num2str(ratio(iterNum+1))]));
        output.ratio = ratio;
    end
    Dictionary = I_clearDictionary(Dictionary, CoefMatrix(size(FixedDictionaryElement,2)+1:end,:), Data);

    if (isfield(param, 'waitBarHandle'))
        waitbar(iterNum/param.counterForWaitBar);
    end
end

output.CoefMatrix = CoefMatrix;
Dictionary = [FixedDictionaryElement, Dictionary];

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  findBetterDictionaryElement
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [betterDictionaryElement, CoefMatrix, NewVectorAdded] = I_findBetterDictionaryElement(Data, Dictionary, j, CoefMatrix, numCoefUsed)
if (nargin < 5)
    numCoefUsed = 1;
end
relevantDataIndices = find(CoefMatrix(j,:)); % the data indices that use the j'th dictionary element.
if (length(relevantDataIndices) < 1) %(length(relevantDataIndices)==0)
    % unused atom: replace it with the worst-represented signal
    ErrorMat = Data-Dictionary*CoefMatrix;
    ErrorNormVec = sum(ErrorMat.^2);
    [d, i] = max(ErrorNormVec);
    betterDictionaryElement = Data(:,i);%ErrorMat(:,i);
    betterDictionaryElement = betterDictionaryElement./sqrt(betterDictionaryElement'*betterDictionaryElement);
    betterDictionaryElement = betterDictionaryElement.*sign(betterDictionaryElement(1));
    CoefMatrix(j,:) = 0;
    NewVectorAdded = 1;
    return;
end

NewVectorAdded = 0;
tmpCoefMatrix = CoefMatrix(:, relevantDataIndices);
tmpCoefMatrix(j,:) = 0; % the coefficients of the element we now improve are not relevant.
errors = (Data(:, relevantDataIndices) - Dictionary*tmpCoefMatrix); % matrix of errors that we want to minimize with the new element
% The better dictionary element and the values of beta are found using svd,
% because we would like to minimize || errors - beta*element ||_F^2,
% that is, to approximate the matrix 'errors' with a rank-one matrix. This
% is done using the largest singular value.
[betterDictionaryElement, singularValue, betaVector] = svds(errors, 1);
CoefMatrix(j, relevantDataIndices) = singularValue*betaVector';% *signOfFirstElem

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  findDistanseBetweenDictionaries
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ratio, totalDistances] = I_findDistanseBetweenDictionaries(original, new)
% first, make all the columns start with positive values.
catchCounter = 0;
totalDistances = 0;
for i = 1:size(new,2)
    new(:,i) = sign(new(1,i))*new(:,i);
end
for i = 1:size(original,2)
    d = sign(original(1,i))*original(:,i);
    distances = sum((new-repmat(d,1,size(new,2))).^2);
    [minValue, index] = min(distances);
    errorOfElement = 1-abs(new(:,index)'*d);
    totalDistances = totalDistances+errorOfElement;
    catchCounter = catchCounter+(errorOfElement<0.01);
end
ratio = 100*catchCounter/size(original,2);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%  I_clearDictionary
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function Dictionary = I_clearDictionary(Dictionary, CoefMatrix, Data)
T2 = 0.99;   % atoms whose inner product with another atom exceeds this are replaced
T1 = 3;      % atoms used by at most this many signals are replaced
K = size(Dictionary,2);
Er = sum((Data-Dictionary*CoefMatrix).^2, 1); % residual energy of each signal
% remove identical (highly correlated) or rarely used atoms
G = Dictionary'*Dictionary;
G = G-diag(diag(G));
for jj = 1:K
    if max(G(jj,:)) > T2 || length(find(abs(CoefMatrix(jj,:)) > 1e-7)) <= T1
        [val, pos] = max(Er);
        Er(pos(1)) = 0;
        Dictionary(:,jj) = Data(:,pos(1))/norm(Data(:,pos(1)));
        G = Dictionary'*Dictionary;
        G = G-diag(diag(G));
    end
end
```
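Running the code produces a figure with four panels: the original image, the measurement image, the learned dictionary, and the recovered image with its PSNR in the title.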