TensorFlow Wild-Idea Face Recognition (Part 2)


This post covers data processing, model construction, and the recognition algorithm.

 

1. Data processing

This part is fairly simple, so let's start with the code:

 

import os
import threading
import random
import time
import numpy as np
import cv2
import tensorflow as tf
import classfier

dataList = []
dataLabelList = []
isStarted = False
isInited = False
maxCachedData = 200
cachedImageList = []
cachedIndexList = []
lock = threading.Lock()
imageDir = ["face", "body", "background"]


def initSampler(path):
    # Start the background thread that loads the data set and keeps the cache filled.
    global isInited, isStarted
    if isStarted != False:
        return
    isStarted = True
    thread = threading.Thread(target=threadFun, args=[path])
    thread.setDaemon(True)
    thread.start()


def threadFun(path):
    global isInited, isStarted
    if loadImageData(path) != True:
        isStarted = False
        return
    isInited = True
    fillCache()


def loadImageData(path):
    # Every sub-directory listed in imageDir becomes one class; all images are resized and kept in memory.
    global dataList, dataLabelList
    dataLabelList = []
    dataList = []
    for dir in imageDir:
        dirPath = os.path.join(path, dir)
        if os.path.isdir(dirPath) != True:
            continue
        dataLabelList.append(dir)
        tempList = []
        for image in os.listdir(dirPath):
            imagePath = os.path.join(dirPath, image)
            tempList.append(cv2.resize(cv2.imread(imagePath), classfier.IMAGE_SIZE))
        dataList.append(tempList)
    return True


# A tiny graph that subtracts the mean and normalizes a single image.
input_image = tf.placeholder(tf.float32, [classfier.IMAGE_SIZE[0], classfier.IMAGE_SIZE[1], 3])
output_image = tf.image.per_image_standardization(input_image)
session = tf.Session()


def preProcessImage(image):
    return session.run([output_image], feed_dict={input_image: image})[0]


def fillCache():
    # Runs forever in the sampler thread: whenever the cache is not full,
    # pick random images, augment and standardize them, then append to the cache.
    global dataList
    global lock, cachedImageList, cachedIndexList
    while True:
        if len(cachedImageList) >= maxCachedData:
            time.sleep(0.01)
            continue
        indexList = []
        imageList = []
        for i in range(maxCachedData):
            index = random.randint(0, len(dataList) - 1)
            image = dataList[index][random.randint(0, len(dataList[index]) - 1)]
            image = transfromImage(image, 0.2, 0.2, 0.2).astype(np.float32)
            image = preProcessImage(image)
            indexList.append(index)
            imageList.append(image)
        lock.acquire()
        cachedIndexList += indexList
        cachedImageList += imageList
        lock.release()


def transfromImage(image, fRandomFactor, hRandomFactor, vRandomFactor):
    # Random affine transform: jitter three anchor points and shift the whole image.
    shape = image.shape[0:2]
    fromPoint = [[0, 0], [shape[0] - 1, 0], [0, shape[1] - 1]]
    moveScale = shape[0]
    hMoveLength = moveScale * random.uniform(-hRandomFactor, hRandomFactor)
    vMoveLength = moveScale * random.uniform(-vRandomFactor, vRandomFactor)
    transformScale = shape[0] / 4
    toPoint = np.array(fromPoint, np.float32)
    for i in range(3):
        toPoint[i][0] += transformScale * random.uniform(-fRandomFactor, fRandomFactor) + hMoveLength
        toPoint[i][1] += transformScale * random.uniform(-fRandomFactor, fRandomFactor) + vMoveLength
    transfromMat = cv2.getAffineTransform(np.array(fromPoint, np.float32), toPoint)
    return cv2.warpAffine(image, transfromMat, shape, flags=cv2.INTER_LINEAR)


def rotateImage(image, randomFactor):
    shape = image.shape[0:2]
    angle = random.uniform(-randomFactor, randomFactor)
    rotateMat = cv2.getRotationMatrix2D((shape[0] / 2, shape[1] / 2), angle, 1)
    return cv2.warpAffine(image, rotateMat, shape)


def check():
    if isStarted == True and isInited == False:
        while True:
            if isInited == False:
                time.sleep(5)
            else:
                break
    if isInited != True:
        raise Exception("sampler is not inited")
    if len(dataLabelList) != len(dataList):
        raise Exception("sampler data is wrong")


def getBatch(num):
    # Pop num pre-processed images (and their label indices) off the cache.
    global maxCachedData, lock, cachedImageList, cachedIndexList
    check()
    if num * 2 > maxCachedData:
        maxCachedData = 2 * num
    while True:
        if len(cachedImageList) <= num:
            print("sleep to get batch")
            time.sleep(0.01)
            continue
        lock.acquire()
        imageList = cachedImageList[0:num]
        cachedImageList = cachedImageList[num:len(cachedImageList)]
        indexList = cachedIndexList[0:num]
        cachedIndexList = cachedIndexList[num:len(cachedIndexList)]
        lock.release()
        return np.array(imageList), np.array(indexList)


def getLabel(index):
    if index < 0 or index >= len(imageDir):
        raise Exception("invalid index")
    return imageDir[index]

 

Calling initSampler(path) from an external module starts a background thread. This thread loads all samples into memory and pre-processes images into a cache queue for the network to consume. Pre-processing first applies a random affine transform with OpenCV, then uses TensorFlow's per-image standardization to subtract the mean and normalize each image.

Calling getBatch(num) returns a mini-batch of the requested size.
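To make the interface concrete, here is a minimal usage sketch (my own example, not from the original project). It assumes the code above is saved as sampler.py and that a hypothetical ./data directory contains face, body and background sub-folders with images:

import sampler

sampler.initSampler("./data")   # hypothetical data directory with face/body/background sub-folders

# Blocks until the loader thread has finished, then returns 32 standardized images
# of shape (32, 128, 128, 3) together with their integer class indices of shape (32,).
images, labels = sampler.getBatch(32)
print(images.shape, labels.shape)
print([sampler.getLabel(i) for i in labels[:5]])   # map indices back to folder names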

 

 

2. Building the model

The architecture mainly follows AlexNet, though heavily simplified. Here is the code:

 

# IMAGE_SIZE is defined before "import sampler" on purpose: sampler reads
# classfier.IMAGE_SIZE at import time, and the two modules import each other.
IMAGE_SIZE = (128, 128)

import threading
import tensorflow as tf
import numpy as np
import cv2
import sampler


def addConvLayer(input, shape, name):
    # Convolution + ReLU; the weight's L2 loss is added to the "loss" collection.
    with tf.name_scope(name) as scope:
        weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="weight")
        bias = tf.Variable(tf.zeros([shape[3]]), name="bias")
        output = tf.nn.relu(tf.nn.conv2d(input, weight, [1, 1, 1, 1], padding="SAME") + bias, name="output")
    weightLoss = tf.multiply(tf.nn.l2_loss(weight), 0.01)
    tf.add_to_collection("loss", weightLoss)
    return output


def addFullLayer(input, shape, name):
    # Fully connected + ReLU, with the same L2 regularization as the conv layers.
    with tf.name_scope(name) as scope:
        weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name="weight")
        bias = tf.Variable(tf.zeros([shape[1]]), name="bias")
        output = tf.nn.relu(tf.matmul(input, weight) + bias, name="output")
    weightLoss = tf.multiply(tf.nn.l2_loss(weight), 0.01)
    tf.add_to_collection("loss", weightLoss)
    return output


# Standardization graph for a single image (used at inference time).
input_image = tf.placeholder(tf.float32, [IMAGE_SIZE[0], IMAGE_SIZE[1], 3])
output_image = tf.image.per_image_standardization(input_image)

# The classification network: 3 conv layers, each followed by 4x4 max pooling, then 2 fully connected layers.
inputImage = tf.placeholder(tf.float32, [None, IMAGE_SIZE[0], IMAGE_SIZE[1], 3])
inputLabel = tf.placeholder(tf.int32, [None])
conv1 = addConvLayer(inputImage, [7, 7, 3, 8], "conv1")
pool1 = tf.nn.max_pool(conv1, [1, 4, 4, 1], [1, 4, 4, 1], padding="SAME")
conv2 = addConvLayer(pool1, [5, 5, 8, 16], "conv2")
pool2 = tf.nn.max_pool(conv2, [1, 4, 4, 1], [1, 4, 4, 1], padding="SAME")
conv3 = addConvLayer(pool2, [5, 5, 16, 32], "conv3")
pool3 = tf.nn.max_pool(conv3, [1, 4, 4, 1], [1, 4, 4, 1], padding="SAME")
featureLength = int(IMAGE_SIZE[0] * IMAGE_SIZE[1] / 128)
print("feature length=" + str(featureLength))
reShape = tf.reshape(pool3, [-1, featureLength])
fc1 = addFullLayer(reShape, [featureLength, 32], "fc1")
fc2 = addFullLayer(fc1, [32, 3], "fc2")
logit = tf.nn.softmax(fc2)
outIndex = tf.argmax(fc2, axis=1)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    logits=fc2, labels=inputLabel, name="cross_entropy")
tf.add_to_collection("loss", tf.reduce_mean(cross_entropy))
train_op = tf.train.AdamOptimizer().minimize(tf.add_n(tf.get_collection("loss")))
accurate = tf.reduce_sum(tf.cast(tf.equal(tf.cast(inputLabel, tf.int64), tf.argmax(fc2, 1)), tf.float32))
initer = tf.global_variables_initializer()
saver = tf.train.Saver()

#######################################################################
BATCH = 50
message = ""
isStart = False
path = None


def startTrain():
    global isStart, path
    if isStart == True:
        return
    isStart = True
    thread = threading.Thread(target=threadFun, args=[tf.get_default_graph()])
    thread.setDaemon(True)
    thread.start()


def setSavePath(savePath):
    # Setting a save path tells the training thread to save the model and stop.
    global path
    path = savePath


def threadFun(graph):
    global message, isStart, path
    sess = tf.Session(graph=graph)
    sess.run([initer])
    iterCount = np.zeros([200], np.float32)
    accurateCount = np.zeros([200], np.float32)
    curIter = -1
    while True:
        curIter += 1
        trainData = sampler.getBatch(BATCH)
        retAcc = sess.run([accurate, train_op],
                          feed_dict={inputImage: trainData[0], inputLabel: trainData[1]})[0]
        iterCount[curIter % 200] = BATCH
        accurateCount[curIter % 200] = retAcc
        # Running accuracy over the last 200 iterations.
        message = "iter:" + str(curIter) + "  cur:" + str(int(retAcc)) \
                  + "  accurate:" + str(np.sum(accurateCount) / np.sum(iterCount))
        if path != None:
            saver.save(sess, path)
            sess.close()
            isStart = False
            path = None
            message = ""
            return

#######################################################################
classfySess = None


def loadModule(path):
    global classfySess
    if classfySess != None:
        classfySess.close()
    classfySess = tf.Session()
    saver.restore(sess=classfySess, save_path=path)


def recognizeImage(image):
    # Resize, standardize, run the network and return {label: probability}.
    if classfySess == None:
        raise Exception("classfy session is not inited")
    image = cv2.resize(image, IMAGE_SIZE).astype(np.float32)
    image = classfySess.run([output_image], feed_dict={input_image: image})[0]
    probability = classfySess.run([logit], feed_dict={inputImage: np.array([image])})[0][0]
    returnDic = {}
    returnDic[sampler.getLabel(0)] = probability[0]
    returnDic[sampler.getLabel(1)] = probability[1]
    returnDic[sampler.getLabel(2)] = probability[2]
    return returnDic

 

The model-construction code sits at the top of the file and is straightforward, so I won't go through it in detail. The model uses three convolutional layers and two fully connected layers with ReLU activations; all weights are L2-regularized.
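One detail worth spelling out is where featureLength = IMAGE_SIZE[0] * IMAGE_SIZE[1] / 128 comes from. A quick sanity check of the shape arithmetic (an illustrative calculation of my own, not part of the project code):

# Three 4x4 max-pools with stride 4 shrink the 128x128 input to 2x2, and the
# last conv layer outputs 32 channels, so the flattened feature vector has
# 2 * 2 * 32 = 128 elements -- the same value as 128 * 128 / 128.
side = 128
for _ in range(3):
    side //= 4                 # each max_pool divides the spatial size by 4
feature_length = side * side * 32
print(feature_length)          # 128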

Calling startTrain() from an external module starts training, but the data-processing module has to be initialized first. Calling setSavePath(savePath) saves the model to the given path and ends training.
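As a rough sketch of how an external script might drive training (my own example; the data directory, save path and sleep-based polling are assumptions, not part of the original post):

import time
import sampler
import classfier

sampler.initSampler("./data")            # initialize the data module first (hypothetical path)
classfier.startTrain()                   # training runs in a background thread

for _ in range(60):                      # poll the progress message for a while
    time.sleep(10)
    print(classfier.message)

# The training thread saves the model and exits on its next iteration.
classfier.setSavePath("./model/face")    # hypothetical save path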

Calling loadModule(path) loads a saved model; after that, recognizeImage(image) classifies an image and returns a dictionary mapping each class label to its predicted probability.
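The inference side then looks roughly like this (again a sketch of my own, with assumed file and model paths):

import cv2
import classfier

classfier.loadModule("./model/face")      # hypothetical path of a previously saved model
image = cv2.imread("test.jpg")            # hypothetical test image
result = classfier.recognizeImage(image)  # e.g. {"face": 0.91, "body": 0.06, "background": 0.03}
print(max(result, key=result.get), result)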

 

 

3. Recognition algorithm

The approach was already covered in the first post, so straight to the code!

 

import classfier

MIN_SREACH_SIZE = 40
# Each entry is a relative sub-window (row_start, row_end, col_start, col_end) of the current region.
subRegionList = [(0, 0.6, 0, 1), (0.4, 1, 0, 1), (0, 1, 0, 0.6), (0, 1, 0.4, 1), (0.2, 0.8, 0.2, 0.8)]


def getFaceRegion(image):
    # Returns the face region as normalized (row_start, row_end, col_start, col_end), or None.
    retDic, retRect = recognizeFace(image, (0, image.shape[0], 0, image.shape[1]))
    if retDic["background"] >= 1.0 / 3:
        return None
    return (retRect[0] / image.shape[0], retRect[1] / image.shape[0],
            retRect[2] / image.shape[1], retRect[3] / image.shape[1])


def recognizeFace(image, rect):
    # rect is (row_start, row_end, col_start, col_end) in pixels.
    retDic = classfier.recognizeImage(image[rect[0]:rect[1], rect[2]:rect[3]])
    if retDic["background"] >= 1.0 / 3:
        return retDic, rect
    width = rect[1] - rect[0]
    height = rect[3] - rect[2]
    if width <= MIN_SREACH_SIZE or height <= MIN_SREACH_SIZE:
        return retDic, rect
    goodDic = retDic
    goodRect = rect
    for subRegion in subRegionList:
        subRect = []
        subRect.append(rect[0] + int(subRegion[0] * width))
        subRect.append(rect[0] + int(subRegion[1] * width))
        subRect.append(rect[2] + int(subRegion[2] * height))
        subRect.append(rect[2] + int(subRegion[3] * height))
        # Recurse into each sub-window and keep the one with the highest "face" probability.
        subRetDic, subRetRect = recognizeFace(image, subRect)
        if subRetDic["face"] > goodDic["face"]:
            goodDic = subRetDic
            goodRect = subRetRect
    return goodDic, goodRect

 

Calling getFaceRegion(image) from an external module returns the face region of an image; the recognition model module has to be loaded before that.

recognizeFace(image, rect) first classifies the current crop. If the result is "face" or "body" and the crop is larger than MIN_SREACH_SIZE, it splits the crop according to subRegionList and calls itself recursively on each sub-region, looking for the one with the highest "face" probability.
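Putting the pieces together, an end-to-end call might look like this (my own sketch; the module name detector.py, the model path and the image file are assumptions, not from the original post):

import cv2
import classfier
import detector                                # hypothetical module name for the search code above

classfier.loadModule("./model/face")           # hypothetical model path
image = cv2.imread("group_photo.jpg")          # hypothetical input image
region = detector.getFaceRegion(image)         # normalized (row0, row1, col0, col1) or None
if region is None:
    print("no face region found")
else:
    h, w = image.shape[0], image.shape[1]
    r0, r1, c0, c1 = region
    cv2.rectangle(image, (int(c0 * w), int(r0 * h)), (int(c1 * w), int(r1 * h)), (0, 255, 0), 2)
    cv2.imwrite("result.jpg", image)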

 

 

4. Summary

Nothing much to summarize; the next post will show the running results.
