Java 中使用 ASM 动态创建一个类，实现动态代理

来源:互联网 编辑:程序博客网 时间:2024/05/22 00:13

辅助代码:

package com.test.bean;

import java.lang.reflect.Method;

public interface InvocationHandler {

    /**
     * Receives a call that was made on a proxy object.
     *
     * @param proxy      the proxy object the call was made on
     * @param methodName the name of the called method
     * @param method     the reflective handle of the called method
     * @param args       the call's arguments, each wrapped as an Object
     * @return the value the proxy hands back to its caller
     */
    Object invoke(Object proxy, String methodName, Method method, Object[] args);
}

 

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;

public class MyInvocationHandler implements InvocationHandler {
    private Object instance;
    public MyInvocationHandler(Object instance)
    {
     this.instance=instance;
    }
 public Object invoke(Object proxy, String methodName, Method method,
   Object[] args) {
  // TODO Auto-generated method stub
  System.out.println("invoke");
  try {
   return method.invoke(instance, args);
  } catch (IllegalArgumentException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (IllegalAccessException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (InvocationTargetException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  }
  return null;
 }

}

 

package com.test.bean;

public class Student {

    /** Prints {@code "add"} and returns that same word; the arguments are ignored. */
    public Object add(Object instance, String id, short s1, int s2) {
        String label = "add";
        System.out.println(label);
        return label;
    }

    /** Prints {@code "update"} and returns that same word; the arguments are ignored. */
    public String update(Object instance, String id, int s1) {
        String label = "update";
        System.out.println(label);
        return label;
    }

    /** Prints {@code "del"}; the arguments are ignored. */
    public void del(Object instance, String id) {
        System.out.println("del");
    }
}

关键代码:

package com.test.bean;

import java.io.File;
import java.io.FileOutputStream;
import java.io.PrintWriter;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

import org.objectweb.asm.ClassReader;
import org.objectweb.asm.ClassWriter;
import org.objectweb.asm.Label;
import org.objectweb.asm.MethodVisitor;
import org.objectweb.asm.Opcodes;
import org.objectweb.asm.Type;
import org.objectweb.asm.util.CheckClassAdapter;
import org.objectweb.asm.util.TraceClassVisitor;

public class Proxy {
    private Proxy()
    {
    
    }
    public static void main(String args[])
    {
  com.test.bean.InvocationHandler h=new MyInvocationHandler(new Student());
  try {
   Proxy.getProxyInstance(Student.class, h);
  } catch (SecurityException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (IllegalArgumentException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (NoSuchMethodException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (InstantiationException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (IllegalAccessException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  } catch (InvocationTargetException e) {
   // TODO Auto-generated catch block
   e.printStackTrace();
  }
    }
    public static Object getProxyInstance(Class superClass,InvocationHandler h) throws SecurityException, NoSuchMethodException, IllegalArgumentException, InstantiationException, IllegalAccessException, InvocationTargetException
    {
     Generator g=new Generator();
     g.setSuperClass(superClass);
     Class clz=g.create();
     Constructor c=clz.getConstructor(new Class[]{InvocationHandler.class});
     return c.newInstance(new Object[]{h});
    }
    static class Generator
    {
     private Class superClass;
   
     public void setSuperClass(Class superClass)
     {
      this.superClass=superClass;
     }

     public Class create()
     {
      try {
      ClassWriter cw=new ClassWriter(ClassWriter.COMPUTE_MAXS);//使用自动计算
      String superName=this.superClass.getName().replace('.', '/');
      cw.visit(Opcodes.V1_6, Opcodes.ACC_PUBLIC, superName+"Prxoy0",null ,superName, null);
      //构建字段
      cw.visitField(Opcodes.ACC_PRIVATE, "h", "Lcom/test/bean/InvocationHandler;", null, null).visitEnd();
      //构建无参数的构造函数
      MethodVisitor mv1=cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "()V", null, null);
      mv1.visitCode();
     
      mv1.visitVarInsn(Opcodes.ALOAD, 0);
      mv1.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", "<init>", "()V");
      mv1.visitInsn(Opcodes.RETURN);
      mv1.visitMaxs(1, 1);
      mv1.visitEnd();
      //构建有参数的构造函数
      MethodVisitor mv2=cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", "(Lcom/test/bean/InvocationHandler;)V", null, null);
      mv2.visitCode();
      mv2.visitVarInsn(Opcodes.ALOAD,0);
      mv2.visitMethodInsn(Opcodes.INVOKESPECIAL, "java/lang/Object", "<init>", "()V");
      mv2.visitVarInsn(Opcodes.ALOAD, 0);
      mv2.visitVarInsn(Opcodes.ALOAD, 1);
      mv2.visitFieldInsn(Opcodes.PUTFIELD, superName+"Prxoy0", "h", "Lcom/test/bean/InvocationHandler;");
      mv2.visitInsn(Opcodes.RETURN);
      mv2.visitMaxs(2, 2);
      mv2.visitEnd();
         Method[] methods=getMethods(this.superClass);
         Type type=Type.getType(this.superClass);
         for(int i=0;i<methods.length;i++)
         {//构建方法
          Method method=methods[i];
          //添加静态字段
          cw.visitField(Opcodes.ACC_PUBLIC+Opcodes.ACC_STATIC, "method"+i, "Ljava/lang/reflect/Method;", null, null).visitEnd();
          MethodVisitor mv=cw.visitMethod(Opcodes.ACC_PUBLIC, method.getName(),type.getMethodDescriptor(method) , null, null);
          mv.visitCode();
          Type[] types=type.getArgumentTypes(method);
          mv.visitVarInsn(Opcodes.ALOAD, 0);//给proxy赋值
          //访问对象的h字段
          mv.visitFieldInsn(Opcodes.GETFIELD, superName+"Prxoy0", "h", "Lcom/test/bean/InvocationHandler;");
          mv.visitVarInsn(Opcodes.ALOAD, 0);//给proxy赋值
          mv.visitLdcInsn(method.getName());//给methodName赋值
          //给method赋值
          mv.visitVarInsn(Opcodes.ALOAD,0);
          mv.visitFieldInsn(Opcodes.GETSTATIC, superName+"Prxoy0", "method"+i, "Ljava/lang/reflect/Method;");
         
         
          //给args参数赋值
          //创建object数组
          mv.visitLdcInsn(new Integer(types.length));
          mv.visitMultiANewArrayInsn("[Ljava/lang/Object;",1);
          //mv.visitVarInsn(Opcodes.ASTORE, types.length+1);
          for(int j=0;j<types.length;j++)
          {//传递参数
           if(types[j].equals(Type.INT_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //int参数包装成integer
            mv.visitVarInsn(Opcodes.ILOAD, j+1);//参数
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Integer","valueOf", "(I)Ljava/lang/Integer;");
            mv.visitInsn(Opcodes.AASTORE);//给数组元素赋值
           }
           else if(types[j].equals(Type.FLOAT_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //float参数包装成Float
            mv.visitVarInsn(Opcodes.FLOAD, j+1);//参数
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Float", "valueOf","(F)Ljava/lang/Float;");
            mv.visitInsn(Opcodes.AASTORE);
           }
           else if(types[j].equals(Type.DOUBLE_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //double参数包装成Double
            mv.visitVarInsn(Opcodes.DLOAD, j+1);//参数
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Double", "valueOf", "(D)Ljava/lang/Double;");
            mv.visitInsn(Opcodes.AASTORE);//给数组元素赋值
           }
           else if(types[j].equals(Type.LONG_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //long参数包装成Long
            mv.visitVarInsn(Opcodes.LLOAD, j+1);//元素
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Long", "valueOf", "(J)Ljava/lang/Long;");
            mv.visitInsn(Opcodes.AASTORE);//赋值
           
           }
           else if(types[j].equals(Type.BYTE_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD,types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //byte参数包装成Byte
            mv.visitVarInsn(Opcodes.ILOAD, j+1);//元素
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Byte", "valueOf", "(B)Ljava/lang/Byte;");
            mv.visitInsn(Opcodes.AASTORE);//赋值
           }
           else if(types[j].equals(Type.BOOLEAN_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //boolean参数包装成Boolean
            mv.visitVarInsn(Opcodes.ILOAD, j+1);//元素
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Boolean", "valueOf", "(Z)Ljava/lang/Boolean;");
            mv.visitInsn(Opcodes.AASTORE);//赋值
           }
           else if(types[j].equals(Type.CHAR_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //char参数包装成Character
            mv.visitVarInsn(Opcodes.ILOAD, j+1);//元素
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Character", "valueOf", "(C)Ljava/lang/Character;");
            mv.visitInsn(Opcodes.AASTORE);//赋值
           }
           else if(types[j].equals(Type.SHORT_TYPE))
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            //short参数包装成Short
            mv.visitVarInsn(Opcodes.ILOAD, j+1);//元素
            mv.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Short", "valueOf", "(S)Ljava/lang/Short;");
            mv.visitInsn(Opcodes.AASTORE);//赋值
           }
           else
           {
            mv.visitInsn(Opcodes.DUP);//数组引用
            //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//数组引用
            mv.visitLdcInsn(new Integer(j));//数组下标
            mv.visitVarInsn(Opcodes.ALOAD, j+1);//元素
            mv.visitInsn(Opcodes.AASTORE);//赋值
           }
          }
          //mv.visitVarInsn(Opcodes.ALOAD, types.length+1);//给数组赋值
          //调用拦截器
             mv.visitMethodInsn(Opcodes.INVOKEINTERFACE,"com/test/bean/InvocationHandler", "invoke", "(Ljava/lang/Object;Ljava/lang/String;Ljava/lang/reflect/Method;[Ljava/lang/Object;)Ljava/lang/Object;");
          //转换并返回
       if(method.getReturnType().toString().equals("int"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Integer","parseInt","(Ljava/lang/String;)I");
        mv.visitInsn(Opcodes.IRETURN);
       }
       else if(method.getReturnType().toString().equals("float"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Float","parseFloat","(Ljava/lang/String;)F");
        mv.visitInsn(Opcodes.FRETURN);
       
       }
       else if(method.getReturnType().toString().equals("double"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Double","parseDouble","(Ljava/lang/String;)D");
        mv.visitInsn(Opcodes.DRETURN);
       }
       else if(method.getReturnType().toString().equals("long"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Long","parseLong","(Ljava/lang/String;)L");
        mv.visitInsn(Opcodes.LRETURN);
       }
       else if(method.getReturnType().toString().equals("byte"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Byte","parseByte","(Ljava/lang/String;)J");
        mv.visitInsn(Opcodes.IRETURN);
       }
       else if(method.getReturnType().toString().equals("boolean"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Boolean","parseBoolean","(Ljava/lang/String;)Z");
        mv.visitInsn(Opcodes.IRETURN);
       }
       else if(method.getReturnType().toString().equals("char"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/String","toCharArray","()[C");
        mv.visitInsn(Opcodes.ICONST_0);
        mv.visitInsn(Opcodes.CALOAD);
        mv.visitInsn(Opcodes.IRETURN);
       
       }
       else if(method.getReturnType().toString().equals("short"))
       {
        mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Object", "toString", "()Ljava/lang/String;");
        mv.visitMethodInsn(Opcodes.INVOKESTATIC,"java/lang/Short","parseShort","(Ljava/lang/String;)S");
        mv.visitInsn(Opcodes.IRETURN);
       }
       else if(method.getReturnType().toString().equals("void"))
       {
        mv.visitInsn(Opcodes.POP);
        mv.visitInsn(Opcodes.RETURN);
       }
       else
       {
        String str=type.getReturnType(method).getDescriptor();
        if(str.startsWith("["))
        {//数组类型
         mv.visitTypeInsn(Opcodes.CHECKCAST, str);
        }
        else
        {
         int length=str.length();
         mv.visitTypeInsn(Opcodes.CHECKCAST, str.substring(1, length-1));
        }
        mv.visitInsn(Opcodes.ARETURN);
       }
       mv.visitMaxs(types.length+4, types.length+3);
       mv.visitEnd();
         }
         //静态块
         MethodVisitor mv3=cw.visitMethod(Opcodes.ACC_STATIC, "<clinit>", "()V", null, null);
         mv3.visitCode();
         Label tryLabel=new Label();
         Label catchLabel=new Label();
         mv3.visitTryCatchBlock(tryLabel, catchLabel, catchLabel, "java/lang/Exception");
         mv3.visitLabel(tryLabel);
         mv3.visitFrame(Opcodes.F_NEW, 0, null, 0, null);
         //给静态字段赋值
         for(int i=0;i<methods.length;i++)
         {
          Method method=methods[i];
          mv3.visitVarInsn(Opcodes.ALOAD, 0);
          mv3.visitFieldInsn(Opcodes.GETSTATIC, superName+"Prxoy0", "method"+i, "Ljava/lang/reflect/Method;");
          mv3.visitLdcInsn(method.getDeclaringClass().getName());
          mv3.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Class", "forName", "(Ljava/lang/String;)Ljava/lang/Class;");
          mv3.visitLdcInsn(method.getName());
          //构建方法参数
          Class[] params=method.getParameterTypes();
          mv3.visitLdcInsn(new Integer(params.length));
          //new数组
          mv3.visitMultiANewArrayInsn("[Ljava/lang/Class;", 1);
          for(int j=0;j<params.length;j++)
          {
           mv3.visitInsn(Opcodes.DUP); //引用数组
           mv3.visitLdcInsn(new Integer(j));//数组下标
           mv3.visitLdcInsn(getType(params[j].getName()));
           mv3.visitMethodInsn(Opcodes.INVOKESTATIC, "java/lang/Class", "forName", "(Ljava/lang/String;)Ljava/lang/Class;");
           mv3.visitInsn(Opcodes.AASTORE);//赋值
          }
          mv3.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Class", "getMethod", "(Ljava/lang/String;[Ljava/lang/Class;)Ljava/lang/reflect/Method;");
          mv3.visitFieldInsn(Opcodes.PUTSTATIC, superName+"Prxoy0", "method"+i, "Ljava/lang/reflect/Method;");
         }
      mv3.visitInsn(Opcodes.RETURN);
      mv3.visitLabel(catchLabel);
      mv3.visitFrame(Opcodes.F_NEW, 0, null, 0, null);
      mv3.visitVarInsn(Opcodes.ASTORE, 0);
      mv3.visitVarInsn(Opcodes.ALOAD, 0);
      mv3.visitMethodInsn(Opcodes.INVOKEVIRTUAL, "java/lang/Exception", "printStackTrace", "()V");
      mv3.visitInsn(Opcodes.RETURN);
      mv3.visitMaxs(methods.length, methods.length);
         mv3.visitEnd();
        
         cw.visitEnd();
      byte[] b=cw.toByteArray();
      CommonClassLoader commonClassLoader=new CommonClassLoader();
      return commonClassLoader.defineClass(b);
      /*ClassReader cr=new ClassReader(b);
      ClassWriter cw3=new ClassWriter(0);
       TraceClassVisitor cw4=new TraceClassVisitor(new CheckClassAdapter(cw3),new PrintWriter(System.out));
       cr.accept(cw4, 0);
       Class clz=null;
     
       FileOutputStream out=new FileOutputStream(new File("D:\com\test\bean\StudentPrxoy0.class"));
       out.write(b);
       out.close();
       ClassLoader parent=Thread.currentThread().getContextClassLoader();
       URL[] urls=new URL[]{new File("D:\com.zip").toURL()};
       System.out.println(urls[0]);
       ClassLoader loader=new URLClassLoader(urls,parent);
       Thread.currentThread().setContextClassLoader(loader);
          clz=loader.loadClass("com.test.bean.StudentPrxoy0");
        
       Object instance=clz.newInstance();
       //clz.getMethods()[0].invoke(instance, null);
          return clz;*/
          //return defineClass(loader,b,0,b.length);
      } catch (Exception e) {
       // TODO Auto-generated catch block
       e.printStackTrace();
      }
      return null;
     }
     //方法过滤
     private Method[] getMethods(Class clz) throws SecurityException, NoSuchMethodException
     {
      List list=new ArrayList();
      Method[] methods=clz.getDeclaredMethods();
      for(int i=0;i<methods.length;i++)
      {
       if(methods[i].getModifiers()==0||methods[i].getModifiers()==1||methods[i].getModifiers()==4)
       {
        list.add(methods[i]);
       }
      }
      Method m1=clz.getMethod("hashCode", null);
      Method m2=clz.getMethod("equals", new Class[]{Object.class});
      Method m3=clz.getMethod("toString", null);
      list.add(m1);
      list.add(m2);
      list.add(m3);
      return (Method[])list.toArray(new Method[]{});
     }
     //参数转换
     private String getType(String type)
     {
      if(type.equals("byte"))
      {
       return"java/lang/Byte";
      }
      else if(type.equals("char"))
      {
       //Character
       return"java/lang/Character";
      }
      else if(type.equals("boolean"))
      {
       return"java/lang/Boolean";
      }
      else if(type.equals("short"))
      {
       return"java/lang/Short";
      }
      else if(type.equals("int"))
      {
       return"java/lang/Integer";
      }
      else if(type.equals("float"))
      {
       return"java/lang/Float";
      }
      else if(type.equals("double"))
      {
       return"java/lang/Double";
      }
      else if(type.equals("long"))
      {
       return"java/lang/Long";
      }
      else
      {
       return type;
      }
     }
    }
    static class CommonClassLoader extends ClassLoader
    {
     public Class defineClass(byte[] b)
     {
      return super.defineClass(b, 0, b.length);
     }
    }
}

原创粉丝点击