Implementing Data Permissions (MyBatis Interceptor + JSqlParser)

Source: Internet  Editor: 程序博客网  Date: 2024/05/21 19:25

I am still a beginner, so please read this article as my learning notes from implementing data permissions.

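Overall idea

Intercept each SQL statement with a MyBatis interceptor before it is prepared, parse it with JSqlParser, append the current user's permission conditions to the WHERE clause, and hand the rewritten SQL back to MyBatis.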

I. MyBatis Interceptor

References:

MyBatis official site (Chinese edition)

imooc (慕课网) video course on MyBatis

SQL parsing

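Quoting the official documentation:

MyBatis allows you to intercept calls at certain points during the execution of a mapped statement. By default, MyBatis allows plug-ins to intercept method calls of: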

Executor(update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)

ParameterHandler(getParameterObject, setParameters)

ResultSetHandler(handleResultSets, handleOutputParameters)

StatementHandler(prepare, parameterize, batch, update, query)

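Using plug-ins is pretty simple given the power they provide: simply implement the Interceptor interface and specify the signatures you want to intercept.

The feature MyBatis provides is called a Plugin, but in practice it is exactly the interceptor we need here.

Methods and parameters:

1. The Interceptor interface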

    public interface Interceptor {
        Object intercept(Invocation invocation) throws Throwable;
        Object plugin(Object target);
        void setProperties(Properties properties);
    }

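Implementing the Interceptor interface means implementing three methods: intercept, plugin and setProperties.

intercept is where we act on the intercepted call; this is where our own logic will go.

plugin receives a candidate target object and decides whether to wrap it in a proxy. It usually just returns Plugin.wrap(target, this), which produces a proxy only when the target matches the declared signatures.

setProperties receives the properties configured for the plugin in the configuration file, so they can be read here.

These three methods run in the order: setProperties ---> plugin ---> intercept

2. Fields of the Invocation class passed to intercept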

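    private Object target;  // the intercepted target object (possibly itself a proxy created by another plugin)
    private Method method;  // the intercepted method
    private Object[] args;  // the arguments of the call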
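
To make the roles of plugin, intercept and the three Invocation fields more concrete, here is a minimal JDK-only sketch of the same idea built on java.lang.reflect.Proxy, the mechanism MyBatis's Plugin class relies on. Handler, MiniInvocation and PluginSketch are hypothetical stand-ins invented for this illustration; they are not MyBatis types.

```java
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.ArrayList;
import java.util.List;

// Hypothetical stand-in for an interceptable interface such as StatementHandler.
interface Handler {
    String prepare(String sql);
}

// Hypothetical stand-in for Invocation: target, method, args, plus proceed().
final class MiniInvocation {
    final Object target;   // the object whose method was intercepted
    final Method method;   // the intercepted method
    final Object[] args;   // the arguments of that call

    MiniInvocation(Object target, Method method, Object[] args) {
        this.target = target;
        this.method = method;
        this.args = args;
    }

    // Continue with the original call.
    Object proceed() throws Exception {
        return method.invoke(target, args);
    }
}

final class PluginSketch {
    static final List<String> LOG = new ArrayList<>();

    // Plays the role of plugin(): wrap the target in a proxy that routes calls to intercept().
    static Handler wrap(Handler target) {
        InvocationHandler h =
                (proxy, method, args) -> intercept(new MiniInvocation(target, method, args));
        return (Handler) Proxy.newProxyInstance(
                Handler.class.getClassLoader(), new Class<?>[] {Handler.class}, h);
    }

    // Plays the role of intercept(): custom logic around proceed().
    static Object intercept(MiniInvocation invocation) throws Exception {
        LOG.add("before " + invocation.method.getName());
        Object result = invocation.proceed();
        LOG.add("after " + invocation.method.getName());
        return result;
    }

    static String demo() {
        Handler real = sql -> "prepared: " + sql;
        return wrap(real).prepare("select 1");
    }

    public static void main(String[] args) {
        System.out.println(demo());
        System.out.println(LOG);
    }
}
```

The caller only ever sees the proxy, which is why the interceptor can run code before and after the real method without the caller changing anything.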


Implementing the Interceptor interface

    @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
    public class MyInterceptor implements Interceptor {

        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            // custom logic goes here
            return invocation.proceed();
        }

        @Override
        public Object plugin(Object target) {
            // create the proxy object
            return Plugin.wrap(target, this);
        }

        @Override
        public void setProperties(Properties properties) {
        }
    }

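Explanation:

@Intercepts marks the class as an interceptor, and @Signature declares what to intercept.

MyBatis obtains the Statement inside StatementHandler. The SQL has to be changed before the Statement is created, and StatementHandler has a prepare method that returns a Statement, so StatementHandler is the type to intercept.

method is the prepare method we want to intercept, type is the interface that declares it, and args lists the parameter types of prepare. The single-argument prepare(Connection) used here matches the StatementHandler source quoted later in this article; from MyBatis 3.4.0 on, prepare also takes an Integer transactionTimeout, so args would have to be { Connection.class, Integer.class }.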
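
As far as I can tell, MyBatis resolves each @Signature by looking the method up reflectively by name and parameter types, so a wrong method name or a wrong args list fails at startup instead of silently intercepting nothing. The sketch below reproduces only that lookup with plain JDK reflection; HandlerLike and SignatureLookup are hypothetical names for this illustration.

```java
import java.lang.reflect.Method;
import java.sql.Connection;
import java.sql.Statement;

// Hypothetical stand-in mirroring the single-argument prepare(Connection) used in this article.
interface HandlerLike {
    Statement prepare(Connection connection);
}

final class SignatureLookup {

    // True when the type has a public method with exactly this name and parameter list.
    static boolean matches(Class<?> type, String method, Class<?>... args) {
        try {
            Method found = type.getMethod(method, args);
            return found != null;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(matches(HandlerLike.class, "prepare", Connection.class));
        System.out.println(matches(HandlerLike.class, "prepare", Connection.class, Integer.class));
    }
}
```

The practical consequence is that the args array has to mirror the declared parameter types of the MyBatis version actually on the classpath.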




Source walkthrough:

StatementHandler source:

    public interface StatementHandler {

        Statement prepare(Connection connection)
            throws SQLException;

        void parameterize(Statement statement)
            throws SQLException;

        void batch(Statement statement)
            throws SQLException;

        int update(Statement statement)
            throws SQLException;

        <E> List<E> query(Statement statement, ResultHandler resultHandler)
            throws SQLException;

        BoundSql getBoundSql();

        ParameterHandler getParameterHandler();
    }

The prepare method in this interface is the one we intercept. Its actual implementation lives in BaseStatementHandler:

    @Override
    public Statement prepare(Connection connection) throws SQLException {
        ErrorContext.instance().sql(boundSql.getSql());
        Statement statement = null;
        try {
            statement = instantiateStatement(connection); // <----- this is the call we care about
            setStatementTimeout(statement);
            setFetchSize(statement);
            return statement;
        } catch (SQLException e) {
            closeStatement(statement);
            throw e;
        } catch (Exception e) {
            closeStatement(statement);
            throw new ExecutorException("Error preparing statement.  Cause: " + e, e);
        }
    }

    protected abstract Statement instantiateStatement(Connection connection) throws SQLException;

instantiateStatement is abstract. Because our SQL is precompiled (a PreparedStatement), the implementation that applies is the one in PreparedStatementHandler:

    @Override
    protected Statement instantiateStatement(Connection connection) throws SQLException {
        String sql = boundSql.getSql(); // <---- this is our SQL statement
        if (mappedStatement.getKeyGenerator() instanceof Jdbc3KeyGenerator) {
            String[] keyColumnNames = mappedStatement.getKeyColumns();
            if (keyColumnNames == null) {
                return connection.prepareStatement(sql, PreparedStatement.RETURN_GENERATED_KEYS);
            } else {
                return connection.prepareStatement(sql, keyColumnNames);
            }
        } else if (mappedStatement.getResultSetType() != null) {
            return connection.prepareStatement(sql, mappedStatement.getResultSetType().getValue(), ResultSet.CONCUR_READ_ONLY);
        } else {
            return connection.prepareStatement(sql);
        }
    }

Now that the execution path of the SQL is clear, we can start working on the intercepted StatementHandler.

    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        StatementHandler handler = (StatementHandler) invocation.getTarget();
        // mappedStatement holds the statement id we need, but the field is protected,
        // so read it reflectively through MetaObject
        MetaObject statementHandler = SystemMetaObject.forObject(handler);
        MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
        // get the SQL
        BoundSql boundSql = handler.getBoundSql();
        String sql = boundSql.getSql();
        // get the statement id
        String id = mappedStatement.getId();
        if ("id of the statement to enhance".equals(id)) {
            // SQL enhancement goes here
        }
        return invocation.proceed();
    }
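
The property path "delegate.mappedStatement" works because MetaObject walks the object graph one segment at a time and can reach non-public fields. The following JDK-only sketch imitates that walk with plain reflection, under the assumption that field access alone is enough (the real MetaObject also honours getters and setters). Routing, Base, Prepared, Mapped and PathReflection are hypothetical stand-ins, not MyBatis classes.

```java
import java.lang.reflect.Field;

// Hypothetical stand-ins for the handler -> delegate -> mappedStatement chain.
final class Mapped {
    private final String id = "com.test.UserMapper.selectAll";
}

class Base {
    protected Mapped mappedStatement = new Mapped();
    protected String sql = "select * from t";
}

final class Prepared extends Base {
}

final class Routing {
    private final Base delegate = new Prepared();
}

final class PathReflection {

    // Find a field by name on the class or any superclass and make it accessible.
    private static Field find(Class<?> type, String name) {
        for (Class<?> c = type; c != null; c = c.getSuperclass()) {
            try {
                Field f = c.getDeclaredField(name);
                f.setAccessible(true);
                return f;
            } catch (NoSuchFieldException ignored) {
                // not declared here, keep looking in the superclass
            }
        }
        throw new IllegalStateException("no field named " + name);
    }

    // Read a dotted path such as "delegate.mappedStatement.id".
    static Object get(Object root, String path) {
        Object current = root;
        try {
            for (String part : path.split("\\.")) {
                current = find(current.getClass(), part).get(current);
            }
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
        return current;
    }

    // Write the last segment of a dotted path such as "delegate.sql".
    static void set(Object root, String path, Object value) {
        int dot = path.lastIndexOf('.');
        Object owner = dot < 0 ? root : get(root, path.substring(0, dot));
        try {
            find(owner.getClass(), path.substring(dot + 1)).set(owner, value);
        } catch (IllegalAccessException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        Routing handler = new Routing();
        System.out.println(get(handler, "delegate.mappedStatement.id"));
        set(handler, "delegate.sql", "select * from t where dept_id = 'D01'");
        System.out.println(get(handler, "delegate.sql"));
    }
}
```

Writing through the same kind of path is how the rewritten SQL is put back into the handler later in this article.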


After all of the above, don't forget to register the interceptor:

    <configuration>
        <plugins>
            <plugin interceptor="com.test.interceptor.MyInterceptor"/>
        </plugins>
    </configuration>

That completes writing and configuring the MyBatis interceptor. The next step is SQL parsing with JSqlParser.

II. JSqlParser

GitHub

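1. Add the jsqlparser dependency to the project

    <dependency>
        <groupId>com.github.jsqlparser</groupId>
        <artifactId>jsqlparser</artifactId>
        <version>1.0</version>
    </dependency>

2. Parse the SQL

First determine the statement type (SELECT, UPDATE, INSERT, DELETE, ...).
Then convert the SQL into the corresponding object for that type.

    CCJSqlParserManager parserManager = new CCJSqlParserManager();
    // sqlCommandType comes from mappedStatement.getSqlCommandType(); it is an enum, so compare with the constant
    if (SqlCommandType.SELECT == sqlCommandType) {
        Select select = (Select) parserManager.parse(new StringReader(sql));
    }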
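
One detail worth checking when dispatching on the statement type: mappedStatement.getSqlCommandType() returns an enum, and a String is never equal to an enum constant, so a string comparison compiles but never matches. The sketch below shows the difference with a hypothetical CommandType enum standing in for MyBatis's SqlCommandType.

```java
// Hypothetical stand-in for org.apache.ibatis.mapping.SqlCommandType.
enum CommandType { UNKNOWN, INSERT, UPDATE, DELETE, SELECT }

final class CommandTypeCheck {

    // Compiles, but is false for every input: String.equals rejects non-String arguments.
    static boolean stringEquals(CommandType type) {
        return "SELECT".equals(type);
    }

    // Compare enum constants directly.
    static boolean enumEquals(CommandType type) {
        return type == CommandType.SELECT;
    }

    // Comparing against the constant's name also works if a String is really needed.
    static boolean nameEquals(CommandType type) {
        return type != null && "SELECT".equals(type.name());
    }

    public static void main(String[] args) {
        System.out.println(stringEquals(CommandType.SELECT));
        System.out.println(enumEquals(CommandType.SELECT));
        System.out.println(nameEquals(CommandType.SELECT));
    }
}
```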

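3. Visit each visitor implementation (SelectVisitorImpl is our own implementation of SelectVisitor)

The overall idea is to split the SQL statement into many small parts and let each part be handled by the matching visitor implementation.

    select.getSelectBody().accept(new SelectVisitorImpl());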
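
For readers new to the visitor pattern, this is the mechanism behind accept: each node of the parsed statement calls back the visit overload for its own type, and the visitor decides how to recurse. The sketch below is a deliberately tiny JDK-only model of a WHERE clause; Node, NodeVisitor, Compare, And and ColumnCollector are hypothetical names and not JSqlParser types.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical miniature of a parsed WHERE clause.
interface Node {
    void accept(NodeVisitor visitor);
}

interface NodeVisitor {
    void visit(Compare compare);
    void visit(And and);
}

// A leaf such as: dept_id = 'D01'
final class Compare implements Node {
    final String column;
    final String op;
    final String value;

    Compare(String column, String op, String value) {
        this.column = column;
        this.op = op;
        this.value = value;
    }

    @Override
    public void accept(NodeVisitor visitor) {
        visitor.visit(this); // double dispatch: selects visit(Compare)
    }
}

// An inner node: left AND right
final class And implements Node {
    final Node left;
    final Node right;

    And(Node left, Node right) {
        this.left = left;
        this.right = right;
    }

    @Override
    public void accept(NodeVisitor visitor) {
        visitor.visit(this); // double dispatch: selects visit(And)
    }
}

// Walks the tree and records every column it meets, in visiting order.
final class ColumnCollector implements NodeVisitor {
    final List<String> columns = new ArrayList<>();

    @Override
    public void visit(Compare compare) {
        columns.add(compare.column);
    }

    @Override
    public void visit(And and) {
        and.left.accept(this);
        and.right.accept(this);
    }

    static List<String> collect(Node root) {
        ColumnCollector collector = new ColumnCollector();
        root.accept(collector);
        return collector.columns;
    }

    public static void main(String[] args) {
        Node where = new And(new Compare("dept_id", "=", "'D01'"), new Compare("age", ">", "18"));
        System.out.println(collect(where));
    }
}
```

The visitor classes in this article follow the same shape, only with one visit overload per JSqlParser node type.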

SelectVisitorImpl.class:

    package com.test.sqlparser.visitor;

    import net.sf.jsqlparser.expression.Expression;
    import net.sf.jsqlparser.expression.Parenthesis;
    import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
    import net.sf.jsqlparser.statement.select.FromItem;
    import net.sf.jsqlparser.statement.select.Join;
    import net.sf.jsqlparser.statement.select.OrderByElement;
    import net.sf.jsqlparser.statement.select.PlainSelect;
    import net.sf.jsqlparser.statement.select.SelectBody;
    import net.sf.jsqlparser.statement.select.SelectItem;
    import net.sf.jsqlparser.statement.select.SelectVisitor;
    import net.sf.jsqlparser.statement.select.SetOperationList;
    import net.sf.jsqlparser.statement.select.WithItem;

    public class SelectVisitorImpl implements SelectVisitor {

        // The main job is to implement the low-level visitors and add conditions while walking the parsed SQL.

        // A plain SELECT, i.e. one that can carry every clause
        @Override
        public void visit(PlainSelect plainSelect) {
            // visit the select items
            if (plainSelect.getSelectItems() != null) {
                for (SelectItem item : plainSelect.getSelectItems()) {
                    item.accept(new SelectItemVisitorImpl());
                }
            }

            // visit FROM (a statement such as "SELECT 1" has no FROM item)
            FromItem fromItem = plainSelect.getFromItem();
            FromItemVisitorImpl fromItemVisitorImpl = new FromItemVisitorImpl();
            if (fromItem != null) {
                fromItem.accept(fromItemVisitorImpl);
            }

            // visit WHERE
            if (plainSelect.getWhere() != null) {
                plainSelect.getWhere().accept(new ExpressionVisitorImpl());
            }

            // add the permission condition collected while visiting FROM
            if (fromItemVisitorImpl.getEnhancedCondition() != null) {
                if (plainSelect.getWhere() != null) {
                    Expression expr = new Parenthesis(plainSelect.getWhere());
                    Expression enhancedCondition = new Parenthesis(fromItemVisitorImpl.getEnhancedCondition());
                    AndExpression and = new AndExpression(enhancedCondition, expr);
                    plainSelect.setWhere(and);
                } else {
                    plainSelect.setWhere(fromItemVisitorImpl.getEnhancedCondition());
                }
            }

            // visit JOIN
            // note: each join gets a fresh visitor, so conditions collected for joined tables
            // are not added to the WHERE clause here
            if (plainSelect.getJoins() != null) {
                for (Join join : plainSelect.getJoins()) {
                    join.getRightItem().accept(new FromItemVisitorImpl());
                }
            }

            // visit ORDER BY
            if (plainSelect.getOrderByElements() != null) {
                for (OrderByElement orderByElement : plainSelect.getOrderByElements()) {
                    orderByElement.getExpression().accept(new ExpressionVisitorImpl());
                }
            }

            // visit GROUP BY ... HAVING
            if (plainSelect.getHaving() != null) {
                plainSelect.getHaving().accept(new ExpressionVisitorImpl());
            }
        }

        // set operations (UNION, INTERSECT, ...)
        @Override
        public void visit(SetOperationList setOpList) {
            for (SelectBody plainSelect : setOpList.getSelects()) {
                plainSelect.accept(new SelectVisitorImpl());
            }
        }

        // WITH item
        @Override
        public void visit(WithItem withItem) {
            withItem.getSelectBody().accept(new SelectVisitorImpl());
        }
    }

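
The reason SelectVisitorImpl wraps both sides in Parenthesis before AND-ing them is operator precedence: AND binds tighter than OR, so an unparenthesised original WHERE containing OR could let rows slip past the permission condition. The JDK-only sketch below illustrates this on plain strings and booleans; WhereMerge and its methods are hypothetical helpers for illustration, not part of the article's code.

```java
final class WhereMerge {

    // Combine the original WHERE text with a permission condition, parenthesising both sides.
    static String merge(String originalWhere, String permission) {
        if (permission == null || permission.isEmpty()) {
            return originalWhere;
        }
        if (originalWhere == null || originalWhere.isEmpty()) {
            return permission;
        }
        return "(" + permission + ") AND (" + originalWhere + ")";
    }

    // p AND a OR b, evaluated the way SQL would without parentheses: (p AND a) OR b
    static boolean withoutParens(boolean p, boolean a, boolean b) {
        return p && a || b;
    }

    // p AND (a OR b), which is what the permission check intends
    static boolean withParens(boolean p, boolean a, boolean b) {
        return p && (a || b);
    }

    public static void main(String[] args) {
        System.out.println(merge("status = 1 OR status = 2", "dept_id = 'D01'"));
        // permission fails (p = false) but the second OR branch is true
        System.out.println(withoutParens(false, false, true));
        System.out.println(withParens(false, false, true));
    }
}
```
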
SelectItemVisitorImpl.class

    package com.test.sqlparser.visitor;

    import net.sf.jsqlparser.statement.select.AllColumns;
    import net.sf.jsqlparser.statement.select.AllTableColumns;
    import net.sf.jsqlparser.statement.select.SelectExpressionItem;
    import net.sf.jsqlparser.statement.select.SelectItemVisitor;

    public class SelectItemVisitorImpl implements SelectItemVisitor {

        @Override
        public void visit(AllColumns allColumns) {
        }

        @Override
        public void visit(AllTableColumns allTableColumns) {
        }

        @Override
        public void visit(SelectExpressionItem selectExpressionItem) {
            selectExpressionItem.getExpression().accept(new ExpressionVisitorImpl());
        }
    }

ExpressionVisitorImpl.class

    package com.test.sqlparser.visitor;

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    import net.sf.jsqlparser.expression.AllComparisonExpression;
    import net.sf.jsqlparser.expression.AnalyticExpression;
    import net.sf.jsqlparser.expression.AnyComparisonExpression;
    import net.sf.jsqlparser.expression.BinaryExpression;
    import net.sf.jsqlparser.expression.CaseExpression;
    import net.sf.jsqlparser.expression.CastExpression;
    import net.sf.jsqlparser.expression.DateTimeLiteralExpression;
    import net.sf.jsqlparser.expression.DateValue;
    import net.sf.jsqlparser.expression.DoubleValue;
    import net.sf.jsqlparser.expression.Expression;
    import net.sf.jsqlparser.expression.ExpressionVisitor;
    import net.sf.jsqlparser.expression.ExtractExpression;
    import net.sf.jsqlparser.expression.Function;
    import net.sf.jsqlparser.expression.HexValue;
    import net.sf.jsqlparser.expression.IntervalExpression;
    import net.sf.jsqlparser.expression.JdbcNamedParameter;
    import net.sf.jsqlparser.expression.JdbcParameter;
    import net.sf.jsqlparser.expression.JsonExpression;
    import net.sf.jsqlparser.expression.KeepExpression;
    import net.sf.jsqlparser.expression.LongValue;
    import net.sf.jsqlparser.expression.MySQLGroupConcat;
    import net.sf.jsqlparser.expression.NullValue;
    import net.sf.jsqlparser.expression.NumericBind;
    import net.sf.jsqlparser.expression.OracleHierarchicalExpression;
    import net.sf.jsqlparser.expression.OracleHint;
    import net.sf.jsqlparser.expression.Parenthesis;
    import net.sf.jsqlparser.expression.RowConstructor;
    import net.sf.jsqlparser.expression.SignedExpression;
    import net.sf.jsqlparser.expression.StringValue;
    import net.sf.jsqlparser.expression.TimeKeyExpression;
    import net.sf.jsqlparser.expression.TimeValue;
    import net.sf.jsqlparser.expression.TimestampValue;
    import net.sf.jsqlparser.expression.UserVariable;
    import net.sf.jsqlparser.expression.WhenClause;
    import net.sf.jsqlparser.expression.WithinGroupExpression;
    import net.sf.jsqlparser.expression.operators.arithmetic.Addition;
    import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseAnd;
    import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseOr;
    import net.sf.jsqlparser.expression.operators.arithmetic.BitwiseXor;
    import net.sf.jsqlparser.expression.operators.arithmetic.Concat;
    import net.sf.jsqlparser.expression.operators.arithmetic.Division;
    import net.sf.jsqlparser.expression.operators.arithmetic.Modulo;
    import net.sf.jsqlparser.expression.operators.arithmetic.Multiplication;
    import net.sf.jsqlparser.expression.operators.arithmetic.Subtraction;
    import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
    import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
    import net.sf.jsqlparser.expression.operators.relational.Between;
    import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
    import net.sf.jsqlparser.expression.operators.relational.ExistsExpression;
    import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
    import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
    import net.sf.jsqlparser.expression.operators.relational.InExpression;
    import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
    import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
    import net.sf.jsqlparser.expression.operators.relational.Matches;
    import net.sf.jsqlparser.expression.operators.relational.MinorThan;
    import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
    import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
    import net.sf.jsqlparser.expression.operators.relational.RegExpMatchOperator;
    import net.sf.jsqlparser.expression.operators.relational.RegExpMySQLOperator;
    import net.sf.jsqlparser.schema.Column;
    import net.sf.jsqlparser.statement.select.SubSelect;
    import net.sf.jsqlparser.statement.select.WithItem;

    public class ExpressionVisitorImpl implements ExpressionVisitor {

        Logger logger = LoggerFactory.getLogger(ExpressionVisitorImpl.class);

        // signed expression (e.g. -x): visit the inner expression.
        // Calling signedExpression.accept(...) here would call this method again and recurse forever.
        @Override
        public void visit(SignedExpression signedExpression) {
            signedExpression.getExpression().accept(new ExpressionVisitorImpl());
        }

        // JDBC parameter (?)
        @Override
        public void visit(JdbcParameter jdbcParameter) {}

        // named JDBC parameter
        @Override
        public void visit(JdbcNamedParameter jdbcNamedParameter) {}

        // parenthesised expression
        @Override
        public void visit(Parenthesis parenthesis) {
            parenthesis.getExpression().accept(new ExpressionVisitorImpl());
        }

        // BETWEEN
        @Override
        public void visit(Between between) {
            between.getLeftExpression().accept(new ExpressionVisitorImpl());
            between.getBetweenExpressionStart().accept(new ExpressionVisitorImpl());
            between.getBetweenExpressionEnd().accept(new ExpressionVisitorImpl());
        }

        // IN expression
        // ItemsListVisitorImpl (an ItemsListVisitor implementation in the same package) is not listed in this article
        @Override
        public void visit(InExpression inExpression) {
            if (inExpression.getLeftExpression() != null) {
                inExpression.getLeftExpression().accept(new ExpressionVisitorImpl());
            } else if (inExpression.getLeftItemsList() != null) {
                inExpression.getLeftItemsList().accept(new ItemsListVisitorImpl());
            }
            inExpression.getRightItemsList().accept(new ItemsListVisitorImpl());
        }

        // subquery
        @Override
        public void visit(SubSelect subSelect) {
            if (subSelect.getWithItemsList() != null) {
                for (WithItem withItem : subSelect.getWithItemsList()) {
                    withItem.accept(new SelectVisitorImpl());
                }
            }
            subSelect.getSelectBody().accept(new SelectVisitorImpl());
        }

        // EXISTS
        @Override
        public void visit(ExistsExpression existsExpression) {
            existsExpression.getRightExpression().accept(new ExpressionVisitorImpl());
        }

        // ALL (subquery)
        @Override
        public void visit(AllComparisonExpression allComparisonExpression) {
            allComparisonExpression.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
        }

        // ANY (subquery)
        @Override
        public void visit(AnyComparisonExpression anyComparisonExpression) {
            anyComparisonExpression.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
        }

        // Oracle hierarchical query (START WITH ... CONNECT BY ...)
        @Override
        public void visit(OracleHierarchicalExpression oexpr) {
            if (oexpr.getStartExpression() != null) {
                oexpr.getStartExpression().accept(this);
            }
            if (oexpr.getConnectExpression() != null) {
                oexpr.getConnectExpression().accept(this);
            }
        }

        // row constructor
        @Override
        public void visit(RowConstructor rowConstructor) {
            for (Expression expr : rowConstructor.getExprList().getExpressions()) {
                expr.accept(this);
            }
        }

        // CAST
        @Override
        public void visit(CastExpression cast) {
            cast.getLeftExpression().accept(new ExpressionVisitorImpl());
        }

        // addition
        @Override
        public void visit(Addition addition) {
            visitBinaryExpression(addition);
        }

        // division
        @Override
        public void visit(Division division) {
            visitBinaryExpression(division);
        }

        // multiplication
        @Override
        public void visit(Multiplication multiplication) {
            visitBinaryExpression(multiplication);
        }

        // subtraction
        @Override
        public void visit(Subtraction subtraction) {
            visitBinaryExpression(subtraction);
        }

        // AND
        @Override
        public void visit(AndExpression andExpression) {
            visitBinaryExpression(andExpression);
        }

        // OR
        @Override
        public void visit(OrExpression orExpression) {
            visitBinaryExpression(orExpression);
        }

        // equals
        @Override
        public void visit(EqualsTo equalsTo) {
            visitBinaryExpression(equalsTo);
        }

        // greater than
        @Override
        public void visit(GreaterThan greaterThan) {
            visitBinaryExpression(greaterThan);
        }

        // greater than or equal
        @Override
        public void visit(GreaterThanEquals greaterThanEquals) {
            visitBinaryExpression(greaterThanEquals);
        }

        // LIKE
        @Override
        public void visit(LikeExpression likeExpression) {
            visitBinaryExpression(likeExpression);
        }

        // less than
        @Override
        public void visit(MinorThan minorThan) {
            visitBinaryExpression(minorThan);
        }

        // less than or equal
        @Override
        public void visit(MinorThanEquals minorThanEquals) {
            visitBinaryExpression(minorThanEquals);
        }

        // not equal
        @Override
        public void visit(NotEqualsTo notEqualsTo) {
            visitBinaryExpression(notEqualsTo);
        }

        // string concatenation
        @Override
        public void visit(Concat concat) {
            visitBinaryExpression(concat);
        }

        // matches operator
        @Override
        public void visit(Matches matches) {
            visitBinaryExpression(matches);
        }

        // bitwise AND
        @Override
        public void visit(BitwiseAnd bitwiseAnd) {
            visitBinaryExpression(bitwiseAnd);
        }

        // bitwise OR
        @Override
        public void visit(BitwiseOr bitwiseOr) {
            visitBinaryExpression(bitwiseOr);
        }

        // bitwise XOR
        @Override
        public void visit(BitwiseXor bitwiseXor) {
            visitBinaryExpression(bitwiseXor);
        }

        // modulo
        @Override
        public void visit(Modulo modulo) {
            visitBinaryExpression(modulo);
        }

        // regular-expression match operator
        @Override
        public void visit(RegExpMatchOperator rexpr) {
            visitBinaryExpression(rexpr);
        }

        // MySQL regular-expression operator
        @Override
        public void visit(RegExpMySQLOperator regExpMySQLOperator) {
            visitBinaryExpression(regExpMySQLOperator);
        }

        // any binary expression: visit both sides
        public void visitBinaryExpression(BinaryExpression binaryExpression) {
            binaryExpression.getLeftExpression().accept(new ExpressionVisitorImpl());
            binaryExpression.getRightExpression().accept(new ExpressionVisitorImpl());
        }

        // ------------------------- everything below is unused -------------------------

        // analytic expression
        @Override
        public void visit(AnalyticExpression aexpr) {}

        // WITHIN GROUP expression
        @Override
        public void visit(WithinGroupExpression wgexpr) {}

        // EXTRACT expression
        @Override
        public void visit(ExtractExpression eexpr) {}

        // INTERVAL expression
        @Override
        public void visit(IntervalExpression iexpr) {}

        // JSON expression
        @Override
        public void visit(JsonExpression jsonExpr) {}

        // Oracle hint
        @Override
        public void visit(OracleHint hint) {}

        // time key expression
        @Override
        public void visit(TimeKeyExpression timeKeyExpression) {}

        // CASE expression
        @Override
        public void visit(CaseExpression caseExpression) {}

        // WHEN clause
        @Override
        public void visit(WhenClause whenClause) {}

        // user variable
        @Override
        public void visit(UserVariable var) {}

        // numeric bind
        @Override
        public void visit(NumericBind bind) {}

        // KEEP expression
        @Override
        public void visit(KeepExpression aexpr) {}

        // MySQL GROUP_CONCAT
        @Override
        public void visit(MySQLGroupConcat groupConcat) {}

        // table column
        @Override
        public void visit(Column tableColumn) {}

        // double value
        @Override
        public void visit(DoubleValue doubleValue) {}

        // long value
        @Override
        public void visit(LongValue longValue) {}

        // hexadecimal value
        @Override
        public void visit(HexValue hexValue) {}

        // date value
        @Override
        public void visit(DateValue dateValue) {}

        // time value
        @Override
        public void visit(TimeValue timeValue) {}

        // timestamp value
        @Override
        public void visit(TimestampValue timestampValue) {}

        // NULL value
        @Override
        public void visit(NullValue nullValue) {}

        // function call
        @Override
        public void visit(Function function) {}

        // string value
        @Override
        public void visit(StringValue stringValue) {}

        // IS NULL expression
        @Override
        public void visit(IsNullExpression isNullExpression) {}

        // date/time literal
        @Override
        public void visit(DateTimeLiteralExpression literal) {}
    }

FromItemVisitorImpl.class

    package com.test.sqlparser.visitor;

    import java.util.ArrayList;
    import java.util.List;

    import net.sf.jsqlparser.expression.Expression;
    import net.sf.jsqlparser.expression.LongValue;
    import net.sf.jsqlparser.expression.StringValue;
    import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
    import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
    import net.sf.jsqlparser.expression.operators.relational.Between;
    import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
    import net.sf.jsqlparser.expression.operators.relational.GreaterThan;
    import net.sf.jsqlparser.expression.operators.relational.GreaterThanEquals;
    import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
    import net.sf.jsqlparser.expression.operators.relational.LikeExpression;
    import net.sf.jsqlparser.expression.operators.relational.MinorThan;
    import net.sf.jsqlparser.expression.operators.relational.MinorThanEquals;
    import net.sf.jsqlparser.expression.operators.relational.NotEqualsTo;
    import net.sf.jsqlparser.schema.Column;
    import net.sf.jsqlparser.schema.Table;
    import net.sf.jsqlparser.statement.select.FromItemVisitor;
    import net.sf.jsqlparser.statement.select.LateralSubSelect;
    import net.sf.jsqlparser.statement.select.SubJoin;
    import net.sf.jsqlparser.statement.select.SubSelect;
    import net.sf.jsqlparser.statement.select.TableFunction;
    import net.sf.jsqlparser.statement.select.ValuesList;

    import org.slf4j.Logger;
    import org.slf4j.LoggerFactory;

    // TableCondition and UserUtils are project classes that are not listed in this article
    import com.test.entity.TableCondition;
    import com.test.security.UserUtils;

    public class FromItemVisitorImpl implements FromItemVisitor {

        private static Logger logger = LoggerFactory.getLogger(FromItemVisitorImpl.class);

        // the permission condition collected for this FROM item
        private Expression enhancedCondition;

        // FROM table  <---- the main case: check what restrictions the user has on this table
        @Override
        public void visit(Table tableName) {
            // check whether this table needs processing at all
            if (isActionTable(tableName.getFullyQualifiedName())) {
                // look up the current user's restrictions for this table
                List<TableCondition> conditions = UserUtils.getTableCondition(tableName.getFullyQualifiedName().toUpperCase());
                if (conditions != null) {
                    // enhance the SQL
                    for (TableCondition tableCondition : conditions) {
                        String op = tableCondition.getOperator();
                        // qualify the column with the alias when the table has one
                        Column column = new Column(
                                new Table(tableName.getAlias() != null ? tableName.getAlias().getName() : tableName.getFullyQualifiedName()),
                                tableCondition.getFieldName());
                        Expression[] expressions;
                        if ("between".equalsIgnoreCase(op) || "not between".equalsIgnoreCase(op)) {
                            // The original code left this branch commented out. BETWEEN needs a lower and an
                            // upper bound, but only one field value is available here, so the condition is skipped.
                            continue;
                        } else if ("is null".equalsIgnoreCase(op) || "is not null".equalsIgnoreCase(op)) {
                            // IS NULL / IS NOT NULL only needs the column
                            expressions = new Expression[] { column };
                        } else if ("1".equals(tableCondition.getFieldName())) {
                            // constant condition such as WHERE 1 = 1
                            expressions = new Expression[] {
                                    new LongValue(tableCondition.getFieldName()),
                                    new LongValue(tableCondition.getFieldValue()) };
                        } else {
                            // the most common case: column <operator> value
                            expressions = new Expression[] { column, new StringValue(tableCondition.getFieldValue()) };
                        }
                        // build the condition for this operator
                        Expression condition = this.getOperator(op, expressions);
                        if (condition == null) {
                            // unsupported operator: do not AND a null into the WHERE clause
                            continue;
                        }
                        if (this.enhancedCondition != null) {
                            enhancedCondition = new AndExpression(enhancedCondition, condition);
                        } else {
                            enhancedCondition = condition;
                        }
                    }
                }
            }
        }

        // FROM subquery
        @Override
        public void visit(SubSelect subSelect) {
            // a subquery goes back to the select visitor
            subSelect.getSelectBody().accept(new SelectVisitorImpl());
        }

        // FROM subjoin
        @Override
        public void visit(SubJoin subjoin) {
            subjoin.getLeft().accept(new FromItemVisitorImpl());
            subjoin.getJoin().getRightItem().accept(new FromItemVisitorImpl());
        }

        // FROM lateral subquery
        @Override
        public void visit(LateralSubSelect lateralSubSelect) {
            lateralSubSelect.getSubSelect().getSelectBody().accept(new SelectVisitorImpl());
        }

        // FROM values list
        @Override
        public void visit(ValuesList valuesList) {
        }

        // FROM table function
        @Override
        public void visit(TableFunction tableFunction) {
        }

        // Turn an operator given as a string into the matching JSqlParser expression
        private Expression getOperator(String op, Expression[] exp) {
            if ("=".equals(op)) {
                EqualsTo eq = new EqualsTo();
                eq.setLeftExpression(exp[0]);
                eq.setRightExpression(exp[1]);
                return eq;
            } else if (">".equals(op)) {
                GreaterThan gt = new GreaterThan();
                gt.setLeftExpression(exp[0]);
                gt.setRightExpression(exp[1]);
                return gt;
            } else if (">=".equals(op)) {
                GreaterThanEquals geq = new GreaterThanEquals();
                geq.setLeftExpression(exp[0]);
                geq.setRightExpression(exp[1]);
                return geq;
            } else if ("<".equals(op)) {
                MinorThan mt = new MinorThan();
                mt.setLeftExpression(exp[0]);
                mt.setRightExpression(exp[1]);
                return mt;
            } else if ("<=".equals(op)) {
                MinorThanEquals leq = new MinorThanEquals();
                leq.setLeftExpression(exp[0]);
                leq.setRightExpression(exp[1]);
                return leq;
            } else if ("<>".equals(op)) {
                NotEqualsTo neq = new NotEqualsTo();
                neq.setLeftExpression(exp[0]);
                neq.setRightExpression(exp[1]);
                return neq;
            } else if ("is null".equalsIgnoreCase(op)) {
                IsNullExpression isNull = new IsNullExpression();
                isNull.setNot(false);
                isNull.setLeftExpression(exp[0]);
                return isNull;
            } else if ("is not null".equalsIgnoreCase(op)) {
                IsNullExpression isNull = new IsNullExpression();
                isNull.setNot(true);
                isNull.setLeftExpression(exp[0]);
                return isNull;
            } else if ("like".equalsIgnoreCase(op)) {
                LikeExpression like = new LikeExpression();
                like.setNot(false);
                like.setLeftExpression(exp[0]);
                like.setRightExpression(exp[1]);
                return like;
            } else if ("not like".equalsIgnoreCase(op)) {
                LikeExpression nlike = new LikeExpression();
                nlike.setNot(true);
                nlike.setLeftExpression(exp[0]);
                nlike.setRightExpression(exp[1]);
                return nlike;
            } else if ("between".equalsIgnoreCase(op)) {
                Between bt = new Between();
                bt.setNot(false);
                bt.setLeftExpression(exp[0]);
                bt.setBetweenExpressionStart(exp[1]);
                bt.setBetweenExpressionEnd(exp[2]);
                return bt;
            } else if ("not between".equalsIgnoreCase(op)) {
                Between bt = new Between();
                bt.setNot(true);
                bt.setLeftExpression(exp[0]);
                bt.setBetweenExpressionStart(exp[1]);
                bt.setBetweenExpressionEnd(exp[2]);
                return bt;
            } else {
                // no expression is defined for this operator
                return null;
            }
        }

        public Expression getEnhancedCondition() {
            return enhancedCondition;
        }

        // Decide whether the given table should be processed
        public boolean isActionTable(String tableName) {
            // process by default
            boolean flag = true;
            // names of the tables that need no processing
            List<String> tableNames = new ArrayList<String>();
            // table names in hand-written SQL may be lower case, so compare in upper case (as in visit(Table) above)
            if (tableNames.contains(tableName.toUpperCase())) {
                // the table is in the exclusion list, so set flag to false
                flag = false;
            }
            return flag;
        }
    }
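
The long if/else chain in getOperator is essentially a lookup from an operator string to a way of combining a left and a right operand. As an alternative shape, here is a minimal JDK-only sketch that keeps that mapping in a table. It renders plain strings instead of building JSqlParser Expression objects so that it runs without the library; OperatorTable and render are hypothetical names, and in real code the map values would construct EqualsTo, GreaterThan and so on.

```java
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.function.BinaryOperator;

final class OperatorTable {

    // operator (lower case) -> function that renders "left <op> right"
    private static final Map<String, BinaryOperator<String>> BINARY = new HashMap<>();

    static {
        for (String op : new String[] {"=", ">", ">=", "<", "<=", "<>"}) {
            BINARY.put(op, (left, right) -> left + " " + op + " " + right);
        }
        BINARY.put("like", (left, right) -> left + " LIKE " + right);
        BINARY.put("not like", (left, right) -> left + " NOT LIKE " + right);
    }

    // Returns the rendered condition, or null when the operator is not supported.
    static String render(String op, String left, String right) {
        if (op == null) {
            return null;
        }
        BinaryOperator<String> renderer = BINARY.get(op.trim().toLowerCase(Locale.ROOT));
        return renderer == null ? null : renderer.apply(left, right);
    }

    public static void main(String[] args) {
        System.out.println(render("=", "t.dept_id", "'D01'"));
        System.out.println(render("NOT LIKE", "t.name", "'tmp%'"));
        System.out.println(render("regexp", "t.name", "'x'"));
    }
}
```

Whichever shape is used, the caller has to handle the null returned for an unknown operator before AND-ing the result into the WHERE clause.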


The complete interceptor code

    package com.test.interceptor;

    import java.io.StringReader;
    import java.sql.Connection;
    import java.util.Properties;

    import net.sf.jsqlparser.parser.CCJSqlParserManager;
    import net.sf.jsqlparser.statement.select.Select;

    import org.apache.ibatis.executor.statement.StatementHandler;
    import org.apache.ibatis.mapping.BoundSql;
    import org.apache.ibatis.mapping.MappedStatement;
    import org.apache.ibatis.mapping.SqlCommandType;
    import org.apache.ibatis.plugin.Interceptor;
    import org.apache.ibatis.plugin.Intercepts;
    import org.apache.ibatis.plugin.Invocation;
    import org.apache.ibatis.plugin.Plugin;
    import org.apache.ibatis.plugin.Signature;
    import org.apache.ibatis.reflection.MetaObject;
    import org.apache.ibatis.reflection.SystemMetaObject;

    import com.test.sqlparser.visitor.SelectVisitorImpl;

    @Intercepts({ @Signature(method = "prepare", type = StatementHandler.class, args = { Connection.class }) })
    public class MyInterceptor implements Interceptor {

        CCJSqlParserManager parserManager = new CCJSqlParserManager();

        @Override
        public Object intercept(Invocation invocation) throws Throwable {
            StatementHandler handler = (StatementHandler) invocation.getTarget();
            // mappedStatement is a protected field, so read it reflectively through MetaObject
            MetaObject statementHandler = SystemMetaObject.forObject(handler);
            // mappedStatement holds the statement id we need
            MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
            // get the SQL
            BoundSql boundSql = handler.getBoundSql();
            String sql = boundSql.getSql();
            // get the statement id
            String id = mappedStatement.getId();
            // get the statement type
            SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
            if ("id of the statement to enhance".equals(id)) {
                // SqlCommandType is an enum: "SELECT".equals(sqlCommandType) would always be false,
                // so compare against the enum constant
                if (SqlCommandType.SELECT == sqlCommandType) {
                    // for a SELECT, parse the SQL into a Select object
                    Select select = (Select) parserManager.parse(new StringReader(sql));
                    // walk the statement with the visitors
                    select.getSelectBody().accept(new SelectVisitorImpl());
                    // put the enhanced SQL back
                    statementHandler.setValue("delegate.boundSql.sql", select.toString());
                }
            }
            return invocation.proceed();
        }

        @Override
        public Object plugin(Object target) {
            return Plugin.wrap(target, this);
        }

        @Override
        public void setProperties(Properties properties) {
        }
    }





















