[MyBatis] Paging with an Interceptor

Source: Internet | Editor: 程序博客网 | Time: 2024/04/29 23:23

This article shows how to use a MyBatis interceptor to rewrite SQL so that queries support paging.

For how MyBatis interceptors are used and how they work internally, see my other article (http://www.cnblogs.com/daxin/p/3544188.html).

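First, the idea: when MyBatis runs a query, an Interceptor (MyBatis's plugin mechanism) intercepts it just before the JDBC statement is prepared and rewrites the SQL on the fly, turning the unpaged SQL into paged SQL.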

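For example, on MySQL a normal query such as select id, name from user where name = ? is rewritten to select id, name from user where name = ? limit 5, 10, which skips the first 5 rows and returns at most 10.

That rewrite is all paging requires.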
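For MySQL, the rewrite described above amounts to appending a LIMIT clause. The following is a minimal sketch, not part of the article's code: MySqlPagingSketch and toPagingSql are hypothetical names, and it assumes the query has no LIMIT of its own. MySQL's LIMIT offset, count skips offset rows and returns at most count rows.

```java
final class MySqlPagingSketch {

    private MySqlPagingSketch() {
    }

    /**
     * Appends a MySQL LIMIT clause (hypothetical helper, for illustration only).
     *
     * @param sql    the original query, assumed to have no LIMIT clause yet
     * @param offset number of rows to skip
     * @param limit  maximum number of rows to return
     */
    static String toPagingSql(String sql, int offset, int limit) {
        // MySQL syntax: LIMIT <offset>, <row count>
        return sql + " limit " + offset + ", " + limit;
    }

    public static void main(String[] args) {
        String sql = "select id, name from user where name = ?";
        // prints: select id, name from user where name = ? limit 5, 10
        System.out.println(toPagingSql(sql, 5, 10));
    }
}
```

The "?" placeholder is left untouched, so the statement's existing parameter bindings still apply to the rewritten SQL.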

 

1. First, extend MyBatis's RowBounds.

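The class works like RowBounds, plus a total property that holds the total number of matching records. Because RowBounds keeps offset and limit in private final fields, PagingBounds declares its own copies and overrides getOffset() and getLimit(). It also adds setMeToDefault(), which resets the object to its initial state (no offset, no limit) once the SQL has been rewritten, so that MyBatis does not page the already-paged results a second time in memory.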

 

```java
package com.framework.core.mybatis;

import org.apache.ibatis.session.RowBounds;

public class PagingBounds extends RowBounds {

    // Total number of matching records, filled in by the interceptor
    private int total;
    // Index of the first row to return (0-based)
    private int offset;
    // Maximum number of rows to return
    private int limit;

    public PagingBounds() {
        this.offset = NO_ROW_OFFSET;
        this.limit = NO_ROW_LIMIT;
    }

    public PagingBounds(int offset, int limit) {
        this.offset = offset;
        this.limit = limit;
    }

    public int getTotal() {
        return total;
    }

    public void setTotal(int total) {
        this.total = total;
    }

    @Override
    public int getOffset() {
        return offset;
    }

    @Override
    public int getLimit() {
        return limit;
    }

    // Reset to "no paging" so MyBatis does not skip or limit rows again in memory
    public void setMeToDefault() {
        this.offset = NO_ROW_OFFSET;
        this.limit = NO_ROW_LIMIT;
    }

    // Row number of the last row on the page: offset + limit
    public int getSelectCount() {
        return limit + offset;
    }
}
```
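Why the reset matters can be shown with a simplified model. Assumption: for a non-default RowBounds, MyBatis skips offset rows of the ResultSet and then stops after limit rows; the hypothetical RowBoundsModel class below imitates only that behavior, not MyBatis's actual result handling. If the interceptor left offset 2 and limit 4 in place, the 4 rows already returned by the paged SQL would be trimmed a second time.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

final class RowBoundsModel {

    // Same sentinel values as RowBounds.NO_ROW_OFFSET and RowBounds.NO_ROW_LIMIT
    static final int NO_ROW_OFFSET = 0;
    static final int NO_ROW_LIMIT = Integer.MAX_VALUE;

    private RowBoundsModel() {
    }

    /** Skips offset rows, then keeps at most limit rows (simplified in-memory paging). */
    static <T> List<T> apply(List<T> rows, int offset, int limit) {
        List<T> kept = new ArrayList<>();
        for (int i = offset; i < rows.size() && kept.size() < limit; i++) {
            kept.add(rows.get(i));
        }
        return kept;
    }

    public static void main(String[] args) {
        // Rows 3..6 as already returned by the paged SQL
        List<String> pagedRows = Arrays.asList("row3", "row4", "row5", "row6");

        // Bounds left at (2, 4): rows are skipped twice, only [row5, row6] survive
        System.out.println(apply(pagedRows, 2, 4));
        // Bounds reset as setMeToDefault() does: all four rows survive
        System.out.println(apply(pagedRows, NO_ROW_OFFSET, NO_ROW_LIMIT));
    }
}
```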

 

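2. Define an abstract class that implements MyBatis's Interceptor interface and holds the core paging logic, so that it can be extended for different databases. This class is the heart of the approach, and its details are relatively involved.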

Two abstract methods are left for subclasses to implement:

1:protected abstract String getSelectTotalSql(String targetSql);

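This method receives the target SQL (the SQL defined in your mapper XML, such as select id, name from user where name = ?, with runs of whitespace collapsed to single spaces). The subclass wraps it in whatever way suits its database and returns SQL that counts the total number of matching records.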

2:protected abstract String getSelectPagingSql(String targetSql, PagingBounds pagingBounds);

This method rewrites the target SQL into paged SQL; the PagingBounds argument carries the offset and limit.
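As one possible getSelectTotalSql strategy (an alternative, not the article's), the query can be wrapped in a derived table instead of having its select list replaced; this keeps GROUP BY and DISTINCT queries countable, and because the original statement and its "?" placeholders stay intact, the original parameter mappings still line up. The sketch assumes, like the article's code, that the last ORDER BY belongs to the outer query, and that the database accepts an aliased derived table (MySQL requires the alias). CountSqlSketch is a hypothetical name.

```java
import java.util.Locale;

final class CountSqlSketch {

    private CountSqlSketch() {
    }

    /**
     * Builds "select count(1) from ( <sql> ) count_alias", dropping a trailing ORDER BY.
     * Assumes whitespace is already collapsed to single spaces.
     */
    static String toCountSql(String sql) {
        // Search a lower-cased copy so the original text (and its literals) keeps its case
        String lower = sql.toLowerCase(Locale.ROOT);
        String body = sql;
        int orderByPos = lower.lastIndexOf("order by");
        if (orderByPos != -1) {
            // Assumption: the last ORDER BY belongs to the outer query
            body = sql.substring(0, orderByPos).trim();
        }
        return "select count(1) from ( " + body + " ) count_alias";
    }

    public static void main(String[] args) {
        // prints: select count(1) from ( select dept, count(*) from emp group by dept ) count_alias
        System.out.println(toCountSql("select dept, count(*) from emp group by dept order by dept"));
    }
}
```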

```java
package com.framework.core.mybatis;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Properties;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.RoutingStatementHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.RowBounds;

public abstract class AbstractPagingInterceptor implements Interceptor {

    private static final Pattern PATTERN_SQL_BLANK = Pattern.compile("\\s+");

    // Names of non-public fields inside MyBatis classes, read via reflection
    private static final String FIELD_DELEGATE = "delegate";
    private static final String FIELD_ROWBOUNDS = "rowBounds";
    private static final String FIELD_MAPPEDSTATEMENT = "mappedStatement";
    private static final String FIELD_SQL = "sql";

    private static final String BLANK = " ";

    public static final String SELECT = "select";
    public static final String FROM = "from";
    public static final String ORDER_BY = "order by";
    public static final String UNION = "union";

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        Connection connection = (Connection) invocation.getArgs()[0];
        RoutingStatementHandler statementHandler = (RoutingStatementHandler) invocation.getTarget();

        // RoutingStatementHandler delegates to the real handler (e.g. PreparedStatementHandler)
        StatementHandler handler = (StatementHandler) readField(statementHandler, FIELD_DELEGATE);
        PagingBounds pagingBounds = (PagingBounds) readField(handler, FIELD_ROWBOUNDS);
        MappedStatement mappedStatement = (MappedStatement) readField(handler, FIELD_MAPPEDSTATEMENT);
        BoundSql boundSql = handler.getBoundSql();

        // Collapse whitespace so that e.g. "order\n  by" matches "order by"
        String targetSql = replaceSqlBlank(boundSql.getSql());

        // Run the count query, then swap the paged SQL into the BoundSql
        getTotalAndSetInPagingBounds(targetSql, boundSql, pagingBounds, mappedStatement, connection);
        String pagingSql = getSelectPagingSql(targetSql, pagingBounds);
        writeDeclaredField(boundSql, FIELD_SQL, pagingSql);

        // Reset the bounds so MyBatis does not page the results again in memory
        pagingBounds.setMeToDefault();
        return invocation.proceed();
    }

    private void getTotalAndSetInPagingBounds(String targetSql, BoundSql boundSql, PagingBounds bounds,
            MappedStatement mappedStatement, Connection connection) throws SQLException {
        String totalSql = getSelectTotalSql(targetSql);
        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
        Object parameterObject = boundSql.getParameterObject();
        // Reuses the original mappings: assumes the count SQL keeps the same "?" placeholders in order
        BoundSql totalBoundSql = new BoundSql(mappedStatement.getConfiguration(), totalSql,
                parameterMappings, parameterObject);
        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, totalBoundSql);

        // Closes the statement and result set, but not the connection MyBatis is still using
        try (PreparedStatement pstmt = connection.prepareStatement(totalSql)) {
            parameterHandler.setParameters(pstmt);
            try (ResultSet rs = pstmt.executeQuery()) {
                if (rs.next()) {
                    bounds.setTotal(rs.getInt(1));
                }
            }
        }
    }

    /** Returns SQL that counts the rows matched by targetSql. */
    protected abstract String getSelectTotalSql(String targetSql);

    /** Returns targetSql restricted to the rows described by pagingBounds. */
    protected abstract String getSelectPagingSql(String targetSql, PagingBounds pagingBounds);

    // Note: this also collapses whitespace inside string literals
    private String replaceSqlBlank(String originalSql) {
        Matcher matcher = PATTERN_SQL_BLANK.matcher(originalSql);
        return matcher.replaceAll(BLANK);
    }

    @Override
    public Object plugin(Object target) {
        // Wrap only statement handlers whose RowBounds is a PagingBounds
        if (target instanceof RoutingStatementHandler) {
            try {
                StatementHandler handler = (StatementHandler) readField(target, FIELD_DELEGATE);
                RowBounds rowBounds = (RowBounds) readField(handler, FIELD_ROWBOUNDS);
                if (rowBounds instanceof PagingBounds) {
                    return Plugin.wrap(target, this);
                }
            } catch (IllegalAccessException e) {
                // field not readable: leave the target unwrapped
            }
        }
        return target;
    }

    private void writeDeclaredField(Object target, String fieldName, Object value)
            throws IllegalAccessException {
        if (target == null) {
            throw new IllegalArgumentException("target object must not be null");
        }
        Class<?> cls = target.getClass();
        Field field = getField(cls, fieldName);
        if (field == null) {
            throw new IllegalArgumentException("Cannot locate declared field " + cls.getName() + "." + fieldName);
        }
        // BoundSql exposes no setter for its SQL, so it is replaced via reflection
        field.set(target, value);
    }

    private Object readField(Object target, String fieldName) throws IllegalAccessException {
        if (target == null) {
            throw new IllegalArgumentException("target object must not be null");
        }
        Class<?> cls = target.getClass();
        Field field = getField(cls, fieldName);
        if (field == null) {
            throw new IllegalArgumentException("Cannot locate field " + fieldName + " on " + cls);
        }
        return field.get(target);
    }

    // Walks up the class hierarchy, since some fields are declared on a superclass
    private static Field getField(final Class<?> cls, String fieldName) {
        for (Class<?> acls = cls; acls != null; acls = acls.getSuperclass()) {
            try {
                Field field = acls.getDeclaredField(fieldName);
                field.setAccessible(true);
                return field;
            } catch (NoSuchFieldException ex) {
                // not declared on this class; try the superclass
            }
        }
        return null;
    }

    @Override
    public void setProperties(Properties properties) {
    }
}
```
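The reflection helpers are the least obvious part of the class. In the MyBatis 3 versions this article targets, delegate is declared on RoutingStatementHandler itself, while rowBounds and mappedStatement are declared on the superclass BaseStatementHandler, so getDeclaredField must be tried on each class up the hierarchy. The stand-alone sketch below reproduces that lookup; BaseHandler and ChildHandler are hypothetical stand-ins for the MyBatis types.

```java
import java.lang.reflect.Field;

final class FieldLookupSketch {

    // Hypothetical stand-ins for BaseStatementHandler and PreparedStatementHandler
    static class BaseHandler {
        private final String rowBounds = "bounds-from-base";
    }

    static class ChildHandler extends BaseHandler {
    }

    private FieldLookupSketch() {
    }

    /** Finds a declared field on cls or any superclass and makes it accessible. */
    static Field findField(Class<?> cls, String name) {
        for (Class<?> c = cls; c != null; c = c.getSuperclass()) {
            try {
                Field field = c.getDeclaredField(name);
                field.setAccessible(true);
                return field;
            } catch (NoSuchFieldException e) {
                // not declared on this class; keep walking up
            }
        }
        return null;
    }

    static Object readField(Object target, String name) throws IllegalAccessException {
        Field field = findField(target.getClass(), name);
        if (field == null) {
            throw new IllegalArgumentException("No field " + name + " on " + target.getClass());
        }
        return field.get(target);
    }

    public static void main(String[] args) throws IllegalAccessException {
        // ChildHandler.class.getDeclaredField("rowBounds") alone would throw NoSuchFieldException
        System.out.println(readField(new ChildHandler(), "rowBounds"));
    }
}
```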

 

3. Extend the class for each database. Two examples follow.

MSSQL2008

```java
package com.framework.core.mybatis;

import java.sql.Connection;
import java.util.Locale;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Signature;

// MyBatis 3.4.0+ declares prepare(Connection, Integer); there use args = {Connection.class, Integer.class}
@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}))
public class SqlServerPagingInterceptor extends AbstractPagingInterceptor {

    /**
     * Turns the query into one that counts its rows.
     * @param targetSql the normal query: select id, name from user where name = ?
     * @return select count(1) as _count from user where name = ?
     */
    @Override
    protected String getSelectTotalSql(String targetSql) {
        // Search a lower-cased copy but edit the original, so string literals keep their case
        String lower = targetSql.toLowerCase(Locale.ROOT);
        StringBuilder sqlBuilder = new StringBuilder(targetSql);

        // An ORDER BY would make the count query (or the derived table below) invalid
        int orderByPos = lower.lastIndexOf(ORDER_BY);
        if (orderByPos != -1) {
            sqlBuilder.delete(orderByPos, sqlBuilder.length());
            lower = lower.substring(0, orderByPos);
        }

        // UNION queries cannot be counted by replacing the select list, so wrap them
        if (lower.contains(UNION)) {
            sqlBuilder.insert(0, "select count(1) as _count from ( ").append(" ) as _alias");
            return sqlBuilder.toString();
        }

        // Replace everything before the first FROM with a count select list.
        // Caveat: wrong for GROUP BY / DISTINCT queries or a FROM inside the select list.
        int fromPos = lower.indexOf(FROM);
        if (fromPos != -1) {
            sqlBuilder.delete(0, fromPos);
            sqlBuilder.insert(0, "select count(1) as _count ");
        }
        return sqlBuilder.toString();
    }

    /**
     * Turns the query into one that returns only the requested window of rows.
     * @param targetSql the normal query: select id, name from user where name = ?
     * @return WITH query AS (SELECT inner_query.*, ROW_NUMBER() OVER (ORDER BY CURRENT_TIMESTAMP)
     *         as __mybatis_row_nr__ FROM ( select id, name from user where name = ? ) inner_query )
     *         SELECT * FROM query WHERE __mybatis_row_nr__ >= 3 AND __mybatis_row_nr__ <= 6
     */
    @Override
    protected String getSelectPagingSql(String targetSql, PagingBounds bounds) {
        String lower = targetSql.toLowerCase(Locale.ROOT);
        StringBuilder sqlBuilder = new StringBuilder(targetSql);

        // ORDER BY inside a derived table is only legal with TOP, so add TOP(offset + limit).
        // Caveat: "select distinct" would need TOP after DISTINCT.
        if (lower.contains(ORDER_BY)) {
            int selectPos = lower.indexOf(SELECT);
            sqlBuilder.insert(selectPos + SELECT.length(), " TOP(" + bounds.getSelectCount() + ")");
        }

        // Number the rows, then keep rows offset+1 .. offset+limit
        sqlBuilder.insert(0, "SELECT inner_query.*, ROW_NUMBER() OVER (ORDER BY CURRENT_TIMESTAMP) as __mybatis_row_nr__ FROM ( ");
        sqlBuilder.append(" ) inner_query ");
        sqlBuilder.insert(0, "WITH query AS (").append(") SELECT * FROM query ");
        sqlBuilder.append("WHERE __mybatis_row_nr__ >= ").append(bounds.getOffset() + 1)
                  .append(" AND __mybatis_row_nr__ <= ").append(bounds.getSelectCount());
        return sqlBuilder.toString();
    }
}
```
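The ROW_NUMBER() wrapper is needed because SQL Server 2008 has no OFFSET/FETCH clause. If the database is SQL Server 2012 or later (an assumption beyond what this article targets), getSelectPagingSql could instead append OFFSET ... ROWS FETCH NEXT ... ROWS ONLY. That clause requires an ORDER BY; a common workaround when the query has none is ORDER BY (SELECT NULL), which gives no guaranteed order. A sketch with the hypothetical class OffsetFetchSketch:

```java
import java.util.Locale;

final class OffsetFetchSketch {

    private OffsetFetchSketch() {
    }

    /** Builds SQL Server 2012+ paging SQL for rows offset+1 .. offset+limit. */
    static String toPagingSql(String sql, int offset, int limit) {
        StringBuilder sb = new StringBuilder(sql);
        // OFFSET/FETCH is only valid after ORDER BY.
        // Assumes any ORDER BY present belongs to the outer query.
        if (!sql.toLowerCase(Locale.ROOT).contains("order by")) {
            sb.append(" ORDER BY (SELECT NULL)");
        }
        sb.append(" OFFSET ").append(offset).append(" ROWS FETCH NEXT ")
          .append(limit).append(" ROWS ONLY");
        return sb.toString();
    }

    public static void main(String[] args) {
        System.out.println(toPagingSql("select id, name from connect", 2, 4));
    }
}
```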

 

Oracle 10g

```java
package com.framework.core.mybatis;

import java.sql.Connection;
import java.util.Locale;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Signature;

// MyBatis 3.4.0+ declares prepare(Connection, Integer); there use args = {Connection.class, Integer.class}
@Intercepts(@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class}))
public class OraclePagingInterceptor extends AbstractPagingInterceptor {

    @Override
    protected String getSelectTotalSql(String targetSql) {
        // Search a lower-cased copy but edit the original, so string literals keep their case
        String lower = targetSql.toLowerCase(Locale.ROOT);
        StringBuilder sqlBuilder = new StringBuilder(targetSql);

        // Row order does not affect the count, so drop the trailing ORDER BY
        int orderByPos = lower.lastIndexOf(ORDER_BY);
        if (orderByPos != -1) {
            sqlBuilder.delete(orderByPos, sqlBuilder.length());
        }
        // Unquoted Oracle identifiers must start with a letter, hence total_count rather than _count
        sqlBuilder.insert(0, "select count(1) as total_count from ( ").append(" )");
        return sqlBuilder.toString();
    }

    @Override
    protected String getSelectPagingSql(String targetSql, PagingBounds bounds) {
        // ROWNUM is assigned as rows pass the WHERE clause, so the upper bound can use rownum,
        // but the lower bound must filter the materialised mybatis_rowNo in an outer query
        StringBuilder sqlBuilder = new StringBuilder(targetSql);
        sqlBuilder.insert(0, "select * from ( select table_alias.*, rownum mybatis_rowNo from ( ");
        sqlBuilder.append(" ) table_alias where rownum <= ").append(bounds.getSelectCount());
        sqlBuilder.append(" ) where mybatis_rowNo >= ").append(bounds.getOffset() + 1);
        return sqlBuilder.toString();
    }
}
```
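The nesting in the Oracle version follows from how ROWNUM works: Oracle assigns ROWNUM to a row only as it passes the WHERE predicate, so a condition such as rownum >= 3 is false for the first candidate row (which would receive ROWNUM 1), and every later candidate again receives ROWNUM 1 and is rejected. That is why the upper bound uses rownum in the middle query and the lower bound filters the materialised mybatis_rowNo column in the outer query. The hypothetical RownumModel below imitates just that rule; it is a simplified model, not Oracle's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.function.IntPredicate;

final class RownumModel {

    private RownumModel() {
    }

    /**
     * Simplified model of "where <predicate on rownum>": each row is tested with the
     * ROWNUM it would receive, and ROWNUM only advances when the row is kept.
     */
    static <T> List<T> filterByRownum(List<T> rows, IntPredicate predicate) {
        List<T> kept = new ArrayList<>();
        for (T row : rows) {
            int candidateRownum = kept.size() + 1;
            if (predicate.test(candidateRownum)) {
                kept.add(row);
            }
        }
        return kept;
    }

    public static void main(String[] args) {
        List<Integer> rows = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8);
        // "where rownum >= 3": every candidate gets ROWNUM 1, so nothing is returned
        System.out.println(filterByRownum(rows, rn -> rn >= 3));
        // "where rownum <= 6" in the middle query keeps the first six rows
        List<Integer> firstSix = filterByRownum(rows, rn -> rn <= 6);
        // Outer "mybatis_rowNo >= 3" on the materialised numbers (index + 1) keeps rows 3..6
        System.out.println(firstSix.subList(2, firstSix.size()));
    }
}
```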

 

4. Register the interceptor in the MyBatis configuration file. Register whichever class matches your database.

```xml
<?xml version="1.0" encoding="UTF-8" ?>
<!DOCTYPE configuration PUBLIC "-//mybatis.org//DTD Config 3.0//EN" "http://mybatis.org/dtd/mybatis-3-config.dtd">
<configuration>
    <settings>
        <setting name="logImpl" value="LOG4J" />
        <!-- <setting name="cacheEnabled" value="false"/> -->
    </settings>
    <plugins>
        <plugin interceptor="com.framework.core.mybatis.SqlServerPagingInterceptor" />
    </plugins>
    <environments default="development">
        <environment id="development">
            <transactionManager type="JDBC" />
            <dataSource type="POOLED">
                <property name="driver" value="net.sourceforge.jtds.jdbc.Driver" />
                <property name="url" value="jdbc:jtds:sqlserver://127.0.0.1/FDK2" />
                <property name="username" value="sa" />
                <property name="password" value="sql2008" />
            </dataSource>
        </environment>
    </environments>
    <mappers>
        <mapper resource="com/framework/code/mapper/ConnectMapper.xml" />
    </mappers>
</configuration>
```

 

 

5. Next, write a demo to test it.


 

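```java
import java.io.IOException;
import java.io.InputStream;
import java.util.HashMap;
import java.util.List;

import org.apache.ibatis.io.Resources;
import org.apache.ibatis.session.SqlSession;
import org.apache.ibatis.session.SqlSessionFactory;
import org.apache.ibatis.session.SqlSessionFactoryBuilder;

import com.framework.code.domain.Connect;
import com.framework.code.mapper.ConnectMapper; // package assumed from the mapper XML path
import com.framework.core.mybatis.PagingBounds;

public class Demo {

    public static void main(String[] args) throws IOException {
        String resource = "mybatis-config.xml";
        InputStream inputStream = Resources.getResourceAsStream(resource);
        SqlSessionFactory sqlSessionFactory = new SqlSessionFactoryBuilder().build(inputStream);
        SqlSession session = sqlSessionFactory.openSession();
        try {
            ConnectMapper mapper = session.getMapper(ConnectMapper.class);

            // Create a PagingBounds: skip 2 rows (start at row 3) and fetch 4 records
            PagingBounds bounds = new PagingBounds(2, 4);
            HashMap<String, Object> params = new HashMap<String, Object>();

            // The requested page of results
            List<Connect> list = mapper.selectPaging(bounds, params);
            System.out.println(list);
            // Total number of records, filled in by the interceptor
            System.out.println(bounds.getTotal());

            session.commit();
        } finally {
            session.close();
        }
    }
}
```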
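Callers usually think in page numbers rather than offsets. Assuming 1-based page numbers, the hypothetical PageMath helper below shows the conversion to the (offset, limit) pair that PagingBounds takes, and how the total filled in by the interceptor turns into a page count.

```java
final class PageMath {

    private PageMath() {
    }

    /** Offset of the first row on a 1-based page: (pageNo - 1) * pageSize. */
    static int offsetOf(int pageNo, int pageSize) {
        if (pageNo < 1 || pageSize < 1) {
            throw new IllegalArgumentException("pageNo and pageSize must be >= 1");
        }
        return (pageNo - 1) * pageSize;
    }

    /** Number of pages needed for total rows, rounding up. */
    static int pageCount(int total, int pageSize) {
        return (total + pageSize - 1) / pageSize;
    }

    public static void main(String[] args) {
        int pageSize = 4;
        // Page 2 with 4 rows per page corresponds to new PagingBounds(4, 4)
        System.out.println("offset for page 2: " + offsetOf(2, pageSize));
        // A total of 4 rows, as counted in the log, fits on 1 page
        System.out.println("pages for 4 rows: " + pageCount(4, pageSize));
    }
}
```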

 

 

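The log output shows that the query has been paged: the count query runs first, then the paged query. The table holds 4 records in total, so the window of rows 3 to 6 returns only two of them, followed by the printed total of 4.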

```
DEBUG 2014-07-23 11:25:47,216 org.apache.ibatis.logging.jdbc.BaseJdbcLogger: ==>  Preparing: select count(1) as _count from connect
DEBUG 2014-07-23 11:25:47,259 org.apache.ibatis.logging.jdbc.BaseJdbcLogger: ==> Parameters:
DEBUG 2014-07-23 11:25:47,290 org.apache.ibatis.logging.jdbc.BaseJdbcLogger: ==>  Preparing: WITH query AS (SELECT inner_query.*, ROW_NUMBER() OVER (ORDER BY CURRENT_TIMESTAMP) as __mybatis_row_nr__ FROM ( select id, pid, number, numbertype, numbertypename, previousreceiveuser, previousreceiveusername, receiveuser, receiveusername, happentime, whethersuccess, successtime, whetherhhold from connect ) inner_query ) SELECT * FROM query WHERE __mybatis_row_nr__ >= 3 AND __mybatis_row_nr__ <= 6
DEBUG 2014-07-23 11:25:47,290 org.apache.ibatis.logging.jdbc.BaseJdbcLogger: ==> Parameters:
[com.framework.code.domain.Connect@c26dd5, com.framework.code.domain.Connect@b8cab9]
4
```

Original article: http://www.cnblogs.com/daxin/p/3236861.html
