twisted实现的Redis服务器
来源:互联网 发布:抱抱软件安全吗 编辑:程序博客网 时间:2024/05/04 11:56
- # -*- coding: utf-8 -*-
- from twisted.internet import reactor
- from twisted.internet.protocol import ServerFactory
- from twisted.protocols.basic import LineReceiver
- import fnmatch
- import shlex
- import time
class Error(Exception):
    """Root of the errors raised by command handlers of this server."""


class WrongNumberArgumentsError(Error):
    """Raised when a command receives an unexpected number of arguments."""


class WrongValueTypeError(Error):
    """Raised when a command is applied to a key holding a non-string value."""
class RObj(object):
    """A value stored in a database, tagged with its type and optional expiry.

    Attributes:
        value: the stored payload.
        type: value type tag; ``'string'`` for values created here.
        expireAt: absolute UNIX timestamp (as from ``time.time()``) after
            which the value is expired, or ``None`` for no expiry.
    """
    # Defaults are assigned per instance: a class attribute with the same
    # name as a slot raises ValueError on new-style classes. The original
    # old-style class silently ignored __slots__ instead.
    __slots__ = ('type', 'expireAt', 'value')

    def __init__(self, value):
        self.value = value
        self.type = 'string'
        self.expireAt = None
- class RedisServerProtocol(LineReceiver):
def __init__(self):
    # Parser state for the multi-bulk request currently being read:
    #   argc - bulk arguments still expected (counted down as each arrives)
    #   argl - payload length taken from the latest "$<len>" header
    #   argv - arguments collected so far
    self.argc, self.argl, self.argv = 0, 0, []
def connectionMade(self):
    """Bind every new connection to database 0 via SELECT.

    The command-name slot is left empty because no client request is involved.
    """
    initial_request = ['', '0']
    self.SELECT(initial_request)
def dataReceived(self, data):
    """Hand incoming data unchanged to LineReceiver's line/raw splitting."""
    base_impl = LineReceiver.dataReceived
    return base_impl(self, data)
def rawDataReceived(self, data):
    """Collect the payload of one ``$<len>`` bulk argument.

    The payload is ``self.argl`` bytes followed by a two-byte terminator
    (CRLF in the Redis protocol). TCP may deliver it across several calls,
    so data is buffered until payload plus terminator are complete. The
    original sliced the first chunk only, truncating large arguments and
    desynchronising the stream.

    Once the argument is complete it is appended to ``self.argv``. If it was
    the last expected one, the request is dispatched via ``process``. Bytes
    after the terminator are handed back to line mode.
    """
    # data[:0] yields an empty buffer of the same type as the incoming data.
    pending = getattr(self, '_pendingBulk', data[:0]) + data
    needed = self.argl + 2  # payload + trailing CRLF
    if len(pending) < needed:
        self._pendingBulk = pending
        return
    self._pendingBulk = data[:0]
    self.argv.append(pending[:self.argl])
    self.argc -= 1
    if not self.argc:
        self.process(self.argv)
        self.argv = []
    self.setLineMode(pending[needed:])
def lineReceived(self, line):
    """Handle one delimiter-terminated line from the client.

    * While a multi-bulk request is in progress (``self.argc`` non-zero),
      the line is a ``$<len>`` header. Remember the payload length and
      switch to raw mode so ``rawDataReceived`` reads the payload.
    * A line starting with ``*`` is the ``*<count>`` header of a new
      multi-bulk request.
    * Anything else is an inline command, split shell-style with ``shlex``.
    """
    if self.argc:
        self.argl = int(line[1:])
        self.setRawMode()
        return
    if line.startswith('*'):
        # Clamp at 0: a negative count (e.g. the null multi-bulk "*-1") would
        # leave argc below zero forever, so the request would never dispatch.
        self.argc = max(int(line[1:]), 0)
        if self.argv:
            self.sendLine('-ERR: %r' % self.argv)
        return
    try:
        argv = shlex.split(line)
    except ValueError:
        # Unbalanced quotes: reply instead of raising out of the protocol.
        self.sendError('Protocol error: unbalanced quotes in request')
        return
    # Empty or whitespace-only lines carry no command (the original indexed
    # line[0] / argv[0] and raised IndexError).
    if argv:
        self.process(argv)
def process(self, argv):
    """Dispatch one parsed request to the matching command method.

    ``argv[0]`` is the command name, matched case-insensitively against
    upper-case attributes of this protocol; the handler receives the full
    ``argv``. Unknown commands fall back to ``todo``. Errors raised by a
    handler are reported to the client as ``-ERR`` replies.
    """
    if not argv:
        # Nothing to dispatch (the original raised IndexError on argv[0]).
        return
    command = argv[0].upper()
    handler = getattr(self, command, self.todo)
    try:
        handler(argv)
    except WrongNumberArgumentsError:
        self.sendError("wrong number of arguments for %r command" % command)
    except WrongValueTypeError:
        self.sendError('Operation against a key holding the wrong kind of value')
    except Exception as e:
        # Log and keep serving. The original stopped in pdb.set_trace() here,
        # which waits on stdin and blocks the whole reactor.
        import traceback
        traceback.print_exc()
        self.sendLine('-ERR %r' % e)
def todo(self, argv):
    """Placeholder handler for commands that are not implemented yet.

    Prints the request to stdout and replies ``+OK``.
    """
    # Print the dispatched argv: for inline commands self.argv is not the
    # request being handled. The print(...) form works on Python 2 and 3.
    print('TODO %s' % (argv,))
    self.sendOK()
def sendOK(self):
    """Write the simple-string reply ``+OK``."""
    reply = '+OK'
    self.sendLine(reply)
def SELECT(self, argv):
    """SELECT index -- switch this connection to ``self.factory.db[index]``.

    Replies ``+OK`` to clients; the original sent no reply, leaving clients
    waiting. ``connectionMade`` calls this with an empty command name
    (``['', '0']``) to pick the initial database. That internal call is not
    a client request and must not produce a reply.
    """
    if len(argv) != 2: raise WrongNumberArgumentsError
    # NOTE(review): assumes factory.db is indexable by int (list of dicts?).
    self.db = self.factory.db[int(argv[1])]
    # Client requests always carry a non-empty command name here.
    if argv[0]:
        self.sendOK()
def PING(self, argv):
    """PING -- answer with the simple string ``+PONG``; arguments are ignored."""
    pong = '+PONG'
    self.sendLine(pong)
def SET(self, argv):
    """SET key value -- store ``value`` under ``key`` as a fresh string object.

    Because a new object replaces the old one, any previous expiry is dropped.
    """
    if len(argv) != 3:
        raise WrongNumberArgumentsError
    _, key, value = argv
    self.db[key] = RObj(value)
    self.sendOK()
def SETNX(self, argv):
    """SETNX key value -- set ``key`` only if it does not already exist.

    Replies ``:1`` when the value was stored and ``:0`` when the key was
    already present. Expired keys are purged first, so they count as absent.
    """
    if len(argv) != 3: raise WrongNumberArgumentsError
    key = argv[1]
    self.expireCheck(key)
    if key in self.db:
        self.sendInteger(0)
        return
    self.db[key] = RObj(argv[2])
    # SETNX answers with an integer reply, not +OK.
    self.sendInteger(1)
def MSET(self, argv):
    """MSET key value [key value ...] -- set several string keys at once.

    Raises WrongNumberArgumentsError unless at least one complete key/value
    pair follows the command name. The original accepted a bare ``MSET``.
    """
    if len(argv) < 3 or len(argv) % 2 == 0:
        raise WrongNumberArgumentsError
    for key, value in zip(argv[1::2], argv[2::2]):
        self.db[key] = RObj(value)
    self.sendOK()
def EXISTS(self, argv):
    """EXISTS key -- reply ``:1`` if ``key`` holds a live value, else ``:0``."""
    if len(argv) != 2:
        raise WrongNumberArgumentsError
    key = argv[1]
    self.expireCheck(key)  # drop the key first if its expiry has passed
    self.sendInteger(1 if key in self.db else 0)
def GET(self, argv):
    """GET key -- send the string stored at ``key`` as a bulk reply.

    Missing or expired keys yield a null bulk. A key holding a non-string
    value raises WrongValueTypeError.
    """
    if len(argv) != 2:
        raise WrongNumberArgumentsError
    key = argv[1]
    self.expireCheck(key)
    robj = self.db.get(key)
    if not robj:
        self.sendBulk(None)
        return
    if robj.type != 'string':
        raise WrongValueTypeError
    self.sendBulk(robj.value)
def MGET(self, argv):
    """MGET key [key ...] -- reply with a multi-bulk of the keys' string values.

    Missing, expired and non-string keys produce a null element.
    """
    if len(argv) < 2: raise WrongNumberArgumentsError
    keys = argv[1:]
    # Use an explicit loop for the side effect: map() is lazy on Python 3, so
    # map(self.expireCheck, keys) would never purge anything there.
    for key in keys:
        self.expireCheck(key)
    robjs = [self.db.get(key) for key in keys]
    values = [o.value if o and o.type == 'string' else None for o in robjs]
    self.sendBulks(values)
def DEL(self, argv):
    """DEL key [key ...] -- delete keys and reply with how many existed."""
    if len(argv) < 2: raise WrongNumberArgumentsError
    keys = argv[1:]
    # Purge expired keys first so they are not counted as deleted. Use a
    # plain loop: map() is lazy on Python 3 and would run nothing.
    for key in keys:
        self.expireCheck(key)
    count = 0
    for key in keys:
        if key in self.db:
            del self.db[key]
            count += 1
    self.sendInteger(count)
def DELEQ(self, argv):
    '''DELEQ key value

    Remove key only when its stored value equals value. Replies :1 when the
    key was removed and :0 otherwise. Intended for string keys.
    '''
    if len(argv) != 3:
        raise WrongNumberArgumentsError
    _, key, expected = argv
    self.expireCheck(key)
    robj = self.db.get(key)
    removed = 0
    if robj and robj.value == expected:
        del self.db[key]
        removed = 1
    self.sendInteger(removed)
def KEYS(self, argv):
    """KEYS pattern -- reply with every live key matching glob ``pattern``.

    Matching uses ``fnmatch.fnmatchcase`` so it is case-sensitive on every
    platform. ``fnmatch.filter`` normalises case via ``os.path.normcase``
    and is case-insensitive on Windows. Matching keys whose expiry has
    passed are purged and left out of the reply.
    """
    if len(argv) != 2: raise WrongNumberArgumentsError
    pattern = argv[1]
    matched = [key for key in self.db if fnmatch.fnmatchcase(key, pattern)]
    # Plain loop for the side effect: map() is lazy on Python 3.
    for key in matched:
        self.expireCheck(key)
    self.sendBulks([key for key in matched if key in self.db])
def _INCRBY(self, key, increment):
    """Add ``increment`` to the integer held at ``key`` and reply with the sum.

    A missing key is created with value '0' first. A non-string value raises
    WrongValueTypeError. A string that does not parse as an integer gets an
    error reply and is left untouched.
    """
    self.expireCheck(key)
    if key not in self.db:
        self.db[key] = RObj('0')
    robj = self.db[key]
    if robj.type != 'string':
        raise WrongValueTypeError
    try:
        current = int(robj.value)
    except ValueError:
        self.sendError('value is not an integer or out of range')
        return
    total = current + increment
    robj.value = str(total)
    self.sendInteger(total)
def INCR(self, argv):
    """INCR key -- add one to the integer stored at ``key``."""
    if len(argv) == 2:
        self._INCRBY(argv[1], 1)
        return
    raise WrongNumberArgumentsError
def INCRBY(self, argv):
    """INCRBY key increment -- add the integer ``increment`` to ``key``.

    A non-integer ``increment`` gets the standard error reply. It no longer
    escapes as a bare ValueError to the generic exception handler in
    ``process``.
    """
    if len(argv) != 3: raise WrongNumberArgumentsError
    try:
        increment = int(argv[2])
    except ValueError:
        self.sendError('value is not an integer or out of range')
        return
    self._INCRBY(argv[1], increment)
def DECR(self, argv):
    """DECR key -- subtract one from the integer stored at ``key``."""
    if len(argv) == 2:
        self._INCRBY(argv[1], -1)
        return
    raise WrongNumberArgumentsError
def DECRBY(self, argv):
    """DECRBY key decrement -- subtract the integer ``decrement`` from ``key``.

    A non-integer ``decrement`` gets the standard error reply. It no longer
    escapes as a bare ValueError to the generic exception handler in
    ``process``.
    """
    if len(argv) != 3: raise WrongNumberArgumentsError
    try:
        decrement = int(argv[2])
    except ValueError:
        self.sendError('value is not an integer or out of range')
        return
    self._INCRBY(argv[1], -decrement)
def _EXPIREAT(self, key, expireAt):
    """Make ``key`` expire at the absolute UNIX timestamp ``expireAt``.

    Replies ``:1`` if the key exists and ``:0`` otherwise. A key whose
    previous expiry has already passed is purged first and counts as
    missing, so a new EXPIRE cannot revive it.
    """
    self.expireCheck(key)
    robj = self.db.get(key)
    if robj:
        robj.expireAt = expireAt
        self.sendInteger(1)
    else:
        self.sendInteger(0)
def EXPIRE(self, argv):
    """EXPIRE key seconds -- expire ``key`` ``seconds`` from now (fractions allowed)."""
    if len(argv) != 3:
        raise WrongNumberArgumentsError
    deadline = time.time() + float(argv[2])
    self._EXPIREAT(argv[1], deadline)
def EXPIREAT(self, argv):
    """EXPIREAT key timestamp -- expire ``key`` at an absolute UNIX time."""
    if len(argv) != 3:
        raise WrongNumberArgumentsError
    key, when = argv[1], float(argv[2])
    self._EXPIREAT(key, when)
def _EXPIREATEQ(self, key, value, expireAt):
    """Set ``key`` to expire at ``expireAt`` only if its value equals ``value``.

    Replies ``:1`` when the expiry was set and ``:0`` otherwise. A key whose
    earlier expiry has already passed is purged first, so it cannot match.
    """
    self.expireCheck(key)
    robj = self.db.get(key)
    if robj and robj.value == value:
        robj.expireAt = expireAt
        self.sendInteger(1)
    else:
        self.sendInteger(0)
def EXPIREEQ(self, argv):
    '''EXPIREEQ key value seconds

    Give key a time to live of `seconds` from now, but only while its
    current value equals `value`. Intended for string keys.
    '''
    if len(argv) != 4:
        raise WrongNumberArgumentsError
    key, expected = argv[1], argv[2]
    deadline = time.time() + float(argv[3])
    self._EXPIREATEQ(key, expected, deadline)
def EXPIREATEQ(self, argv):
    '''EXPIREATEQ key value timestamp

    Set the expiration for a key as a UNIX timestamp only if value is matched.
    Used only for strings.
    '''
    if len(argv) != 4: raise WrongNumberArgumentsError
    self._EXPIREATEQ(argv[1], argv[2], float(argv[3]))
def TTL(self, argv):
    """TTL key -- reply with the remaining time to live of ``key`` in seconds.

    Replies ``:-1`` when the key is missing or has no expiry. Expired keys
    are purged first instead of reporting a negative TTL. The remaining time
    is rounded to the nearest second (as Redis does) rather than truncated.
    """
    if len(argv) != 2: raise WrongNumberArgumentsError
    key = argv[1]
    self.expireCheck(key)
    robj = self.db.get(key)
    # Compare against None: an expiry timestamp of 0 is falsy but real.
    if robj and robj.expireAt is not None:
        self.sendInteger(int(round(robj.expireAt - time.time())))
    else:
        self.sendInteger(-1)
def expireCheck(self, key):
    """Delete ``key`` from the current database if its expiry time has passed.

    Commands call this before reading a key, so expired keys behave as
    missing. Keys whose ``expireAt`` is ``None`` never expire.
    """
    robj = self.db.get(key)
    # Compare against None: an expiry timestamp of 0 (EXPIREAT key 0) is
    # falsy but still means "already expired".
    if robj and robj.expireAt is not None and time.time() > robj.expireAt:
        del self.db[key]
def sendInteger(self, i):
    """Send the integer reply ``:<i>``; ``i`` is rendered with ``%d``."""
    line = ':%d' % (i,)
    self.sendLine(line)
def sendBulk(self, bulk):
    """Send ``bulk`` as a bulk reply (``$<len>`` then the data).

    ``None`` is sent as the null bulk ``$-1``.
    """
    if bulk is not None:
        self.sendLine('$%d' % len(bulk))
        self.sendLine(bulk)
    else:
        self.sendLine('$-1')
def sendBulks(self, bulks):
    """Send ``bulks`` as a multi-bulk reply.

    The reply is a ``*<count>`` header followed by one bulk reply per element;
    ``None`` elements become null bulks.
    """
    self.sendLine('*%d' % len(bulks))
    # Plain loop for the side effect: map() is lazy on Python 3, so
    # map(self.sendBulk, bulks) would send the header and no elements.
    for bulk in bulks:
        self.sendBulk(bulk)
- def sendError(self, message):
- self.sendLine('-ERR %s