GMM EM C CODE c代码

来源:互联网 发布:虚拟机mac可以升级吗 编辑:程序博客网 时间:2024/04/30 08:39
///////////////////////////////
//function: calculate u and sigma with histogram
//summary : get u and sigma
//parameters:
//          pImg : the image
//          num  : total pixels
//          u    : 均值  引用传递 返回
//          sigma2: 方差  注意不是平均差
//引用传递 返回u和sigma^2 

int GetUAndSigma(const IplImage *pImg, int num, float &u, float &sigma2)  
{
    
int i;
    
int histo[256];    //个数统计
    float sum;
    
    
    sum 
= 0.0;
    
    
for(i = 0; i < 256; i++)
    
{
        histo[i] 
= 0;
    }

    
    
//histogram
    for(i = 0; i < num; i++ )
    
{    
        histo[(unsigned 
char)pImg->imageData[i]]++;
    }

    
    
for(i = 0; i < 256; i++)
    
{
        sum 
+= i * histo[i];    
    }

    
    u 
= sum / num;       //平均值
    
    sum 
= 0.0;
    
    
for(i = 0; i < 256; i++)
    
{
        sum 
+= (u - i) * (u - i) * histo[i] ;
    }

    
    sigma2 
= sum / num;  //方差   注意:不是标准差
    
    
return true;
    
}


//////////////////////////////////////////////////
//summary: 利用flag 标志更新各个类的均值和方差
//
//
void UpdateUandSigma(const IplImage *pImg, int num, int k, char *flag, int *center, int *sigma)
{
    
int i,j;
    
int histo[256];    //个数统计
    float sum;
    
int n ;
    
    
for(j = 0; j < k; j++)
    
{
        sum 
= 0.0;
        
        n 
= 0;
        
        
for(i = 0; i < 256; i++)
        
{
            histo[i] 
= 0;
        }

        
        
//histogram
        for(i = 0; i < num; i++ )
        
{    
            
if(flag[i] == j)
            
{
                histo[(unsigned 
char)pImg->imageData[i]]++;
                n
++;
            }

        }

        
        
for(i = 0; i < 256; i++)
        
{
            sum 
+= i * histo[i];    
        }

        
        center[j] 
= sum / n;       //平均值
        
        sum 
= 0.0;
        
        
for(i = 0; i < 256; i++)
        
{
            sum 
+= (center[j] - i) * (center[j] - i) * histo[i] ;
        }

        
        sigma[j] 
= sum / n;  //方差   注意:不是标准差
    }

    
}


//////////////////////////////////////
//
//
//
void GMM(char *fileName, int k)
{
    
//-------opencv 读取图象相关数据----------
    IplImage *pImg;
    IplImage 
*pImgK;
    
    
int width, height, step;
    
int num;          //total pixels  width * height
    
    pImg 
= cvLoadImage(fileName, CV_LOAD_IMAGE_GRAYSCALE);
    
    
if(NULL == pImg)
    
{
        printf(
"image open fail ");
        
return ;
    }

    
    pImgK 
= cvCloneImage(pImg);
    width 
= pImg->width;
    height 
= pImg->height;
    step 
= pImg->widthStep;
    num 
= width * height;
    
    
//--------------------------------------------------
    
    
int i, j, ii;  //loop variant 
    
    
int *center;   //clustering center  均值
    int *sigma;    //方差 每个类一个方差 注意:是方差不是标准差
    float *a;        //ai i = 0..k-1 sum(ai) == 1
    float *z;        // k * num matrix  p(xk|wj)
    float *p;        // k*num matrix  p(wj|xk)
    char *flag;      //indicate which class the points are in
    unsigned char data; //pixel value
    
    
float u, sigma2;  //
    int iteration;
    
    center 
= (int*) malloc(k * sizeof(int));  //for memory reason , substitute int for char
    sigma = (int*) malloc(k * sizeof(int));  //方差数组  注意:是方差不是标准差
    a = (float*) malloc(k*sizeof(float));        //
    z = (float*) malloc(k * num * sizeof(float));  //
    p = (float*) malloc(k * num * sizeof(float));
    flag 
= (char*) malloc(num * sizeof(char));
    
    
//------函数初始化------
    
    GetUAndSigma(pImgK, num, u, sigma2);
    
    printf(
"the whole image's u = %f  sigma2 = %f  ", u, sigma2);   //test
    
    
for(i = 0; i < k; i++)
    
{
        a[i] 
= 1.0 / k;     //初始化每部分概率为1/k 各部分概率相同 for p
    }

    
    
for(i = 0; i < (k*num); i++)
    
{
        z[i] 
= 0.0;   
        p[i] 
= 0.0;
    }

    
    
//srand((unsigned) time(NULL));
    for(i = 0; i < num; i++)
    
{
        flag[i] 
= rand() % k;                //随机初始化各个类 
    }

    
//---------------------------------------
    
    UpdateUandSigma(pImg, num, k, flag, center, sigma);
    
    
    
//主循环
    iteration = 0;
    
float maxLike_old; 
    
float    maxLike;
    
    maxLike 
= FLT_MAX;
    
    
    
do
    
{
        
        printf(
"center[0] = %d, center[1] = %d  ", center[0], center[1]); //test
        printf("sigma[0] = %d, sigma[1] = %d  ",sigma[0], sigma[1]);
        printf(
"a[0] = %f, a[1] = %f ", a[0], a[1]);
        
        maxLike_old 
= maxLike;
        
        
//E-step
        
//---------------p(xn|wk)--------------------
        for(i = 0; i < num; i++)
        
{
            data  
= pImg->imageData[i] ;
            
            
for(j = 0; j < k; j++)
            
{    
                z[j
*num + i] = exp(-pow((data - center[j]), 2.0/ (2* sigma[j] + FLT_MIN)) / (sqrt(2*PI*sigma[j]) + FLT_MIN) ;
            }

        }
    //计算每个点在对应类的正态分布值
        
        
//e-step
        float maxProb = FLT_MIN;
        
char currentClass = -1;
        
float temp;
        
        
for(i = 0; i < num; i++)
        
{
            currentClass 
= -1;  //进入前清空
            maxProb = FLT_MIN;
            temp 
= 0.0;
            
            
for(ii = 0; ii < k; ii++)
            
{
                temp 
+= z[ii*num + i] * a[ii];    //分母啊
            }

            
            
for(j = 0; j < k; j++)
            
{
                p[j
*num + i] = z[j*num + i]* a[j] / temp ;      //p(wj|xk)
                if(p[j*num + i] > maxProb)
                
{
                    maxProb 
= p[j*num + i];
                    currentClass 
= j ;
                }

            }
    //bayes公式 点在哪个类的概率大 划到哪个类去
            
            flag[i] 
= currentClass;
            
        }

        
        
//        for(i = 0; i < num; i++)
        
//        {
        
//            printf("flag = %d ", flag[i]);
        
//        }
        
// UpdateUandSigma(pImg, num, k, flag, center, sigma);
        
//------------------------------------------------------------------
        
        
        
//M-step
        float denominator; //分母
        float numerator ;  //分子
        
//m-step
        
        
for(j = 0; j < k; j++)
        
{
            numerator 
= 0.0;
            denominator 
= 0.0;
            
            
for(i = 0; i < num; i++)
            
{
                data 
= pImg->imageData[i];
                numerator 
+= p[j*num +i] * data;
                denominator 
+= p[j*num +i];
            }

            
            center[j] 
= numerator / denominator ;
            a[j] 
= denominator / num ;
        }

        
        
//UpdateUandSigma(pImg, num, k, flag, center, sigma);
        
        
        
// 计算max likehood for GMM
        maxLike = 0.0;
        
        
        
for(i = 0; i < num; i++)
        
{
            temp 
= 0.0;
            
for(j = 0; j < k; j++)
            
{
                temp 
+= a[j]*z[j*num + i] ;
            }

            maxLike 
+= log(temp);
        }

        
        iteration
++;
        
        printf(
"maxLike = %f ", maxLike);
        printf(
"maxLike_old = %f  ", maxLike_old);
        
    }
while(fabs(maxLike - maxLike_old) > 0.000001);   //fabs(maxLike - maxLike_old) > 0
    
    printf(
"center[0] = %d, center[1] = %d  ", center[0], center[1]); //test
    printf(" iteration = %d  ", iteration);
    
    
    
    printf(
" z[i][j]  ");
    
/*    for(i = 0; i < num; i++)
    {
    
      for(j = 0; j < k; j++)
      {
      printf(" %f ", z[j*num +i]);
      }
      printf("  ");
      }
      
        printf(" p[i][j]  ");
        for(i = 0; i < num; i++)
        {
        
          for(j = 0; j < k; j++)
          {
          printf(" %f ", p[j*num +i]);
          }
          printf("  ");
    } //
*/

    
    
//-------------display clustering results-----------------
    int color = 255 / k;
    
for(i = 0; i < height; i++)
        
for(j = 0; j < width; j++)
        
{
            pImgK
->imageData[i*step +j] = (unsigned char) flag[i*step + j] *  color ;
            
        }

        
        
        
//--------display image-----------------------------------------
        cvNamedWindow("source", CV_WINDOW_AUTOSIZE);
        cvNamedWindow(
"GMM", CV_WINDOW_AUTOSIZE);
        
        cvShowImage(
"source", pImg);
        cvShowImage(
"GMM",pImgK);
        
        cvWaitKey(
0);
        
        cvDestroyWindow(
"source");
        cvDestroyWindow(
"GMM");
        
        cvReleaseImage(
&pImg);
        cvReleaseImage(
&pImgK);
        
        
        free(center);    
//release memory
        center = NULL;
        free(a);
        a 
= NULL;
        free(z);
        z 
= NULL;
        free(flag);
        flag 
= NULL;
        
}

 
原创粉丝点击