using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using VisitorDelegate = System.Func<System.Linq.Expressions.Expression, System.Linq.Expressions.Expression>;
using System.Diagnostics;
/// <summary>
/// An <see cref="IQueryProvider"/> that wraps another provider and rewrites every expression
/// tree with a chain of visitor delegates before passing it to the wrapped provider.
/// </summary>
/// <remarks>http://blogs.msdn.com/b/alexj/archive/2010/03/01/tip-55-how-to-extend-an-iqueryable-by-wrapping-it.aspx</remarks>
internal class InterceptingProvider : IQueryProvider
{
    readonly IQueryProvider _underlyingProvider;
    readonly VisitorDelegate[] _visitors;
    readonly VisitorDelegate _afterUnderlyingVisitor;

    private InterceptingProvider(VisitorDelegate afterUnderlyingVisitor,
        IQueryProvider underlyingQueryProvider,
        params VisitorDelegate[] visitors)
    {
        this._underlyingProvider = underlyingQueryProvider;
        this._afterUnderlyingVisitor = afterUnderlyingVisitor;
        // A null list means "no visitors", so InterceptExpr can always iterate it.
        this._visitors = visitors ?? new VisitorDelegate[0];
    }

    /// <summary>
    /// Wraps <paramref name="underlyingQuery"/> so that its expression is rewritten by each of
    /// <paramref name="visitors"/>, in order, before the wrapped provider sees it.
    /// </summary>
    /// <param name="afterUnderlyingVisitor">Optional (may be null). During enumeration it is
    /// applied to the expression of the query built by the wrapped provider; its return value
    /// is not used.</param>
    /// <param name="underlyingQuery">The query to wrap.</param>
    /// <param name="visitors">Visitors applied in array order; null means none.</param>
    /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
    public static IQueryable<T> Intercept<T>(
        ExpressionVisitor afterUnderlyingVisitor,
        IQueryable<T> underlyingQuery,
        params ExpressionVisitor[] visitors)
    {
        VisitorDelegate afterDelegate = afterUnderlyingVisitor != null
            ? (VisitorDelegate)afterUnderlyingVisitor.Visit
            : null;
        return Intercept<T>(afterDelegate, underlyingQuery, ToDelegates(visitors));
    }

    /// <summary>
    /// Wraps <paramref name="underlyingQuery"/> so that its expression is rewritten by each of
    /// <paramref name="visitors"/>, in order, before the wrapped provider sees it.
    /// </summary>
    /// <param name="underlyingQuery">The query to wrap.</param>
    /// <param name="visitors">Visitors applied in array order; null means none.</param>
    /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
    public static IQueryable<T> Intercept<T>(
        IQueryable<T> underlyingQuery,
        params ExpressionVisitor[] visitors)
    {
        return Intercept<T>((VisitorDelegate)null, underlyingQuery, ToDelegates(visitors));
    }

    /// <summary>
    /// Wraps <paramref name="underlyingQuery"/> so that its expression is rewritten by each of
    /// <paramref name="visitors"/>, in order, before the wrapped provider sees it.
    /// </summary>
    /// <param name="afterUnderlyingVisitor">Optional (may be null). During enumeration it is
    /// applied to the expression of the query built by the wrapped provider; its return value
    /// is not used.</param>
    /// <param name="underlyingQuery">The query to wrap.</param>
    /// <param name="visitors">Rewrite functions applied in array order; null means none.</param>
    /// <exception cref="ArgumentNullException"><paramref name="underlyingQuery"/> is null.</exception>
    public static IQueryable<T> Intercept<T>(Func<Expression, Expression> afterUnderlyingVisitor,
        IQueryable<T> underlyingQuery,
        params Func<Expression, Expression>[] visitors)
    {
        if (underlyingQuery == null)
            throw new ArgumentNullException("underlyingQuery");
        var provider = new InterceptingProvider(afterUnderlyingVisitor,
            underlyingQuery.Provider,
            visitors);
        return provider.CreateQuery<T>(underlyingQuery.Expression);
    }

    /// <summary>Binds each visitor's <c>Visit</c> method as a delegate; null yields an empty array.</summary>
    private static VisitorDelegate[] ToDelegates(ExpressionVisitor[] visitors)
    {
        if (visitors == null)
            return new VisitorDelegate[0];
        return visitors.Select(v => (VisitorDelegate)v.Visit).ToArray();
    }

    /// <summary>
    /// When true, <see cref="ExecuteQuery{TElement}"/> writes the wrapped query's
    /// <c>ObjectQuery.ToTraceString()</c> output to <see cref="Trace"/> (only for queries that
    /// are <c>System.Data.Objects.ObjectQuery</c> instances).
    /// </summary>
    public static bool DoTrace { get; set; }

    /// <summary>
    /// Rewrites <paramref name="expression"/>, builds the query on the wrapped provider and
    /// enumerates it. If the wrapped enumerator already yields <typeparamref name="TElement"/>
    /// it is returned as is. Otherwise every row is read into memory. Rows whose type is not
    /// assignable to <typeparamref name="TElement"/> are copied into a new
    /// <typeparamref name="TElement"/> by matching constructor parameters to source properties
    /// by name.
    /// </summary>
    /// <exception cref="InvalidOperationException">A row cannot be translated to
    /// <typeparamref name="TElement"/>.</exception>
    public IEnumerator<TElement> ExecuteQuery<TElement>(Expression expression)
    {
        Expression intercepted;
        using (Profiler.Step("intercepting query"))
        {
            intercepted = InterceptExpr(expression);
        }
        IQueryable newQuery;
        using (Profiler.Step("Ef Translating query"))
        {
            newQuery = _underlyingProvider.CreateQuery(intercepted);
        }
        // The rewritten expression must not be typed by any type whose full name contains "Shared".
        Debug.Assert(intercepted.Type.FullName.Contains("Shared") == false);
        if (DoTrace)
        {
            using (Profiler.Step("ToTraceString"))
            {
                // Tracing is diagnostic only; skip rather than throw for non-ObjectQuery queries.
                var objectQuery = newQuery as System.Data.Objects.ObjectQuery;
                if (objectQuery != null)
                    Trace.WriteLine(objectQuery.ToTraceString());
            }
        }
        if (_afterUnderlyingVisitor != null)
        {
            // The visitor's return value is intentionally discarded.
            _afterUnderlyingVisitor(newQuery.Expression);
        }
        using (Profiler.Step("enumerating query"))
        {
            var enumerator = newQuery.GetEnumerator();
            // Fast path: hand the wrapped enumerator to the caller, who owns disposing it.
            if (enumerator is IEnumerator<TElement>)
                return (IEnumerator<TElement>)enumerator;
            try
            {
                var sourceType = newQuery.ElementType;
                var targetType = typeof(TElement);
                Func<object, object> convert = targetType.IsAssignableFrom(sourceType)
                    ? (Func<object, object>)(item => item)
                    : CreateAnonymousTranslator(sourceType, targetType);
                var items = new List<TElement>();
                while (enumerator.MoveNext())
                    items.Add((TElement)convert(enumerator.Current));
                return items.GetEnumerator();
            }
            finally
            {
                // Release the wrapped enumerator even if reading or translating a row throws.
                var disposable = enumerator as IDisposable;
                if (disposable != null)
                    disposable.Dispose();
            }
        }
    }

    /// <summary>
    /// Builds a function that copies an instance of <paramref name="sourceType"/> into a new
    /// <paramref name="targetType"/> by passing same-named source property values to the
    /// target's constructor (the constructor whose parameter types are the target's generic
    /// arguments, as on anonymous types). Reflection lookups happen once, not per row.
    /// </summary>
    /// <exception cref="InvalidOperationException">No such constructor exists, or a constructor
    /// parameter has no matching source property.</exception>
    private static Func<object, object> CreateAnonymousTranslator(Type sourceType, Type targetType)
    {
#warning does not handle nested anonymous types
        var targetConstructor = targetType.GetConstructor(targetType.GetGenericArguments());
        if (targetConstructor == null)
            throw new InvalidOperationException(string.Format(
                "Cannot translate {0} to {1}: {1} has no constructor taking its generic arguments.",
                sourceType, targetType));
        var sourceProperties = sourceType.GetProperties();
        // Bind getters in constructor-parameter order rather than relying on GetProperties()
        // order, which reflection does not guarantee.
        var getters = targetConstructor.GetParameters()
            .Select(parameter =>
            {
                // Anonymous-type constructor parameters share their property's name; the
                // case-insensitive fallback covers constructors like Tuple<T1,T2>(item1, item2).
                var property = sourceProperties.FirstOrDefault(p => p.Name == parameter.Name)
                    ?? sourceProperties.FirstOrDefault(p => string.Equals(p.Name, parameter.Name, StringComparison.OrdinalIgnoreCase));
                if (property == null)
                    throw new InvalidOperationException(string.Format(
                        "Cannot translate {0} to {1}: no source property matches constructor parameter '{2}'.",
                        sourceType, targetType, parameter.Name));
                return property.GetGetMethod();
            })
            .ToArray();
        return source =>
        {
            var args = new object[getters.Length];
            for (var i = 0; i < getters.Length; i++)
                args[i] = getters[i].Invoke(source, null);
            return targetConstructor.Invoke(args);
        };
    }

    /// <summary>
    /// Returns <paramref name="source"/> cast to <typeparamref name="TResult"/> when
    /// <paramref name="inputType"/> is assignable to it; otherwise copies it into a new
    /// <typeparamref name="TResult"/> by property name.
    /// </summary>
    TResult TranslateAnonymous<TResult>(Type inputType, object source)
    {
        var targetType = typeof(TResult);
        if (targetType.IsAssignableFrom(inputType))
            return (TResult)source;
        return (TResult)CreateAnonymousTranslator(inputType, targetType)(source);
    }

    /// <summary>Creates an <see cref="InterceptedQuery{T}"/> over <paramref name="expression"/>; nothing is executed yet.</summary>
    public IQueryable<TElement> CreateQuery<TElement>(Expression expression)
    {
        return new InterceptedQuery<TElement>(this, expression);
    }

    /// <summary>
    /// Non-generic <see cref="CreateQuery{TElement}"/>: builds an <c>InterceptedQuery&lt;T&gt;</c>
    /// via reflection for the element type of <paramref name="expression"/>.
    /// </summary>
    public IQueryable CreateQuery(Expression expression)
    {
        // NOTE(review): assumes TypeHelper.FindIEnumerable returns the sequence element type.
        Type elementType = TypeHelper.FindIEnumerable(expression.Type);
        Type queryType = typeof(InterceptedQuery<>).MakeGenericType(elementType);
        // InterceptedQuery<T>'s constructor is internal, hence the NonPublic lookup.
        var constructor = queryType.GetConstructor(
            BindingFlags.NonPublic | BindingFlags.Instance,
            null,
            new Type[] { typeof(InterceptingProvider), typeof(Expression) },
            null);
        return (IQueryable)constructor.Invoke(new object[] { this, expression });
    }

    /// <summary>
    /// Executes a single-result query after rewriting it; a null result yields
    /// <c>default(TResult)</c>, other results are translated as in <see cref="TranslateAnonymous{TResult}"/>.
    /// </summary>
    public TResult Execute<TResult>(Expression expression)
    {
        var result = this._underlyingProvider.Execute(InterceptExpr(expression));
        if (result == null)
            return default(TResult);
        return TranslateAnonymous<TResult>(result.GetType(), result);
    }

    /// <summary>Executes <paramref name="expression"/> on the wrapped provider after rewriting it.</summary>
    public object Execute(Expression expression)
    {
        return this._underlyingProvider.Execute(InterceptExpr(expression));
    }

    /// <summary>Applies every visitor, in order, each to the previous visitor's output.</summary>
    private Expression InterceptExpr(Expression expression)
    {
        var exp = expression;
        foreach (var visitor in _visitors)
            exp = visitor(exp);
        return exp;
    }
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
using System.Reflection;
using System.Diagnostics;
using System.Collections;
/// <summary>
/// Expression visitor that rewrites a tree so that every type used as a key in the
/// replacement map is swapped for its mapped type. Parameters, member accesses, constants,
/// <c>new</c> expressions, lambdas, unary expressions and method calls whose types involve a
/// replaced type are rebuilt; everything else is visited normally.
/// </summary>
/// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
public class TypeChangeVisitor : ExpressionVisitor
{
    readonly IDictionary<Type, Type> _typeReplacements;

    // Nesting depth of Visit calls; maintained by Visit, not read elsewhere in this class.
    int visitStack = 0;

    /// <summary>
    /// Creates the visitor. Every interface inherited by an interface key is also mapped to
    /// that key's replacement, unless the map already contains it.
    /// </summary>
    /// <param name="typeReplacements">Map from the type to remove to its replacement type. It is
    /// copied, so the caller's dictionary is never modified.</param>
    /// <exception cref="ArgumentNullException"><paramref name="typeReplacements"/> is null.</exception>
    public TypeChangeVisitor(IDictionary<Type, Type> typeReplacements)
    {
        if (typeReplacements == null)
            throw new ArgumentNullException("typeReplacements");
        _typeReplacements = new Dictionary<Type, Type>(typeReplacements);
        foreach (var replacement in typeReplacements)
        {
            if (!replacement.Key.IsInterface)
                continue;
            foreach (var inherited in replacement.Key.GetInterfaces())
            {
                // Checking the combined map (not only the original keys) avoids a duplicate-key
                // exception when two interface keys share a base interface; the first one enumerated wins.
                if (!_typeReplacements.ContainsKey(inherited))
                    _typeReplacements.Add(inherited, replacement.Value);
            }
        }
    }

    /// <summary>Maps each generic argument of <paramref name="method"/> through <see cref="VisitType"/>.</summary>
    IEnumerable<Type> TransformMethodArgs(MethodBase method)
    {
        return method.GetGenericArguments().Select(VisitType);
    }

    /// <summary>
    /// True when <paramref name="t"/> is a replaced type or is built from one, at any depth,
    /// as a generic argument or as an array/pointer/by-ref element type.
    /// </summary>
    bool NeedsTypeChange(Type t)
    {
        if (t == null)
            return false;
        if (_typeReplacements.ContainsKey(t))
            return true;
        if (t.HasElementType)
            return NeedsTypeChange(t.GetElementType());
        return t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange);
    }

    /// <summary>
    /// Returns the replacement for <paramref name="t"/>. Arrays and generic types built from a
    /// replaced type are rebuilt over the replacements; other types are returned unchanged.
    /// </summary>
    Type VisitType(Type t)
    {
        Type replacement;
        if (_typeReplacements.TryGetValue(t, out replacement))
            return replacement;
        if (t.IsArray && NeedsTypeChange(t.GetElementType()))
        {
            var elementType = VisitType(t.GetElementType());
            var rank = t.GetArrayRank();
            // Rank-1 arrays are rebuilt as vectors (T[]); MakeArrayType(1) would give T[*].
            var newArrayType = rank == 1 ? elementType.MakeArrayType() : elementType.MakeArrayType(rank);
            Debug.Assert(NeedsTypeChange(newArrayType) == false);
            return newArrayType;
        }
        if (t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange))
        {
            var types = t.GetGenericArguments().Select(VisitType).ToArray();
            var newType = t.GetGenericTypeDefinition().MakeGenericType(types);
            Debug.Assert(NeedsTypeChange(newType) == false);
            return newType;
        }
        Debug.Assert(NeedsTypeChange(t) == false);
        return t;
    }

    /// <summary>
    /// Rebuilds a <c>new</c> expression over the replacement types. The arguments are visited;
    /// when the declaring type is generic over a replaced type (e.g. an anonymous type), the
    /// constructor is looked up on the rebuilt type, and member bindings are matched by name.
    /// </summary>
    NewExpression TransformNewCall(NewExpression node)
    {
        Debug.Assert(node.Constructor != null);
        // Materialized once: Visit has side effects (parameter mappings, LastNew*), so the
        // arguments must not be re-visited every time the sequence is enumerated.
        var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
        Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);
        var constructor = node.Constructor;
        if (constructor.DeclaringType.IsGenericType && constructor.DeclaringType.GetGenericArguments().Any(NeedsTypeChange))
        {
            var newType = VisitType(constructor.DeclaringType);
            Debug.Assert(NeedsTypeChange(newType) == false);
            var constructorTypes = constructor.GetParameters().Select(s => s.ParameterType).Select(VisitType).ToArray();
            Debug.Assert(constructorTypes.Any(NeedsTypeChange) == false);
            constructor = newType.GetConstructor(constructorTypes);
            Debug.Assert(constructor != null && NeedsTypeChange(constructor.DeclaringType) == false);
        }
        // A plain constructor call carries no member bindings to remap.
        if (node.Members == null)
            return Expression.New(constructor, argParams);
        var declaringMembers = constructor.DeclaringType.GetMembers();
        // Join keeps node.Members order, so members stay aligned with the arguments.
        var membersTransformed = (from fMember in node.Members
                                  join nMember in declaringMembers
                                  on fMember.Name equals nMember.Name
                                  select nMember).ToArray();
        // Matching by name is safe only because the source is assumed to be an interface and
        // the new type an anonymous type.
        var visited = Expression.New(constructor, argParams, membersTransformed);
        Debug.Assert(visited.Members.Count == node.Members.Count);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Rebuilds a method call over the replacement types: the instance and arguments are
    /// visited and the method is re-bound with <see cref="RebindMethod"/>.
    /// </summary>
    MethodCallExpression TransformMethodCall(MethodCallExpression node)
    {
        Debug.Assert(node.Method != null);
        // Visit the instance too, so an old parameter cannot survive as the call target.
        var instance = Visit(node.Object);
        var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
        Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);
        var methodInfo = RebindMethod(node.Method);
        var visited = Expression.Call(instance, methodInfo, argParams);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Returns the counterpart of <paramref name="method"/> for the replacement types.
    /// Generic methods are re-instantiated with visited type arguments. Non-generic methods
    /// whose declaring type is replaced are looked up by name and visited parameter types on
    /// the replacement type. Any other method is returned unchanged.
    /// </summary>
    /// <exception cref="InvalidOperationException">The replacement type has no matching public method.</exception>
    MethodInfo RebindMethod(MethodInfo method)
    {
        if (method.IsGenericMethod)
        {
            var argTypes = TransformMethodArgs(method).ToArray();
            Debug.Assert(argTypes.Any(NeedsTypeChange) == false);
            var generic = method.GetGenericMethodDefinition().MakeGenericMethod(argTypes);
            Debug.Assert(NeedsTypeChange(generic.DeclaringType) == false);
            return generic;
        }
        if (!NeedsTypeChange(method.DeclaringType))
            return method;
        var newDeclaringType = VisitType(method.DeclaringType);
        var parameterTypes = method.GetParameters().Select(p => VisitType(p.ParameterType)).ToArray();
        var rebound = newDeclaringType.GetMethod(method.Name, parameterTypes);
        if (rebound == null)
            throw new InvalidOperationException(string.Format(
                "Method {0} has no counterpart on {1}.", method, newDeclaringType));
        return rebound;
    }

    /// <summary>
    /// Rebuilds calls whose return type or argument types involve a replaced type, then lets
    /// the base visitor walk the rebuilt call.
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        var needsTransform = node.Method != null
            && (NeedsTypeChange(node.Method.ReturnType)
                || (node.Arguments != null && node.Arguments.Any(t => NeedsTypeChange(t.Type))));
        if (!needsTransform)
            return base.VisitMethodCall(node);
        var transformed = TransformMethodCall(node);
        Debug.WriteLine("Transformed methodcall");
        return base.VisitMethodCall(transformed);
    }

    /// <summary>
    /// Re-creates unary nodes (e.g. Convert, Quote) whose result type involves a replaced
    /// type, with the visited operand and the replacement type.
    /// </summary>
    /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
    protected override Expression VisitUnary(UnaryExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            var operand = Visit(node.Operand);
            var newType = VisitType(node.Type);
            visited = Expression.MakeUnary(node.NodeType, operand, newType);
        }
        else
            visited = base.VisitUnary(node);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>Pass-through to the base visitor.</summary>
    protected override Expression VisitInvocation(InvocationExpression node)
    {
        return base.VisitInvocation(node);
    }

    /// <summary>Most recent <see cref="NewExpression"/> passed to <see cref="VisitNew"/> (set there; not read by this class).</summary>
    public NewExpression LastNew { get; private set; }

    /// <summary>Result of the most recent <see cref="VisitNew"/> call (set there; not read by this class).</summary>
    public NewExpression LastNewResult { get; private set; }

    /// <summary>
    /// Rebuilds <c>new</c> expressions whose type or argument types involve a replaced type,
    /// writing the expression before and after to <see cref="Debug"/>.
    /// </summary>
    protected override Expression VisitNew(NewExpression node)
    {
        LastNew = node;
        Debug.WriteLine("Transforming NewExpression");
        Debug.WriteLine(ExpressionWriter.WriteToString(node));
        Debug.WriteLine(string.Empty);
        Expression visited;
        if (NeedsTypeChange(node.Type) || node.Arguments.Any(a => NeedsTypeChange(a.Type)))
        {
            var transformed = TransformNewCall(node);
            Debug.Assert(NeedsTypeChange(transformed.Type) == false);
            visited = base.VisitNew(transformed);
        }
        else
            visited = base.VisitNew(node);
        Debug.WriteLine("Transformed to");
        Debug.WriteLine(ExpressionWriter.WriteToString(visited));
        Debug.WriteLine(string.Empty);
        LastNewResult = (NewExpression)visited;
        return visited;
    }

    /// <summary>
    /// Re-types constants whose declared type involves a replaced type. Null values keep null
    /// with the replacement type, arrays are copied into an array of the replacement element
    /// type, and other values are re-typed by their runtime type.
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            if (node.Value == null)
                return Expression.Constant(null, VisitType(node.Type));
            var valueType = node.Value.GetType();
            if (valueType.IsArray)
            {
                var value = node.Value;
                if (NeedsTypeChange(valueType))
                {
                    // The runtime array type itself is a replaced type: copy the elements.
                    var oldArray = (Array)node.Value;
                    var newElementType = VisitType(valueType.GetElementType());
                    var newArray = Array.CreateInstance(newElementType, oldArray.Length);
                    oldArray.CopyTo(newArray, 0);
                    value = newArray;
                }
                else
                {
                    // Declared type needs changing but the runtime array type does not;
                    // not expected to happen.
                    Debug.Assert(false);
                }
                visited = base.VisitConstant(Expression.Constant(value));
            }
            else
            {
                visited = Expression.Constant(node.Value);
            }
        }
        else
            visited = base.VisitConstant(node);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>Entry point for every node; tracks nesting depth and checks that no replaced type survives.</summary>
    public override Expression Visit(Expression node)
    {
        if (node == null)
            return null;
        visitStack++;
        var found = base.Visit(node);
        Debug.Assert(found == null || NeedsTypeChange(found.Type) == false);
        visitStack--;
        return found;
    }

    /// <summary>Pass-through to the base visitor.</summary>
    protected override Expression VisitBinary(BinaryExpression node)
    {
        return base.VisitBinary(node);
    }

    /// <summary>
    /// In a given query the params must be the same instance, not just the same name/type
    /// </summary>
    /// <remarks>
    /// The parameter xxx was not bound in the specified LINQ to Entities query expression
    /// http://social.msdn.microsoft.com/Forums/en/adodotnetentityframework/thread/8c2b0b1c-01bb-4de2-af46-0b4ea866cf8f
    /// </remarks>
    readonly Dictionary<ParameterExpression, ParameterExpression> paramMappings = new Dictionary<ParameterExpression, ParameterExpression>();

    /// <summary>
    /// Rebuilds lambdas whose return or parameter types involve a replaced type, from the
    /// visited body and the mapped parameters.
    /// </summary>
    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        Expression visited;
        if (NeedsTypeChange(node.ReturnType) || node.Parameters.Any(p => NeedsTypeChange(p.Type)))
        {
            // Parameters referenced by the body are mapped while visiting it; VisitParameter
            // then returns those same instances for the parameter list.
            var visitedBody = Visit(node.Body);
            Debug.Assert(NeedsTypeChange(visitedBody.Type) == false);
            var transformedParams = node.Parameters
                .Select(p => VisitParameter(p))
                .Cast<ParameterExpression>()
                .ToArray();
            Debug.Assert(transformedParams.Any(t => NeedsTypeChange(t.Type)) == false);
            // The delegate type is inferred from the new types, so it usually differs from T.
            var transformed = Expression.Lambda(visitedBody, transformedParams);
            Debug.Assert(NeedsTypeChange(transformed.Type) == false);
            Debug.Assert(NeedsTypeChange(transformed.ReturnType) == false);
            var typed = transformed as Expression<T>;
            visited = typed != null ? base.VisitLambda<T>(typed) : transformed;
        }
        else
            visited = base.VisitLambda<T>(node);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Maps each parameter of a replaced type to a single new parameter of the replacement
    /// type, reusing the same instance on every occurrence.
    /// </summary>
    protected override Expression VisitParameter(ParameterExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            ParameterExpression mapped;
            if (!paramMappings.TryGetValue(node, out mapped))
            {
                mapped = Expression.Parameter(VisitType(node.Type), node.Name);
                paramMappings.Add(node, mapped);
            }
            visited = base.VisitParameter(mapped);
        }
        else
            visited = base.VisitParameter(node);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Re-targets member accesses (x.Name) on a replaced type to the same-named property or
    /// field of the replacement type.
    /// </summary>
    /// <exception cref="InvalidOperationException">The replacement type has no such member.</exception>
    protected override Expression VisitMember(MemberExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type) || NeedsTypeChange(node.Member.DeclaringType))
        {
            var newType = VisitType(node.Member.DeclaringType);
            // Null for static members.
            var visitedExpression = Visit(node.Expression);
            Debug.Assert(visitedExpression == null || NeedsTypeChange(visitedExpression.Type) == false);
            MemberExpression access;
            if (node.Member is FieldInfo)
            {
                var newField = newType.GetField(node.Member.Name);
                if (newField == null)
                    throw new InvalidOperationException(string.Format(
                        "Member {0} has no counterpart on {1}.", node.Member.Name, newType));
                access = Expression.Field(visitedExpression, newField);
            }
            else
            {
                var newProperty = newType.GetProperty(node.Member.Name);
                if (newProperty == null)
                    throw new InvalidOperationException(string.Format(
                        "Member {0} has no counterpart on {1}.", node.Member.Name, newType));
                access = Expression.Property(visitedExpression, newProperty.GetGetMethod());
            }
            Debug.Assert(NeedsTypeChange(access.Type) == false);
            visited = base.VisitMember(access);
        }
        else
            visited = base.VisitMember(node);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }
}
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Linq.Expressions;
/// <summary>
/// Queryable handed out by <see cref="InterceptingProvider"/>. It only holds the expression;
/// enumeration is delegated to the provider's <c>ExecuteQuery</c>.
/// </summary>
/// <typeparam name="T">Element type of the query.</typeparam>
internal class InterceptedQuery<T> : IOrderedQueryable<T>
{
    protected internal readonly InterceptingProvider _provider;
    private readonly Expression _expression;

    // Internal on purpose: InterceptingProvider.CreateQuery locates it with NonPublic binding flags.
    internal InterceptedQuery(
        InterceptingProvider provider,
        Expression expression)
    {
        _expression = expression;
        _provider = provider;
    }

    /// <summary>Always <typeparamref name="T"/>.</summary>
    public Type ElementType
    {
        get { return typeof(T); }
    }

    /// <summary>The (not yet intercepted) expression this query represents.</summary>
    public Expression Expression
    {
        get { return _expression; }
    }

    /// <summary>The intercepting provider that created this query.</summary>
    public IQueryProvider Provider
    {
        get { return _provider; }
    }

    /// <summary>Executes the query through the intercepting provider.</summary>
    public IEnumerator<T> GetEnumerator()
    {
        return _provider.ExecuteQuery<T>(_expression);
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.IO;
using System.Diagnostics;
/// <summary>
/// A <see cref="StreamWriter"/> whose text is forwarded to <see cref="Debug.Write(string)"/>,
/// tagged with a category when one is given. AutoFlush is on, so text is forwarded on every write.
/// </summary>
public class DebugTextWriter : StreamWriter
{
    // UTF-16LE without a byte-order mark. The stream below is not seekable, so StreamWriter
    // would otherwise write the BOM on first flush, and it would be decoded into a stray U+FEFF.
    static readonly Encoding OutputEncoding = new UnicodeEncoding(false, false);

    /// <summary>Creates a writer that forwards to Debug output.</summary>
    /// <param name="category">Category passed to Debug.Write; null or empty writes without a category.</param>
    public DebugTextWriter(string category = null)
        : base(new DebugOutStream(category), OutputEncoding, 1024)
    {
        this.AutoFlush = true;
    }

    /// <summary>Write-only, non-seekable stream that decodes written bytes and passes them to Debug.Write.</summary>
    class DebugOutStream : Stream
    {
        readonly string _category;

        public DebugOutStream(string category)
        {
            _category = category;
        }

        public override void Write(byte[] buffer, int offset, int count)
        {
            var text = OutputEncoding.GetString(buffer, offset, count);
            // Tag the output with the category only when one was supplied.
            if (String.IsNullOrEmpty(_category))
                Debug.Write(text);
            else
                Debug.Write(text, _category);
        }

        public override bool CanRead { get { return false; } }
        public override bool CanSeek { get { return false; } }
        public override bool CanWrite { get { return true; } }
        public override void Flush() { Debug.Flush(); }

        // Unsupported Stream members throw NotSupportedException, per the Stream contract.
        public override long Length { get { throw new NotSupportedException(); } }
        public override int Read(byte[] buffer, int offset, int count) { throw new NotSupportedException(); }
        public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); }
        public override void SetLength(long value) { throw new NotSupportedException(); }
        public override long Position
        {
            get { throw new NotSupportedException(); }
            set { throw new NotSupportedException(); }
        }
    }
}
}