编写ORM框架

来源:互联网 发布:网络出版物的图书 编辑:程序博客网 时间:2024/06/05 19:57

ORM:Object-Relational Mapping,把关系数据库中的表结构映射到对象上。然后操作数据库就不需要构造SQL语句,而是直接调用相应的方法。ORM框架可以方便的完成这些转换,然后,数据库表中的一行记录就对应着python中的一个对象,就不需要使用SQL语句,可以调用方法直接操作数据库。


Python中有名的ORM框架是SQLAlchemySQLAlchemy的用法:

from sqlalchemy import Column, String, create_enginefrom sqlalchemy.orm import sessionmakerfrom sqlalchemy.ext.declarative import declarative_base# 创建对象的基类:Base = declarative_base()# 定义User对象,继承上面的基类。对应着数据库中的一张表class User(Base):     __tablename__ = 'user' # 表的名字:    # 表的结构:    id = Column(String(20), primary_key=True)    name = Column(String(20))# 初始化数据库连接:engine = create_engine('mysql+mysqlconnector://root:password@localhost:3306/test')# 创建DBSession类型:DBSession = sessionmaker(bind=engine)往数据库中写入数据:session = DBSession()  # 创建session对象(连接数据库)new_user = User(id='5', name='Bob')  # 创建新User对象session.add(new_user) # 添加到sessionsession.commit() # 提交即保存到数据库session.close() # # 关闭session查询数据:session = DBSession() # 创建Sessionuser = session.query(User).filter(User.id=='5').one()  # 创建Query查询,filter是where条件,最后调用one()返回唯一行,如果调用all()则返回所有行print('type:', type(user)) # 打印类型和对象的name属性print('name:', user.name)session.close() # 关闭Session

廖雪峰老师Python教程实战部分使用异步,另外写了一个ORM框架.可以参考SQLAlchemy框架使用来编写新的ORM框架。


orm.py代码总共两百多行,整体结构如下:

import asyncio, loggingimport aiomysqlasync def create_pool(loop, **kw):  # 创建全局的连接池,每个HTTP请求都能从连接池中直接获取数据库连接,不必频繁地打开、关闭数据库连接。    ...async def select(sql, args, size=None):  # 用select()函数来执行SELECT语句,需要传入SQL语句和SQL参数。    ...async def execute(sql, args, autocommit=True): # UPDATE,INSERT,DELETE不需要详细的查询结果,封装在一个execute()函数中。    ...class Field(object):  # 定义字段基类    ...class StringField(Field): #继承Field定义不同类型字段类    ......class ModelMetaclass(type):  # 定义一个元类。每个表(Model对象)需要不同的继承模板,这里通过元类动态创建类。    ...class Model(dict, metaclass=ModelMetaclass): # ORM映射的基类。    ...然后,完成了一个简单ORM框架,使用时参照sqlalchemy框架。

具体细节:

数据库连接池

百度百科中关于数据库连接池的解释:

数据库连接是一种关键的、有限的、昂贵的资源,这一点在多用户的网页应用程序中体现得尤为突出。
数据库连接池在初始化时将创建一定数量的数据库连接放到连接池中,这些数据库连接的数量是由最小数据库连接数制约。无论这些数据库连接是否被使用,连接池都将一直保证至少拥有这么多的连接数量。连接池的最大数据库连接数量限定了这个连接池能占有的最大连接数,当应用程序向连接池请求的连接数超过最大连接数量时,这些请求将被加入到等待队列中。
连接池基本的思想是在系统初始化的时候,将数据库连接作为对象存储在内存中,当用户需要访问数据库时,并非建立一个新的连接,而是从连接池中取出一个已建立的空闲连接对象。使用完毕后,用户也并非将连接关闭,而是将连接放回连接池中,以供下一个请求访问使用。而连接的建立、断开都由连接池自身来管理。同时,还可以通过设置连接池的参数来控制连接池中的初始连接数、连接的上下限数以及每个连接的最大使用次数、最大空闲时间等等。也可以通过其自身的管理机制来监视数据库连接的数量、使用情况等。

python中的aiomysql为MySQL提供了异步IO的驱动。
aiomysql中有create_pool()方法,这里有create_pool()的源码。前辈们准备好了工具,现在先学会使用再说。

@asyncio.coroutine  # 表明create_pool()为协程def create_pool(loop, **kw):    logging.info('create database connection pool...')    global __pool   # 全局变量__pool来存储连接池。    __pool = yield from aiomysql.create_pool(        host=kw.get('host', 'localhost'),        port=kw.get('port', 3306),        user=kw['user'],  # 从参数中获取        password=kw['password'],        db=kw['db'],        charset=kw.get('charset', 'utf8'),        autocommit=kw.get('autocommit', True),  # 自动连接        maxsize=kw.get('maxsize', 10),  # 最多10个连接对象        minsize=kw.get('minsize', 1),  # 最少1个        loop=loop    )

封装SELECT方法:

查找是数据库最重要的一部分。这里写了select()来执行查找语句。

@asyncio.coroutinedef select(sql, args, size=None):  # sql指SQL语句,传递参数指定查找什么,size规定查找几条,默认None,会查找所有数据    log(sql, args)  #记录日志    global __pool    with (yield from __pool) as conn:        cur = yield from conn.cursor(aiomysql.DictCursor)  #创建游标来操作数据库。        yield from cur.execute(sql.replace('?', '%s'), args or ())  #SQL的占位符为?MySQL占位符为%s,然后执行SQL语句。        if size:            rs = yield from cur.fetchmany(size)  # yield            from,协程中调用另一个协程        else:            rs = yield from cur.fetchall()        yield from cur.close()  #关闭游标        logging.info('rows returned: %s' % len(rs))  # 再记录一下        return rs  # 返回查找结果

Insert, Update, Delete

这三个方法,Cursor操作完数据库,不用返回详细结果,封装在了一个execute()函数中。

@asyncio.coroutinedef execute(sql, args):    log(sql)    with (yield from __pool) as conn:        try:            cur = yield from conn.cursor()            yield from cur.execute(sql.replace('?', '%s'), args)  # 这里执行数据库操作            affected = cur.rowcount            yield from cur.close()        except BaseException as e:            raise        return affected  # 只返回影响数据库结果数

Field

有了直接操作数据库的方法,还需要定义数据库表中对应的字段。数据库中一张表有任意行,固定列,每一列的字段类型可能不同。

首先定义Field:

class Field(object):    def __init__(self, name, column_type, primary_key, default):        self.name = name   # 对应着数据库表中的字段名        self.column_type = column_type  #字段数据类型        self.primary_key = primary_key # 是否为主键        self.default = default # 有无默认值    def __str__(self):  #返回对象的字符串形式        return '<%s, %s:%s>' % (self.__class__.__name__, self.column_type, self.name)

Field子类:

class StringField(Field):    def __init__(self, name=None, primary_key=False, default=None, ddl='varchar(100)'):        super().__init__(name, ddl, primary_key, default) #  初始化self。

Model

开始定义所有ORM映射的基类Model

首先要有这样熟悉的功能:

>>> user['id']123>>> user.id123

然后要有find(),findAll(),remove(),update(),save(),这些方便的方法。

class Model(dict, metaclass=ModelMetaclass):  # #拥有dict的功能,同时继承自元类`ModelMetaclass`动态生成Model对象。     def __init__(self, **kw):        super(Model, self).__init__(**kw)    def __getattr__(self, key): #从对象中读取某个属性        try:            return self[key]        except KeyError:            raise AttributeError(r"'Model' object has no attribute '%s'" % key)    def __setattr__(self, key, value): #设置对象的属性        self[key] = value    def getValue(self, key):        return getattr(self, key, None)    def getValueOrDefault(self, key):  # 取默认值,定义字段类设置了默认值属性,默认值也可以是函数        value = getattr(self, key, None)        if value is None:            field = self.__mappings__[key]            if field.default is not None:                value = field.default() if callable(field.default) else field.default                logging.debug('using default value for %s: %s' % (key, str(value)))                setattr(self, key, value)        return value## 然后,find(),findAll(),remove(),update(),save()等好记又好用的方法。    @classmethod  #将方法变成属性    @asyncio.coroutine  # 这些方法都要是协程    def findAll(cls, where=None, args=None, **kw):        ' find objects by where clause. '        sql = [cls.__select__]  # cls for clause,每个表名都不相同,这里的__select__方法是动态生成的。        if where:            sql.append('where')  # 以下都是为了得到完整的SQL查询语句。            sql.append(where)        if args is None:            args = []        orderBy = kw.get('orderBy', None)        if orderBy:            sql.append('order by')            sql.append(orderBy)        limit = kw.get('limit', None)        if limit is not None:            sql.append('limit')            if isinstance(limit, int):                sql.append('?')                args.append(limit)            elif isinstance(limit, tuple) and len(limit) == 2:                sql.append('?, ?')                args.extend(limit)            else:                raise ValueError('Invalid limit value: %s' % str(limit))        rs = yield from select(' '.join(sql), args)  # 调用一开始定义的select()查询记录。        return [cls(**r) for r in rs]  # 将所有结果以列表形式返回。    @classmethod    @asyncio.coroutine    def findNumber(cls, selectField, where=None, args=None):        ' find number by select and where. '        sql = ['select %s _num_ from `%s`' % (selectField, cls.__table__)]  # 这个__table__也各不相同。        if where:            sql.append('where')            sql.append(where)        rs = yield from select(' '.join(sql), args, 1)        if len(rs) == 0:            return None        return rs[0]['_num_']    @classmethod    @asyncio.coroutine    def find(cls, pk):        ' find object by primary key. '        rs = yield from select('%s where `%s`=?' % (cls.__select__, cls.__primary_key__), [pk], 1)        if len(rs) == 0:            return None        return cls(**rs[0])  #返回一个实例对象引用    @asyncio.coroutine    def save(self):        args = list(map(self.getValueOrDefault, self.__fields__))  # 需要传递到SQL语句中的参数        args.append(self.getValueOrDefault(self.__primary_key__))        rows = yield from execute(self.__insert__, args)  # 调用上面定义的execute()方法,返回影响数        if rows != 1:            logging.warn('failed to insert record: affected rows: %s' % rows)    @asyncio.coroutine    def update(self):        args = list(map(self.getValue, self.__fields__))        args.append(self.getValue(self.__primary_key__))        rows = yield from execute(self.__update__, args)        if rows != 1:            logging.warn('failed to update by primary key: affected rows: %s' % rows)    @asyncio.coroutine    def remove(self):        args = [self.getValue(self.__primary_key__)]        rows = yield from execute(self.__delete__, args)        if rows != 1:            logging.warn('failed to remove by primary key: affected rows: %s' % rows)

元类ModelMetaclass

创建一个元类让Model继承,这样,对象需要不同的继承模板。使用元类,通过继承Model就能继承ModelMetaclass,就能动态生成一个对象。

class ModelMetaclass(type):  # 类是对象的模板,元类是类的模板。type看成类工厂,制造各种类。    def __new__(cls, name, bases, attrs):  # 当一个类指定通过莫元类来创建,会调用该元类的__new__方法。    # cls 参数为当前准备创建类的对象 name 为类的名字, bases为继承的父类集合, attrs为类的属性/方法集合。    # 创建User=Model(),name就是User, bases就是Model, attrs就是一个包含User类属性的dict        if name=='Model': # Model是基类,要排除掉            return type.__new__(cls, name, bases, attrs) # 直接返回就行        # 获取table名称:        tableName = attrs.get('__table__', None) or name        logging.info('found model: %s (table: %s)' % (name, tableName))        mappings = dict()  # 用于存储所有的字段名和字段的映射        fields = []  # 用于存储非主键以外的其他字段,而且只存key        primaryKey = None        # 这里k for key, 是字段名, v for vale, 是字段实例,例如StringField        for k, v in attrs.items():            if isinstance(v, Field):                logging.info('  found mapping: %s ==> %s' % (k, v))                mappings[k] = v  # 储存到mappings字典中。                if v.primary_key: # 创建字段会设置primary_key=True                    # 找到主键:                    if primaryKey:                        raise StandardError('Duplicate primary key for field: %s' % k)                    primaryKey = k # 上述条件成立,把这个字段名赋值给primaryKey变量。                else:                    fields.append(k)  # 非主键保存再fields中。        if not primaryKey:  # 一个主键都没有,报错            raise StandardError('Primary key not found.')        for k in mappings.keys():            attrs.pop(k)  # 去除掉不需要的字段名,返回下面的属性。        escaped_fields = list(map(lambda f: '`%s`' % f, fields))        #通过attrs返回的东西,子类中都能通过实例获取,例如self.__table__        # 这样,任何继承自Model的类(比如User),会自动通过ModelMetaclass扫描映射关系,并存储到自身的类属性如__table__、__mappings__中。        attrs['__mappings__'] = mappings # 保存属性和列的映射关系        attrs['__table__'] = tableName        attrs['__primary_key__'] = primaryKey # 主键属性名        attrs['__fields__'] = fields # 除主键外的属性名        # 在这里定义这些属性,Model看起来更简单些        attrs['__select__'] = 'select `%s`, %s from `%s`' % (primaryKey, ', '.join(escaped_fields), tableName)        attrs['__insert__'] = 'insert into `%s` (%s, `%s`) values (%s)' % (tableName, ', '.join(escaped_fields), primaryKey, create_args_string(len(escaped_fields) + 1))        attrs['__update__'] = 'update `%s` set %s where `%s`=?' % (tableName, ', '.join(map(lambda f: '`%s`=?' % (mappings.get(f).name or f), fields)), primaryKey)        attrs['__delete__'] = 'delete from `%s` where `%s`=?' % (tableName, primaryKey)        return type.__new__(cls, name, bases, attrs)

以上,完成了一个简单的ORM 框架。
廖雪峰老师教程中的源码:orm.py


这篇文章很详细:跟着廖大学python之orm框架实现

深刻理解Python中的元类(metaclass)

原创粉丝点击