Keras 实现 U-Net 进行字符定位与识别分类（character localization and multi-class recognition with a U-Net in Keras）

来源:互联网 发布:网络直播云南电视台 编辑:程序博客网 时间:2024/06/03 09:24
# coding=utf-8
"""Imports and a Keras callback that records the training loss of every batch."""

import cv2
import numpy as np
import matplotlib.pyplot as plt

from keras import backend as K
# NOTE: the original pasted source imported ModelCheckpoint and EarlyStopping
# twice; the duplicates are collapsed here.
from keras.callbacks import (
    Callback,
    EarlyStopping,
    ModelCheckpoint,
    ReduceLROnPlateau,
    TensorBoard,
)
# The layer / model / optimizer / loss names used by get_unet_128_muticlass
# below were evidently lost when this file was scraped; restored here.
from keras.layers import (
    Activation,
    BatchNormalization,
    Conv2D,
    Input,
    MaxPooling2D,
    UpSampling2D,
    concatenate,
)
from keras.losses import categorical_crossentropy
from keras.models import Model
from keras.optimizers import RMSprop
from keras.preprocessing.image import img_to_array
from keras.utils import to_categorical
from keras.utils.vis_utils import plot_model

from model.augmentations import (
    randomHorizontalFlip,
    randomHueSaturationValue,
    randomShiftScaleRotate,
)


class LossHistory(Callback):
    """Accumulate the loss reported at the end of every training batch.

    After training, ``self.losses`` holds one float per batch, in order.
    """

    def on_train_begin(self, logs=None):
        # Reset the history at the start of each training run.
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        # `logs` carries the metrics of the batch that just finished.
        logs = logs or {}
        self.losses.append(logs.get('loss'))
def get_unet_128_muticlass(input_shape=(None, 128, 128, 3), num_classes=1):
    """Build and compile a 128x128 U-Net for multi-class segmentation.

    Args:
        input_shape: batch shape ``(batch, height, width, channels)`` handed to
            ``Input(batch_shape=...)``; the default ``None`` batch dimension
            accepts any batch size.
        num_classes: number of output classes; the head is a 1x1 softmax conv,
            so the output is per-pixel class probabilities.

    Returns:
        A compiled ``Model`` (RMSprop lr=0.001, categorical cross-entropy,
        accuracy metric).
    """

    def conv_bn_relu(x, filters, repeats):
        # `repeats` consecutive (3x3 conv -> batch norm -> ReLU) stages.
        for _ in range(repeats):
            x = Conv2D(filters, (3, 3), padding='same')(x)
            x = BatchNormalization()(x)
            x = Activation('relu')(x)
        return x

    inputs = Input(batch_shape=input_shape)

    # Encoder: resolution halves at each stage, 128 -> 64 -> 32 -> 16 -> 8.
    down1 = conv_bn_relu(inputs, 64, 2)
    down1_pool = MaxPooling2D((2, 2), strides=(2, 2))(down1)
    down2 = conv_bn_relu(down1_pool, 128, 2)
    down2_pool = MaxPooling2D((2, 2), strides=(2, 2))(down2)
    down3 = conv_bn_relu(down2_pool, 256, 2)
    down3_pool = MaxPooling2D((2, 2), strides=(2, 2))(down3)
    down4 = conv_bn_relu(down3_pool, 512, 2)
    down4_pool = MaxPooling2D((2, 2), strides=(2, 2))(down4)

    # Bottleneck at 8x8.
    center = conv_bn_relu(down4_pool, 1024, 2)

    # Decoder: upsample, concatenate the matching encoder feature map
    # (skip connection, channel axis 3), then three conv stages.
    up4 = UpSampling2D((2, 2))(center)
    up4 = concatenate([down4, up4], axis=3)
    up4 = conv_bn_relu(up4, 512, 3)
    up3 = UpSampling2D((2, 2))(up4)
    up3 = concatenate([down3, up3], axis=3)
    up3 = conv_bn_relu(up3, 256, 3)
    up2 = UpSampling2D((2, 2))(up3)
    up2 = concatenate([down2, up2], axis=3)
    up2 = conv_bn_relu(up2, 128, 3)
    up1 = UpSampling2D((2, 2))(up2)
    up1 = concatenate([down1, up1], axis=3)
    up1 = conv_bn_relu(up1, 64, 3)

    # 1x1 conv head: per-pixel softmax over the classes, back at 128x128.
    classify = Conv2D(num_classes, (1, 1), activation='softmax')(up1)

    model = Model(inputs=inputs, outputs=classify)
    model.compile(optimizer=RMSprop(lr=0.001),
                  loss=categorical_crossentropy,
                  metrics=['acc'])
    return model
print('model summary...')
model = get_unet_128_muticlass(num_classes=3)
model.summary()
plot_model(model, 'model.png', show_shapes=True)

SIZE = (128, 128)  # network input resolution (width, height) for cv2.resize


def fix_mask(mask):
    """Snap grayscale mask values in place toward the three class levels.

    Values < 100 become 0, values > 128 become 255; 128 itself is the middle
    class and values in [100, 128) are deliberately left untouched (they fall
    into the middle class when one-hot encoded below).
    """
    mask[mask < 100] = 0
    # The original also did `mask[mask == 128] = 128`, which is a no-op.
    mask[mask > 128] = 255


def fix_mask_onehot(mask):
    """In-place variant mapping the three gray levels to class indices 0/1/2.

    Not called anywhere in this script; kept for compatibility.
    """
    mask[mask < 100] = 0
    mask[mask == 128] = 1
    mask[mask > 128] = 2


def mask_to_onehot(mask):
    """Convert a (128, 128) grayscale mask to a (128, 128, 3) one-hot map.

    Class 0: bright pixels (> 200); class 2: dark pixels (< 100);
    class 1: everything in between.  Vectorized replacement for the original
    16384-iteration Python loop (same output).
    """
    onehot = np.zeros((128, 128, 3), dtype=np.int64)
    bright = mask > 200
    dark = mask < 100
    onehot[bright, 0] = 1
    onehot[dark, 2] = 1
    onehot[~bright & ~dark, 1] = 1
    return onehot


def train_process(data):
    """Augment one (image, mask) pair for training.

    Returns ``(img, onehot_mask)``: img is float in [0, 1] resized to SIZE,
    onehot_mask is a (128, 128, 3) one-hot array.
    """
    img, mask = data
    img = img[:, :, :3]
    mask = cv2.cvtColor(mask[:, :, :3], cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-50, 50),
                                   sat_shift_limit=(0, 0),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.0625, 0.0625),
                                       scale_limit=(-0.1, 0.1),
                                       rotate_limit=(-20, 20))
    img, mask = randomHorizontalFlip(img, mask)
    # Re-snap gray values smeared by the resize/rotation interpolation.
    fix_mask(mask)
    img = img / 255.
    return img, mask_to_onehot(mask)


def validation_process(data):
    """Prepare one (image, mask) pair for validation — no augmentation."""
    img, mask = data
    img = img[:, :, :3]
    mask = cv2.cvtColor(mask[:, :, :3], cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    fix_mask(mask)
    img = img / 255.
    return img, mask_to_onehot(mask)


# Renamed from `dir` (shadowed the builtin).
DATA_DIR = r'data\train3_muticlass'
epochs = 500  # NOTE(review): unused — fit() below hard-codes epochs=100

import glob
import os


def _mask_path(image_file):
    """Return the path of the mask image that pairs with *image_file*.

    Fixes the original's missing separator (`dir+r'\mask'+name`) and the
    double backslash in `r'\mask\\'` by using os.path.join.
    """
    return os.path.join(DATA_DIR, 'mask', os.path.split(image_file)[1])


x_train, y_train = [], []
for file in glob.glob(os.path.join(DATA_DIR, '*.tif')):
    print(file)
    print(_mask_path(file))
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(_mask_path(file), cv2.IMREAD_COLOR)
    x = 255 - x  # images are inverted; masks are used as-is
    for _ in range(30):  # 30 random augmentations per source image
        xx, yy = train_process((x, y))
        x_train.append(xx)
        y_train.append(yy)

# NOTE(review): validation reuses the exact training files (no held-out
# split), so val_loss measures fit, not generalization — confirm intent.
x_val, y_val = [], []
for file in glob.glob(os.path.join(DATA_DIR, '*.tif')):
    print(file)
    print(_mask_path(file))
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(_mask_path(file), cv2.IMREAD_COLOR)
    x = 255 - x
    xx, yy = validation_process((x, y))
    x_val.append(xx)
    y_val.append(yy)

# Save the weights whenever the validation loss reaches a new minimum.
checkpointer = ModelCheckpoint(filepath="model_muticlass.w", verbose=0,
                               save_best_only=True, save_weights_only=True)
history = LossHistory()
earlystop = EarlyStopping(patience=5)

# Resume from a previous run's best weights when they exist.  (The original
# loaded unconditionally, which crashes on the very first run.)
if os.path.exists('model_muticlass.w'):
    model.load_weights('model_muticlass.w')

model.fit(np.array(x_train), np.array(y_train),
          epochs=100, batch_size=3, verbose=1,
          validation_data=(np.array(x_val), np.array(y_val)),
          callbacks=[checkpointer, history, earlystop])

# Reload the best checkpoint and visualise its predictions per image.
model.load_weights('model_muticlass.w')
for file in glob.glob(os.path.join(DATA_DIR, '*.tif')):
    x = cv2.imread(file, cv2.IMREAD_COLOR)
    y = cv2.imread(_mask_path(file), cv2.IMREAD_COLOR)
    x = 255 - x
    xx, yy = validation_process((x, y))
    batch = np.array([xx])  # single-image batch
    print(np.shape(batch))
    predicted_mask_batch = model.predict(batch)
    print(np.shape(predicted_mask_batch))
    predicted_mask = predicted_mask_batch[0].reshape((128, 128, 3))
    # BUGFIX: the original did plt.imshow(xx[0]), which draws one (128, 3)
    # image row, not the image.  Show the whole input under the overlay.
    plt.imshow(xx)
    plt.imshow(predicted_mask[:, :, 2], alpha=0.6)
    plt.show()

K.clear_session()
阅读全文
0 0
原创粉丝点击