简单的实现一个自定义的Linq to Sql Provider
来源:互联网 发布:mac air怎样新建文件夹 编辑:程序博客网 时间:2024/06/07 11:33
这两天利用空闲时间研究了一下 LINQ 的查询提供程序（Query Provider），并简单实现了一个。代码写得比较乱，没有注释，也没怎么对代码进行设计，因此有很多临时变量和一些不必要的操作，但重点在于展示实现原理。微软的 LINQ to SQL 实现水很深，这个例子只简单实现了 Select 和 Where，其他操作均未实现；对于 Where 查询，也只支持有限的 ==、>、< 运算符。不过这并不重要，如有需要可以自行添加对应的实现。
先把代码记录下来吧,以后有时间再优化下代码和添加些注释。
IQueryable的实现:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Collections;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Queryable entry point of the simple LINQ-to-SQL provider.
    /// A root instance (parameterless constructor) represents the whole mapped table of
    /// <typeparamref name="T"/>; instances created by a provider wrap a composed query expression.
    /// </summary>
    /// <typeparam name="T">Element type produced when the query is enumerated.</typeparam>
    public class CustomTable<T> : IQueryable<T>
    {
        private readonly Type _ElementType;
        private readonly Expression _Expression;
        private readonly IQueryProvider _Provider;

        /// <summary>Gets the element type; always <c>typeof(T)</c>.</summary>
        public Type ElementType { get { return _ElementType; } }

        /// <summary>Gets the expression tree that describes this query.</summary>
        public Expression Expression { get { return _Expression; } }

        /// <summary>Gets the provider that translates and executes this query.</summary>
        public IQueryProvider Provider { get { return _Provider; } }

        /// <summary>
        /// Wraps a query expression produced by <paramref name="provider"/> (used by CreateQuery).
        /// </summary>
        /// <param name="expression">Query expression; its type must be assignable to IQueryable&lt;T&gt;.</param>
        /// <param name="provider">Provider that will execute the expression.</param>
        /// <exception cref="ArgumentNullException"><paramref name="provider"/> or <paramref name="expression"/> is null.</exception>
        /// <exception cref="ArgumentException"><paramref name="expression"/> does not produce IQueryable&lt;T&gt;.</exception>
        public CustomTable(Expression expression, IQueryProvider provider)
        {
            if (provider == null)
            {
                throw new ArgumentNullException("provider", "provider can't be null");
            }
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }
            // Fail here rather than at enumeration time, where the cast in Execute would blow up.
            if (!typeof(IQueryable<T>).IsAssignableFrom(expression.Type))
            {
                throw new ArgumentException("expression must produce IQueryable<" + typeof(T).Name + ">", "expression");
            }

            _ElementType = typeof(T);
            _Expression = expression;
            _Provider = provider;
        }

        /// <summary>
        /// Creates a root table backed by a new <c>CustomProvider</c>.
        /// </summary>
        public CustomTable()
        {
            _ElementType = typeof(T);
            _Provider = new CustomProvider();
            // The root expression is a constant referring to this table; the provider reads its
            // generic argument to find the mapped table name.
            _Expression = Expression.Constant(this);
        }

        /// <summary>Executes the query through the provider and enumerates the materialized results.</summary>
        public IEnumerator<T> GetEnumerator()
        {
            return _Provider.Execute<IEnumerable<T>>(_Expression).GetEnumerator();
        }

        IEnumerator IEnumerable.GetEnumerator()
        {
            return GetEnumerator();
        }

        /// <summary>Returns the provider's <c>ToString()</c> (for <c>CustomProvider</c>, its current SQL text).</summary>
        public override string ToString()
        {
            return _Provider.ToString();
        }
    }
}
IQueryProvider的实现:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Data;
using System.Data.SqlClient;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Minimal LINQ query provider that translates chains of Queryable.Where / Queryable.Select
    /// over a CustomTable into nested SQL Server SELECT statements and materializes the rows.
    /// </summary>
    /// <remarks>
    /// Supported shapes:
    ///  - Where(r =&gt; r.Prop OP value) with exactly one comparison (==, !=, &lt;, &lt;=, &gt;, &gt;=);
    ///    value may be a constant or any expression (e.g. a captured variable) that does not use r.
    ///  - Select(r =&gt; new { Alias = r.Prop, r.Other, ... }) projecting plain properties.
    /// Each operator wraps the previous SQL as a derived table aliased t0, t1, ...
    /// Comparison values are sent as SQL parameters (@p0, @p1, ...) instead of being concatenated.
    /// The provider keeps mutable per-query state and performs no synchronization.
    /// Any other operator throws <see cref="NotSupportedException"/>.
    /// </remarks>
    public class CustomProvider : IQueryProvider
    {
        // Connection string used by the parameterless constructor (the value the sample hard-coded).
        private const string DefaultConnectionString = "Data Source=.;Initial Catalog=TestLinq;Integrated Security=True";

        private readonly string _connectionString;

        // Values for the @pN placeholders in `sql`, in the order they were created.
        private readonly List<KeyValuePair<string, object>> _parameters = new List<KeyValuePair<string, object>>();

        private string sql = "";
        private int count = 0;
        private string tableName = "";
        private string selector = "";
        private string where = "";
        private Type _PreType = null;
        private Type _ElementType = null;

        /// <summary>Creates a provider that connects to the default local TestLinq database.</summary>
        public CustomProvider() : this(DefaultConnectionString)
        {
        }

        /// <summary>Creates a provider that connects with <paramref name="connectionString"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="connectionString"/> is null.</exception>
        public CustomProvider(string connectionString)
        {
            if (connectionString == null)
            {
                throw new ArgumentNullException("connectionString");
            }
            _connectionString = connectionString;
        }

        /// <summary>Translates one more operator (for ToString) and wraps the expression in a CustomTable.</summary>
        public IQueryable<T> CreateQuery<T>(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }
            _ElementType = typeof(T);
            SetQueryText(expression);
            count++;
            return new CustomTable<T>(expression, this);
        }

        /// <summary>Non-generic counterpart of CreateQuery&lt;T&gt;; element type comes from the expression type.</summary>
        public IQueryable CreateQuery(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }
            _ElementType = GetElementType(expression);
            SetQueryText(expression);
            count++;
            object[] args = new object[] { expression, this };
            return (IQueryable)Activator.CreateInstance(typeof(CustomTable<>).MakeGenericType(_ElementType), args);
        }

        /// <summary>Runs the query and casts the materialized List&lt;element&gt; to <typeparamref name="T"/>.</summary>
        public T Execute<T>(Expression expression)
        {
            return (T)ExecuteSql(expression);
        }

        /// <summary>Runs the query and returns the materialized List&lt;element&gt;.</summary>
        public object Execute(Expression expression)
        {
            return ExecuteSql(expression);
        }

        // Translates a single Where/Select call on top of the SQL currently held in `sql`.
        private void SetQueryText(Expression expression)
        {
            MethodCallExpression call = expression as MethodCallExpression;
            if (call == null || call.Arguments.Count != 2)
            {
                throw new NotSupportedException("Only Queryable.Where and Queryable.Select are supported");
            }

            SetTableName(call.Arguments[0]);
            LambdaExpression lambda = StripQuotes(call.Arguments[1]) as LambdaExpression;
            Expression body = lambda == null ? null : lambda.Body;

            switch (call.Method.Name)
            {
                case "Where":
                    BinaryExpression comparison = body as BinaryExpression;
                    if (comparison == null)
                    {
                        throw new NotSupportedException("Where only supports a single comparison such as r => r.Sex == 1");
                    }
                    selector = "select t" + count + ".* ";
                    // Start a fresh filter: the previous operator's filter lives inside the derived table.
                    where = "";
                    ProcessBinary(comparison);
                    break;
                case "Select":
                    NewExpression projection = body as NewExpression;
                    if (projection == null)
                    {
                        throw new NotSupportedException("Select only supports projections such as r => new { r.Name }");
                    }
                    where = "";
                    ProcessNew(projection);
                    break;
                default:
                    throw new NotSupportedException("Unsupported query operator: " + call.Method.Name);
            }

            sql = selector + " from " + tableName + " " + where;
        }

        // Sets `tableName` (and `_PreType`) for the source of the current operator: the mapped
        // table for the root constant, or the previous SQL wrapped as a derived table.
        private void SetTableName(Expression expression)
        {
            if (expression is ConstantExpression)
            {
                _PreType = GetElementType(expression);
                tableName = MapHelper.GetTableName(_PreType) + " as t" + count + " ";
            }
            else if (expression is MethodCallExpression)
            {
                _PreType = GetElementType(expression);
                tableName = "( " + sql + " ) as t" + count + " ";
            }
            else
            {
                throw new NotSupportedException("Unsupported query source: " + expression.NodeType);
            }
        }

        // Translates "r.Prop OP value" into a parameterized where clause.
        private void ProcessBinary(BinaryExpression expression)
        {
            if (expression.Left is BinaryExpression || expression.Right is BinaryExpression)
            {
                throw new NotSupportedException("only be one binary");
            }

            MemberExpression member = StripConvert(expression.Left) as MemberExpression;
            if (member == null || !(member.Expression is ParameterExpression))
            {
                throw new NotSupportedException(
                    "The left side of a Where comparison must be a property of the lambda parameter, e.g. r => r.Sex == 1");
            }

            string column = "t" + count + "." + MapHelper.GetColumnName(_PreType, member.Member.Name);
            object value = EvaluateOperand(expression.Right);

            if (value == null)
            {
                // "= NULL" is never true in SQL, so null comparisons must use IS [NOT] NULL.
                if (expression.NodeType == ExpressionType.Equal)
                {
                    where = " where " + column + " is null";
                    return;
                }
                if (expression.NodeType == ExpressionType.NotEqual)
                {
                    where = " where " + column + " is not null";
                    return;
                }
                throw new NotSupportedException("null can only be compared with == or !=");
            }

            string ope = GetSqlOperator(expression.NodeType);
            string parameterName = "@p" + _parameters.Count;
            _parameters.Add(new KeyValuePair<string, object>(parameterName, value));
            where = " where " + column + ope + parameterName;
        }

        // Translates "new { Alias = r.Prop, ... }" into the select list.
        private void ProcessNew(NewExpression expression)
        {
            if (expression.Members == null || expression.Arguments.Count == 0
                || expression.Members.Count != expression.Arguments.Count)
            {
                throw new NotSupportedException("Select only supports anonymous-type projections such as r => new { r.Name }");
            }

            List<string> columns = new List<string>();
            for (int i = 0; i < expression.Arguments.Count; i++)
            {
                MemberExpression source = expression.Arguments[i] as MemberExpression;
                if (source == null || !(source.Expression is ParameterExpression))
                {
                    throw new NotSupportedException("Select can only project properties of the lambda parameter");
                }

                string column = MapHelper.GetColumnName(_PreType, source.Member.Name);
                string alias = GetMemberName(expression.Members[i]);
                string item = "t" + count + "." + column;
                // Alias whenever the result name differs from the *column* name, so a mapped
                // property (e.g. Name -> StuName) still comes back under the projected name.
                if (alias != column)
                {
                    item += " as " + alias;
                }
                columns.Add(item);
            }
            selector = "select " + string.Join(",", columns.ToArray());
        }

        // Rebuilds sql/parameters from scratch for `expression`, replaying its operators from the root.
        private void RebuildQuery(Expression expression)
        {
            Stack<MethodCallExpression> operators = new Stack<MethodCallExpression>();
            Expression source = expression;
            MethodCallExpression call = source as MethodCallExpression;
            while (call != null && call.Arguments.Count > 0)
            {
                operators.Push(call);
                source = call.Arguments[0];
                call = source as MethodCallExpression;
            }

            sql = "";
            count = 0;
            tableName = "";
            selector = "";
            where = "";
            _PreType = null;
            _parameters.Clear();
            _ElementType = GetElementType(expression);

            if (operators.Count == 0)
            {
                if (!(source is ConstantExpression))
                {
                    throw new NotSupportedException("Unsupported query source: " + source.NodeType);
                }
                // Enumerating the bare table: select every column.
                SetTableName(source);
                selector = "select t" + count + ".* ";
                sql = selector + " from " + tableName;
                return;
            }

            while (operators.Count > 0)
            {
                SetQueryText(operators.Pop());
                count++;
            }
        }

        private object ExecuteSql(Expression expression)
        {
            if (expression == null)
            {
                throw new ArgumentNullException("expression");
            }

            // Translate the expression actually being executed rather than trusting the state left
            // by the last CreateQuery call, which may belong to a different query on this provider.
            RebuildQuery(expression);

            DataTable table = new DataTable();
            using (SqlConnection connection = new SqlConnection(_connectionString))
            using (SqlCommand cmd = new SqlCommand(sql, connection))
            {
                foreach (KeyValuePair<string, object> parameter in _parameters)
                {
                    cmd.Parameters.AddWithValue(parameter.Key, parameter.Value ?? DBNull.Value);
                }
                connection.Open();
                using (SqlDataAdapter da = new SqlDataAdapter(cmd))
                {
                    da.Fill(table);
                }
            }
            return Table2Entity.ConvertFromTable(table, _ElementType);
        }

        // Element type of a sequence expression (IQueryable<X> or CustomTable<X> -> X).
        private static Type GetElementType(Expression expression)
        {
            Type[] arguments = expression.Type.GetGenericArguments();
            if (arguments.Length == 0)
            {
                throw new NotSupportedException("Only sequence-returning queries are supported, not " + expression.Type.Name);
            }
            return arguments[0];
        }

        // Queryable wraps lambda arguments in Quote nodes.
        private static Expression StripQuotes(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Quote)
            {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression;
        }

        // Comparisons on enums/nullables wrap the member access in Convert nodes.
        private static Expression StripConvert(Expression expression)
        {
            while (expression.NodeType == ExpressionType.Convert || expression.NodeType == ExpressionType.ConvertChecked)
            {
                expression = ((UnaryExpression)expression).Operand;
            }
            return expression;
        }

        // Evaluates the right side of a comparison locally (constants, captured variables, ...).
        private static object EvaluateOperand(Expression operand)
        {
            ConstantExpression constant = operand as ConstantExpression;
            if (constant != null)
            {
                return constant.Value;
            }
            try
            {
                return Expression.Lambda(operand).Compile().DynamicInvoke();
            }
            catch (InvalidOperationException ex)
            {
                // Compile fails when the operand references the lambda parameter (e.g. r.A == r.B).
                throw new NotSupportedException("The right side of a Where comparison must not reference the lambda parameter", ex);
            }
        }

        private static string GetSqlOperator(ExpressionType nodeType)
        {
            switch (nodeType)
            {
                case ExpressionType.Equal: return " = ";
                case ExpressionType.NotEqual: return " <> ";
                case ExpressionType.LessThan: return " < ";
                case ExpressionType.LessThanOrEqual: return " <= ";
                case ExpressionType.GreaterThan: return " > ";
                case ExpressionType.GreaterThanOrEqual: return " >= ";
                default: throw new NotSupportedException("Unsupported operator in Where: " + nodeType);
            }
        }

        // Older runtimes report anonymous-type members as get_X accessor methods.
        private static string GetMemberName(MemberInfo member)
        {
            string name = member.Name;
            if (member is MethodInfo && name.StartsWith("get_", StringComparison.Ordinal))
            {
                name = name.Substring(4);
            }
            return name;
        }

        /// <summary>Returns the most recently built SQL text (placeholders @pN, values not inlined).</summary>
        public override string ToString()
        {
            return sql;
        }
    }
}
实体类、属性与数据库表、列之间的映射帮助类
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Reflection helpers that map entity types and properties to database table and column
    /// names using <see cref="TableAttribute"/> and <see cref="ColumnAttribute"/>.
    /// </summary>
    public static class MapHelper
    {
        private const BindingFlags PropertyFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance;

        /// <summary>Returns the table name declared by [Table] directly on <paramref name="type"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="InvalidOperationException"><paramref name="type"/> has no [Table] attribute.</exception>
        public static string GetTableName(Type type)
        {
            if (type == null)
            {
                throw new ArgumentNullException("type");
            }
            TableAttribute ta = Attribute.GetCustomAttribute(type, typeof(TableAttribute), false) as TableAttribute;
            if (ta == null)
            {
                throw new InvalidOperationException("Type " + type.FullName + " is not marked with [Table]");
            }
            return ta.TableName;
        }

        /// <summary>
        /// Returns the column name for a property: the [Column] name if declared on the property,
        /// otherwise <paramref name="propertyName"/> itself.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">No instance property named <paramref name="propertyName"/>.</exception>
        public static string GetColumnName(Type type, string propertyName)
        {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false) as ColumnAttribute;
            return ca == null ? propertyName : ca.ColumnName;
        }

        /// <summary>
        /// Returns the CLR type of a property's column: the type implied by [Column(ColumnType=...)]
        /// if declared, otherwise the property's own type.
        /// </summary>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="ArgumentException">No instance property named <paramref name="propertyName"/>.</exception>
        public static Type GetColumnType(Type type, string propertyName)
        {
            PropertyInfo pi = FindProperty(type, propertyName);
            ColumnAttribute ca = Attribute.GetCustomAttribute(pi, typeof(ColumnAttribute), false) as ColumnAttribute;
            return ca == null ? pi.PropertyType : ToClrType(ca.ColumnType);
        }

        // Shared lookup used by GetColumnName/GetColumnType (public or non-public instance properties).
        private static PropertyInfo FindProperty(Type type, string propertyName)
        {
            if (type == null)
            {
                throw new ArgumentNullException("type");
            }
            if (propertyName == null)
            {
                throw new ArgumentNullException("propertyName");
            }
            PropertyInfo pi = type.GetProperty(propertyName, PropertyFlags);
            if (pi == null)
            {
                throw new ArgumentException("Type " + type.FullName + " has no instance property named '" + propertyName + "'", "propertyName");
            }
            return pi;
        }

        // Maps the DataType enum to CLR types; unknown values fail loudly instead of returning null.
        private static Type ToClrType(DataType dtype)
        {
            switch (dtype)
            {
                case DataType.String: return typeof(String);
                case DataType.Int: return typeof(Int32);
                case DataType.DateTime: return typeof(DateTime);
                case DataType.Float: return typeof(float);
                case DataType.Double: return typeof(double);
                default: throw new ArgumentOutOfRangeException("dtype", dtype, "Unknown DataType value");
            }
        }
    }
}
自定义TableAttribute
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Marks an entity class with the name of the database table it is mapped to.
    /// </summary>
    [AttributeUsage(AttributeTargets.Class)]
    internal class TableAttribute : Attribute
    {
        /// <summary>Creates the attribute for the given table.</summary>
        /// <param name="tableName">Table name, stored exactly as passed.</param>
        public TableAttribute(string tableName)
        {
            TableName = tableName;
        }

        /// <summary>Gets the mapped table name supplied to the constructor.</summary>
        public string TableName { get; private set; }
    }
}
自定义ColumnAttribute
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Maps an entity property to a database column name and, optionally, a declared column type.
    /// </summary>
    [AttributeUsage(AttributeTargets.Property)]
    internal class ColumnAttribute : Attribute
    {
        /// <summary>Creates the attribute; <see cref="ColumnType"/> starts as <see cref="DataType.String"/>.</summary>
        /// <param name="columnName">Column name, stored exactly as passed.</param>
        public ColumnAttribute(string columnName)
        {
            ColumnName = columnName;
            ColumnType = DataType.String;
        }

        /// <summary>Gets the mapped column name supplied to the constructor.</summary>
        public string ColumnName { get; private set; }

        /// <summary>Gets or sets the declared column type (settable as a named attribute argument).</summary>
        public DataType ColumnType { get; set; }
    }
}
数据类型枚举
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Column data types that can be declared through <c>ColumnAttribute.ColumnType</c>.
    /// </summary>
    public enum DataType
    {
        /// <summary>Integer column (MapHelper maps it to System.Int32).</summary>
        Int = 0,

        /// <summary>Text column (MapHelper maps it to System.String).</summary>
        String = 1,

        /// <summary>Single-precision column (MapHelper maps it to System.Single).</summary>
        Float = 2,

        /// <summary>Double-precision column (MapHelper maps it to System.Double).</summary>
        Double = 3,

        /// <summary>Date/time column (MapHelper maps it to System.DateTime).</summary>
        DateTime = 4
    }
}
将 DataTable 转换为对应实体类的帮助类
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data;
using System.Reflection;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Converts the rows of a <see cref="DataTable"/> into instances of an entity or projection type.
    /// </summary>
    /// <remarks>
    /// Types marked with [Table] are created with their parameterless constructor and their writable
    /// instance properties are filled from the mapped columns. Other types (e.g. anonymous projection
    /// types) are created through their public constructor, matching parameter names to column names.
    /// </remarks>
    internal static class Table2Entity
    {
        /// <summary>Materializes every row of <paramref name="dt"/> into a List&lt;<paramref name="type"/>&gt;.</summary>
        /// <param name="dt">Result set to convert.</param>
        /// <param name="type">Element type of the returned list.</param>
        /// <returns>A List&lt;type&gt; boxed as object (empty when the table has no rows).</returns>
        /// <exception cref="ArgumentNullException">An argument is null.</exception>
        /// <exception cref="InvalidOperationException">A required column is missing from the result set.</exception>
        /// <exception cref="InvalidCastException">A NULL is read into a non-nullable value type.</exception>
        public static object ConvertFromTable(DataTable dt, Type type)
        {
            if (dt == null)
            {
                throw new ArgumentNullException("dt");
            }
            if (type == null)
            {
                throw new ArgumentNullException("type");
            }

            Type listType = typeof(List<>).MakeGenericType(type);
            System.Collections.IList list = (System.Collections.IList)Activator.CreateInstance(listType);
            if (dt.Rows.Count == 0)
            {
                return list;
            }

            // Resolve reflection and column lookups once per table, not once per row.
            Func<DataRow, object> createItem = type.IsDefined(typeof(TableAttribute), false)
                ? CreatePropertySetter(dt, type)
                : CreateConstructorCall(dt, type);

            foreach (DataRow dr in dt.Rows)
            {
                list.Add(createItem(dr));
            }
            return list;
        }

        // [Table] entities: parameterless constructor + property assignment from mapped columns.
        private static Func<DataRow, object> CreatePropertySetter(DataTable dt, Type type)
        {
            List<PropertyInfo> properties = new List<PropertyInfo>();
            List<DataColumn> columns = new List<DataColumn>();
            foreach (PropertyInfo p in type.GetProperties(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.Instance))
            {
                // Indexers and get-only properties cannot be populated from a column.
                if (p.GetIndexParameters().Length > 0 || p.GetSetMethod(true) == null)
                {
                    continue;
                }
                properties.Add(p);
                columns.Add(FindColumn(dt, MapHelper.GetColumnName(type, p.Name), type));
            }

            return row =>
            {
                object entity = Activator.CreateInstance(type);
                for (int k = 0; k < properties.Count; k++)
                {
                    properties[k].SetValue(entity, ConvertValue(row[columns[k]], properties[k].PropertyType), null);
                }
                return entity;
            };
        }

        // Projection types: constructor whose parameter names match the result columns. Matching by
        // name avoids relying on GetProperties order, which reflection does not guarantee.
        private static Func<DataRow, object> CreateConstructorCall(DataTable dt, Type type)
        {
            ConstructorInfo constructor = null;
            foreach (ConstructorInfo candidate in type.GetConstructors())
            {
                if (constructor == null || candidate.GetParameters().Length > constructor.GetParameters().Length)
                {
                    constructor = candidate;
                }
            }
            if (constructor == null)
            {
                throw new NotSupportedException("Type " + type.FullName + " has no public constructor to materialize rows with");
            }

            ParameterInfo[] parameters = constructor.GetParameters();
            DataColumn[] columns = new DataColumn[parameters.Length];
            for (int i = 0; i < parameters.Length; i++)
            {
                columns[i] = FindColumn(dt, parameters[i].Name, type);
            }

            return row =>
            {
                object[] args = new object[parameters.Length];
                for (int k = 0; k < parameters.Length; k++)
                {
                    args[k] = ConvertValue(row[columns[k]], parameters[k].ParameterType);
                }
                return constructor.Invoke(args);
            };
        }

        private static DataColumn FindColumn(DataTable dt, string columnName, Type type)
        {
            DataColumn column = dt.Columns[columnName];
            if (column == null)
            {
                throw new InvalidOperationException("Result set has no column '" + columnName + "' required by " + type.FullName);
            }
            return column;
        }

        // Converts a raw cell value to the target CLR type, handling DBNull, Nullable<T> and enums,
        // none of which Convert.ChangeType supports directly.
        private static object ConvertValue(object value, Type targetType)
        {
            Type underlying = Nullable.GetUnderlyingType(targetType);
            if (value == null || value is DBNull)
            {
                if (targetType.IsValueType && underlying == null)
                {
                    throw new InvalidCastException("Cannot assign a database NULL to non-nullable type " + targetType.FullName);
                }
                return null;
            }

            Type effective = underlying ?? targetType;
            if (effective.IsInstanceOfType(value))
            {
                return value;
            }
            if (effective.IsEnum)
            {
                return Enum.ToObject(effective, value);
            }
            // Database values are machine data, so convert culture-independently.
            return Convert.ChangeType(value, effective, System.Globalization.CultureInfo.InvariantCulture);
        }
    }
}
自定义的实体类:
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Entity mapped to the <c>Student</c> table.
    /// </summary>
    [Table("Student")]
    public class Student
    {
        /// <summary>Value of the ID column (declared as DataType.Int).</summary>
        [Column("ID", ColumnType = DataType.Int)]
        public int ID { get; set; }

        /// <summary>Value of the StuName column.</summary>
        [Column("StuName")]
        public string Name { get; set; }

        /// <summary>Value of the Address column.</summary>
        [Column("Address")]
        public string Address { get; set; }

        /// <summary>Value of the Sex column (declared as DataType.Int).</summary>
        [Column("Sex", ColumnType = DataType.Int)]
        public int Sex { get; set; }

        /// <summary>Value of the CollegeID column (declared as DataType.Int).</summary>
        [Column("CollegeID", ColumnType = DataType.Int)]
        public int CollegeID { get; set; }
    }
}
Program执行
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;

namespace SimpleLinq2Sql
{
    /// <summary>
    /// Demo: composes Where -> Select -> Where over the Student table, prints the generated SQL,
    /// then prints every resulting row as comma-separated values.
    /// </summary>
    class Program
    {
        static void Main(string[] args)
        {
            CustomTable<Student> students = new CustomTable<Student>();

            // Query syntax compiles to the same Where(...).Select(...).Where(...) call chain; the
            // trailing "select row" after "where" is a degenerate select the compiler removes.
            var query = from s in students
                        where s.Address == "china"
                        select new { NewName = s.Name, Country = s.Address, s.CollegeID, s.Sex }
                        into row
                        where row.Sex == 1
                        select row;

            Console.WriteLine(query.ToString());

            foreach (var item in query)
            {
                Console.WriteLine("{0},{1},{2},{3}", item.NewName, item.Country, item.CollegeID, item.Sex);
            }

            Console.Read();
        }
    }
}
0 0
- 简单的实现一个自定义的Linq to Sql Provider
- 简单的linq to sql
- Linq to sql之简单的分页
- LINQ to SQL的一些简单用法
- LINQ TO SQL CAST方法的实现
- linq Distinct的一个简单实现方法
- 简单的linq to sql 的例子 ,实现了增删改查。
- 简单的linq to sql 的例子 ,实现了增删改查
- LINQ : 如何为LINQ TO SQL实现自定义业务逻辑
- Linq to SQL 简单的增删改操作
- 关于Linq to sql的一个更新问题
- Linq to SQL 插入数据时的一个问题
- LINQ to SQL 中一个注意的问题
- LINQ to SQL的不足
- LINQ to SQL的不足
- LINQ to SQL 的EntitySet)>)
- linq to sql 的学习
- Linq to sql 的学习体会
- javaIO输入输出流
- OpenCV_连通区域分析(Connected Component Analysis-Labeling)
- opencv提高对比度算法
- TCP协议连接建立与连接断开过程(含断开时的TCP状态图)
- 完全卸载oracle|oracle卸载|彻底卸载oracle
- 简单的实现一个自定义的Linq to Sql Provider
- myeclipse启动An internal error occurred during: "Initializing Java Tooling".
- 多线程
- 零基础写嵌入式操作系统-2
- ASIHTTPRequest 详解, ios http网络请求
- winform+wcf(netTcpBinding)双向通讯 自定义用户名密码验证
- PHP中的 抽象类(abstract class)和 接口(interface)
- extjs2.0 xtype属性
- Raspberry Pi Airplay Tutorial