Skip to content

教程:使数据范围权限DataScope支持Bean Searcher #246

@zengyufei

Description

@zengyufei
  1. 引入依赖
<dependency>
	<groupId>cn.zhxu</groupId>
	<artifactId>bean-searcher-boot-starter</artifactId>
	<version>4.1.2</version>
</dependency>
  2. 实现 SqlInterceptor
/**
 * Data-permission interceptor shared by MyBatis ({@link Interceptor}) and Bean Searcher ({@link SqlInterceptor}).
 * <p>
 * On the Bean Searcher side, every SQL statement that the current fetch will actually execute (the cluster/count
 * SQL and/or the list SQL) is rewritten independently by {@link DataScopeSqlProcessor} to carry the applicable
 * data-scope conditions.
 */
@RequiredArgsConstructor
@Intercepts({
		@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class DataPermissionInterceptor implements Interceptor, SqlInterceptor {

	private final DataScopeSqlProcessor dataScopeSqlProcessor;

	private final DataPermissionHandler dataPermissionHandler;

	// MyBatis entry point; its body is elided in this tutorial.
	@Override
	public Object intercept(Invocation invocation) throws Throwable {
            ......
	}

	/**
	 * Bean Searcher entry point: applies data-scope rewriting to each SQL string the current fetch executes.
	 * <p>
	 * The cluster SQL and the list SQL are handled separately, each with its own cache key. A "no data scope
	 * matched" result for one statement therefore never causes the other statement to skip filtering. Only
	 * statements that will actually run are touched, so a SQL string that was never generated is never read.
	 *
	 * @param searchSql the SQL about to be executed; its SQL strings are replaced in place
	 * @param map       search parameters (unused here)
	 * @param fetchType tells which statements (cluster and/or list) will be executed
	 * @return the same {@code searchSql} instance
	 */
	@Override
	public <T> SearchSql<T> intercept(SearchSql<T> searchSql, Map<String, Object> map, FetchType fetchType) {
		if (fetchType.shouldQueryCluster()) {
			searchSql.setClusterSqlString(applyDataScopes(searchSql.getClusterSqlString()));
		}
		if (fetchType.shouldQueryList()) {
			searchSql.setListSqlString(applyDataScopes(searchSql.getListSqlString()));
		}
		return searchSql;
	}

	/**
	 * Rewrites a single SQL statement according to the data scopes that apply to it.
	 * <p>
	 * If parsing shows that none of the data scopes matched this statement, its key is recorded so later
	 * executions of the same statement can skip parsing.
	 *
	 * @param sql the SQL statement to process
	 * @return the rewritten SQL, or {@code sql} unchanged when no data scope applies or permission control is ignored
	 */
	private String applyDataScopes(String sql) {
		// The hash of the SQL text is the cache key (the role a MappedStatement id plays on the MyBatis side).
		// NOTE(review): String.hashCode can collide. Two different statements sharing a hash would also share the
		// "no matching data scope" entry, which could skip filtering for one of them. Consider keying on the SQL text.
		String sqlKey = Convert.toStr(sql.hashCode());

		// Data scopes that still need to be checked for this statement
		List<DataScope> filterDataScopes = dataPermissionHandler.filterDataScopes(sqlKey);
		if (filterDataScopes == null || filterDataScopes.isEmpty()) {
			return sql;
		}

		// Decide from the user's permissions whether to intercept at all, e.g. administrators may see everything
		if (dataPermissionHandler.ignorePermissionControl(filterDataScopes, sqlKey)) {
			return sql;
		}

		// Set up the per-thread match counter
		DataScopeMatchNumHolder.initMatchNum();
		try {
			String scopedSql = dataScopeSqlProcessor.parserSingle(sql, filterDataScopes);

			// If parsing found no data scope matching this statement, remember it so it can be skipped next time
			Integer matchNum = DataScopeMatchNumHolder.pollMatchNum();
			List<DataScope> allDataScopes = dataPermissionHandler.dataScopes();
			if (allDataScopes.size() == filterDataScopes.size() && matchNum != null && matchNum == 0) {
				MappedStatementIdsWithoutDataScope.addToWithoutSet(filterDataScopes, sqlKey);
			}
			return scopedSql;
		} finally {
			DataScopeMatchNumHolder.removeIfEmpty();
		}
	}

	/**
	 * MyBatis plugin hook: wraps only {@link StatementHandler} targets.
	 */
	@Override
	public Object plugin(Object target) {
		if (target instanceof StatementHandler) {
			return Plugin.wrap(target, this);
		}
		return target;
	}

}
  3. 解决 Bean Searcher 查询不参与 Spring 事务的问题
@AutoConfiguration
@RequiredArgsConstructor
@ConditionalOnBean(DataScope.class)
public class DataScopeAutoConfiguration {

	/**
	 * Registers {@link MyDefaultSqlExecutor} as the primary {@link SqlExecutor}, so that Bean Searcher queries
	 * obtain their connections in a way that can take part in Spring transactions.
	 *
	 * @param dataSource   default data source handed to the executor
	 * @param slowListener optional slow-SQL listener; applied only if one is available
	 * @param config       Bean Searcher starter properties; supplies the slow-SQL threshold
	 * @return the configured executor
	 */
	@Bean
	@Primary
	public SqlExecutor regMyDefaultSqlExecutor(@Autowired DataSource dataSource, ObjectProvider<SqlExecutor.SlowListener> slowListener, BeanSearcherProperties config) {
		MyDefaultSqlExecutor sqlExecutor = new MyDefaultSqlExecutor(dataSource);
		ifAvailable(slowListener, sqlExecutor::setSlowListener);
		long threshold = config.getSql().getSlowSqlThreshold();
		sqlExecutor.setSlowSqlThreshold(threshold);
		return sqlExecutor;
	}

	/**
	 * Passes the provider's instance to {@code consumer} when one is available; does nothing otherwise.
	 * <p>
	 * Deliberately avoids {@code ObjectProvider.ifAvailable} to stay compatible with Spring Boot 1.x (down to v1.4).
	 */
	private <T> void ifAvailable(ObjectProvider<T> provider, Consumer<T> consumer) {
		T candidate = provider.getIfAvailable();
		if (candidate == null) {
			return;
		}
		consumer.accept(candidate);
	}
}

MyDefaultSqlExecutor 类

package com.hccake.ballcat.common.datascope;

import cn.zhxu.bs.BeanMeta;
import cn.zhxu.bs.SearchException;
import cn.zhxu.bs.SearchSql;
import cn.zhxu.bs.SqlExecutor;
import cn.zhxu.bs.SqlResult;
import cn.zhxu.bs.bean.SearchBean;
import cn.zhxu.bs.implement.DefaultSqlExecutor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.jdbc.datasource.DataSourceUtils;

import javax.sql.DataSource;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/**
 * JDBC SQL executor for Bean Searcher that takes part in Spring-managed transactions.
 * <p>
 * Adapted from Bean Searcher's {@link DefaultSqlExecutor}. It does not open and close connections itself.
 * Instead it obtains them with {@link DataSourceUtils#doGetConnection(DataSource)} and returns them with
 * {@link DataSourceUtils#releaseConnection(Connection, DataSource)}. As a result, a query issued inside a Spring
 * transaction runs on the transaction-bound connection.
 * <p>
 * A connection is always released against the same {@link DataSource} (default or named) it was obtained from.
 *
 * @author Troy.Zhou
 * @since 1.1.1
 */
public class MyDefaultSqlExecutor implements SqlExecutor {

    /**
     * Logs under the {@link DefaultSqlExecutor} category. This is presumably so that logging levels already
     * configured for Bean Searcher's built-in executor keep applying. NOTE(review): confirm this is intentional.
     */
    protected static final Logger log = LoggerFactory.getLogger(DefaultSqlExecutor.class);

    /**
     * Default data source, used when a bean does not name one.
     */
    private DataSource dataSource;

    /**
     * Named data sources, keyed by (trimmed) name.
     *
     * @since v3.0.0
     */
    private final Map<String, DataSource> dataSourceMap = new ConcurrentHashMap<>();

    /**
     * Read-only transaction flag, kept for API compatibility with {@link DefaultSqlExecutor}.
     * It is NOT applied by this executor: transaction handling is left to Spring.
     */
    private boolean transactional = false;

    /**
     * Slow SQL threshold in milliseconds; default 500 ms.
     *
     * @since v3.7.0
     */
    private long slowSqlThreshold = 500;

    /**
     * Slow SQL listener (optional).
     *
     * @since v3.7.0
     */
    private SlowListener slowListener;


    public MyDefaultSqlExecutor() {
    }

    /**
     * @param dataSource default data source
     */
    public MyDefaultSqlExecutor(DataSource dataSource) {
        this.dataSource = dataSource;
    }


    /**
     * Executes the cluster (count/summary) SQL and/or the list SQL of {@code searchSql}.
     * <p>
     * On success the connection stays open and is released when the returned {@link SqlResult} is closed.
     * On any failure it is released immediately.
     *
     * @throws SearchException if no connection can be obtained or a SQL error occurs
     */
    @Override
    public <T> SqlResult<T> execute(SearchSql<T> searchSql) {
        if (!searchSql.isShouldQueryList() && !searchSql.isShouldQueryCluster()) {
            return new SqlResult<>(searchSql);
        }
        BeanMeta<?> beanMeta = searchSql.getBeanMeta();
        // Remember where the connection comes from so it is released against the same data source
        DataSource connectionSource = resolveDataSource(beanMeta);
        Connection connection;
        try {
            connection = getConnection(beanMeta);
        } catch (SQLException e) {
            throw new SearchException("Can not get connection from dataSource!", e);
        }
        try {
            return doExecute(searchSql, connection, connectionSource);
        } catch (SQLException e) {
            // Release right away on failure; otherwise the connection is released together with the SqlResult
            DataSourceUtils.releaseConnection(connection, connectionSource);
            throw new SearchException("A exception occurred when executing sql.", e);
        } catch (RuntimeException e) {
            DataSourceUtils.releaseConnection(connection, connectionSource);
            throw e;
        }
    }

    /**
     * Obtains a (possibly transaction-bound) connection from the bean's data source.
     *
     * @throws SearchException if the bean's data source is not configured
     */
    protected Connection getConnection(BeanMeta<?> beanMeta) throws SQLException {
        return DataSourceUtils.doGetConnection(resolveDataSource(beanMeta));
    }

    /**
     * Resolves the data source for a bean: the named data source if the bean declares one, else the default.
     *
     * @throws SearchException if the required data source is not configured
     */
    protected DataSource resolveDataSource(BeanMeta<?> beanMeta) {
        String name = beanMeta.getDataSource();
        if (name == null || "".equals(name)) {
            final DataSource defaultDataSource = this.getDataSource();
            if (defaultDataSource == null) {
                throw new SearchException("There is not a default dataSource for " + beanMeta.getBeanClass());
            }
            return defaultDataSource;
        }
        DataSource named = this.getDataSourceMap().get(name);
        if (named == null) {
            throw new SearchException("There is not a dataSource named " + name + " for " + beanMeta.getBeanClass());
        }
        return named;
    }

    /**
     * Legacy overload; the connection is released against the DEFAULT data source.
     * Prefer {@link #doExecute(SearchSql, Connection, DataSource)}.
     */
    protected <T> SqlResult<T> doExecute(SearchSql<T> searchSql, Connection connection) throws SQLException {
        return doExecute(searchSql, connection, dataSource);
    }

    /**
     * Runs the cluster SQL (if requested) and then the list SQL (if requested). The list query is skipped when
     * the count is known to be zero.
     *
     * @param connectionSource the data source {@code connection} was obtained from; used to release it
     * @return a result whose {@code close()} also releases the connection
     */
    protected <T> SqlResult<T> doExecute(SearchSql<T> searchSql, Connection connection, DataSource connectionSource)
            throws SQLException {
        SqlResult.ResultSet listResult = null;
        SqlResult.Result clusterResult = null;
        try {
            Number totalCount = null;
            if (searchSql.isShouldQueryCluster()) {
                clusterResult = executeClusterSql(searchSql, connection);
                String countAlias = searchSql.getCountAlias();
                if (countAlias != null) {
                    totalCount = (Number) clusterResult.get(countAlias);
                }
            }
            if (searchSql.isShouldQueryList()) {
                if (totalCount == null || totalCount.longValue() > 0) {
                    listResult = executeListSql(searchSql, connection);
                } else {
                    listResult = SqlResult.ResultSet.EMPTY;
                }
            }
        } catch (SQLException | RuntimeException e) {
            // Do not leak the already-open cluster result when anything later fails
            closeQuietly(clusterResult);
            throw e;
        }
        return new SqlResult<T>(searchSql, listResult, clusterResult) {
            @Override
            public void close() {
                try {
                    super.close();
                } finally {
                    // Hands the connection back to Spring (kept open if it belongs to an active transaction)
                    DataSourceUtils.releaseConnection(connection, connectionSource);
                }
            }
        };
    }

    /**
     * Executes the list SQL and exposes its rows as a {@link SqlResult.ResultSet}.
     */
    protected SqlResult.ResultSet executeListSql(SearchSql<?> searchSql, Connection connection) throws SQLException {
        Result result = executeQuery(connection, searchSql.getListSqlString(),
                searchSql.getListSqlParams(), searchSql.getBeanMeta());
        ResultSet resultSet = result.resultSet;
        return new SqlResult.ResultSet() {
            @Override
            public boolean next() throws SQLException {
                return resultSet.next();
            }

            @Override
            public Object get(String columnLabel) throws SQLException {
                return resultSet.getObject(columnLabel);
            }

            @Override
            public void close() {
                result.close();
            }
        };
    }

    /**
     * Executes the cluster SQL and exposes its single (first) row; {@code get} returns null if there is no row.
     */
    protected SqlResult.Result executeClusterSql(SearchSql<?> searchSql, Connection connection) throws SQLException {
        Result result = executeQuery(connection, searchSql.getClusterSqlString(),
                searchSql.getClusterSqlParams(), searchSql.getBeanMeta());
        ResultSet resultSet = result.resultSet;
        boolean hasValue;
        try {
            hasValue = resultSet.next();
        } catch (SQLException e) {
            result.close();
            throw e;
        }
        return new SqlResult.Result() {
            @Override
            public Object get(String columnLabel) throws SQLException {
                if (hasValue) {
                    return resultSet.getObject(columnLabel);
                }
                return null;
            }

            @Override
            public void close() {
                result.close();
            }
        };
    }

    /**
     * A statement together with its result set; closing it closes both.
     */
    protected static class Result {

        final PreparedStatement statement;
        final ResultSet resultSet;

        public Result(PreparedStatement statement, ResultSet resultSet) {
            this.statement = statement;
            this.resultSet = resultSet;
        }

        public void close() {
            closeQuietly(resultSet);
            closeQuietly(statement);
        }

    }

    /**
     * Prepares, binds and executes a query. The statement is closed if anything fails before a result is returned.
     * Execution time (excluding binding) is reported to {@link #afterExecute}.
     */
    protected Result executeQuery(Connection connection, String sql, List<Object> params,
                                  BeanMeta<?> beanMeta) throws SQLException {
        PreparedStatement statement = connection.prepareStatement(sql);
        try {
            int size = params.size();
            for (int i = 0; i < size; i++) {
                statement.setObject(i + 1, params.get(i));
            }
        } catch (SQLException | RuntimeException e) {
            // Binding failed: the statement would otherwise leak
            closeQuietly(statement);
            throw e;
        }
        long t0 = System.currentTimeMillis();
        try {
            int timeout = beanMeta.getTimeout();
            if (timeout > 0) {
                // This call is relatively expensive, so only make it when a timeout is actually configured
                statement.setQueryTimeout(timeout);
            }
            ResultSet resultSet = statement.executeQuery();
            return new Result(statement, resultSet);
        } catch (SQLException | RuntimeException e) {
            closeQuietly(statement);
            throw e;
        } finally {
            long cost = System.currentTimeMillis() - t0;
            afterExecute(beanMeta, sql, params, cost);
        }
    }

    /**
     * Logs the executed SQL. Statements at or above the slow-SQL threshold are logged at WARN and reported to
     * the slow listener; all others are logged at DEBUG.
     */
    protected void afterExecute(BeanMeta<?> beanMeta, String sql, List<Object> params, long timeCost) {
        if (timeCost >= slowSqlThreshold) {
            Class<?> beanClass = beanMeta.getBeanClass();
            SlowListener listener = slowListener;
            if (listener != null) {
                listener.onSlowSql(beanClass, sql, params, timeCost);
            }
            log.warn("bean-searcher [{}ms] slow-sql: [{}] params: {} on [{}]", timeCost, sql, params, beanClass.getName());
        } else {
            log.debug("bean-searcher [{}ms] sql: [{}] params: {}", timeCost, sql, params);
        }
    }

    /**
     * Closes {@code resource} if non-null, logging (not propagating) any failure.
     */
    protected static void closeQuietly(AutoCloseable resource) {
        try {
            if (resource != null) {
                resource.close();
            }
        } catch (Exception e) {
            log.error("Can not close {}", resource.getClass().getSimpleName(), e);
        }
    }

    /**
     * Sets the default data source.
     *
     * @param dataSource the data source; must not be null
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = Objects.requireNonNull(dataSource);
    }

    public DataSource getDataSource() {
        return dataSource;
    }

    /**
     * Registers a named data source. Null arguments are ignored; the name is trimmed.
     *
     * @param name       data source name
     * @param dataSource data source
     * @see SearchBean#dataSource()
     * @since v3.1.0
     */
    public void setDataSource(String name, DataSource dataSource) {
        if (name != null && dataSource != null) {
            dataSourceMap.put(name.trim(), dataSource);
        }
    }

    public Map<String, DataSource> getDataSourceMap() {
        return dataSourceMap;
    }

    /**
     * Sets the read-only transaction flag. Kept for API compatibility only; this executor does not apply it.
     *
     * @param transactional whether to use a transaction
     */
    public void setTransactional(boolean transactional) {
        this.transactional = transactional;
    }

    public boolean isTransactional() {
        return transactional;
    }

    public long getSlowSqlThreshold() {
        return slowSqlThreshold;
    }

    /**
     * Sets the slow SQL threshold (minimum execution time counted as slow).
     *
     * @param slowSqlThreshold slow SQL threshold in milliseconds
     * @since v3.7.0
     */
    public void setSlowSqlThreshold(long slowSqlThreshold) {
        this.slowSqlThreshold = slowSqlThreshold;
    }

    public SlowListener getSlowListener() {
        return slowListener;
    }

    public void setSlowListener(SlowListener slowListener) {
        this.slowListener = slowListener;
    }

}
  4. 可选: 简化前端多值参数传递,支持 XX,YY,ZZ 形式的逗号分隔值
/**
 * Optional: simplifies passing multiple values (not required for data scope support).
 * <p>
 * The front end can send a multi-valued operator's values as one comma-separated string, e.g.
 * {@code age=1,5&age-op=bt}, instead of indexed keys such as {@code age-0=1&age-1=5&age-op=bt}.
 * This filter expands the former into the latter.
 * See: https://github.com/troyzhxu/bean-searcher/issues/10
 *
 * @param config Bean Searcher starter properties; the parameter separator and operator key are read from it
 * @return the parameter filter
 */
@Bean
public ParamFilter myParamFilter(BeanSearcherProperties config) {
    final BeanSearcherProperties.Params configParams = config.getParams();
    final String separator = configParams.getSeparator();
    final String operatorKey = configParams.getOperatorKey();
    return new ParamFilter() {

        // Suffix that marks an operator key, e.g. "-op" in "age-op"
        final String OP_SUFFIX = separator + operatorKey;

        // Operator aliases whose value is split on commas; built once instead of per entry
        final List<String> MULTI_VALUE_OPS = Arrays.asList("mv", "il", "bt", "nb", "ol", "ni");

        @Override
        public <T> Map<String, Object> doFilter(BeanMeta<T> beanMeta, Map<String, Object> paraMap) {
            Map<String, Object> newParaMap = new HashMap<>();
            paraMap.forEach((key, value) -> {
                if (key == null) {
                    return;
                }
                boolean isOpKey = key.endsWith(OP_SUFFIX);
                String opKey = isOpKey ? key : key + OP_SUFFIX;
                Object opVal = paraMap.get(opKey);
                if (!isMultiValueOp(opVal)) {
                    newParaMap.put(key, value);
                    return;
                }
                if (newParaMap.containsKey(key)) {
                    return;
                }
                // Work on the value key even when we are visiting the operator key
                String valKey = key;
                Object valVal = value;
                if (isOpKey) {
                    valKey = key.substring(0, key.length() - OP_SUFFIX.length());
                    valVal = paraMap.get(valKey);
                }
                if (strContainDou(valVal)) {
                    try {
                        final List<String> split = StrUtil.split((String) valVal, ",");
                        for (int i = 0; i < split.size(); i++) {
                            newParaMap.put(valKey + separator + i, StrUtil.trim(split.get(i)));
                        }
                        newParaMap.put(opKey, opVal);
                        return;
                    } catch (RuntimeException ignored) {
                        // Best effort: if splitting fails, fall through and pass the raw value on unchanged
                    }
                }
                newParaMap.put(key, value);
            });
            return newParaMap;
        }

        /**
         * True if {@code opVal} is a string naming one of the multi-valued operators.
         * Operator values are not necessarily strings, so anything else is passed through untouched.
         * The previous unconditional cast to CharSequence threw ClassCastException for such values.
         */
        private boolean isMultiValueOp(Object opVal) {
            return opVal instanceof CharSequence
                    && MULTI_VALUE_OPS.contains(StrUtil.trim((CharSequence) opVal));
        }

        /** True if {@code value} is a String containing a comma. */
        private boolean strContainDou(Object value) {
            if (value instanceof String) {
                String str = ((String) value).trim();
                return str.contains(",");
            }
            return false;
        }

    };
}

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions