简单地实现一个自定义的 LINQ to SQL Provider

这两天利用空闲时间研究了一下 LINQ 的查询提供程序（Query Provider），并简单地实现了一个。代码写得比较乱，没有注释，也没怎么对代码进行设计，因此有很多临时变量和一些不必要的操作，但重点在于说明实现原理。微软的 LINQ to SQL 实现非常复杂（"水很深"），这个例子只简单地实现了 Select 和 Where，其他操作都没有实现；并且 Where 查询只支持有限的 ==、>、< 运算符。不过这并不重要，如果需要，可以自行添加对应的实现。



using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Queryable source for entities of type <typeparamref name="T"/>. The parameterless
    /// constructor creates a root table backed by a new <see cref="CustomProvider"/>;
    /// the two-argument constructor is used by the provider to represent composed queries.
    /// </summary>
    public class CustomTable<T> : IQueryable<T>
    {
        private readonly Type _ElementType;
        private readonly Expression _Expression;
        private readonly IQueryProvider _Provider;

        /// <summary>Always <c>typeof(T)</c>.</summary>
        public Type ElementType
        {
            get { return _ElementType; }
        }

        /// <summary>
        /// Expression tree describing this query: a constant referencing this instance for
        /// a root table, or the expression supplied by the provider for a composed query.
        /// </summary>
        public Expression Expression
        {
            get { return _Expression; }
        }

        /// <summary>Provider that translates and executes <see cref="Expression"/>.</summary>
        public IQueryProvider Provider
        {
            get { return _Provider; }
        }

        /// <summary>Creates a query object over <paramref name="expression"/>.</summary>
        /// <param name="expression">Expression tree for this query (may be null; the root constructor fills it in).</param>
        /// <param name="provider">Provider that executes the query; must not be null.</param>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> is null.</exception>
        public CustomTable(Expression expression, IQueryProvider provider)
        {
            if (provider == null)
                throw new ArgumentNullException("provider", "provider can't be null");
            _ElementType = typeof(T);
            _Expression = expression;
            _Provider = provider;
        }

        /// <summary>Creates a root table with its own <see cref="CustomProvider"/>.</summary>
        public CustomTable()
            : this(null, new CustomProvider())
        {
            // The root expression refers to the table itself, which is how LINQ providers
            // recognise the start of a query chain.
            _Expression = Expression.Constant(this);
        }

        /// <summary>Executes the query through the provider and enumerates the results.</summary>
        public IEnumerator<T> GetEnumerator()
        {
            return (Provider.Execute<IEnumerable<T>>(Expression)).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>
        /// Delegates to the provider's <c>ToString</c>. NOTE(review): with CustomProvider this
        /// is the most recently translated SQL, which is shared by every query built from
        /// the same root table, not necessarily this query's own SQL.
        /// </summary>
        public override string ToString()
        {
            return _Provider.ToString();
        }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Data;
using System.Data.SqlClient;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Minimal LINQ query provider that translates chains of <c>Where</c> and
    /// <c>Select</c> calls into nested T-SQL and runs the result on SQL Server.
    /// </summary>
    /// <remarks>
    /// Translation happens eagerly in <see cref="CreateQuery{T}"/>. Each call wraps the SQL
    /// produced so far as a derived table aliased <c>t{n}</c>, where n is the number of
    /// operators already translated. Values from <c>Where</c> predicates are sent as
    /// parameters (<c>@p0</c>, <c>@p1</c>, ...) rather than concatenated into the text.
    /// Captured variables are evaluated at translation time.
    /// NOTE(review): the generated SQL, its parameters and the element type live on the
    /// provider instance. Every query built from the same root table therefore shares
    /// them, and execution runs whichever query was translated last.
    /// </remarks>
    public class CustomProvider : IQueryProvider
    {
        // Used by the parameterless constructor (the one CustomTable<T>() calls).
        private const string DefaultConnectionString =
            "Data Source=.;Initial Catalog=TestLinq;Integrated Security=True";

        private readonly string _connectionString;

        // Values referenced from 'sql' as @p0, @p1, ... They are kept as name/value pairs,
        // not SqlParameter objects: a SqlParameter can belong to only one command, and the
        // query may be executed more than once.
        private readonly List<KeyValuePair<string, object>> _parameters =
            new List<KeyValuePair<string, object>>();

        private string sql = "";
        private int count = 0;
        private string tableName = "";
        private string selector = "";
        private string where = "";
        private Type _PreType = null;
        private Type _ElementType = null;

        /// <summary>Creates a provider that uses the built-in default connection string.</summary>
        public CustomProvider()
            : this(DefaultConnectionString)
        {
        }

        /// <summary>Creates a provider that executes queries over <paramref name="connectionString"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null.</exception>
        public CustomProvider(string connectionString)
        {
            if (connectionString == null)
                throw new ArgumentNullException("connectionString");
            _connectionString = connectionString;
        }

        public IQueryable<T> CreateQuery<T>(Expression expression)
        {
            _ElementType = typeof(T);
            SetQueryText(expression);
            count++;
            return new CustomTable<T>(expression, this);
        }

        public IQueryable CreateQuery(Expression expression)
        {
            _ElementType = expression.Type.GetGenericArguments()[0];
            SetQueryText(expression);
            count++;
            object[] args = new object[] { expression, this };
            return (IQueryable)Activator.CreateInstance(typeof(CustomTable<>).MakeGenericType(_ElementType), args);
        }

        public T Execute<T>(Expression expression)
        {
            return (T)ExecuteSql(expression);
        }

        public object Execute(Expression expression)
        {
            return ExecuteSql(expression);
        }

        // Translates one Where/Select call on top of the SQL generated so far.
        private void SetQueryText(Expression expression)
        {
            MethodCallExpression call = expression as MethodCallExpression;
            if (call == null || call.Arguments.Count != 2 || !IsSupportedOperator(call.Method.Name))
                throw new NotSupportedException("Only Where and Select are supported, but got: " + expression);

            SetTableName(call.Arguments[0]);
            if (call.Method.Name == "Select")
            {
                where = " ";
            }
            else
            {
                selector = "select " + "t" + count + ".*  ";
                // Any earlier predicate is already inside the derived table, so this Where
                // starts from an empty clause instead of appending a second "where".
                where = "";
            }
            ProcessExpression(call.Arguments[1]);
            sql = selector + " from " + tableName + " " + where;
        }

        // True for the query operators this provider can translate.
        private static bool IsSupportedOperator(string methodName)
        {
            return methodName == "Where" || methodName == "Select";
        }

        // Sets the FROM source: the mapped table for the root, else the previous SQL as a derived table.
        private void SetTableName(Expression expression)
        {
            if (expression is ConstantExpression)
            {
                _PreType = expression.Type.GetGenericArguments()[0];
                tableName = MapHelper.GetTableName(_PreType) + " as t" + count + " ";
            }
            else if (expression is MethodCallExpression)
            {
                _PreType = expression.Type.GetGenericArguments()[0];
                tableName = "( " + sql + " ) as t" + count + " ";
            }
            else
            {
                throw new NotSupportedException("Unsupported query source: " + expression);
            }
        }

        // Walks quote -> lambda -> body and dispatches to the comparison or projection translator.
        void ProcessExpression(Expression expression)
        {
            if (expression is UnaryExpression)
            {
                ProcessExpression(((UnaryExpression)expression).Operand);
            }
            else if (expression is LambdaExpression)
            {
                ProcessExpression(((LambdaExpression)expression).Body);
            }
            else if (expression is BinaryExpression)
            {
                ProcessBinary((BinaryExpression)expression);
            }
            else if (expression is NewExpression)
            {
                ProcessNew((NewExpression)expression);
            }
            else
            {
                // Previously ignored, which silently produced an unfiltered query or a stale select list.
                throw new NotSupportedException("The expression '" + expression + "' is not supported.");
            }
        }

        // Translates "member <op> value" into a parameterized WHERE clause.
        void ProcessBinary(BinaryExpression expression)
        {
            if (expression.Left is BinaryExpression || expression.Right is BinaryExpression)
                throw new NotSupportedException("Only a single comparison is supported in a Where predicate: " + expression);

            MemberExpression member = StripConvert(expression.Left) as MemberExpression;
            if (member == null)
                throw new NotSupportedException("The left operand of '" + expression + "' must be a member of the query element.");

            string column = "t" + count + "." + MapHelper.GetColumnName(_PreType, member.Member.Name);
            string ope = GetSqlOperator(expression.NodeType);
            object value = EvaluateOperand(expression.Right);

            if (value == null)
            {
                // "= NULL" is never true in SQL, so == null / != null become IS [NOT] NULL.
                if (expression.NodeType == ExpressionType.Equal)
                    where = " where " + column + " is null";
                else if (expression.NodeType == ExpressionType.NotEqual)
                    where = " where " + column + " is not null";
                else
                    throw new NotSupportedException("Cannot compare with null using '" + expression.NodeType + "'.");
                return;
            }

            string parameterName = "@p" + _parameters.Count;
            _parameters.Add(new KeyValuePair<string, object>(parameterName, value));
            where = " where " + column + ope + parameterName;
        }

        // Removes compiler-inserted conversions (e.g. enum/short widened to int) around an operand.
        private static Expression StripConvert(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)
                expression = ((UnaryExpression)expression).Operand;
            return expression;
        }

        // Returns the value of a right-hand operand: constants directly, anything else
        // (captured locals, conversions) by compiling and invoking it client-side.
        private static object EvaluateOperand(Expression expression)
        {
            ConstantExpression constant = expression as ConstantExpression;
            if (constant != null)
                return constant.Value;

            Delegate evaluator;
            try
            {
                evaluator = Expression.Lambda(expression).Compile();
            }
            catch (InvalidOperationException ex)
            {
                // Raised when the operand references the lambda parameter (e.g. r.A == r.B).
                throw new NotSupportedException("The right operand '" + expression + "' must not depend on the query element.", ex);
            }
            return evaluator.DynamicInvoke();
        }

        // Maps a comparison node type to its T-SQL operator.
        private static string GetSqlOperator(ExpressionType nodeType)
        {
            switch (nodeType)
            {
                case ExpressionType.Equal: return " = ";
                case ExpressionType.NotEqual: return " <> ";
                case ExpressionType.LessThan: return " < ";
                case ExpressionType.LessThanOrEqual: return " <= ";
                case ExpressionType.GreaterThan: return " > ";
                case ExpressionType.GreaterThanOrEqual: return " >= ";
                default:
                    throw new NotSupportedException("The operator '" + nodeType + "' is not supported.");
            }
        }

        // Translates "new { A = r.X, r.Y }" into a select list, aliasing renamed members.
        void ProcessNew(NewExpression expression)
        {
            if (expression.Members == null || expression.Arguments.Count == 0)
                throw new NotSupportedException("Select only supports non-empty anonymous-type projections: " + expression);

            List<string> columns = new List<string>();
            for (int i = 0; i < expression.Arguments.Count; i++)
            {
                MemberExpression arg = expression.Arguments[i] as MemberExpression;
                if (arg == null)
                    throw new NotSupportedException("Only plain member accesses are supported in a Select projection: " + expression.Arguments[i]);

                string oldName = arg.Member.Name;
                string newName = expression.Members[i].Name;
                string column = "t" + count + "." + MapHelper.GetColumnName(_PreType, oldName);
                columns.Add(newName == oldName ? column : column + " as " + newName + " ");
            }
            selector = "select " + string.Join(",", columns.ToArray());
        }

        // Runs the translated SQL (or a full-table select for a bare root table) and
        // materializes the rows as a List of the element type.
        private object ExecuteSql(Expression expression)
        {
            MethodCallExpression call = expression as MethodCallExpression;
            if (call != null && !IsSupportedOperator(call.Method.Name))
                throw new NotSupportedException("The query operator '" + call.Method.Name + "' cannot be executed by this provider.");

            string commandText = sql;
            Type elementType = _ElementType;
            List<KeyValuePair<string, object>> parameters = _parameters;

            // Enumerating the root table never goes through CreateQuery, so the translated
            // state does not describe it; select the whole mapped table instead.
            ConstantExpression constant = expression as ConstantExpression;
            if (constant != null && constant.Value is IQueryable)
            {
                elementType = ((IQueryable)constant.Value).ElementType;
                commandText = "select * from " + MapHelper.GetTableName(elementType);
                parameters = new List<KeyValuePair<string, object>>();
            }

            DataSet ds = new DataSet();
            using (SqlConnection connection = new SqlConnection(_connectionString))
            using (SqlCommand cmd = new SqlCommand(commandText, connection))
            using (SqlDataAdapter da = new SqlDataAdapter(cmd))
            {
                foreach (KeyValuePair<string, object> parameter in parameters)
                    cmd.Parameters.AddWithValue(parameter.Key, parameter.Value);
                connection.Open();
                da.Fill(ds);
            }
            return Table2Entity.ConvertFromTable(ds.Tables[0], elementType);
        }

        /// <summary>Returns the most recently translated SQL text.</summary>
        public override string ToString()
        {
            return sql;
        }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Reads the <see cref="TableAttribute"/> / <see cref="ColumnAttribute"/> mapping
    /// metadata declared on entity types.
    /// </summary>
    public static class MapHelper
    {
        private const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        /// <summary>Returns the table name declared by <c>[Table]</c> on <paramref name="type"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="InvalidOperationException">The type has no <c>[Table]</c> attribute.</exception>
        public static string GetTableName(Type type)
        {
            if (type == null) throw new ArgumentNullException("type");
            // inherit: false matches the original IsDefined(..., false) check.
            TableAttribute ta = (TableAttribute)Attribute.GetCustomAttribute(type, typeof(TableAttribute), false);
            if (ta == null)
                throw new InvalidOperationException("Type '" + type.FullName + "' is not mapped to a table; mark it with [Table].");
            return ta.TableName;
        }

        /// <summary>
        /// Returns the column name mapped to <paramref name="propertyName"/>: the
        /// <c>[Column]</c> name when one is declared and non-empty, otherwise the property name.
        /// </summary>
        /// <exception cref="ArgumentException">The type has no such property.</exception>
        public static string GetColumnName(Type type, string propertyName)
        {
            ColumnAttribute ca = GetColumnAttribute(FindProperty(type, propertyName));
            if (ca == null || string.IsNullOrEmpty(ca.ColumnName)) return propertyName;
            return ca.ColumnName;
        }

        /// <summary>
        /// Returns the CLR type of the column: the property type when no <c>[Column]</c> is
        /// declared, otherwise the type corresponding to <see cref="ColumnAttribute.ColumnType"/>.
        /// </summary>
        /// <exception cref="ArgumentException">The type has no such property.</exception>
        public static Type GetColumnType(Type type, string propertyName)
        {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = GetColumnAttribute(pi);
            if (ca == null) return pi.PropertyType;
            return SwithType(ca.ColumnType);
        }

        // Looks up an instance property (public or not) and fails with a descriptive message.
        private static PropertyInfo FindProperty(Type type, string propertyName)
        {
            if (type == null) throw new ArgumentNullException("type");
            PropertyInfo pi = type.GetProperty(propertyName, PropertyFlags);
            if (pi == null)
                throw new ArgumentException("Type '" + type.FullName + "' has no property '" + propertyName + "'.", "propertyName");
            return pi;
        }

        // Returns the [Column] declared directly on the property, or null.
        private static ColumnAttribute GetColumnAttribute(PropertyInfo pi)
        {
            return (ColumnAttribute)Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false);
        }

        // Maps the DataType enum to its CLR type.
        static Type SwithType(DataType dtype)
        {
            switch (dtype)
            {
                case DataType.String: return typeof(String);
                case DataType.Int: return typeof(Int32);
                case DataType.DateTime: return typeof(DateTime);
                case DataType.Float: return typeof(float);
                case DataType.Double: return typeof(double);
                default:
                    // Previously returned null, which only failed later with a NullReferenceException.
                    throw new ArgumentOutOfRangeException("dtype", dtype, "Unknown DataType value.");
            }
        }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Marks an entity class as mapped to a database table; read by MapHelper.GetTableName.
    /// </summary>
    [AttributeUsage(AttributeTargets.Class)]
    internal class TableAttribute : Attribute
    {
        /// <param name="tableName">Name of the table the class maps to.</param>
        public TableAttribute(string tableName)
        {
            TableName = tableName;
        }

        /// <summary>Table name supplied to the constructor.</summary>
        public string TableName { get; private set; }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Maps a property to a database column and optionally declares the column's logical type.
    /// </summary>
    [AttributeUsage(AttributeTargets.Property)]
    internal class ColumnAttribute : Attribute
    {
        /// <param name="columnName">Name of the column the property maps to.</param>
        public ColumnAttribute(string columnName)
        {
            ColumnName = columnName;
            // The default is String (not default(DataType), which is Int); a named
            // argument such as ColumnType = DataType.Int overrides it after construction.
            ColumnType = DataType.String;
        }

        /// <summary>Column name supplied to the constructor.</summary>
        public string ColumnName { get; private set; }

        /// <summary>Logical column type; defaults to <see cref="DataType.String"/>.</summary>
        public DataType ColumnType { get; set; }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Logical column types that can be declared via ColumnAttribute.ColumnType.
    /// Values are spelled out explicitly and match declaration order.
    /// </summary>
    public enum DataType
    {
        Int = 0,
        String = 1,
        Float = 2,
        Double = 3,
        DateTime = 4
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>Materializes <see cref="DataTable"/> rows into objects of a given type.</summary>
    internal static class Table2Entity
    {
        // Static properties are not per-row data, so only instance properties are mapped.
        private const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        // Types with [Table] are created via their parameterless constructor and populated
        // through mapped properties; other types (anonymous projections) via their constructor.
        static object ConvertFromDataRow(DataRow dr, Type type)
        {
            if (!type.IsDefined(typeof(TableAttribute), false))
                return CreateFromConstructor(dr, type);
            return CreateFromProperties(dr, type);
        }

        // Binds constructor parameters by name to result columns. Anonymous types have one
        // constructor whose parameter names equal their property names. Binding by name
        // avoids depending on GetProperties() order, which the framework does not guarantee.
        private static object CreateFromConstructor(DataRow dr, Type type)
        {
            ConstructorInfo ctor = type.GetConstructors()
                .OrderByDescending(c => c.GetParameters().Length)
                .FirstOrDefault();
            if (ctor == null)
                throw new InvalidOperationException("Type '" + type.FullName + "' has no public constructor to materialize rows into.");

            ParameterInfo[] parameters = ctor.GetParameters();
            object[] args = new object[parameters.Length];
            for (int i = 0; i < parameters.Length; i++)
            {
                args[i] = ConvertValue(ReadColumn(dr, parameters[i].Name, type), parameters[i].ParameterType);
            }
            return ctor.Invoke(args);
        }

        // Creates the entity and assigns every writable, non-indexer property from its mapped column.
        private static object CreateFromProperties(DataRow dr, Type type)
        {
            object entity = Activator.CreateInstance(type);
            foreach (PropertyInfo p in type.GetProperties(PropertyFlags))
            {
                if (!p.CanWrite || p.GetIndexParameters().Length > 0)
                    continue;
                object raw = ReadColumn(dr, MapHelper.GetColumnName(type, p.Name), type);
                p.SetValue(entity, ConvertValue(raw, p.PropertyType), null);
            }
            return entity;
        }

        // Reads a column, failing with a descriptive message when the result set lacks it.
        private static object ReadColumn(DataRow dr, string columnName, Type type)
        {
            if (!dr.Table.Columns.Contains(columnName))
                throw new InvalidOperationException("The query result has no column '" + columnName + "' required by type '" + type.FullName + "'.");
            return dr[columnName];
        }

        // Converts a raw database value to the target CLR type. DBNull becomes null for
        // reference and Nullable<T> targets. Nullable<T> converts to its underlying type,
        // because Convert.ChangeType cannot target Nullable<T> directly. Enums are built
        // from their underlying value.
        private static object ConvertValue(object value, Type targetType)
        {
            Type underlying = Nullable.GetUnderlyingType(targetType);
            if (value == null || value is DBNull)
            {
                if (!targetType.IsValueType || underlying != null)
                    return null;
                throw new InvalidCastException("Cannot assign a database NULL to non-nullable type '" + targetType.FullName + "'.");
            }

            Type effective = underlying ?? targetType;
            if (effective.IsEnum)
                return Enum.ToObject(effective, value);
            return Convert.ChangeType(value, effective);
        }

        /// <summary>
        /// Converts every row of <paramref name="dt"/> and returns them as a
        /// <c>List&lt;type&gt;</c> (typed as object).
        /// </summary>
        public static object ConvertFromTable(DataTable dt, Type type)
        {
            System.Collections.IList list =
                (System.Collections.IList)Activator.CreateInstance(typeof(List<>).MakeGenericType(type));
            foreach (DataRow dr in dt.Rows)
            {
                list.Add(ConvertFromDataRow(dr, type));
            }
            return list;
        }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>Sample entity mapped to the <c>Student</c> table.</summary>
    [Table("Student")]
    public class Student
    {
        [Column("ID", ColumnType = DataType.Int)]
        public int ID { get; set; }

        /// <summary>Stored in the <c>StuName</c> column.</summary>
        [Column("StuName")]
        public string Name { get; set; }

        [Column("Address")]
        public string Address { get; set; }

        [Column("Sex", ColumnType = DataType.Int)]
        public int Sex { get; set; }

        [Column("CollegeID", ColumnType = DataType.Int)]
        public int CollegeID { get; set; }
    }
}


using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    class Program
    {
        // Demo: filter, project into an anonymous type, filter again, then print the
        // generated SQL followed by each materialized row.
        static void Main(string[] args)
        {
            var students = new CustomTable<Student>();
            var query = students
                .Where(s => s.Address == "china")
                .Select(s => new { NewName = s.Name, Country = s.Address, s.CollegeID, s.Sex })
                .Where(x => x.Sex == 1);

            Console.WriteLine(query.ToString());

            foreach (var row in query)
            {
                string line = string.Format("{0},{1},{2},{3}", row.NewName, row.Country, row.CollegeID, row.Sex);
                Console.WriteLine(line);
            }

            // Keeps the console window open until a key is pressed.
            Console.Read();
        }
    }
}


