基于 POCO 框架的 TCP 连接分流程序

来源:互联网 发布:java 引用类型 编辑:程序博客网 时间:2024/04/28 16:00

介绍


下面的程序实现了对 TCP 连接的分流,即将一个 TCP 连接的流量分布到多个 TCP 连接上进行传输。

本程序的主要作用是在特定网络环境下提升通过 TCP 连接的 OpenVPN 服务的速率,使之充分利用带宽。

程序的主要复杂之处在于单生产者—多消费者、多生产者—单消费者两种同步方式的实现。


代码


// multisocks, copyright (c) 2013 coolypf#include <stdio.h>#include <stdlib.h>#include <string.h>#include <string>#include <map>#include <vector>#include <Poco/Thread.h>#include <Poco/Event.h>#include <Poco/ErrorHandler.h>#include <Poco/FIFOBuffer.h>#include <Poco/Net/StreamSocket.h>#include <Poco/Net/ServerSocket.h>using namespace std;using namespace Poco;using namespace Poco::Net;class error_handler : public ErrorHandler{public:    virtual void exception(const Exception& exc)    {        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.displayText().c_str());        fflush(stdout);        exit(1);    }    virtual void exception(const std::exception& exc)    {        printf("[!] Thread %u: Caught exception: %s\n", (unsigned int)Thread::currentTid(), exc.what());        fflush(stdout);        exit(2);    }    virtual void exception()    {        printf("[!] Thread %u: Caught unknown exception\n", (unsigned int)Thread::currentTid());        fflush(stdout);        exit(3);    }};class{    map<string, string> config;    string empty;    int read_file(const char *filename, vector<char> &out)    {        FILE *fp = fopen(filename, "rb");        if (!fp)        {            printf("[!] Fail to load: %s\n", filename);            return 1;        }        char buf[4096];        size_t sz;        while ((sz = fread(buf, 1, 4096, fp)))            out.insert(out.end(), buf, buf + sz);        fclose(fp);        for (size_t i = 0; i < out.size(); ++i)            if (!out[i])                out[i] = ' ';        return 0;    }    void split_lines(const vector<char> &in, vector<string> &out)    {        size_t pos = 0;        for (size_t i = 0; i < in.size(); ++i)        {            if (in[i] != '\r' && in[i] != '\n')                continue;            out.push_back(string(in.begin() + pos, in.begin() + i));            if (in[i] == '\r' && i + 1 < in.size() && in[i + 1] == '\n')                i++;            pos = i + 1;        }        if (pos < in.size())            out.push_back(string(in.begin() + pos, in.end()));    }    void trim_string(string &str)    {        const char *spaces = " \r\n\t\xb\xc";        int first = 0, last = (int)str.size() - 1;        while (first <= last && strchr(spaces, str[first]))            ++first;        while (last >= first && strchr(spaces, str[last]))            --last;        str.resize(last + 1);        str.erase(0, first);    }    void split_string(const string &in, vector<string> &out, const char *delims)    {        size_t pos = 0;        for (size_t i = 0; i < in.size(); ++i)        {            if (!strchr(delims, in[i]))                continue;            string part(in.begin() + pos, in.begin() + i);            trim_string(part);            out.push_back(part);            pos = i + 1;        }        {            string part(in.begin() + pos, in.end());            trim_string(part);            out.push_back(part);        }    }public:    int load(const char *filename)    {        vector<char> content;        if (read_file(filename, content))            return 1;        vector<string> lines;        split_lines(content, lines);        for (size_t i = 0; i < lines.size(); ++i)        {            if (lines[i].empty() || lines[i][0] == '#')                continue;            vector<string> parts;            split_string(lines[i], parts, "=");            if (parts.size() != 2 || parts[0].empty())            {                printf("[!] Invalid config: %s (%d)\n", filename, (int)i + 1);                continue;            }            if (config.find(parts[0]) != config.end())                printf("[!] Override config: %s (%d)\n", filename, (int)i + 1);            config[parts[0]] = parts[1];        }        return 0;    }    int i(const char *key) const    {        map<string, string>::const_iterator iter = config.find(string(key));        if (iter == config.end())        {            printf("[!] No config: %s\n", key);            return 0;        }        int ret = 0;        if (sscanf(iter->second.c_str(), "%d", &ret) != 1)            printf("[!] Invalid int: %s\n", key);        return ret;    }    const string & s(const char *key) const    {        map<string, string>::const_iterator iter = config.find(string(key));        if (iter == config.end())        {            printf("[!] No config: %s\n", key);            return empty;        }        return iter->second;    }} config;// Read from out buffer and send to remoteclass reader : public Runnable{    int index;    StreamSocket socket;    Event &readable, &writable;    FIFOBuffer &buffer;    bool &self_stopped, &peer_stopped;public:    reader(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)    {    }    virtual void run()    {        printf("[.] Reader %d, tid = %u\n", index, (unsigned int)Thread::currentTid());        fflush(stdout);        char buf[8192];        bool closed = false;        while (true)        {            if (buffer.isReadable())            {                int sz = (int)buffer.read(buf, 8192);                writable.set();                int sent = 0;                while (sent < sz)                {                    int sz2;                    try                    {                        sz2 = socket.sendBytes(buf + sent, sz - sent);                    }                    catch (Exception &exc)                    {                        printf("[!] Reader %d: %s\n", index, exc.displayText().c_str());                        fflush(stdout);                        sz2 = 0;                    }                    if (sz2 <= 0)                    {                        closed = true;                        break;                    }                    sent += sz2;                }                if (closed)                    break;            }            else            {                if (peer_stopped)                    break;                readable.wait();            }        }        try { socket.shutdownSend(); } catch (...) {}        self_stopped = true;        writable.set();    }};// Receive from remote and write to in bufferclass writer : public Runnable{    int index;    StreamSocket socket;    Event &readable, &writable;    FIFOBuffer &buffer;    bool &self_stopped, &peer_stopped;public:    writer(int i, StreamSocket s, Event &r, Event &w, FIFOBuffer &b, bool &self, bool &peer)        : index(i), socket(s), readable(r), writable(w), buffer(b), self_stopped(self), peer_stopped(peer)    {    }    virtual void run()    {        printf("[.] Writer %d, tid = %u\n", index, (unsigned int)Thread::currentTid());        fflush(stdout);        char buf[8192];        while (true)        {            int sz;            try            {                sz = socket.receiveBytes(buf, 8192);            }            catch (Exception &exc)            {                printf("[!] Writer %d: %s\n", index, exc.displayText().c_str());                fflush(stdout);                sz = 0;            }            if (sz <= 0)                break;            int written = 0;            while (written < sz)            {                if (buffer.isWritable())                {                    written += buffer.write(buf + written, sz - written);                    readable.set();                }                else                {                    if (peer_stopped)                        break;                    writable.wait();                }            }            if (peer_stopped)                break;        }        try { socket.shutdownReceive(); } catch (...) {}        self_stopped = true;        readable.set();    }};// Receive from local and write to out buffersclass divider : public Runnable{    int nr_conn;    StreamSocket socket;    vector<Event *> &readables;    Event &writable;    vector<FIFOBuffer *> buffers;    bool &self_stopped, *peers_stopped;public:    divider(int n, StreamSocket s, vector<Event *> &vr, Event &w, vector<FIFOBuffer *> &b, bool &self, bool *peers)        : nr_conn(n), socket(s), readables(vr), writable(w), buffers(b), self_stopped(self), peers_stopped(peers)    {    }    virtual void run()    {        printf("[.] Divider, tid = %u\n", (unsigned int)Thread::currentTid());        fflush(stdout);        int sz0 = 8192 * nr_conn;        char *buf0 = new char[sz0];        char **bufv = new char *[nr_conn];        for (int i = 0; i < nr_conn; ++i)            bufv[i] = new char[8192];        int *szv = new int[nr_conn];        int *writtenv = new int[nr_conn];        long long total = 0;        bool peer_stopped = false;        while (true)        {            int sz;            try            {                sz = socket.receiveBytes(buf0, sz0);            }            catch (Exception &exc)            {                printf("[!] Divider: %s\n", exc.displayText().c_str());                fflush(stdout);                sz = 0;            }            if (sz <= 0)                break;            for (int i = 0; i < nr_conn; ++i)            {                szv[i] = 0;                writtenv[i] = 0;            }            for (int i = 0; i < sz; ++i)                bufv[(total + i) % nr_conn][szv[(total + i) % nr_conn]++] = buf0[i];            total += sz;            while (true)            {                bool done = true;                for (int i = 0; i < nr_conn; ++i)                    if (writtenv[i] < szv[i])                        done = false;                if (done)                    break;                bool written = false;                for (int i = 0; i < nr_conn; ++i)                {                    if (buffers[i]->isWritable())                    {                        writtenv[i] += buffers[i]->write(bufv[i] + writtenv[i], szv[i] - writtenv[i]);                        readables[i]->set();                        written = true;                    }                }                if (!written)                {                    for (int i = 0; i < nr_conn; ++i)                        if (peers_stopped[i])                            peer_stopped = true;                    if (peer_stopped)                        break;                    writable.wait();                    writable.reset();                }                if (peer_stopped)                    break;            }        }        try { socket.shutdownReceive(); } catch (...) {}        self_stopped = true;        for (int i = 0; i < nr_conn; ++i)            readables[i]->set();    }};// Read from in buffers and send to localclass combiner : public Runnable{    int nr_conn;    StreamSocket socket;    Event &readable;    vector<Event *> &writables;    vector<FIFOBuffer *> buffers;    bool &self_stopped, *peers_stopped;public:    combiner(int n, StreamSocket s, Event &r, vector<Event *> &vw, vector<FIFOBuffer *> &b, bool &self, bool *peers)        : nr_conn(n), socket(s), readable(r), writables(vw), buffers(b), self_stopped(self), peers_stopped(peers)    {    }    virtual void run()    {        printf("[.] Combiner, tid = %u\n", (unsigned int)Thread::currentTid());        fflush(stdout);        char *buf0 = new char[8192 * nr_conn];        char **bufv = new char *[nr_conn];        for (int i = 0; i < nr_conn; ++i)            bufv[i] = new char[8192];        int *szv = new int[nr_conn];        int *readv = new int[nr_conn];        long long total = 0;        bool peer_stopped = false, closed = false;        while (true)        {            for (int i = 0; i < nr_conn; ++i)            {                szv[i] = buffers[i]->peek(bufv[i], 8192);                readv[i] = 0;            }            int sz = 0;            while (true)            {                int i = (total + sz) % nr_conn;                if (readv[i] >= szv[i])                    break;                buf0[sz++] = bufv[i][readv[i]++];            }            total += sz;            if (sz > 0)            {                for (int i = 0; i < nr_conn; ++i)                {                    if (readv[i] > 0)                    {                        buffers[i]->read(bufv[i], readv[i]);                        writables[i]->set();                    }                }                int sent = 0;                while (sent < sz)                {                    int sz1;                    try                    {                        sz1 = socket.sendBytes(buf0 + sent, sz - sent);                    }                    catch (Exception &exc)                    {                        printf("[!] Combiner: %s\n", exc.displayText().c_str());                        fflush(stdout);                        sz1 = 0;                    }                    if (sz1 <= 0)                    {                        closed = true;                        break;                    }                    sent += sz1;                }                if (closed)                    break;            }            else            {                for (int i = 0; i < nr_conn; ++i)                    if (peers_stopped[i])                        peer_stopped = true;                if (peer_stopped)                    break;                readable.wait();                readable.reset();            }        }        try { socket.shutdownSend(); } catch (...) {}        self_stopped = true;        for (int i = 0; i < nr_conn; ++i)            writables[i]->set();    }};int main(int argc, char **argv){    config.load("multisocks.txt");    for (int i = 1; i < argc; ++i)        if (config.load(argv[i]))            return 1;    if (!config.s("log").empty())        freopen(config.s("log").c_str(), "w", stdout);    ErrorHandler::set(new error_handler);    fflush(stdout);    try    {        int nr_conn = config.i("nr_conn");        bool divider_stopped = false, combiner_stopped = false;        bool *readers_stopped = new bool[nr_conn](), *writers_stopped = new bool[nr_conn]();        Event in_readable(false), out_writable(false);        in_readable.reset();        out_writable.set();        vector<Event *> in_writables, out_readables;        for (int i = 0; i < nr_conn; ++i)        {            in_writables.push_back(new Event);            in_writables.back()->set();            out_readables.push_back(new Event);            out_readables.back()->reset();        }        vector<StreamSocket> remotes;        if (config.s("remote") == "listen")        {            ServerSocket server;            if (config.s("remote.protocol") == "ipv6")                server.bind6(config.i("remote.port"), true, true);            else                server.bind(config.i("remote.port"), true);            server.listen();            for (int i = 0; i < nr_conn; ++i)            {                SocketAddress client_addr;                remotes.push_back(server.acceptConnection(client_addr));                remotes.back().setNoDelay(true);                printf("[.] Remote incoming: %s\n", client_addr.toString().c_str());                fflush(stdout);                if (client_addr.host() != remotes.front().peerAddress().host())                    throw Exception(string("client addresses mismatch"));            }            server.close();        }        else        {            SocketAddress server_addr(config.s("remote.host"), config.i("remote.port"));            for (int i = 0; i < nr_conn; ++i)            {                StreamSocket remote;                if (!config.s("remote.bind").empty())                    remote.impl()->bind(SocketAddress(config.s("remote.bind"), 0), true);                remote.connect(server_addr);                remote.setNoDelay(true);                remotes.push_back(remote);                printf("[.] Connected to remote from: %s\n", remote.address().toString().c_str());                fflush(stdout);            }        }        vector<FIFOBuffer *> in_buffers, out_buffers;        vector<reader *> readers;        vector<writer *> writers;        Thread *reader_threads = new Thread[nr_conn], *writer_threads = new Thread[nr_conn];        for (int i = 0; i < nr_conn; ++i)        {            in_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));            out_buffers.push_back(new FIFOBuffer(config.i("buffer_size")));            readers.push_back(new reader(i, remotes[i], *out_readables[i], out_writable, *out_buffers.back(), readers_stopped[i], divider_stopped));            writers.push_back(new writer(i, remotes[i], in_readable, *in_writables[i], *in_buffers.back(), writers_stopped[i], combiner_stopped));            reader_threads[i].setStackSize(65536);            writer_threads[i].setStackSize(65536);            reader_threads[i].start(*readers.back());            writer_threads[i].start(*writers.back());        }        StreamSocket local;        if (config.s("local") == "listen")        {            ServerSocket server;            if (config.s("local.protocol") == "ipv6")                server.bind6(config.i("local.port"), true, true);            else                server.bind(config.i("local.port"), true);            server.listen();            local = server.acceptConnection();            printf("[.] Local incoming: %s\n", local.peerAddress().toString().c_str());            fflush(stdout);            local.setNoDelay(true);            server.close();        }        else        {            SocketAddress server_addr(config.s("local.host"), config.i("local.port"));            local.connect(server_addr);            local.setNoDelay(true);            printf("[.] Connected to local from: %s\n", local.address().toString().c_str());            fflush(stdout);        }        divider *pdivider = new divider(nr_conn, local, out_readables, out_writable, out_buffers, divider_stopped, readers_stopped);        combiner *pcombiner = new combiner(nr_conn, local, in_readable, in_writables, in_buffers, combiner_stopped, writers_stopped);        Thread divider_thread, combiner_thread;        divider_thread.setStackSize(65536);        combiner_thread.setStackSize(65536);        divider_thread.start(*pdivider);        combiner_thread.start(*pcombiner);        printf("[.] Connection established\n");        fflush(stdout);        divider_thread.join();        combiner_thread.join();        local.close();        for (int i = 0; i < nr_conn; ++i)        {            reader_threads[i].join();            writer_threads[i].join();            remotes[i].close();        }        printf("[.] Connection closed\n");    }    catch (Exception &exc)    {        ErrorHandler::handle(exc);    }    catch (exception &exc)    {        ErrorHandler::handle(exc);    }    catch (...)    {        ErrorHandler::handle();    }    return 0;}

测试


客户端网络环境为 CERNET ,服务器是位于达拉斯的 BurstNET VPS ,连接方式是 TCP/IPv6 。

直接连接, OpenVPN 速率为 341 KB/s 。使用 5 个连接进行分流, OpenVPN 速率可达 1.4 MB/s 。