Learning the optical flow algorithm in Python


While writing some code I happened to run into the optical flow algorithm, so I had a quick look at the sample code, got a basic feel for it, and am writing it down here.

The code actually ships with the OpenCV Python samples, at:
D:\Program Files\opencv\sources\samples\python2\lk_track.py

There is plenty of material on this online, but little of it comes with detailed comments, so I read through the code myself and annotated the parts I understood. Corrections are welcome wherever I got something wrong.

import numpy as np
import cv2
import video  # video.py ships alongside the OpenCV samples and provides create_capture
from common import anorm2, draw_str
from time import clock

# Parameters for Lucas-Kanade optical flow
lk_params = dict( winSize  = (15, 15),  # size of the search window
                  maxLevel = 2,         # maximum number of pyramid levels
                  # termination criteria: stop after 10 iterations or once the window moves by less than 0.03
                  criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

# Parameters for corner detection
feature_params = dict( maxCorners = 500,    # maximum number of corners
                       qualityLevel = 0.3,  # minimum accepted corner quality
                       minDistance = 7,     # minimum Euclidean distance between corners
                       blockSize = 7 )      # neighborhood size used to compute the corner quality; I had only ever set the parameters above myself, corrections welcome

class App:
    def __init__(self, video_src):
        self.track_len = 10
        self.detect_interval = 5
        self.tracks = []
        self.cam = video.create_capture(video_src)
        self.frame_idx = 0

    def run(self):
        while True:
            ret, frame = self.cam.read()  # grab one frame from the camera
            frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)  # convert to grayscale
            vis = frame.copy()  # copy for drawing, so frame itself is left untouched

            if len(self.tracks) > 0:  # once corners have been detected, track them with optical flow
                img0, img1 = self.prev_gray, frame_gray
                p0 = np.float32([tr[-1] for tr in self.tracks]).reshape(-1, 1, 2)  # reshape the point array
                # feed the previous frame's corners and the current frame in to get each corner's position
                # in the current frame (a bit convoluted; read the OpenCV source if you want the details)
                p1, st, err = cv2.calcOpticalFlowPyrLK(img0, img1, p0, None, **lk_params)
                # run the flow backwards: from the corners tracked in the current frame back to the previous frame
                p0r, st, err = cv2.calcOpticalFlowPyrLK(img1, img0, p1, None, **lk_params)
                d = abs(p0-p0r).reshape(-1, 2).max(-1)  # distance between the back-tracked corners and the original corners
                good = d < 1  # points whose round-trip error exceeds 1 pixel are treated as bad tracks
                new_tracks = []
                # keep only the points that were tracked successfully
                for tr, (x, y), good_flag in zip(self.tracks, p1.reshape(-1, 2), good):
                    if not good_flag:
                        continue
                    tr.append((x, y))
                    if len(tr) > self.track_len:
                        del tr[0]
                    new_tracks.append(tr)
                    cv2.circle(vis, (x, y), 2, (0, 255, 0), -1)  # draw the tracked point
                self.tracks = new_tracks
                # draw each track as a polyline from its older positions to the current one
                cv2.polylines(vis, [np.int32(tr) for tr in self.tracks], False, (0, 255, 0))
                draw_str(vis, (20, 20), 'track count: %d' % len(self.tracks))

            if self.frame_idx % self.detect_interval == 0:  # detect new features every 5 frames
                mask = np.zeros_like(frame_gray)  # mask with the same size as the frame
                mask[:] = 255  # 255 everywhere, i.e. search the whole image for corners
                for x, y in [np.int32(tr[-1]) for tr in self.tracks]:  # black out corners that are already being tracked so they are not detected again
                    cv2.circle(mask, (x, y), 5, 0, -1)
                p = cv2.goodFeaturesToTrack(frame_gray, mask = mask, **feature_params)  # corner detection
                if p is not None:
                    for x, y in np.float32(p).reshape(-1, 2):
                        self.tracks.append([(x, y)])  # add the new corners to the list of points to track

            self.frame_idx += 1
            self.prev_gray = frame_gray
            cv2.imshow('lk_track', vis)

            ch = 0xFF & cv2.waitKey(1)  # press Esc to quit
            if ch == 27:
                break

def main():
    import sys
    try: video_src = sys.argv[1]
    except: video_src = 0
    print __doc__
    App(video_src).run()
    cv2.destroyAllWindows()

if __name__ == '__main__':
    main()
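The key trick in this sample is the forward-backward check: corners are tracked from the previous frame to the current one, then tracked back again, and a point is kept only if it lands close to where it started. The 1-pixel limit that puzzled me is just a tolerance on this round-trip error. Below is a minimal sketch of that check in isolation; frame0.png and frame1.png are placeholder file names for two consecutive grayscale frames, not part of the original sample.

import numpy as np
import cv2

img0 = cv2.imread('frame0.png', cv2.IMREAD_GRAYSCALE)   # previous frame (placeholder path)
img1 = cv2.imread('frame1.png', cv2.IMREAD_GRAYSCALE)   # current frame (placeholder path)

lk_params = dict(winSize=(15, 15), maxLevel=2,
                 criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

# detect corners in the previous frame
p0 = cv2.goodFeaturesToTrack(img0, maxCorners=500, qualityLevel=0.3, minDistance=7)

# forward flow: previous frame -> current frame
p1, st, err = cv2.calcOpticalFlowPyrLK(img0, img1, p0, None, **lk_params)
# backward flow: current frame -> previous frame
p0r, st, err = cv2.calcOpticalFlowPyrLK(img1, img0, p1, None, **lk_params)

# a track survives only if the back-tracked point lands within 1 pixel of its start
d = abs(p0 - p0r).reshape(-1, 2).max(-1)
good = d < 1
good_new = p1.reshape(-1, 2)[good]
print('%d of %d points survived the round-trip check' % (len(good_new), len(p0)))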

There are still many parts I haven't fully figured out; I'll dig into them when I actually need them. The above draws on http://blog.csdn.net/gjy095/article/details/9226883 (I hope this doesn't count as plagiarism; if it does, let me know and I'll take the post down).

Later, for a specific application of my own, I made some changes and found that some parts of the original code seem optional; I honestly don't see why they were written that way. In any case, here is my modified version for reference; if anyone understands those parts, please point them out.

from PIL import ImageGrab
import numpy as np
import cv2
import time

lk_params = dict( winSize  = (15, 15),
                  maxLevel = 2,
                  criteria = (cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

track_len = 10
detect_interval = 5
tracks = []
frame_idx = 0
#time.sleep(10)

while True:
    # bbox selects a screen region as (left, top, right, bottom), measured from the top-left corner.
    # I replaced the original camera capture with a screen grab; the optical flow effect is harder to
    # see this way, but my application needed it.
    img = ImageGrab.grab(bbox=(600, 300, 800, 500))
    img_np = np.array(img)  # convert the screenshot to a numpy array
    frame_gray = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
    vis = img_np.copy()

    # goodFeaturesToTrack replaced with the FAST detector for speed
    fast = cv2.FastFeatureDetector()
    value_x = []

    if len(tracks) > 0:  # once corners exist, track them with optical flow
        img0, img1 = prev_gray, frame_gray
        p0 = np.float32([tr[-1] for tr in tracks]).reshape(-1, 1, 2)
        # my own corner-filtering code: it keeps a point only if its x coordinate differs by at least
        # 10 pixels from every later point (I no longer remember exactly why I needed this)
        for i in range(0, len(p0)):
            temp = 10
            for j in range(i + 1, len(p0)):
                distance_x = abs(p0[i][0][0] - p0[j][0][0])
                if distance_x < temp:
                    temp = distance_x
            if temp == 10:
                value_x.append(p0[i])
        new_value_x = np.float32([tr[-1] for tr in value_x]).reshape(-1, 1, 2)
        new_p1, st, err = cv2.calcOpticalFlowPyrLK(img0, img1, new_value_x, None, **lk_params)
        new_p0r, st, err = cv2.calcOpticalFlowPyrLK(img1, img0, new_p1, None, **lk_params)
        new_d = abs(new_value_x - new_p0r).reshape(-1, 2).max(-1)  # round-trip error of the filtered points
        mean = np.mean(new_d)

        p1, st, err = cv2.calcOpticalFlowPyrLK(img0, img1, p0, None, **lk_params)
        p0r, st, err = cv2.calcOpticalFlowPyrLK(img1, img0, p1, None, **lk_params)
        d = abs(p0 - p0r).reshape(-1, 2).max(-1)  # round-trip error of all tracked points
        good = d <= 0.5 * mean  # instead of the fixed 1-pixel limit, accept points within half the mean round-trip error
        new_tracks = []
        for tr, (x, y), good_flag in zip(tracks, p1.reshape(-1, 2), good):  # keep only the successfully tracked points
            if not good_flag:
                continue
            tr.append((x, y))
            if len(tr) > track_len:
                del tr[0]
            new_tracks.append(tr)
            cv2.circle(vis, (x, y), 2, (0, 255, 0), -1)
        tracks = new_tracks
        #cv2.polylines(vis, [np.int32(tr) for tr in tracks], False, (0, 255, 0))

    if frame_idx % detect_interval == 0:  # detect new features every 5 frames
        p = fast.detect(frame_gray, None)  # FAST corner detection
        if p is not None:
            for kp in p:
                tracks.append([kp.pt])  # add the detected corners to the list of points to track

    frame_idx += 1
    prev_gray = frame_gray
    cv2.imshow('lk_track', vis)
    ch = 0xFF & cv2.waitKey(1)
    if ch == 27:
        break

cv2.destroyAllWindows()
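One change worth a note is the switch from cv2.goodFeaturesToTrack to FAST. goodFeaturesToTrack already returns an (N, 1, 2) float32 array, while FAST returns cv2.KeyPoint objects, so the coordinates have to be pulled out of .pt before they can be handed to cv2.calcOpticalFlowPyrLK. A minimal sketch, assuming a grayscale image frame.png (a placeholder file name); the cv2.FastFeatureDetector() constructor used above is the OpenCV 2.4 API, on OpenCV 3+ it is cv2.FastFeatureDetector_create():

import numpy as np
import cv2

gray = cv2.imread('frame.png', cv2.IMREAD_GRAYSCALE)

fast = cv2.FastFeatureDetector_create()   # cv2.FastFeatureDetector() on OpenCV 2.4
keypoints = fast.detect(gray, None)       # returns a list of cv2.KeyPoint objects

# unpack the KeyPoint coordinates into the (N, 1, 2) float32 layout expected by calcOpticalFlowPyrLK
p0 = np.float32([kp.pt for kp in keypoints]).reshape(-1, 1, 2)
print(p0.shape)  # (number of corners, 1, 2)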

The overall structure is unchanged; I only modified a few places, which a side-by-side comparison will show.
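The main functional change is the threshold: instead of the sample's fixed good = d < 1, the modified code keeps a point when its round-trip error is at most half the mean round-trip error of the filtered reference points, so the cutoff adapts to how well the current frame is being tracked. A tiny numeric illustration with made-up values:

import numpy as np

new_d = np.array([0.2, 0.4, 3.0, 0.3])   # round-trip errors of the filtered reference points (made up)
d     = np.array([0.1, 0.5, 2.5, 0.45])  # round-trip errors of all tracked points (made up)

mean = np.mean(new_d)    # 0.975 here
good = d <= 0.5 * mean   # threshold ~0.49 px -> [True, False, False, True]
print(mean, good)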
