Custom sessions for Tornado

Source: Internet  Editor: 程序博客网  Time: 2024/05/29 19:32

Cookies and sessions

Before writing a custom session, we need to understand what cookies and sessions are. See my earlier blog post: http://blog.csdn.net/ayhan_huang/article/details/78032097

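In short:

  • A cookie is a set of key-value pairs stored in the browser
  • A session is a set of key-value pairs stored on the server
  • A session depends on a cookie: the cookie carries the session_id that identifies the data on the server

The Django framework lets us work with both cookies and sessions directly, but Tornado only supports cookies. So what do we do if we want sessions? We define our own.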

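Approach and implementation

In Tornado, every request is handled by a RequestHandler object (called the handler below). RequestHandler provides a hook method, initialize, which runs when the handler object is instantiated. So if we subclass RequestHandler and override initialize, we can perform custom setup for every request.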
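The hook mechanism can be illustrated without Tornado. The following is a minimal plain-Python sketch: BaseHandler, SessionMixin and PageHandler are hypothetical stand-ins (not Tornado classes) that only mimic a constructor calling an overridable initialize hook.

```python
class BaseHandler:
    """Hypothetical stand-in for a framework handler class."""

    def __init__(self):
        # The constructor calls the hook, so subclasses can customize
        # setup without overriding __init__ itself.
        self.initialize()

    def initialize(self):
        """Hook: does nothing by default."""


class SessionMixin:
    """Hypothetical mixin that overrides the hook."""

    def initialize(self):
        self.session = {}  # placeholder for a real session object


class PageHandler(SessionMixin, BaseHandler):
    """The mixin is listed first so its initialize() wins in the MRO."""


handler = PageHandler()
print(handler.session)  # {}
```

Because the mixin comes first in the list of base classes, Python finds its initialize before the default one, which is why the order of the base classes matters.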

Implementation:

import os
import time
from hashlib import sha1

import tornado.ioloop
import tornado.web


def create_session_id():
    """Generate a random session_id."""
    raw = '%s%s' % (os.urandom(16), time.time())
    return sha1(raw.encode('utf-8')).hexdigest()


class Session:
    """Custom session."""

    # session_id: {'user': info} --> user info, permissions, etc. stored per session
    info_container = {}

    def __init__(self, handler):
        """
        Takes the RequestHandler object and uses it for cookie operations:
        self.handler.set_cookie()
        self.handler.get_cookie()
        :param handler:
        """
        self.handler = handler

        # Read the random string used as session_id from the cookie;
        # generate a new one if it is missing or unknown.
        random_str = self.handler.get_cookie('session_id')
        if (not random_str) or (random_str not in self.info_container):
            random_str = create_session_id()
            self.info_container[random_str] = {}
        self.random_str = random_str

        # set_cookie runs on every request, so the expiry is reset
        # to 60 seconds from now each time.
        self.handler.set_cookie('session_id', random_str, max_age=60)

    def __getitem__(self, item):
        return self.info_container[self.random_str].get(item)

    def __setitem__(self, key, value):
        self.info_container[self.random_str][key] = value

    def __delitem__(self, key):
        # pop() also removes keys whose value is falsy (0, '', None)
        self.info_container[self.random_str].pop(key, None)

    def delete(self):
        """Remove this session_id from the big dictionary."""
        del self.info_container[self.random_str]


class SessionHandler:
    def initialize(self):
        self.session = Session(self)  # add a session attribute to the handler


class LoginHandler(SessionHandler, tornado.web.RequestHandler):
    def get(self):
        self.render('login.html')

    def post(self):
        user = self.get_argument('user')
        pwd = self.get_argument('pwd')
        if user == 'lena' and pwd == '123':
            self.session['user'] = user
            self.redirect('/index')
        else:
            self.redirect('/login')  # failed login: show the form again


class IndexHandler(SessionHandler, tornado.web.RequestHandler):
    def get(self):
        # Note: session.get('user') does not work here, Session defines no get()
        user = self.session['user']
        if not user:
            self.redirect('/login')
            return
        self.write('Hello, %s' % user)  # send the response


settings = {
    'template_path': 'templates',
    # Secret for signed cookies; only used by set_secure_cookie/get_secure_cookie
    'cookie_secret': 'asdfasdfasd',
}

application = tornado.web.Application([
    (r'/login', LoginHandler),
    (r'/index', IndexHandler),
], **settings)

if __name__ == '__main__':
    application.listen(8080)
    tornado.ioloop.IOLoop.current().start()

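Notes:

  • Define a Session class that receives the handler object when it is instantiated
    • Session has a class attribute (the big dictionary) that maps each session_id to the corresponding user data; every session object can access it.
    • The Session constructor reads and sets the cookie:
      • It calls the handler's get_cookie() method to get the session_id; if there is none, or it is not in the big dictionary, it generates a random string random_str to use as the session_id
      • It writes the session_id into the big dictionary
      • It calls the handler's set_cookie() method, which tells the browser to set the cookie: Set-Cookie: session_id=random_str
    • Session defines __getitem__, __setitem__ and __delitem__ so that a session object can be used like a dictionary
  • In initialize, the handler gets a session attribute whose value is a Session object: self.session = Session(self). Overriding initialize in the handler of every route would be tedious, so multiple inheritance is used instead: this step lives in a separate class, SessionHandler, and each handler class lists it as its first base class.
  • On every request, SessionHandler.initialize runs and instantiates a Session object, which obtains the session_id
  • Working with the session:
    • self.session[key] = value calls __setitem__ to write to the session
    • self.session[key] calls __getitem__ to read from the session
    • del self.session[key] calls __delitem__ to delete a key from the session
    • self.session.delete() calls the delete method, which removes the whole session_id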
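The dictionary-style access described in these notes can be tried without a running server. Below is a minimal sketch under stated assumptions: StubHandler is a hypothetical stand-in that keeps cookies in a plain dict, and secrets.token_hex is swapped in for the sha1-based id generator purely to keep the example short.

```python
import secrets


class StubHandler:
    """Hypothetical stand-in for a request handler: cookies live in a dict."""

    def __init__(self, cookies=None):
        self.cookies = dict(cookies or {})

    def get_cookie(self, name):
        return self.cookies.get(name)

    def set_cookie(self, name, value, **kwargs):
        self.cookies[name] = value


class Session:
    store = {}  # shared by all instances: session_id -> data dict

    def __init__(self, handler):
        sid = handler.get_cookie('session_id')
        if not sid or sid not in self.store:
            sid = secrets.token_hex(20)  # assumption: replaces the sha1-based id
            self.store[sid] = {}
        self.sid = sid
        handler.set_cookie('session_id', sid, max_age=60)

    def __getitem__(self, key):
        return self.store[self.sid].get(key)

    def __setitem__(self, key, value):
        self.store[self.sid][key] = value

    def __delitem__(self, key):
        self.store[self.sid].pop(key, None)


# First request: no cookie yet, so a new session is created.
first = StubHandler()
session = Session(first)
session['user'] = 'lena'

# Second request: the browser sends the cookie back.
second = StubHandler(first.cookies)
same_session = Session(second)
print(same_session['user'])  # lena

del same_session['user']
print(same_session['user'])  # None
```

The second Session object is a different Python object, but it reaches the same data because both requests carry the same session_id and share the class-level dictionary.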

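Consistent hashing and distributed sessions

Storing sessions in Redis cache servers keeps them out of the application process and can scale better than a dictionary inside one process. With more than one cache server, the sessions have to be load-balanced across the servers. There are many algorithms for this; the most common one is hashing: hash the session_id (the random string), take the result modulo the number of servers, and the remainder i selects the i-th server.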
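The modulo scheme can be sketched in a few lines. This is an illustrative example only: pick_server and the generated session ids are hypothetical names, and sha1 is just one possible hash function.

```python
import hashlib


def pick_server(session_id, server_count):
    """Hash the session_id and take the remainder as the server index."""
    digest = hashlib.sha1(session_id.encode('utf-8')).hexdigest()
    return int(digest, 16) % server_count


session_ids = ['session-%d' % n for n in range(1000)]

# Count how many ids land on a different server after a fourth server is added.
moved = sum(1 for sid in session_ids
            if pick_server(sid, 3) != pick_server(sid, 4))
print('%d of %d session ids changed server' % (moved, len(session_ids)))
```

With evenly distributed hashes, a key keeps its server only when its hash has the same remainder modulo 3 and modulo 4, which is the case for about a quarter of the keys. Most sessions therefore move when the number of servers changes, which is the weakness of plain modulo hashing.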

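Consistent hashing is the usual choice for distributed load balancing, because adding or removing a server changes the mapping of only a small share of the keys. Python has an implementation in the hash_ring module. It does not need to be installed: just copy its single file, hash_ring.py, which is little more than 100 lines. Besides this use, consistent hashing is also useful for distributed crawlers.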
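To see the idea in isolation, here is a deliberately simplified ring. It is a sketch, not the hash_ring module: TinyRing, point, the node names and the replicas value are all assumptions made for illustration.

```python
import hashlib
from bisect import bisect


def point(text):
    """Map a string to a position on the ring (a 32-bit integer)."""
    return int(hashlib.md5(text.encode('utf-8')).hexdigest()[:8], 16)


class TinyRing:
    """Hypothetical, simplified ring; not the hash_ring module itself."""

    def __init__(self, nodes, replicas=100):
        self.owner = {}  # ring position -> node
        for node in nodes:
            for i in range(replicas):  # several virtual points per node
                self.owner[point('%s-%d' % (node, i))] = node
        self.points = sorted(self.owner)

    def get_node(self, key):
        # First virtual point clockwise from the key; wrap around at the end.
        pos = bisect(self.points, point(key)) % len(self.points)
        return self.owner[self.points[pos]]


keys = ['session-%d' % n for n in range(1000)]
before = TinyRing(['a:6379', 'b:6379', 'c:6379'])
after = TinyRing(['a:6379', 'b:6379', 'c:6379', 'd:6379'])

moved = [k for k in keys if before.get_node(k) != after.get_node(k)]
print('%d of %d keys changed node' % (len(moved), len(keys)))
print(all(after.get_node(k) == 'd:6379' for k in moved))  # True
```

Adding a node only inserts new points on the ring, so a key either keeps its node or moves to the new one. No key is shuffled between the old nodes.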

hash_ring.py

# -*- coding: utf-8 -*-
"""
    hash_ring
    ~~~~~~~~~~~~~~
    Implements consistent hashing that can be used when
    the number of server nodes can increase or decrease (like in memcached).

    Consistent hashing is a scheme that provides a hash table functionality
    in a way that the adding or removing of one slot
    does not significantly change the mapping of keys to slots.

    More information about consistent hashing can be read in these articles:

        "Web Caching with Consistent Hashing":
            http://www8.org/w8-papers/2a-webserver/caching/paper2.html

        "Consistent hashing and random trees:
        Distributed caching protocols for relieving hot spots on the World Wide Web (1997)":
            http://citeseerx.ist.psu.edu/legacymapper?did=38148

    Example of usage::

        memcache_servers = ['192.168.0.246:11212',
                            '192.168.0.247:11212',
                            '192.168.0.249:11212']

        ring = HashRing(memcache_servers)
        server = ring.get_node('my_key')

    :copyright: 2008 by Amir Salihefendic.
    :license: BSD
"""

import math
import sys
from bisect import bisect

if sys.version_info >= (2, 5):
    import hashlib
    md5_constructor = hashlib.md5
else:
    import md5
    md5_constructor = md5.new


class HashRing(object):

    def __init__(self, nodes=None, weights=None):
        """`nodes` is a list of objects that have a proper __str__ representation.
        `weights` is dictionary that sets weights to the nodes.  The default
        weight is that all nodes are equal.
        """
        self.ring = dict()
        self._sorted_keys = []

        self.nodes = nodes

        if not weights:
            weights = {}
        self.weights = weights

        self._generate_circle()

    def _generate_circle(self):
        """Generates the circle.
        """
        total_weight = 0
        for node in self.nodes:
            total_weight += self.weights.get(node, 1)

        for node in self.nodes:
            weight = 1

            if node in self.weights:
                weight = self.weights.get(node)

            factor = math.floor((40*len(self.nodes)*weight) / total_weight)

            for j in range(0, int(factor)):
                b_key = self._hash_digest( '%s-%s' % (node, j) )

                for i in range(0, 3):
                    key = self._hash_val(b_key, lambda x: x+i*4)
                    self.ring[key] = node
                    self._sorted_keys.append(key)

        self._sorted_keys.sort()

    def get_node(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned.

        If the hash ring is empty, `None` is returned.
        """
        pos = self.get_node_pos(string_key)
        if pos is None:
            return None
        return self.ring[ self._sorted_keys[pos] ]

    def get_node_pos(self, string_key):
        """Given a string key a corresponding node in the hash ring is returned
        along with it's position in the ring.

        If the hash ring is empty, (`None`, `None`) is returned.
        """
        if not self.ring:
            return None

        key = self.gen_key(string_key)

        nodes = self._sorted_keys
        pos = bisect(nodes, key)

        if pos == len(nodes):
            return 0
        else:
            return pos

    def iterate_nodes(self, string_key, distinct=True):
        """Given a string key it returns the nodes as a generator that can hold the key.

        The generator iterates one time through the ring
        starting at the correct position.

        if `distinct` is set, then the nodes returned will be unique,
        i.e. no virtual copies will be returned.
        """
        if not self.ring:
            yield None, None

        returned_values = set()

        def distinct_filter(value):
            if str(value) not in returned_values:
                returned_values.add(str(value))
                return value

        pos = self.get_node_pos(string_key)
        for key in self._sorted_keys[pos:]:
            val = distinct_filter(self.ring[key])
            if val:
                yield val

        for i, key in enumerate(self._sorted_keys):
            if i < pos:
                val = distinct_filter(self.ring[key])
                if val:
                    yield val

    def gen_key(self, key):
        """Given a string key it returns a long value,
        this long value represents a place on the hash ring.

        md5 is currently used because it mixes well.
        """
        b_key = self._hash_digest(key)
        return self._hash_val(b_key, lambda x: x)

    def _hash_val(self, b_key, entry_fn):
        return (( b_key[entry_fn(3)] << 24)
                |(b_key[entry_fn(2)] << 16)
                |(b_key[entry_fn(1)] << 8)
                | b_key[entry_fn(0)] )

    def _hash_digest(self, key):
        m = md5_constructor()
        m.update(key.encode('utf-8'))
        # return map(ord, m.digest())  # python 2
        return list(m.digest())  # python 3

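Note: in Python 3, digest() returns a bytes object, and list() turns it directly into a list of integer byte values. In Python 2, digest() returns a str, so the built-in function ord is needed for the conversion.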

>>> import hashlib
>>> m = hashlib.md5(b'hello world')
>>> list(m.digest())
[94, 182, 59, 187, 224, 30, 238, 208, 147, 203, 34, 187, 143, 90, 205, 195]
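The integer byte values are what the bit arithmetic in _hash_val operates on. As a small sketch, the lines below repeat that arithmetic for an offset of 0 and compare it with int.from_bytes from the standard library; the variable names manual and builtin are only for illustration.

```python
import hashlib

digest = hashlib.md5(b'hello world').digest()

# Iterating over a bytes object in Python 3 yields integers (0-255).
print(list(b'AB'))  # [65, 66]

# Same bit arithmetic as _hash_val with an offset of 0:
# the first four digest bytes, least significant byte first.
b_key = list(digest)
manual = (b_key[3] << 24) | (b_key[2] << 16) | (b_key[1] << 8) | b_key[0]

# The standard library expresses the same thing directly.
builtin = int.from_bytes(digest[:4], 'little')
print(manual == builtin)  # True
```

So each ring position is a 32-bit integer built from four bytes of the md5 digest, read in little-endian order.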

Making the custom session distributed

Instantiate a HashRing object and rewrite the special methods of Session; everything else stays the same:

import os
import time
from hashlib import sha1

import redis
import tornado.ioloop
import tornado.web
from hash_ring import HashRing

# Cache server list
cache_servers = [
    '192.168.0.246:11212',
    '192.168.0.247:11212',
    '192.168.0.249:11212'
]

# Weights
weights = {
    '192.168.0.246:11212': 2,
    '192.168.0.247:11212': 2,
    '192.168.0.249:11212': 1
}

ring = HashRing(cache_servers, weights)  # instantiate the HashRing object


def create_session_id():
    """Generate a random session_id."""
    raw = '%s%s' % (os.urandom(16), time.time())
    return sha1(raw.encode('utf-8')).hexdigest()


class Session:
    """Custom session stored in Redis."""

    def __init__(self, handler):
        """
        Takes the RequestHandler object and uses it for cookie operations:
        self.handler.set_cookie()
        self.handler.get_cookie()
        :param handler:
        """
        self.handler = handler

        # Read the random string used as session_id from the cookie;
        # generate a new one if it is missing or Redis holds no data for it.
        random_str = self.handler.get_cookie('session_id')
        if (not random_str) or (not self._conn(random_str).exists(random_str)):
            random_str = create_session_id()
        self.random_str = random_str

        # set_cookie runs on every request, so the expiry is reset
        # to 60 seconds from now each time.
        self.handler.set_cookie('session_id', random_str, max_age=60)

    @staticmethod
    def _conn(session_id):
        # get_node() picks a server from the position of the hashed string
        # on the ring; split(':') extracts the server's host and port.
        host, port = ring.get_node(session_id).split(':')
        # decode_responses=True makes hget return str instead of bytes
        return redis.Redis(host=host, port=int(port), decode_responses=True)

    def __getitem__(self, item):
        return self._conn(self.random_str).hget(self.random_str, item)

    def __setitem__(self, key, value):
        self._conn(self.random_str).hset(self.random_str, key, value)

    def __delitem__(self, key):
        self._conn(self.random_str).hdel(self.random_str, key)

    def delete(self):
        """Delete the whole session from Redis."""
        self._conn(self.random_str).delete(self.random_str)
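The routing step can be tried without any Redis server. The sketch below rests on several assumptions: FakeRedis is a hypothetical in-memory stand-in that implements only three hash commands, pick_node is a simplified modulo replacement for ring.get_node(), and client_for is a hypothetical helper that creates one client per server and reuses it instead of opening a new connection on every access.

```python
import hashlib


class FakeRedis:
    """Hypothetical in-memory stand-in exposing a few hash commands."""

    def __init__(self, host, port):
        self.host, self.port = host, port
        self.data = {}

    def hset(self, name, key, value):
        self.data.setdefault(name, {})[key] = value

    def hget(self, name, key):
        return self.data.get(name, {}).get(key)

    def delete(self, name):
        self.data.pop(name, None)


servers = ['192.168.0.246:11212', '192.168.0.247:11212', '192.168.0.249:11212']
clients = {}  # 'host:port' -> client, created once and then reused


def pick_node(session_id):
    """Simplified stand-in for ring.get_node(): hash modulo server count."""
    digest = hashlib.md5(session_id.encode('utf-8')).hexdigest()
    return servers[int(digest, 16) % len(servers)]


def client_for(session_id):
    """Return the cached client of the server that owns this session_id."""
    node = pick_node(session_id)
    if node not in clients:
        host, port = node.split(':')
        clients[node] = FakeRedis(host, int(port))
    return clients[node]


client_for('abc123').hset('abc123', 'user', 'lena')
print(client_for('abc123').hget('abc123', 'user'))  # lena
print(client_for('abc123') is client_for('abc123'))  # True
```

Because the same session_id always maps to the same server, every request for that session reads and writes its data in one place.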