# 数据集无损处理 (Lossless dataset processing)

# Scraped page metadata: 来源:互联网 发布:淘宝卖的银壶是真的吗 编辑:程序博客网 时间:2024/05/22 03:25
import numpy as np
import os
import glob
import re
from PIL import Image
import pandas as pd
import scipy.io as sio


class Dataset:
    """Cut every class image into square patches paired with a per-image label.

    Images are read from ``<cwd>/data/<class>/*.bmp`` for each name in
    ``self.classes``.  Each image receives the next value of the ``dmos``
    array loaded from a ``.mat`` file, in the order the images are visited.
    """

    def __init__(self):
        # Sub-folders under <cwd>/data/, visited in this order.
        self.classes = ['jp2k', 'jpeg', 'wn', 'gblur', 'fastfading']
        self.cwd = os.getcwd()
        # Not read by the methods below; kept so external code that touches
        # these attributes keeps working.
        self.dataset = []
        self.arr = [[]]
        self.img_list = [[[[]]]]
        # Number of images processed so far == index of the next label.
        self.label_index = 0
        # Result of the most recent img2array() call.
        self.img_label = []

    @staticmethod
    def _natural_key(path):
        """Sort key that orders embedded numbers numerically (img2 < img10)."""
        # re.split with a capture group alternates text/digit tokens, so the
        # int/str types line up position-by-position between any two keys.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', path)]

    def crop_concat(self, img, window, stride, labels):
        """Slide a square window over *img* and pair every patch with *labels*.

        Args:
            img: image exposing ``size`` as ``(width, height)`` and
                ``crop((left, upper, right, lower))`` (e.g. a PIL image).
            window: side length of each square patch, in pixels.
            stride: step between neighbouring patch origins, in pixels.
            labels: value attached to every patch (stored by reference).

        Returns:
            An ``(n_patches, 2)`` object ndarray.  Column 0 holds each patch as
            a numpy array and column 1 holds *labels*.  Patches are ordered with
            the x offset in the outer loop and the y offset in the inner loop.
            The shape is ``(0, 2)`` when the image is smaller than *window*.
        """
        width, height = img.size
        # Number of window positions that fit entirely inside the image
        # (clamped so a too-small image yields no patches instead of crashing).
        n_x = max((width - window) // stride + 1, 0)
        n_y = max((height - window) // stride + 1, 0)
        patches = [
            np.asarray(img.crop((x, y, x + window, y + window)))
            for x in range(0, n_x * stride, stride)
            for y in range(0, n_y * stride, stride)
        ]
        # Fill an explicit object array: np.array([patch, labels]) on ragged
        # inputs raises ValueError in modern NumPy.
        pairs = np.empty((len(patches), 2), dtype=object)
        for row, patch in enumerate(patches):
            pairs[row, 0] = patch
            pairs[row, 1] = labels
        return pairs

    def img2array(self, label_path='/home/xm/data/dmos.mat', window=120, stride=60):
        """Crop every class image into labelled patches and stack the results.

        Args:
            label_path: ``.mat`` file holding a ``dmos`` array with one value
                per image (defaults to the originally hard-coded path).
            window: patch side length passed to :meth:`crop_concat`.
            stride: patch stride passed to :meth:`crop_concat`.

        Returns:
            The ``(total_patches, 2)`` object array obtained by concatenating
            :meth:`crop_concat` outputs for all images.  It is also stored on
            ``self.img_label``.  The shape is ``(0, 2)`` if no image is found.

        Raises:
            IndexError: if there are more images than ``dmos`` values.
        """
        with open(label_path, 'rb') as f:
            labelset = sio.loadmat(f)['dmos']
        # One label per row; the original hard-coded 982 rows.
        labelset = labelset.reshape([-1, 1])

        # Restart numbering so a repeated call does not index past the labels.
        self.label_index = 0
        chunks = []
        for name in self.classes:
            class_path = os.path.join(self.cwd, 'data', name)
            # glob() order is filesystem-dependent, but labels are assigned by
            # position, so sort deterministically.  NOTE(review): assumes the
            # dmos array lists images in natural filename order (img1, img2,
            # ..., img10) within each class -- confirm against the dataset docs.
            infiles = sorted(glob.glob(os.path.join(class_path, '*.bmp')),
                             key=self._natural_key)
            for infile in infiles:
                with Image.open(infile) as img:
                    chunk = self.crop_concat(img, window=window, stride=stride,
                                             labels=labelset[self.label_index])
                print(chunk.shape)
                chunks.append(chunk)
                self.label_index += 1

        # Concatenate once instead of re-copying the growing array per image.
        if chunks:
            self.img_label = np.concatenate(chunks, axis=0)
        else:
            self.img_label = np.empty((0, 2), dtype=object)
        return self.img_label
haha = Dataset()
dataset = haha.img2array()
print(dataset.shape)

# Drop every patch whose label equals 0 (same rule as the original loop).
# NOTE(review): presumably these are undistorted reference images -- confirm.
is_zero = np.array([bool(np.all(label == 0)) for label in dataset[:, 1]], dtype=bool)
dataset = dataset[~is_zero]
np.random.shuffle(dataset)  # shuffles rows in place
print(dataset.shape)

# Preview one patch; fall back to the last one when fewer than 101 remain.
if len(dataset) > 0:
    show_img = dataset[min(100, len(dataset) - 1)][0]
    # fromarray infers the mode (RGB for HxWx3 uint8).  Forcing mode='RGB' is
    # unnecessary for RGB data, wrong for grayscale patches, and the argument
    # is deprecated in recent Pillow releases.
    show_img = Image.fromarray(show_img)
    show_img.show()
# 原创粉丝点击 (scraped page-footer residue)