vc++ socket实现的支持断点续传的下载器

来源:互联网 发布:中国国家话剧院 知乎 编辑:程序博客网 时间:2024/06/05 07:54

 原文:http://blog.csdn.net/windless0530/article/details/6799882

网上找了一堆代码,有用wininet的,还有用socket的,整理了半天,还是觉得socket靠谱。


只支持内存中断点续传。如果要加上在磁盘上断点续传,原理也差不多,不是本文重点。


注释:

1. CByteBufferVector是一个缓存池,动态分配BYTE形数组空间用的。代码略,可以简单看成BYTE数组。

2. GetStringA是一个CString转CStringA的函数,无需多说。

3. 除了win socket基本没有其它依赖,噢对,ATL::CString除外……


.h头文件:

 

 

  1. class CSocketDownloader;  
  2.   
  3.   
  4. /** 
  5.  *<SPAN style="WHITE-SPACE: pre">  </SPAN>下载任务 
  6.  */  
  7. class CDownloadTask  
  8. {  
  9.     friend class CSocketDownloader;  
  10.   
  11.   
  12. public:  
  13.   
  14.   
  15.     CDownloadTask();  
  16.   
  17.   
  18.     CStringA GetUrlA() const;  
  19.     CStringA GetAgnetA() const;  
  20.     void ParseUrl();  
  21.       
  22.     int Percentage() const;  
  23.     DWORD RemainTimeSec(DWORD dwTickElapsed, unsigned int uBytesTransferred) const;  
  24.   
  25.   
  26.     CString         m_strUrl;           // 下载地址  
  27.     CString         m_strAgent;         // 用户agent  
  28.     int             m_nMaxTryCount;     // 最多重试次数(重定向不算重试,默认20次)  
  29.     int             m_nTimeoutSec;      // socket超时(秒,默认10秒)  
  30.     int             m_nPort;            // 端口(默认80)  
  31.     HWND            m_hWnd;             // 接收下载进度消息的窗口句柄  
  32.     LONG            *m_pTerminate;      // 指向是否中止的标志位,一般由用户界面操作(如点击“取消”按钮)更改此值  
  33.   
  34.   
  35. protected:  
  36.   
  37.   
  38.     CStringA        m_strAbsoluteUrlA;  
  39.     CStringA        m_strQueryA;  
  40.     CStringA        m_strHostA;  
  41.     unsigned int    m_uReadBytes;  
  42.     unsigned int    m_uTotalBytes;  
  43. };  
  44.   
  45.   
  46. /** 
  47.  *<SPAN style="WHITE-SPACE: pre">  </SPAN>socket实现的断点续传下载器 
  48.  */  
  49. class CSocketDownloader  
  50. {  
  51. public:  
  52.     CSocketDownloader();  
  53.     virtual ~CSocketDownloader();  
  54.   
  55.   
  56.     // 下载到一个buffer   
  57.     DWORD DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);  
  58.   
  59.   
  60.     // 下载到一个文件   
  61.     DWORD DownloadToFile(CDownloadTask &task, CString strOutputFile);  
  62.   
  63.   
  64. protected:  
  65.   
  66.   
  67.     DWORD DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec);  
  68.     DWORD ConnectServer(const CDownloadTask &task, SOCKET hSocket);  
  69.     DWORD DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket);  
  70.       
  71.     int GetSleepSecCount(int nTryCount) const;  
  72.     int GetBufferSize(const CDownloadTask &task) const;  
  73.     CStringA GenerateRequest(CDownloadTask &task) const;  
  74.   
  75.   
  76. };  

     

    CPP文件:

    [cpp] view plaincopyprint?
    1. <SPAN style="FONT-FAMILY: Arial, Verdana, sans-serif"><SPAN style="WHITE-SPACE: normal"><SPAN style="FONT-FAMILY: monospace"><SPAN style="WHITE-SPACE: pre">#include <math.h>  
    2. #include <time.h>   
    3.   
    4.   
    5. const int BLOCK_SIZE = 1024 * 64;  
    6. const int DEFAULT_MAX_TRY = 20;  
    7. const int DEFAULT_TIMEOUT = 10;  
    8.   
    9.   
    10. //////////////////////////////////////////////////////////////////////////  
    11. // 下载任务   
    12. //////////////////////////////////////////////////////////////////////////  
    13.   
    14.   
    15. CDownloadTask::CDownloadTask()  
    16.     : m_nPort(INTERNET_DEFAULT_HTTP_PORT),  
    17.     m_nMaxTryCount(DEFAULT_MAX_TRY),  
    18.     m_uReadBytes(0),  
    19.     m_uTotalBytes(0),  
    20.     m_nTimeoutSec(DEFAULT_TIMEOUT),  
    21.     m_hWnd(NULL),  
    22.     m_pTerminate(NULL)  
    23. {  
    24.   
    25.   
    26. }  
    27.   
    28.   
    29. CStringA CDownloadTask::GetUrlA() const  
    30. {  
    31.     return GetStringA(m_strUrl);  
    32. }  
    33.   
    34.   
    35. CStringA CDownloadTask::GetAgnetA() const  
    36. {  
    37.     return GetStringA(m_strAgent);  
    38. }  
    39.   
    40.   
    41. void CDownloadTask::ParseUrl()  
    42. {  
    43.     m_strAbsoluteUrlA = m_strHostA = m_strQueryA = "";  
    44.   
    45.   
    46.     CStringA strUrlA = this->GetUrlA();  
    47.     const char *pUrl = strUrlA;  
    48.     const char *p = pUrl;  
    49.     const char *szHttpHead = "http://";  
    50.   
    51.   
    52.     if (_strnicmp(pUrl, szHttpHead, strlen(szHttpHead)) == 0)  
    53.     {  
    54.         p = pUrl + strlen(szHttpHead);  
    55.     }  
    56.   
    57.   
    58.     int nHostLen = 0;  
    59.     const char *q = strchr(p, '/');  
    60.     if (q != NULL)  
    61.     {  
    62.         nHostLen = q - p;  
    63.         int nPathLen = 0;  
    64.         const char *r = strchr(q, '?');  
    65.         if (r != NULL)  
    66.         {  
    67.             // 解析query   
    68.             r++;  
    69.             m_strQueryA = r;  
    70.             nPathLen = r - q - 1;  
    71.         }  
    72.         else  
    73.         {  
    74.             nPathLen = strlen(q);  
    75.         }  
    76.   
    77.   
    78.         // 解析abs_path   
    79.         m_strAbsoluteUrlA.Append(q, nPathLen);  
    80.     }  
    81.     else  
    82.     {  
    83.         nHostLen = strlen(p);  
    84.     }  
    85.   
    86.   
    87.     // 解析host   
    88.     m_strHostA.Append(p, nHostLen);  
    89.   
    90.   
    91.     // 解析port   
    92.     const char *r = strchr(m_strHostA, ':');  
    93.     if (r == 0)  
    94.     {  
    95.         m_nPort = INTERNET_DEFAULT_HTTP_PORT;  
    96.     }  
    97.     else  
    98.     {  
    99.         m_nPort = atoi(r + 1);  
    100.     }  
    101. }  
    102.   
    103.   
    104. int CDownloadTask::Percentage() const  
    105. {  
    106.     return (m_uTotalBytes == 0)  
    107.         ? 0  
    108.         : (int)((unsigned long long)m_uReadBytes * 100 / (unsigned long long) m_uTotalBytes);  
    109. }  
    110.   
    111.   
    112. DWORD CDownloadTask::RemainTimeSec( DWORD dwTickElapsed, unsigned int uBytesTransferred ) const  
    113. {  
    114.     unsigned long long uTickElapsed = (unsigned long long)dwTickElapsed;  
    115.     unsigned long long uBytes = (unsigned long long)uBytesTransferred;  
    116.     unsigned long long uRemain = (unsigned long long)(m_uTotalBytes - m_uReadBytes);  
    117.     Log(_T("elapsed=%d, get=%d, remain=%d\n"), dwTickElapsed, uBytesTransferred, m_uTotalBytes - m_uReadBytes);  
    118.     return (DWORD)(uTickElapsed * uRemain / (uBytes * CLOCKS_PER_SEC));  
    119. }  
    120.   
    121.   
    122. //////////////////////////////////////////////////////////////////////////  
    123. // socket下载器   
    124. //////////////////////////////////////////////////////////////////////////  
    125.   
    126.   
    127. CSocketDownloader::CSocketDownloader()  
    128. {  
    129.   
    130.   
    131. }  
    132.   
    133.   
    134. CSocketDownloader::~CSocketDownloader()  
    135. {  
    136.   
    137.   
    138. }  
    139.   
    140.   
    141. DWORD CSocketDownloader::DownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)  
    142. {  
    143.     int nTryCount = 0;  
    144.     DWORD dwRet = this->DoDownloadToBuffer(task, bufVec);  
    145.     if (web::THE_REDIRECT != dwRet)  
    146.     {  
    147.         nTryCount++;  
    148.     }  
    149.   
    150.   
    151.     while (  
    152.         dwRet != web::THE_SUCCEED  
    153.         && dwRet != web::THE_USER_CANCELED  
    154.         && nTryCount < task.m_nMaxTryCount  
    155.         )  
    156.     {  
    157.         int nTime = this->GetSleepSecCount(nTryCount);  
    158.         ::Sleep(nTime);  
    159.         dwRet = this->DoDownloadToBuffer(task, bufVec);  
    160.         if (web::THE_REDIRECT != dwRet)  
    161.         {  
    162.             nTryCount++;  
    163.         }  
    164.     }  
    165.   
    166.   
    167.     return dwRet;  
    168. }  
    169.   
    170.   
    171. DWORD CSocketDownloader::DownloadToFile( CDownloadTask &task, CString strOutputFile )  
    172. {  
    173.     CByteBufferVector vec;  
    174.   
    175.   
    176.     DWORD dwRet = this->DownloadToBuffer(task, vec);  
    177.     if (web::THE_SUCCEED != dwRet)  
    178.     {  
    179.         return dwRet;  
    180.     }  
    181.   
    182.   
    183.     HANDLE hFile = ::CreateFile(strOutputFile, GENERIC_WRITE, 0, NULL, CREATE_ALWAYS, FILE_ATTRIBUTE_NORMAL, NULL);  
    184.     if (hFile == INVALID_HANDLE_VALUE)  
    185.     {  
    186.         return web::THE_CREATE_FILE;  
    187.     }  
    188.   
    189.   
    190.     BYTE *pBuffer = vec.Ptr(0, task.m_uTotalBytes);  
    191.     DWORD dwBytesWritten = 0;  
    192.     ::WriteFile(hFile, pBuffer, task.m_uTotalBytes, &dwBytesWritten, NULL);  
    193.     ::CloseHandle(hFile);  
    194.     return (dwBytesWritten == task.m_uTotalBytes) ? web::THE_SUCCEED : web::THE_WRITE_FILE;  
    195. }  
    196.   
    197.   
    198. DWORD CSocketDownloader::DoDownloadToBuffer(CDownloadTask &task, CByteBufferVector &bufVec)  
    199. {  
    200.     task.ParseUrl();  
    201.   
    202.   
    203.     SOCKET hSocket = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP);  
    204.     if (hSocket == INVALID_SOCKET)  
    205.     {  
    206.         return web::THE_CREATE_SOCKET;  
    207.     }  
    208.   
    209.   
    210.     DWORD dwRet = this->ConnectServer(task, hSocket);  
    211.     if (web::THE_SUCCEED != dwRet)  
    212.     {  
    213.         closesocket(hSocket);  
    214.         return dwRet;  
    215.     }  
    216.   
    217.   
    218.     dwRet =  this->DoDownloadToBufferInner(task, bufVec, hSocket);  
    219.     closesocket(hSocket);  
    220.     return dwRet;  
    221. }  
    222.   
    223.   
    224. DWORD CSocketDownloader::DoDownloadToBufferInner(CDownloadTask &task, CByteBufferVector &bufVec, SOCKET hSocket)  
    225. {  
    226.     // 发送请求   
    227.     CStringA strRequest = this->GenerateRequest(task);  
    228.     int nLen = send(hSocket, strRequest, strRequest.GetLength(), 0);  
    229.     if (nLen <= 0)  
    230.     {  
    231.         return web::THE_SEND_HTTP_HEADER;  
    232.     }  
    233.   
    234.   
    235.     // 接收一部分数据(header部分,以"\r\n\r\n"为止)  
    236.     CStringA strRecvBuf;  
    237.     char szRecvBuf[MAX_PATH] = { 0 };  
    238.     nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);  
    239.     while (nLen > 0)  
    240.     {  
    241.         szRecvBuf[nLen] = 0;  
    242.         strRecvBuf.Append(szRecvBuf);  
    243.         if (strstr(szRecvBuf, "\r\n\r\n") != NULL)  
    244.         {  
    245.             break;  
    246.         }  
    247.         nLen = recv(hSocket, szRecvBuf, MAX_PATH - 1, 0);  
    248.     }  
    249.       
    250.     // 找到两个回车换行,即content起始位置。   
    251.     const char *pData = strstr(szRecvBuf, "\r\n\r\n");  
    252.     if (pData == NULL)  
    253.     {  
    254.         return web::THE_INVALID_RECV_END;  
    255.     }  
    256.   
    257.   
    258.     pData += 4;  
    259.   
    260.   
    261.     const char *p = strchr(strRecvBuf, ' ');  
    262.     if (p != NULL)  
    263.     {  
    264.         p++;  
    265.         DWORD dwRet = atoi(p);  
    266.   
    267.   
    268.         if (dwRet == HTTP_STATUS_PARTIAL_CONTENT)        // 206: 断点续传  
    269.         {  
    270.             const char *q = strstr(strRecvBuf, "\r\nContent-Length:");  
    271.             if (q == NULL)  
    272.             {  
    273.                 return web::THE_NO_CONTENT_LENGTH;  
    274.             }  
    275.             task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);  
    276.         }  
    277.         else if (dwRet == HTTP_STATUS_OK)               // 200: 重新下载(服务器不支持断点续传)  
    278.         {  
    279.             const char *q = strstr(strRecvBuf, "\r\nContent-Length:");  
    280.             if (q == NULL)  
    281.             {  
    282.                 return web::THE_NO_CONTENT_LENGTH;  
    283.             }  
    284.             task.m_uTotalBytes = task.m_uReadBytes + atoi(q + 17);  
    285.             // 清除已经下载的内容   
    286.             task.m_uReadBytes = 0;  
    287.             bufVec.Reset();  
    288.         }  
    289.         else if (dwRet == HTTP_STATUS_REDIRECT)         // 302: 重定向  
    290.         {  
    291.             const char *q = strstr(strRecvBuf, "\r\nLocation:");  
    292.             if (q == NULL)  
    293.             {  
    294.                 return web::THE_NO_REDIRECT_LOCATION;  
    295.             }  
    296.             q += 12;  
    297.             const char *r = strstr(q, "\r\n");  
    298.             if (r == NULL)  
    299.             {  
    300.                 return web::THE_REDIRECT_INVALID_FORMAT;  
    301.             }  
    302.   
    303.   
    304.             int nUrlLen = r - q;  
    305.             CStringA strUrlA;  
    306.             strUrlA.Append(q, nUrlLen);  
    307.             task.m_strUrl = GetString(strUrlA);  
    308.             return web::THE_REDIRECT;  
    309.         }  
    310.         else  
    311.         {  
    312.             return web::THE_INVALID_STAUS_CODE;  
    313.         }  
    314.     }  
    315.       
    316.     // 复制已传回来的第一部分content   
    317.     int nSize = nLen - (pData - szRecvBuf);  
    318.     BYTE *pBuffer = bufVec.Ptr(task.m_uReadBytes, nSize);  
    319.     memcpy(pBuffer, pData, nSize);  
    320.     task.m_uReadBytes += nSize;  
    321.   
    322.   
    323.     // 继续接收http content,即下载内容。  
    324.     int nBufferSize = this->GetBufferSize(task);  
    325.     pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);  
    326.   
    327.   
    328.     DWORD dwLastTick = 0;  
    329.   
    330.   
    331.     // 下载测速   
    332.     DWORD dwTickStart = ::GetTickCount();  
    333.     unsigned int uReadBytesStart = task.m_uReadBytes;  
    334.   
    335.   
    336.     while (true)  
    337.     {  
    338.         if (::InterlockedCompareExchange(task.m_pTerminate, 1, 1))  
    339.         {  
    340.             // 用户取消。   
    341.             return web::THE_USER_CANCELED;  
    342.         }  
    343.   
    344.   
    345.         nLen = recv(hSocket, (char *)(pBuffer), nBufferSize, 0);  
    346.         if (nLen < 0)  
    347.         {  
    348.             return web::THE_RECV_FAIL;  
    349.         }  
    350.         else if (nLen == 0)  
    351.         {  
    352.             break;  // 接收完成  
    353.         }  
    354.   
    355.   
    356.         task.m_uReadBytes += nLen;  
    357.         if (task.m_uReadBytes == task.m_uTotalBytes)  
    358.         {  
    359.             break;  // 接收完成  
    360.         }  
    361.   
    362.   
    363.         nBufferSize = this->GetBufferSize(task);  
    364.         pBuffer = bufVec.Ptr(task.m_uReadBytes, nBufferSize);  
    365.   
    366.   
    367.         if (NULL != task.m_hWnd)  
    368.         {  
    369.             DWORD dwTick = ::GetTickCount();  
    370.             if (dwLastTick == 0 || (dwTick - dwLastTick >= 100))        // 每秒最多发10次消息  
    371.             {  
    372.                 // 发送当前下载进度和剩余时间消息   
    373.                 dwLastTick = dwTick;  
    374.                 ::PostMessage(task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,  
    375.                     static_cast<WPARAM>(task.Percentage()), static_cast<LPARAM>(task.RemainTimeSec(dwTick - dwTickStart, task.m_uReadBytes - uReadBytesStart))  
    376.                     );  
    377.             }  
    378.         }  
    379.     }  
    380.   
    381.   
    382.     DWORD dwTick = ::GetTickCount();  
    383.     if (NULL != task.m_hWnd)  
    384.     {  
    385.         ::PostMessage(  
    386.             task.m_hWnd, WM_FASTINSTALL_PROGRESS_VALUE,  
    387.             static_cast<WPARAM>(task.Percentage()), -1  
    388.             );  
    389.     }  
    390.   
    391.   
    392.     return web::THE_SUCCEED;  
    393. }  
    394.   
    395.   
    396. DWORD CSocketDownloader::ConnectServer(const CDownloadTask &task, SOCKET hSocket)  
    397. {  
    398.     PHOSTENT pHostent = gethostbyname(task.m_strHostA);  
    399.     if (pHostent == NULL)  
    400.     {  
    401.         return web::THE_GET_HOST_BY_NAME;  
    402.     }  
    403.   
    404.   
    405.     sockaddr_in addrSvr;  
    406.     addrSvr.sin_port = htons((u_short)task.m_nPort);  
    407.     addrSvr.sin_family = AF_INET;  
    408.     addrSvr.sin_addr.s_addr = *(ULONG*)pHostent->h_addr_list[0];  
    409.     if (SOCKET_ERROR == connect(hSocket, (sockaddr*)&addrSvr, sizeof(addrSvr)))  
    410.     {  
    411.         return web::THE_CONNECT_SOCKET;  
    412.     }  
    413.   
    414.   
    415.     int opt = task.m_nTimeoutSec * 1000;  
    416.     if (0 != setsockopt(hSocket, SOL_SOCKET, SO_RCVTIMEO, (char*)&opt, sizeof(opt)))  
    417.     {  
    418.         return web::THE_SET_SOCK_OPT1;  
    419.     }  
    420.   
    421.   
    422.     BOOL bKeepAlive = TRUE;    
    423.     int len = sizeof(bKeepAlive);  
    424.     getsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char*)&bKeepAlive, &len);  
    425.   
    426.   
    427.     bKeepAlive = TRUE;  
    428.     if (0 != setsockopt(hSocket, SOL_SOCKET, SO_KEEPALIVE, (char *)&bKeepAlive, sizeof(BOOL)))  
    429.     {  
    430.         return web::THE_SET_SOCK_OPT2;  
    431.     }  
    432.   
    433.   
    434.     return web::THE_SUCCEED;  
    435. }  
    436.   
    437.   
    438. int CSocketDownloader::GetSleepSecCount( int nTryCount ) const  
    439. {  
    440.     return (nTryCount + 1) * 1000;  
    441. }  
    442.   
    443.   
    444. int CSocketDownloader::GetBufferSize( const CDownloadTask &task ) const  
    445. {  
    446.     return std::min<int>(BLOCK_SIZE, task.m_uTotalBytes - task.m_uReadBytes);  
    447. }  
    448.   
    449.   
    450. CStringA CSocketDownloader::GenerateRequest( CDownloadTask &task ) const  
    451. {  
    452.     CStringA strRequest;  
    453.   
    454.   
    455.     CStringA strTemp;  
    456.     if (task.m_strQueryA.IsEmpty())  
    457.     {  
    458.         strTemp.Format(  
    459.             "GET %s HTTP/1.1\r\nHOST: %s\r\n",  
    460.             task.m_strAbsoluteUrlA.GetString(), task.m_strHostA.GetString()  
    461.             );  
    462.     }  
    463.     else  
    464.     {  
    465.         strTemp.Format(  
    466.             "GET %s?%s HTTP/1.1\r\nHOST: %s\r\n",  
    467.             task.m_strAbsoluteUrlA.GetString(), task.m_strQueryA.GetString(), task.m_strHostA.GetString()  
    468.             );  
    469.     }  
    470.     strRequest.Append(strTemp);  
    471.   
    472.   
    473.     strTemp.Format("Range: bytes=%d-\r\n", task.m_uReadBytes);  
    474.     strRequest.Append(strTemp);  
    475.   
    476.   
    477.     strTemp.Format("User-Agent: %s\r\n", task.GetAgnetA().GetString());  
    478.     strRequest.Append(strTemp);  
    479.   
    480.   
    481.     strRequest.Append("Accept: */*\r\n");  
    482.     strRequest.Append("Accept-Encoding: gzip, deflate\r\n");  
    483.     strRequest.Append("Connection: Keep-Alive\r\n\r\n");  
    484.       
    485.     return strRequest;  
    486. }  
    487. </SPAN></SPAN></SPAN></SPAN>  

     

原创粉丝点击