Initial training set (with all classes)

来源:互联网 发布:淘宝我是商家怎么激活 编辑:程序博客网 时间:2024/05/16 19:19
function init_train_x
%INIT_TRAIN_X Build per-class 50/50 train/test sets of 28x28 image patches.
%   INIT_TRAIN_X reads the image 'LCB1.tif' and the label map
%   'berlin_GT.tif' from the current folder. The image is divided by its
%   maximum value. For every pixel whose label is > 0, the 28x28
%   neighbourhood at row/column offsets -13..14 is copied, with zeros
%   wherever the window falls outside the image. The pixel itself sits at
%   patch position (14,14). Each patch is flattened column-major into one
%   row.
%
%   For each label, a random ceil(50%) of its pixels become training
%   samples and the remaining pixels become test samples. Targets are
%   one-hot rows whose column index is the label value (at least 17
%   columns).
%
%   Nothing is returned. The results are saved to 'berlin_1.mat':
%     train_x, test_x  - one flattened patch (784 values) per row
%     train_y, test_y  - matching one-hot label rows
%   The elapsed time is printed via TIC/TOC.
%
%   NOTE(review): assumes labels are non-negative integers with 0 meaning
%   "unlabelled", and that the label map is no larger than the image -
%   confirm against the data. No RNG seed is set, so the split differs
%   between runs.

tic

PAD_BEFORE = 13;                      % rows/cols taken above/left of the centre pixel
PAD_AFTER = 14;                       % rows/cols taken below/right of the centre pixel
PATCH = PAD_BEFORE + 1 + PAD_AFTER;   % patch side length (28)
MIN_CLASSES = 17;                     % minimum width of the one-hot label matrices
TRAIN_FRACTION = 0.5;                 % share of each class used for training (rounded up)

% Divide by the image maximum; skip when it is 0 to avoid 0/0 = NaN.
I = double(imread('LCB1.tif'));
peak = max(I(:));
if peak ~= 0
    I = I / peak;
end
I = I(:, :, 1);                       % only the first band is sampled (no-op for 2-D images)
[H, W] = size(I);

Y = double(imread('berlin_GT.tif'));
if size(Y, 1) > H || size(Y, 2) > W
    error('init_train_x:sizeMismatch', ...
        'berlin_GT.tif is larger than LCB1.tif.');
end

% Zero-pad once so every patch is a plain slice, including at the borders.
% Pixel (r,c) of I lands at (r+PAD_BEFORE, c+PAD_BEFORE) in Ipad, so its
% window starts at (r,c) in padded coordinates.
Ipad = zeros(H + PATCH - 1, W + PATCH - 1);
Ipad(PAD_BEFORE + (1:H), PAD_BEFORE + (1:W)) = I;

% Labelled classes (0 = unlabelled is excluded) and their pixel counts.
labels = unique(Y(:));
labels = labels(labels > 0);
counts = arrayfun(@(c) nnz(Y == c), labels);
nTrain = ceil(TRAIN_FRACTION * counts);
nTest = counts - nTrain;
nClasses = max([MIN_CLASSES; labels]);   % labels are used as column indices

train_x = zeros(sum(nTrain), PATCH * PATCH);
train_y = zeros(sum(nTrain), nClasses);
test_x = zeros(sum(nTest), PATCH * PATCH);
test_y = zeros(sum(nTest), nClasses);

kTrain = 0;
kTest = 0;
for i = 1:numel(labels)
    c = labels(i);
    [pixRow, pixCol] = find(Y == c);

    % Draw a random subset of this class's pixels for training. The
    % two-argument form samples from ALL counts(i) pixels; the remaining
    % pixels go to the test set.
    trainIdx = randperm(counts(i), nTrain(i));
    testIdx = setdiff(1:counts(i), trainIdx);

    rTrain = kTrain + (1:nTrain(i));
    train_x(rTrain, :) = extract_patches(Ipad, pixRow(trainIdx), pixCol(trainIdx), PATCH);
    train_y(rTrain, c) = 1;
    kTrain = kTrain + nTrain(i);

    rTest = kTest + (1:nTest(i));
    test_x(rTest, :) = extract_patches(Ipad, pixRow(testIdx), pixCol(testIdx), PATCH);
    test_y(rTest, c) = 1;
    kTest = kTest + nTest(i);
end

save berlin_1 train_x train_y test_x test_y
toc
end

function X = extract_patches(Ipad, r, c, patch)
%EXTRACT_PATCHES Flatten one patch-by-patch window of Ipad per pixel.
%   Row k of X is Ipad(r(k):r(k)+patch-1, c(k):c(k)+patch-1), flattened in
%   column-major order. INIT_TRAIN_X pads its image so that (r(k), c(k)),
%   given in unpadded image coordinates, is the window's top-left corner in
%   Ipad.
X = zeros(numel(r), patch * patch);
for k = 1:numel(r)
    window = Ipad(r(k):r(k) + patch - 1, c(k):c(k) + patch - 1);
    X(k, :) = window(:).';
end
end
1 0