借shared_ptr实现copy on write 以减少锁的使用

来源:互联网 发布:yum wireshark 编辑:程序博客网 时间:2024/06/06 01:51

shared_ptr是引用计数智能指针,如果当前只有一个观察者,那么引用计数为1,可以用shared_ptr::unique()来判断
对于write端,如果发现引用计数为1,这时可以安全地修改对象,不必担心有人在读它。
对于read端,在读之前把引用计数加1,读完之后减1,这样可以保证在读的期间其引用计数大于1,可以阻止并发写。
比较难的是,对于write端,如果发现引用计数大于1,该如何处理?既然要更新数据,肯定要加锁,如果这时候其他线程正在读,那么不能在原来的数据上修改,得创建一个副本,在副本上修改,修改完了再替换。如果没有用户在读,那么可以直接修改。

typedef boost::shared_ptr<TcpConnection> TcpConnectionPtr;
typedef std::set<TcpConnectionPtr> ConnectionList;
typedef boost::shared_ptr<ConnectionList> ConnectionListPtr;

读写公用:ConnectionListPtr connections_;
读端:

void onStringMessage(const TcpConnectionPtr&,

                       const string& message,
                       Timestamp)
  {
    // 引用计数加1,mutex保护的临界区大大缩短
    ConnectionListPtr connections = getConnectionList();
    // 可能大家会有疑问,不受mutex保护,写者更改了连接列表怎么办?
    // 实际上,写者是在另一个复本上修改,所以无需担心。

    for (ConnectionList::iterator it = connections->begin();
        it != connections->end();
        ++it)
    {
      codec_.send(get_pointer(*it), message);
    }

// 这个断言不一定成立
//assert(!connections.unique());
    // 当connections这个栈上的变量销毁的时候,引用计数减1,,同时如果写端执行了connections_.reset(new ConnectionList(*connections_)); 引用计数也会减一,则这时这个connections指向的老对象的引用计数变成0,就销毁了,所以不会出现多个副本
  }

  ConnectionListPtr getConnectionList()
  {
    MutexLockGuard lock(mutex_);
    return connections_;

  }

写端:

  void onConnection(const TcpConnectionPtr& conn)
  {
    LOG_INFO << conn->localAddress().toIpPort() << " -> "
        << conn->peerAddress().toIpPort() << " is "
        << (conn->connected() ? "UP" : "DOWN");

    MutexLockGuard lock(mutex_);//多个线程同时写的时候要有锁才行
    if (!connections_.unique()) // 说明引用计数大于1

    {
      // connections_是一个智能指针,现在重置指针指向,new ConnectionList(*connections_)这段代码拷贝了一份ConnectionList
      connections_.reset(new ConnectionList(*connections_));//这时复制了一份ConnectionList,connections_指向的新对象的引用计数变成1,读端的老对象因为connections_的reset所以老对象的引用会减一
    }
    assert(connections_.unique());
    // 在复本上修改,不会影响读者,所以读者在遍历列表的时候,不需要用mutex保护
    if (conn->connected())
    {
      connections_->insert(conn);
    }
    else
    {
      connections_->erase(conn);
    }
  }

读端转发信息(群发)的时候是在一个线程中就行的,当转发数量多时,转发时间会较长

可以 采用threadlocal变量实现多线程高效转发

class ChatServer : boost::noncopyable
{
 public:
  ChatServer(EventLoop* loop,
             const InetAddress& listenAddr)
  : loop_(loop),
    server_(loop, listenAddr, "ChatServer"),
    codec_(boost::bind(&ChatServer::onStringMessage, this, _1, _2, _3))
  {
    server_.setConnectionCallback(
        boost::bind(&ChatServer::onConnection, this, _1));
    server_.setMessageCallback(
        boost::bind(&LengthHeaderCodec::onMessage, &codec_, _1, _2, _3));
  }

  void setThreadNum(int numThreads)
  {
    server_.setThreadNum(numThreads);
  }

  void start()
  {
    server_.setThreadInitCallback(boost::bind(&ChatServer::threadInit, this, _1));
    server_.start();
  }

 private:
  void onConnection(const TcpConnectionPtr& conn)
  {
    LOG_INFO << conn->localAddress().toIpPort() << " -> "
             << conn->peerAddress().toIpPort() << " is "
             << (conn->connected() ? "UP" : "DOWN");
    if (conn->connected())
    {
      connections_.instance().insert(conn);//各个线程特有的实例,里面保存各个线程从各自loop中获得的conn连接
    }
    else
    {
      connections_.instance().erase(conn);
    }
  }

  void onStringMessage(const TcpConnectionPtr&,
                       const string& message,
                       Timestamp)
  {
    EventLoop::Functor f = boost::bind(&ChatServer::distributeMessage, this, message);
    LOG_DEBUG;
    MutexLockGuard lock(mutex_);
    for (std::set<EventLoop*>::iterator it = loops_.begin();
        it != loops_.end();
        ++it)
    {
      (*it)->queueInLoop(f);
    }
    LOG_DEBUG;
  }

  typedef std::set<TcpConnectionPtr> ConnectionList;

  void distributeMessage(const string& message)
  {
    LOG_DEBUG << "begin";
    for (ConnectionList::iterator it = connections_.instance().begin();
        it != connections_.instance().end();
        ++it)
    {
      codec_.send(get_pointer(*it), message);
    }
    LOG_DEBUG << "end";
  }

//每个io线程启动时都会调用,主要是把io线程的loop返回回来插入loops_
  void threadInit(EventLoop* loop)
  {
    assert(connections_.pointer() == NULL);
    connections_.instance();
    assert(connections_.pointer() != NULL);
    MutexLockGuard lock(mutex_);
    loops_.insert(loop);
  }


  EventLoop* loop_;
  TcpServer server_;
  LengthHeaderCodec codec_;
  ThreadLocalSingleton<ConnectionList>connections_;
  MutexLock mutex_;
  std::set<EventLoop*> loops_;
};

int main(int argc, char* argv[])
{
  LOG_INFO << "pid = " << getpid();
  if (argc > 1)
  {
    EventLoop loop;
    uint16_t port = static_cast<uint16_t>(atoi(argv[1]));
    InetAddress serverAddr(port);
    ChatServer server(&loop, serverAddr);
    if (argc > 2)
    {
      server.setThreadNum(atoi(argv[2]));
    }
    server.start();
    loop.loop();
  }
  else
  {
    printf("Usage: %s port [thread_num]\n", argv[0]);
  }
}

0 0
原创粉丝点击