Spring手动扫描包路径及容器外获取Bean实例

来源:互联网 发布:中金数据系统 编辑:程序博客网 时间:2024/05/23 02:04

最近做的项目有一个需求,希望开放指定包下的Controller给其他应用调用,但需要验证其许可。
解决方案:定义一个Filter,在init初始化方法内扫描指定包下的所有Controller,生成开放URL集合;在doFilter方法内对请求参数校验(加盐MD5生成)

方案用到了两个工具类,第一个是HttpServletRequest的包装类,主要是为了解决RequestBody两次读取的问题。正常情况下HttpServletRequest的流只能被读取一次

/** * 自定义的Http请求封装类,解决RequestBody只能读取一次问题 * * @create 2017-12-11 17:18 */public class HttpRequestTwiceReadingWrapper extends HttpServletRequestWrapper {    private byte[] requestBody = null;    public HttpRequestTwiceReadingWrapper(HttpServletRequest request) {        super(request);        //缓存请求body        try {            requestBody = StreamUtils.copyToByteArray(request.getInputStream());        } catch (IOException e) {            e.printStackTrace();        }    }    /**     * 重写 getInputStream()     */    @Override    public ServletInputStream getInputStream() throws IOException {        if (requestBody == null) {            requestBody = new byte[0];        }        InputStream bis = new ByteArrayInputStream(requestBody);        return new ServletInputStream() {            @Override            public boolean isFinished() {                return false;            }            @Override            public boolean isReady() {                return true;            }            @Override            public void setReadListener(ReadListener readListener) {            }            @Override            public int read() throws IOException {                return bis.read();            }        };    }    /**     * 重写 getReader()     */    @Override    public BufferedReader getReader() throws IOException {        return new BufferedReader(new InputStreamReader(getInputStream()));    }}

第二个工具类是Spring外获取容器内的Bean的类。

/** * Spring容器工具类,用于容器外对象获取容器内的Bean * * @create 2017-12-12 14:19 */ @Componentpublic class SpringBeanInstanceAccessor implements BeanFactoryAware {    //@Autowired不支持static属性注入,只能用实现指定接口的形式    private static BeanFactory factory;    @Override    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {        factory = beanFactory;    }    /**     * 获取指定名称的Bean     *     * @param beanName     * @param clazz     * @param <T>     * @return     */    public static <T> Object getBean(String beanName, Class<T> clazz) {        return factory.getBean(beanName, clazz);    }    /**     * 获取指定类型的Bean     *     * @param clazz     * @param <T>     * @return     */    public static <T> Object getBean(Class<T> clazz) {        return factory.getBean(clazz);    }}

最后一个就是Filter了,扫描指定包下的Controller生成URL集合,获取RequestBody内的参数,通过Spring的Mapper查询参数指定的key,组合key+”,”+Referer+”,”+RequestBody,然后通过加盐MD5生成token,然后和HttpHeader内的token验证。

/** * 请求许可验证过滤器 * [token]和[timestamp]由HttpHeader传入 * * @create 2017-12-08 11:26 */public class CrosRequestPermitCheckingFilter implements Filter {    private static final Logger LOGGER = LoggerFactory.getLogger(CrosRequestPermitCheckingFilter.class);    private static final String RESOURCE_PATTERN = "**/*.class";    private final List<TypeFilter> includeFilters = new LinkedList<TypeFilter>();    private final List<TypeFilter> excludeFilters = new LinkedList<TypeFilter>();    private static final String APPLICATION_NAMESPACE = "";    private static final Integer PERMIT_VALIDITY_IN_MINUTE = 5;    private static final List<String> EXPOSED_CONTROLLER_PACKAGES = Arrays.asList("com.bob.mvc.controller");    private Set<String> exposedRequestUriSet = new LinkedHashSet<String>();    /**     * 初始化,扫描开放的Controller包,生成开放的URL集合     *     * @param filterConfig     * @throws ServletException     */    @Override    public void init(FilterConfig filterConfig) throws ServletException {        includeFilters.add(new AnnotationTypeFilter(RequestMapping.class, false));        List<String> controllers = new ArrayList<String>();        try {            for (String pkg : EXPOSED_CONTROLLER_PACKAGES) {                controllers.addAll(this.findCandidateControllers(pkg));            }            if (controllers.isEmpty()) {                if (LOGGER.isWarnEnabled()) {                    LOGGER.warn("扫描指定包{}时未发现符合的开放Controller类", EXPOSED_CONTROLLER_PACKAGES.toString());                }                return;            }            generateExposedURL(this.transformToClass(controllers), APPLICATION_NAMESPACE);            if (LOGGER.isDebugEnabled()) {                LOGGER.debug("扫描指定Controller包,发现开放URL:{}", exposedRequestUriSet.toString());            }        } catch (Exception e) {            LOGGER.error("扫描开放Controller出现异常", e);            return;        }    }    @Override    public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException {        HttpServletRequest request = new HttpRequestTwiceReadingWrapper((HttpServletRequest)servletRequest);        String path = request.getRequestURI();        if (path.endsWith("/")) {            path = path.substring(0, path.length() - 1);        }        if (!exposedRequestUriSet.contains(path)) {            filterChain.doFilter(servletRequest, servletResponse);            return;        }        try {            processCrosRequestPermitCkecking(request);        } catch (IllegalArgumentException | IllegalStateException e) {            writeResult(servletResponse, e.getMessage());        }        filterChain.doFilter(request, servletResponse);    }    /**     * 处理跨域请求许可验证     *     * @param request     * @throws Exception     */    private void processCrosRequestPermitCkecking(HttpServletRequest request) throws IOException {        String timestamp = request.getHeader("timestamp");        Assert.hasText(timestamp, "跨域请求未指定[timestamp]");        Assert.state(isNumber(timestamp), "[timestamp]不是一个有效的时间戳");        Assert.state(Long.valueOf(timestamp) >= System.currentTimeMillis() - 1000 * 60 * PERMIT_VALIDITY_IN_MINUTE, "跨域请求许可已过期");        String referer = request.getHeader("Referer");        String requestBody = getRequestBodyInString(request);        String appcode = getTargetProperty(requestBody, "appcode");        Assert.hasText(appcode, "跨域请求[appcode]不存在");        String campusId = getTargetProperty(requestBody, "campusId");        Assert.isTrue(isNumber(campusId), "跨域请求[campusId]不正确");        String key = selectKey(appcode, campusId);        Assert.notNull(key, "跨域请求[appcode]和[campusId]相应的key不存在");        Assert.state(verify(key + "," + referer + "," + requestBody, request.getHeader("token")), "跨域请求[token]和参数不匹配");    }    /**     * 写入错误结果     *     * @param servletResponse     * @param result     * @throws IOException     */    private void writeResult(ServletResponse servletResponse, String result) throws IOException {        servletResponse.getOutputStream().write(result.getBytes("UTF-8"));    }    /**     * 如果RequestBody内没有数据,则返回""     *     * @param request     * @return     * @throws IOException     */    private String getRequestBodyInString(HttpServletRequest request) throws IOException {        InputStream is = request.getInputStream();        byte[] bytes = new byte[2048];        int length = is.read(bytes);        return length < 0 ? "" : new String(Arrays.copyOf(bytes, length), "UTF-8");    }    /**     * 从json字符串中获取指定属性的值     *     * @param json     * @param property     * @return     */    private String getTargetProperty(String json, String property) {        if (StringUtils.isEmpty(json)) {            return null;        }        String[] fragments = json.split(",");        String value = null;        for (String fragment : fragments) {            if (fragment.contains(property)) {                value = fragment.substring(fragment.indexOf(":") + 1).trim();                if (value.contains("}")) {                    value = value.substring(0, value.indexOf("}")).trim();                }                if (value.contains("\"")) {                    value = value.substring(1, value.length() - 1).trim();                }                break;            }        }        return value;    }    /**     * 校验加盐后是否和原文一致     *     * @param password     * @param md5     * @return     */    private boolean verify(String password, String md5) {        if (StringUtils.isEmpty(password) || StringUtils.isEmpty(md5)) {            return false;        }        char[] cs1 = new char[32];        char[] cs2 = new char[16];        for (int i = 0; i < 48; i += 3) {            cs1[i / 3 * 2] = md5.charAt(i);            cs1[i / 3 * 2 + 1] = md5.charAt(i + 2);            cs2[i / 3] = md5.charAt(i + 1);        }        return new String(cs1).equals(md5Hex(password + new String(cs2)));    }    /**     * 获取十六进制字符串形式的MD5摘要     */    private String md5Hex(String src) {        try {            MessageDigest md5 = MessageDigest.getInstance("MD5");            byte[] bs = md5.digest(src.getBytes());            return new String(new Hex().encode(bs));        } catch (Exception e) {            return null;        }    }    /**     * 字符串是否是数字     *     * @param value     * @return     */    private boolean isNumber(String value) {        char[] chars = ((String)value).toCharArray();        for (char c : chars) {            if (!Character.isDigit(c)) {                return false;            }        }        return true;    }    /**     * 获取符合要求的Controller名称     * @ComponentScan就是使用这些代码扫描包,然后通过TypeFilter过滤想要的     * @ComponentScan扫描时添加了一个AnnotationTypeFilter(Component.class, false)的类型过滤     *     * @param basePackage     * @return     * @throws IOException     */    private List<String> findCandidateControllers(String basePackage) throws IOException {        if (LOGGER.isDebugEnabled()) {            LOGGER.debug("开始扫描包[" + basePackage + "]下的所有类");        }        List<String> controllers = new ArrayList<String>();        String packageSearchPath = CLASSPATH_ALL_URL_PREFIX + replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN;        ResourceLoader resourceLoader = new DefaultResourceLoader();        MetadataReaderFactory readerFactory = new SimpleMetadataReaderFactory(resourceLoader);        Resource[] resources = ResourcePatternUtils.getResourcePatternResolver(resourceLoader).getResources(packageSearchPath);        for (Resource resource : resources) {            MetadataReader reader = readerFactory.getMetadataReader(resource);            if (isCandidateController(reader, readerFactory)) {                controllers.add(reader.getClassMetadata().getClassName());                if (LOGGER.isDebugEnabled()) {                    LOGGER.debug("扫描到符合要求开放Controller类:[" + controllers.get(controllers.size() - 1) + "]");                }            }        }        return controllers;    }    /**     * 通过TypeFilter得到标识了@RequestMapping的类     *     * @param reader     * @param readerFactory     * @return     * @throws IOException     */    protected boolean isCandidateController(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {        for (TypeFilter tf : this.excludeFilters) {            if (tf.match(reader, readerFactory)) {                return false;            }        }        for (TypeFilter tf : this.includeFilters) {            if (tf.match(reader, readerFactory)) {                return true;            }        }        return false;    }    /**     * 将类名转换为类对象     *     * @param classNames     * @return     * @throws ClassNotFoundException     */    private List<Class<?>> transformToClass(List<String> classNames) throws ClassNotFoundException {        List<Class<?>> classes = new ArrayList<Class<?>>(classNames.size());        for (String className : classNames) {            classes.add(ClassUtils.forName(className, this.getClass().getClassLoader()));        }        return classes;    }    /**     * 用"/"替换包路径中"."     *     * @param path     * @return     */    private String replaceDotByDelimiter(String path) {        return StringUtils.replace(path, ".", "/");    }    /**     * 内省Controllers,生成开放的URL集合     *     * @param controllers     * @param prefix     */    private void generateExposedURL(List<Class<?>> controllers, String prefix) {        for (Class<?> controller : controllers) {            String[] classMappings = controller.getAnnotation(RequestMapping.class).value();            ReflectionUtils.doWithMethods(controller,                (method) -> {                    String[] methodMappings = method.getAnnotation(RequestMapping.class).value();                    exposedRequestUriSet.add(prefix + transformMappings(classMappings) + transformMappings(methodMappings));                },                (method) -> method.isAnnotationPresent(RequestMapping.class)            );        }    }    /**     * 通过工具类获取Spring容器内的Mapper,查询数据库的刀key     *     * @param appcode     * @param campusId     * @return     */    public String selectKey(String appcode, String campusId) {        CampusMapper campusMapper = (CampusMapper )SpringBeanInstanceAccessor.getBean(CampusMapper .class);        return bankUserMapper.selectKey(Integer.valueOf(campusId));    }    /**     * 如果方法或者类上的{@linkplain RequestMapping#value()}未指定,则使用""代替     * value()仅支持单个值     *     * @param mappings     * @return     */    private String transformMappings(String[] mappings) {        return ObjectUtils.isEmpty(mappings) ? "" : mappings[0];    }    @Override    public void destroy() {    }}
原创粉丝点击