在mybatis执行SQL语句之前进行拦截处理

来源:互联网 发布:新疆有网络吗 编辑:程序博客网 时间:2024/05/22 02:24

比较适用于在分页时候进行拦截。对分页的SQL语句通过封装处理,处理成不同的分页sql。

实用性比较强。

[java] view plain copy print?在CODE上查看代码片派生到我的代码片
  1. import java.sql.Connection;  
  2. import java.sql.PreparedStatement;  
  3. import java.sql.ResultSet;  
  4. import java.sql.SQLException;  
  5. import java.util.List;  
  6. import java.util.Properties;  
  7.   
  8. import org.apache.ibatis.executor.parameter.ParameterHandler;  
  9. import org.apache.ibatis.executor.statement.RoutingStatementHandler;  
  10. import org.apache.ibatis.executor.statement.StatementHandler;  
  11. import org.apache.ibatis.mapping.BoundSql;  
  12. import org.apache.ibatis.mapping.MappedStatement;  
  13. import org.apache.ibatis.mapping.ParameterMapping;  
  14. import org.apache.ibatis.plugin.Interceptor;  
  15. import org.apache.ibatis.plugin.Intercepts;  
  16. import org.apache.ibatis.plugin.Invocation;  
  17. import org.apache.ibatis.plugin.Plugin;  
  18. import org.apache.ibatis.plugin.Signature;  
  19. import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;  
  20.   
  21. import com.yidao.utils.Page;  
  22. import com.yidao.utils.ReflectHelper;  
  23.   
  24. /**  
  25.  *  
  26.  * 分页拦截器,用于拦截需要进行分页查询的操作,然后对其进行分页处理。  
  27.  * 利用拦截器实现Mybatis分页的原理:  
  28.  * 要利用JDBC对数据库进行操作就必须要有一个对应的Statement对象,Mybatis在执行Sql语句前就会产生一个包含Sql语句的Statement对象,而且对应的Sql语句  
  29.  * 是在Statement之前产生的,所以我们就可以在它生成Statement之前对用来生成Statement的Sql语句下手。在Mybatis中Statement语句是通过RoutingStatementHandler对象的  
  30.  * prepare方法生成的。所以利用拦截器实现Mybatis分页的一个思路就是拦截StatementHandler接口的prepare方法,然后在拦截器方法中把Sql语句改成对应的分页查询Sql语句,之后再调用  
  31.  * StatementHandler对象的prepare方法,即调用invocation.proceed()。  
  32.  * 对于分页而言,在拦截器里面我们还需要做的一个操作就是统计满足当前条件的记录一共有多少,这是通过获取到了原始的Sql语句后,把它改为对应的统计语句再利用Mybatis封装好的参数和设  
  33.  * 置参数的功能把Sql语句中的参数进行替换,之后再执行查询记录数的Sql语句进行总记录数的统计。  
  34.  *  
  35.  */    
  36. @Intercepts({@Signature(type=StatementHandler.class,method="prepare",args={Connection.class})})  
  37. public class PageInterceptor implements Interceptor {  
  38.     private String dialect = ""//数据库方言    
  39.     private String pageSqlId = ""//mapper.xml中需要拦截的ID(正则匹配)    
  40.         
  41.     public Object intercept(Invocation invocation) throws Throwable {  
  42.         //对于StatementHandler其实只有两个实现类,一个是RoutingStatementHandler,另一个是抽象类BaseStatementHandler,    
  43.         //BaseStatementHandler有三个子类,分别是SimpleStatementHandler,PreparedStatementHandler和CallableStatementHandler,    
  44.         //SimpleStatementHandler是用于处理Statement的,PreparedStatementHandler是处理PreparedStatement的,而CallableStatementHandler是    
  45.         //处理CallableStatement的。Mybatis在进行Sql语句处理的时候都是建立的RoutingStatementHandler,而在RoutingStatementHandler里面拥有一个    
  46.         //StatementHandler类型的delegate属性,RoutingStatementHandler会依据Statement的不同建立对应的BaseStatementHandler,即SimpleStatementHandler、    
  47.         //PreparedStatementHandler或CallableStatementHandler,在RoutingStatementHandler里面所有StatementHandler接口方法的实现都是调用的delegate对应的方法。    
  48.         //我们在PageInterceptor类上已经用@Signature标记了该Interceptor只拦截StatementHandler接口的prepare方法,又因为Mybatis只有在建立RoutingStatementHandler的时候    
  49.         //是通过Interceptor的plugin方法进行包裹的,所以我们这里拦截到的目标对象肯定是RoutingStatementHandler对象。  
  50.         if(invocation.getTarget() instanceof RoutingStatementHandler){    
  51.             RoutingStatementHandler statementHandler = (RoutingStatementHandler)invocation.getTarget();    
  52.             StatementHandler delegate = (StatementHandler) ReflectHelper.getFieldValue(statementHandler, "delegate");    
  53.             BoundSql boundSql = delegate.getBoundSql();  
  54.             Object obj = boundSql.getParameterObject();  
  55.             if (obj instanceof Page<?>) {    
  56.                 Page<?> page = (Page<?>) obj;    
  57.                 //通过反射获取delegate父类BaseStatementHandler的mappedStatement属性    
  58.                 MappedStatement mappedStatement = (MappedStatement)ReflectHelper.getFieldValue(delegate, "mappedStatement");    
  59.                 //拦截到的prepare方法参数是一个Connection对象    
  60.                 Connection connection = (Connection)invocation.getArgs()[0];    
  61.                 //获取当前要执行的Sql语句,也就是我们直接在Mapper映射语句中写的Sql语句    
  62.                 String sql = boundSql.getSql();    
  63.                 //给当前的page参数对象设置总记录数    
  64.                 this.setTotalRecord(page,    
  65.                        mappedStatement, connection);    
  66.                 //获取分页Sql语句    
  67.                 String pageSql = this.getPageSql(page, sql);    
  68.                 //利用反射设置当前BoundSql对应的sql属性为我们建立好的分页Sql语句    
  69.                 ReflectHelper.setFieldValue(boundSql, "sql", pageSql);    
  70.             }   
  71.         }    
  72.         return invocation.proceed();    
  73.     }  
  74.       
  75.     /**  
  76.      * 给当前的参数对象page设置总记录数  
  77.      *  
  78.      * @param page Mapper映射语句对应的参数对象  
  79.      * @param mappedStatement Mapper映射语句  
  80.      * @param connection 当前的数据库连接  
  81.      */    
  82.     private void setTotalRecord(Page<?> page,    
  83.            MappedStatement mappedStatement, Connection connection) {    
  84.        //获取对应的BoundSql,这个BoundSql其实跟我们利用StatementHandler获取到的BoundSql是同一个对象。    
  85.        //delegate里面的boundSql也是通过mappedStatement.getBoundSql(paramObj)方法获取到的。    
  86.        BoundSql boundSql = mappedStatement.getBoundSql(page);    
  87.        //获取到我们自己写在Mapper映射语句中对应的Sql语句    
  88.        String sql = boundSql.getSql();    
  89.        //通过查询Sql语句获取到对应的计算总记录数的sql语句    
  90.        String countSql = this.getCountSql(sql);    
  91.        //通过BoundSql获取对应的参数映射    
  92.        List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();    
  93.        //利用Configuration、查询记录数的Sql语句countSql、参数映射关系parameterMappings和参数对象page建立查询记录数对应的BoundSql对象。    
  94.        BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, page);    
  95.        //通过mappedStatement、参数对象page和BoundSql对象countBoundSql建立一个用于设定参数的ParameterHandler对象    
  96.        ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, page, countBoundSql);    
  97.        //通过connection建立一个countSql对应的PreparedStatement对象。    
  98.        PreparedStatement pstmt = null;    
  99.        ResultSet rs = null;    
  100.        try {    
  101.            pstmt = connection.prepareStatement(countSql);    
  102.            //通过parameterHandler给PreparedStatement对象设置参数    
  103.            parameterHandler.setParameters(pstmt);    
  104.            //之后就是执行获取总记录数的Sql语句和获取结果了。    
  105.            rs = pstmt.executeQuery();    
  106.            if (rs.next()) {    
  107.               int totalRecord = rs.getInt(1);    
  108.               //给当前的参数page对象设置总记录数    
  109.               page.setTotalRecord(totalRecord);    
  110.            }    
  111.        } catch (SQLException e) {    
  112.            e.printStackTrace();    
  113.        } finally {    
  114.            try {    
  115.               if (rs != null)    
  116.                   rs.close();    
  117.                if (pstmt != null)    
  118.                   pstmt.close();    
  119.            } catch (SQLException e) {    
  120.               e.printStackTrace();    
  121.            }    
  122.        }    
  123.     }    
  124.       
  125.     /**  
  126.      * 根据原Sql语句获取对应的查询总记录数的Sql语句  
  127.      * @param sql  
  128.      * @return  
  129.      */    
  130.     private String getCountSql(String sql) {    
  131.        int index = sql.indexOf("from");    
  132.        return "select count(*) " + sql.substring(index);    
  133.     }    
  134.       
  135.     /**  
  136.      * 根据page对象获取对应的分页查询Sql语句,这里只做了两种数据库类型,Mysql和Oracle  
  137.      * 其它的数据库都 没有进行分页  
  138.      *  
  139.      * @param page 分页对象  
  140.      * @param sql 原sql语句  
  141.      * @return  
  142.      */    
  143.     private String getPageSql(Page<?> page, String sql) {    
  144.        StringBuffer sqlBuffer = new StringBuffer(sql);    
  145.        if ("mysql".equalsIgnoreCase(dialect)) {    
  146.            return getMysqlPageSql(page, sqlBuffer);    
  147.        } else if ("oracle".equalsIgnoreCase(dialect)) {    
  148.            return getOraclePageSql(page, sqlBuffer);    
  149.        }    
  150.        return sqlBuffer.toString();    
  151.     }    
  152.       
  153.     /**  
  154.     * 获取Mysql数据库的分页查询语句  
  155.     * @param page 分页对象  
  156.     * @param sqlBuffer 包含原sql语句的StringBuffer对象  
  157.     * @return Mysql数据库分页语句  
  158.     */    
  159.    private String getMysqlPageSql(Page<?> page, StringBuffer sqlBuffer) {    
  160.       //计算第一条记录的位置,Mysql中记录的位置是从0开始的。    
  161. //     System.out.println("page:"+page.getPage()+"-------"+page.getRows());  
  162.       int offset = (page.getPage() - 1) * page.getRows();    
  163.       sqlBuffer.append(" limit ").append(offset).append(",").append(page.getRows());    
  164.       return sqlBuffer.toString();    
  165.    }    
  166.       
  167.    /**  
  168.     * 获取Oracle数据库的分页查询语句  
  169.     * @param page 分页对象  
  170.     * @param sqlBuffer 包含原sql语句的StringBuffer对象  
  171.     * @return Oracle数据库的分页查询语句  
  172.     */    
  173.    private String getOraclePageSql(Page<?> page, StringBuffer sqlBuffer) {    
  174.       //计算第一条记录的位置,Oracle分页是通过rownum进行的,而rownum是从1开始的    
  175.       int offset = (page.getPage() - 1) * page.getRows() + 1;    
  176.       sqlBuffer.insert(0"select u.*, rownum r from (").append(") u where rownum < ").append(offset + page.getRows());    
  177.       sqlBuffer.insert(0"select * from (").append(") where r >= ").append(offset);    
  178.       //上面的Sql语句拼接之后大概是这个样子:    
  179.       //select * from (select u.*, rownum r from (select * from t_user) u where rownum < 31) where r >= 16    
  180.       return sqlBuffer.toString();    
  181.    }    
  182.      
  183.         
  184.     /**  
  185.      * 拦截器对应的封装原始对象的方法  
  186.      */          
  187.     public Object plugin(Object arg0) {    
  188.         // TODO Auto-generated method stub    
  189.         if (arg0 instanceof StatementHandler) {    
  190.             return Plugin.wrap(arg0, this);    
  191.         } else {    
  192.             return arg0;    
  193.         }   
  194.     }    
  195.     
  196.     /**  
  197.      * 设置注册拦截器时设定的属性  
  198.      */   
  199.     public void setProperties(Properties p) {  
  200.           
  201.     }  
  202.   
  203.     public String getDialect() {  
  204.         return dialect;  
  205.     }  
  206.   
  207.     public void setDialect(String dialect) {  
  208.         this.dialect = dialect;  
  209.     }  
  210.   
  211.     public String getPageSqlId() {  
  212.         return pageSqlId;  
  213.     }  
  214.   
  215.     public void setPageSqlId(String pageSqlId) {  
  216.         this.pageSqlId = pageSqlId;  
  217.     }  
  218.       
  219. }  

xml配置:

[html] view plain copy print?在CODE上查看代码片派生到我的代码片
  1. <!-- MyBatis 接口编程配置  -->  
  2.     <bean class="org.mybatis.spring.mapper.MapperScannerConfigurer">  
  3.         <!-- basePackage指定要扫描的包,在此包之下的映射器都会被搜索到,可指定多个包,包与包之间用逗号或分号分隔-->  
  4.         <property name="basePackage" value="com.yidao.mybatis.dao" />  
  5.         <property name="sqlSessionFactoryBeanName" value="sqlSessionFactory" />  
  6.     </bean>  
  7.       
  8.     <!-- MyBatis 分页拦截器-->  
  9.     <bean id="paginationInterceptor" class="com.mybatis.interceptor.PageInterceptor">  
  10.         <property name="dialect" value="mysql"/>   
  11.         <!-- 拦截Mapper.xml文件中,id包含query字符的语句 -->   
  12.         <property name="pageSqlId" value=".*query$"/>  
  13.     </bean>   


Page类


[java] view plain copy print?在CODE上查看代码片派生到我的代码片
  1. package com.yidao.utils;  
  2.   
  3.   
  4. /**自己看看,需要什么字段加什么字段吧*/  
  5. public class Page {  
  6.       
  7.     private Integer rows;  
  8.       
  9.     private Integer page = 1;  
  10.       
  11.     private Integer totalRecord;  
  12.   
  13.     public Integer getRows() {  
  14.         return rows;  
  15.     }  
  16.   
  17.     public void setRows(Integer rows) {  
  18.         this.rows = rows;  
  19.     }  
  20.   
  21.     public Integer getPage() {  
  22.         return page;  
  23.     }  
  24.   
  25.     public void setPage(Integer page) {  
  26.         this.page = page;  
  27.     }  
  28.   
  29.     public Integer getTotalRecord() {  
  30.         return totalRecord;  
  31.     }  
  32.   
  33.     public void setTotalRecord(Integer totalRecord) {  
  34.         this.totalRecord = totalRecord;  
  35.     }  
  36.       
  37. }  



ReflectHelper类

[java] view plain copy print?在CODE上查看代码片派生到我的代码片
  1. package com.yidao.utils;  
  2.   
  3. import java.lang.reflect.Field;  
  4.   
  5. import org.apache.commons.lang3.reflect.FieldUtils;  
  6.   
  7. public class ReflectHelper {  
  8.       
  9.     public static Object getFieldValue(Object obj , String fieldName ){  
  10.           
  11.         if(obj == null){  
  12.             return null ;  
  13.         }  
  14.           
  15.         Field targetField = getTargetField(obj.getClass(), fieldName);  
  16.           
  17.         try {  
  18.             return FieldUtils.readField(targetField, obj, true ) ;  
  19.         } catch (IllegalAccessException e) {  
  20.             e.printStackTrace();  
  21.         }   
  22.         return null ;  
  23.     }  
  24.       
  25.     public static Field getTargetField(Class<?> targetClass, String fieldName) {  
  26.         Field field = null;  
  27.   
  28.         try {  
  29.             if (targetClass == null) {  
  30.                 return field;  
  31.             }  
  32.   
  33.             if (Object.class.equals(targetClass)) {  
  34.                 return field;  
  35.             }  
  36.   
  37.             field = FieldUtils.getDeclaredField(targetClass, fieldName, true);  
  38.             if (field == null) {  
  39.                 field = getTargetField(targetClass.getSuperclass(), fieldName);  
  40.             }  
  41.         } catch (Exception e) {  
  42.         }  
  43.   
  44.         return field;  
  45.     }  
  46.       
  47.     public static void setFieldValue(Object obj , String fieldName , Object value ){  
  48.         if(null == obj){return;}  
  49.         Field targetField = getTargetField(obj.getClass(), fieldName);    
  50.         try {  
  51.              FieldUtils.writeField(targetField, obj, value) ;  
  52.         } catch (IllegalAccessException e) {  
  53.             e.printStackTrace();  
  54.         }   
  55.     }   

0 0
原创粉丝点击