Struts2 源码分析：Request —— 设置 setParameters 的值

来源:互联网 发布:便携式设备恢复数据 编辑:程序博客网 时间:2024/06/16 20:20
StrutsRequestWrapper
/*    */ package org.apache.struts2.dispatcher;/*    */ /*    */ import com.opensymphony.xwork2.ActionContext;/*    */ import com.opensymphony.xwork2.util.ValueStack;/*    */ import javax.servlet.http.HttpServletRequest;/*    */ import javax.servlet.http.HttpServletRequestWrapper;/*    */ /*    */ public class StrutsRequestWrapper extends HttpServletRequestWrapper/*    */ {/*    */   public StrutsRequestWrapper(HttpServletRequest req)/*    */   {/* 49 */     super(req);/*    */   }/*    */ /*    */   public Object getAttribute(String s)/*    */   {/* 58 */     if ((s != null) && (s.startsWith("javax.servlet")))/*    */     {/* 61 */       return super.getAttribute(s);/*    */     }/*    */ /* 64 */     ActionContext ctx = ActionContext.getContext();/* 65 */     Object attribute = super.getAttribute(s);/* 66 */     if ((ctx != null) && /* 67 */       (attribute == null)) {/* 68 */       boolean alreadyIn = false;/* 69 */       Boolean b = (Boolean)ctx.get("__requestWrapper.getAttribute");/* 70 */       if (b != null) {/* 71 */         alreadyIn = b.booleanValue();/*    */       }/*    */ /* 76 */       if ((!alreadyIn) && (s.indexOf("#") == -1)) {/*    */         try/*    */         {/* 79 */           ctx.put("__requestWrapper.getAttribute", Boolean.TRUE);/* 80 */           ValueStack stack = ctx.getValueStack();/* 81 */           if (stack != null)/* 82 */             attribute = stack.findValue(s);/*    */         }/*    */         finally {/* 85 */           ctx.put("__requestWrapper.getAttribute", Boolean.FALSE);/*    */         }/*    */       }/*    */     }/*    */ /* 90 */     return attribute;/*    */   }/*    */ }/* Location:           C:\Documents and Settings\wb_zypt\妗岄潰\lib\struts2-core-2.2.1.1.jar * Qualified Name:     org.apache.struts2.dispatcher.StrutsRequestWrapper * JD-Core Version:    0.6.0 */
 MultiPartRequestWrapper

 

package org.apache.struts2.dispatcher.multipart;import com.opensymphony.xwork2.util.logging.Logger;import com.opensymphony.xwork2.util.logging.LoggerFactory;import java.io.File;import java.io.IOException;import java.util.ArrayList;import java.util.Collection;import java.util.Enumeration;import java.util.HashMap;import java.util.Iterator;import java.util.List;import java.util.Map;import java.util.Vector;import javax.servlet.http.HttpServletRequest;import org.apache.struts2.dispatcher.StrutsRequestWrapper;public class MultiPartRequestWrapper extends StrutsRequestWrapper{  protected static final Logger LOG = LoggerFactory.getLogger(MultiPartRequestWrapper.class);  Collection<String> errors;  //struts2是默认是从这个里面去取值的。没有的话才去原生request里面去取值  MultiPartRequest multi;  public MultiPartRequestWrapper(MultiPartRequest multiPartRequest, HttpServletRequest request, String saveDir)  {    super(request);    this.multi = multiPartRequest;    try {      this.multi.parse(request, saveDir);      for (i$ = this.multi.getErrors().iterator(); i$.hasNext(); ) { Object o = i$.next();        String error = (String)o;        addError(error);      }    }    catch (IOException e)    {      Iterator i$;      addError("Cannot parse request: " + e.toString());    }  }  public Enumeration<String> getFileParameterNames()  {    if (this.multi == null) {      return null;    }    return this.multi.getFileParameterNames();  }  public String[] getContentTypes(String name)  {    if (this.multi == null) {      return null;    }    return this.multi.getContentType(name);  }  public File[] getFiles(String fieldName)  {    if (this.multi == null) {      return null;    }    return this.multi.getFile(fieldName);  }  public String[] getFileNames(String fieldName)  {    if (this.multi == null) {      return null;    }    return this.multi.getFileNames(fieldName);  }  public String[] getFileSystemNames(String fieldName)  {    if (this.multi == null) {      return null;    }    return 
this.multi.getFilesystemName(fieldName);  }  public String getParameter(String name)  {    return (this.multi == null) || (this.multi.getParameter(name) == null) ? super.getParameter(name) : this.multi.getParameter(name);  }  public Map getParameterMap()  {    Map map = new HashMap();    Enumeration enumeration = getParameterNames();    while (enumeration.hasMoreElements()) {      String name = (String)enumeration.nextElement();      map.put(name, getParameterValues(name));    }    return map;  }  public Enumeration getParameterNames()  {    if (this.multi == null) {      return super.getParameterNames();    }    return mergeParams(this.multi.getParameterNames(), super.getParameterNames());  }  public String[] getParameterValues(String name)  {    return (this.multi == null) || (this.multi.getParameterValues(name) == null) ? super.getParameterValues(name) : this.multi.getParameterValues(name);  }  public boolean hasErrors()  {    return (this.errors != null) && (!this.errors.isEmpty());  }  public Collection<String> getErrors()  {    return this.errors;  }  protected void addError(String anErrorMessage)  {    if (this.errors == null) {      this.errors = new ArrayList();    }    this.errors.add(anErrorMessage);  }  protected Enumeration mergeParams(Enumeration params1, Enumeration params2)  {    Vector temp = new Vector();    while (params1.hasMoreElements()) {      temp.add(params1.nextElement());    }    while (params2.hasMoreElements()) {      temp.add(params2.nextElement());    }    return temp.elements();  }}

 

 MultiPartRequest 

 

package org.apache.struts2.dispatcher.multipart;import com.opensymphony.xwork2.inject.Inject;import com.opensymphony.xwork2.util.logging.Logger;import com.opensymphony.xwork2.util.logging.LoggerFactory;import java.io.File;import java.io.IOException;import java.io.InputStream;import java.io.UnsupportedEncodingException;import java.util.ArrayList;import java.util.Collections;import java.util.Enumeration;import java.util.HashMap;import java.util.List;import java.util.Map;import javax.servlet.http.HttpServletRequest;import org.apache.commons.fileupload.FileItem;import org.apache.commons.fileupload.FileUploadException;import org.apache.commons.fileupload.RequestContext;import org.apache.commons.fileupload.disk.DiskFileItem;import org.apache.commons.fileupload.disk.DiskFileItemFactory;import org.apache.commons.fileupload.servlet.ServletFileUpload;public class JakartaMultiPartRequest  implements MultiPartRequest{  static final Logger LOG = LoggerFactory.getLogger(MultiPartRequest.class);  protected Map<String, List<FileItem>> files = new HashMap();  protected Map<String, List<String>> params = new HashMap();  protected List<String> errors = new ArrayList();  protected long maxSize;  @Inject("struts.multipart.maxSize")  public void setMaxSize(String maxSize)  {    this.maxSize = Long.parseLong(maxSize);  }  public void parse(HttpServletRequest request, String saveDir)    throws IOException  {    try    {      processUpload(request, saveDir);    } catch (FileUploadException e) {      LOG.warn("Unable to parse request", e, new String[0]);      this.errors.add(e.getMessage());    }  }  private void processUpload(HttpServletRequest request, String saveDir) throws FileUploadException, UnsupportedEncodingException {    for (FileItem item : parseRequest(request, saveDir)) {      if (LOG.isDebugEnabled()) {        LOG.debug("Found item " + item.getFieldName(), new String[0]);      }      if (item.isFormField())        processNormalFormField(item, request.getCharacterEncoding());   
   else        processFileField(item);    }  }  private void processFileField(FileItem item)  {    LOG.debug("Item is a file upload", new String[0]);    if ((item.getName() == null) || (item.getName().trim().length() < 1)) {      LOG.debug("No file has been uploaded for the field: " + item.getFieldName(), new String[0]);      return;    }    List values;    List values;    if (this.files.get(item.getFieldName()) != null)      values = (List)this.files.get(item.getFieldName());    else {      values = new ArrayList();    }    values.add(item);    this.files.put(item.getFieldName(), values);  }  private void processNormalFormField(FileItem item, String charset) throws UnsupportedEncodingException {    LOG.debug("Item is a normal form field", new String[0]);    List values;    List values;    if (this.params.get(item.getFieldName()) != null)      values = (List)this.params.get(item.getFieldName());    else {      values = new ArrayList();    }    if (charset != null)      values.add(item.getString(charset));    else {      values.add(item.getString());    }    this.params.put(item.getFieldName(), values);  }  //解析request的值  private List<FileItem> parseRequest(HttpServletRequest servletRequest, String saveDir) throws FileUploadException {    DiskFileItemFactory fac = createDiskFileItemFactory(saveDir);    ServletFileUpload upload = new ServletFileUpload(fac);    upload.setSizeMax(this.maxSize);    return upload.parseRequest(createRequestContext(servletRequest));  }  private DiskFileItemFactory createDiskFileItemFactory(String saveDir) {    DiskFileItemFactory fac = new DiskFileItemFactory();    fac.setSizeThreshold(0);    if (saveDir != null) {      fac.setRepository(new File(saveDir));    }    return fac;  }  public Enumeration<String> getFileParameterNames()  {    return Collections.enumeration(this.files.keySet());  }  public String[] getContentType(String fieldName)  {    List items = (List)this.files.get(fieldName);    if (items == null) {      return null;    }   
 List contentTypes = new ArrayList(items.size());    for (FileItem fileItem : items) {      contentTypes.add(fileItem.getContentType());    }    return (String[])contentTypes.toArray(new String[contentTypes.size()]);  }  public File[] getFile(String fieldName)  {    List items = (List)this.files.get(fieldName);    if (items == null) {      return null;    }    List fileList = new ArrayList(items.size());    for (FileItem fileItem : items) {      fileList.add(((DiskFileItem)fileItem).getStoreLocation());    }    return (File[])fileList.toArray(new File[fileList.size()]);  }  public String[] getFileNames(String fieldName)  {    List items = (List)this.files.get(fieldName);    if (items == null) {      return null;    }    List fileNames = new ArrayList(items.size());    for (FileItem fileItem : items) {      fileNames.add(getCanonicalName(fileItem.getName()));    }    return (String[])fileNames.toArray(new String[fileNames.size()]);  }  public String[] getFilesystemName(String fieldName)  {    List items = (List)this.files.get(fieldName);    if (items == null) {      return null;    }    List fileNames = new ArrayList(items.size());    for (FileItem fileItem : items) {      fileNames.add(((DiskFileItem)fileItem).getStoreLocation().getName());    }    return (String[])fileNames.toArray(new String[fileNames.size()]);  }  public String getParameter(String name)  {    List v = (List)this.params.get(name);    if ((v != null) && (v.size() > 0)) {      return (String)v.get(0);    }    return null;  }  public Enumeration<String> getParameterNames()  {    return Collections.enumeration(this.params.keySet());  }  public String[] getParameterValues(String name)  {    List v = (List)this.params.get(name);    if ((v != null) && (v.size() > 0)) {      return (String[])v.toArray(new String[v.size()]);    }    return null;  }  public List getErrors()  {    return this.errors;  }  private String getCanonicalName(String filename)  {    int forwardSlash = filename.lastIndexOf("/");    
int backwardSlash = filename.lastIndexOf("\\");    if ((forwardSlash != -1) && (forwardSlash > backwardSlash))      filename = filename.substring(forwardSlash + 1, filename.length());    else if ((backwardSlash != -1) && (backwardSlash >= forwardSlash)) {      filename = filename.substring(backwardSlash + 1, filename.length());    }    return filename;  }  private RequestContext createRequestContext(HttpServletRequest req)  {    return new RequestContext(req) {      public String getCharacterEncoding() {        return this.val$req.getCharacterEncoding();      }      public String getContentType() {        return this.val$req.getContentType();      }      public int getContentLength() {        return this.val$req.getContentLength();      }      public InputStream getInputStream() throws IOException {        InputStream in = this.val$req.getInputStream();        if (in == null) {          throw new IOException("Missing content in the request");        }        return this.val$req.getInputStream();      }    };  }}

    Struts2 通过 ServletActionContext.getRequest() 获取 Request。

    获取的 Request 对象可能是 MultiPartRequestWrapper，也可能是 StrutsRequestWrapper。

    为了动态地向 Request 设置参数值，先通过源码了解它的取值方式，再用下面的方法动态设置值。

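    如果没有使用 Struts2，获取的 Request 就是原生的 Request，理论上可以直接通过下面的代码设置（注意：Servlet 规范规定 getParameterMap() 返回的 Map 不可修改，Tomcat 等容器会在 put 时抛出 IllegalStateException，实际使用时通常需要配合自定义的 HttpServletRequestWrapper，例如文末的 XssHttpServletRequestWrapper）：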

// The Servlet API documents getParameterMap() as immutable; this put only succeeds when the
// current request returns a mutable map (e.g. a custom wrapper that caches its own copy).
Map parameterMap = getRequest().getParameterMap();
parameterMap.put(key, val);

     但如果使用的是 Struts2 封装的 Request，就麻烦一点。

     继承关系：MultiPartRequestWrapper → StrutsRequestWrapper → HttpServletRequestWrapper

 

 

 /**    * 设置Parameters 的值    * @param key    * @param val    */    public void setParameters(String key,String val){        System.out.println(" getRequest() code == "+getRequest());try {//如果org.apache.struts2.dispatcher.multipart.MultiPartRequestWrapper 中存在  multi属性 就往 multi 设置值if(getRequest() instanceof MultiPartRequestWrapper && getMultiForRequest()){//获取struts2中的取值属性 org.apache.struts2.dispatcher.multipart.JakartaMultiPartRequestField requestField = getRequest().getClass().getDeclaredField("multi");if(requestField!=null){requestField.setAccessible(true);//org.apache.struts2.dispatcher.multipart.JakartaMultiPartRequestMultiPartRequest multiPartRequest =(MultiPartRequest)requestField.get(getRequest());//params 设置值Field paramsField = multiPartRequest.getClass().getDeclaredField("params");requestField.setAccessible(true);if(paramsField!=null){paramsField.setAccessible(true);Map<String, List<String>>  paramsMap= (Map<String, List<String>>)paramsField.get(multiPartRequest);    if(paramsMap!=null){    List<String> paramsList = new ArrayList<String>();    paramsList.add(val);    paramsMap.put(key, paramsList);    }}}}else{Map m = getRequest().getParameterMap();m.put(key, val);}//HttpServletRequestWrapper reqst = (HttpServletRequestWrapper) requestField.get(getRequest());    //Map m = reqst.getParameterMap();//lockedField = m.getClass().getDeclaredField("locked");//lockedField.setAccessible(true);//System.out.println(lockedField.get(m));//lockedField.set(m, false);//System.out.println(lockedField.get(m));//Object flag = m.put(key, val);//System.out.println("m hashCode "+m.hashCode());} catch (Exception e) {log.error(e.getMessage(), e);}    }        /**     * 判断request中是否存在Multi属性     * @return     */    public boolean getMultiForRequest()throws Exception{    boolean flag =false;    String multi = "multi";    for(Field f:getRequest().getClass().getDeclaredFields()){    if(multi.equals(f.getName())){    flag = true;    break;    }    }     return flag;    }

   

XssHttpServletRequestWrapper 
package com.dep.aop;import java.util.HashMap;import java.util.Iterator;import java.util.Map;import java.util.Set;import javax.servlet.http.HttpServletRequest;import javax.servlet.http.HttpServletRequestWrapper;import org.slf4j.Logger;import org.slf4j.LoggerFactory;import com.dep.util.StringUtil;/** * 拦截防止sql注入  * @author wb_zypt * */public class XssHttpServletRequestWrapper extends HttpServletRequestWrapper {HttpServletRequest orgRequest = null;Map newParams = null;private static Logger log = LoggerFactory.getLogger(XssHttpServletRequestWrapper.class);public XssHttpServletRequestWrapper(HttpServletRequest request) {super(request);orgRequest = request;}/*** 覆盖getParameter方法,将参数名和参数值都做xss过滤。<br/>* 如果需要获得原始的值,则通过super.getParameterValues(name)来获取<br/>* getParameterNames,getParameterValues和getParameterMap也可能需要覆盖*/@Overridepublic String getParameter(String name) {String value = super.getParameter(StringUtil.filterDangerString(name));if (value != null) {value = StringUtil.filterDangerString(value);}if(value == null){value = (String)getParameterMap().get(name);}return value;}@Override@SuppressWarnings("unchecked")public Map getParameterMap() {if(newParams !=null){return newParams;}else{newParams = new HashMap();}//Map newParams  = new HashMap();Map params = super.getParameterMap();Set<String> keySet = params.keySet();        for (Iterator iterator = keySet.iterator(); iterator.hasNext();) {            String key = (String) iterator.next();             Object obj =  params.get(key);            if(obj instanceof String){             String str = (String) params.get(key);             newParams.put(key, StringUtil.filterDangerString((String)str));            }else if(obj.getClass() == String[].class){             String[] str = (String[]) params.get(key);             newParams.put(key, xssEncode((String[])str));            }else{             newParams.put(key, obj);            }                                   }/*java.lang.reflect.Field lockedField = null;try {lockedField = 
params.getClass().getDeclaredField("locked");lockedField.setAccessible(true);lockedField.set(params, false);} catch (Exception e) {log.error(e.getMessage(), e);}Set<String> keySet = params.keySet();        for (Iterator iterator = keySet.iterator(); iterator.hasNext();) {            String key = (String) iterator.next();             Object obj =  params.get(key);            if(obj instanceof String){             String str = (String) params.get(key);             params.put(key, xssEncode((String)str));            }else{             String[] str = (String[]) params.get(key);             params.put(key, xssEncode((String[])str));            }                                   }        if(lockedField!=null){        try {lockedField.set(params, true);} catch (Exception e) {log.error(e.getMessage(), e);}        }*/return newParams;}public String[] getParameterValues(String parameter) {      String[] values = super.getParameterValues(parameter);      if (values==null)  {                  return null;          }      int count = values.length;      String[] encodedValues = new String[count];      for (int i = 0; i < count; i++) {                 encodedValues[i] = StringUtil.filterDangerString(values[i]);       }      return encodedValues;    }/*** 覆盖getHeader方法,将参数名和参数值都做xss过滤。<br/>* 如果需要获得原始的值,则通过super.getHeaders(name)来获取<br/>* getHeaderNames 也可能需要覆盖*/@Overridepublic String getHeader(String name) {String value = super.getHeader(StringUtil.filterDangerString(name));if (value != null) {value = StringUtil.filterDangerString(value);}return value;}private static String[] xssEncode(String[] s) {String[] newStr = new String[s.length];for(int i=0;i<s.length;i++){newStr[i]= StringUtil.filterDangerString(s[i]);}return newStr;}/*** 获取最原始的request** @return*/public HttpServletRequest getOrgRequest() {return orgRequest;}/*** 获取最原始的request的静态方法** @return*/public static HttpServletRequest getOrgRequest(HttpServletRequest req) {if (req instanceof XssHttpServletRequestWrapper) {return 
((XssHttpServletRequestWrapper) req).getOrgRequest();}return req;}}

 

0 0
原创粉丝点击