//imaginarydevelopment.blogspot.com
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Text.RegularExpressions;
namespace Adapter
{
/// <summary>
/// Expression visitor that rewrites a tree so that every type listed as a key in a replacement map
/// is swapped for its mapped type. This covers parameters, members, constants, constructors, method
/// calls and lambdas, and it also rewrites generic types/methods closed over such types.
/// </summary>
/// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
public class TypeChangeVisitor : ExpressionVisitor
{
    // Private copy of the replacement map; it also holds the base-interface entries derived in the constructor.
    readonly IDictionary<Type, Type> _typeReplacements;

    // Current nesting depth of Visit calls. It is only maintained here, never read
    // (presumably a debugging aid).
    int visitStack = 0;

    // Pulls candidate type names out of a generic type's FullName,
    // e.g. "List`1[[Ns.IFoo, Asm, ...]]" -> "Ns.IFoo".
    static readonly Regex GenericArgumentNamePattern = new Regex(@"\[([A-Z].*?),");

    /// <summary>
    /// Creates the visitor.
    /// </summary>
    /// <param name="typeReplacements">
    /// Map from the type to replace to the type to use instead. The dictionary is copied and never modified.
    /// For every interface key, each of its base interfaces is also mapped to the same target,
    /// unless that base interface is already a key.
    /// NOTE: when several interface keys share a base interface, the first one enumerated wins.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="typeReplacements"/> is null.</exception>
    public TypeChangeVisitor(IDictionary<Type, Type> typeReplacements)
    {
        if (typeReplacements == null)
            throw new ArgumentNullException("typeReplacements");

        // Copy so the caller's dictionary is never mutated; building several visitors from one map stays safe.
        _typeReplacements = new Dictionary<Type, Type>(typeReplacements);
        foreach (var pair in typeReplacements)
        {
            if (!pair.Key.IsInterface)
                continue;
            foreach (var baseInterface in pair.Key.GetInterfaces())
            {
                // Explicit entries (already copied above) always take precedence over derived ones.
                if (!_typeReplacements.ContainsKey(baseInterface))
                    _typeReplacements.Add(baseInterface, pair.Value);
            }
        }
    }

    /// <summary>
    /// Maps each generic argument of <paramref name="method"/> through <see cref="VisitType"/>.
    /// Only meaningful for generic methods; a non-generic method yields an empty sequence.
    /// </summary>
    IEnumerable<Type> TransformMethodArgs(MethodBase method)
    {
        foreach (var t in method.GetGenericArguments())
        {
            yield return VisitType(t);
        }
    }

    /// <summary>
    /// True when <paramref name="t"/> is (by full name) one of the keys of <paramref name="typeChanges"/>,
    /// or when its FullName mentions such a key inside a generic argument list.
    /// </summary>
    static bool NeedsTypeChange(Type t, IDictionary<Type, Type> typeChanges)
    {
        var fullName = t.FullName;
        if (fullName == null)
        {
            // FullName is null for generic parameters and for constructed types that contain them;
            // fall back to inspecting the generic arguments directly instead of throwing.
            return t.IsGenericType && t.GetGenericArguments().Any(a => NeedsTypeChange(a, typeChanges));
        }
        if (fullName.Contains("["))
        {
            var typeNames = GenericArgumentNamePattern.Matches(fullName).Cast<Match>().Select(m => m.Groups[1].Value);
            foreach (var name in typeNames)
                if (typeChanges.Keys.Any(k => name == k.FullName))
                    return true;
        }
        return typeChanges.Keys.Any(k => fullName == k.FullName);
    }

    /// <summary>Instance shortcut for <see cref="NeedsTypeChange(Type, IDictionary{Type, Type})"/> using this visitor's map.</summary>
    bool NeedsTypeChange(Type t)
    {
        return NeedsTypeChange(t, _typeReplacements);
    }

    /// <summary>
    /// Returns the replacement for <paramref name="t"/>. A direct map hit is returned as-is.
    /// A generic type with offending arguments is rebuilt from its definition with visited arguments.
    /// Any other type is returned unchanged; in Debug builds an assert flags a nested (e.g. anonymous)
    /// type that has a public field needing replacement, because such a type cannot be rebuilt here.
    /// </summary>
    Type VisitType(Type t)
    {
        Type replacement;
        if (_typeReplacements.TryGetValue(t, out replacement))
        {
            return replacement;
        }
        if (t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange))
        {
            var types = t.GetGenericArguments().Select(VisitType).ToArray();
            var newType = t.GetGenericTypeDefinition().MakeGenericType(types);
            Debug.Assert(NeedsTypeChange(newType) == false);
            return newType;
        }
        var fullName = t.FullName;
        if (fullName != null && fullName.IndexOf('+') >= 0)
        {
            // '+' marks a nested type (anonymous non-generic type, closure class, ...).
            if (t.GetMembers().Any(m => m is FieldInfo && NeedsTypeChange(((FieldInfo)m).FieldType)))
            {
                Debug.Assert(false, "Anonymous type with bad field(s)");
            }
        }
        Debug.Assert(NeedsTypeChange(t) == false);
        return t;
    }

    /// <summary>
    /// Rebuilds a <see cref="NewExpression"/> against the replacement types: visits the arguments,
    /// rebinds the constructor on the retyped declaring type, and remaps <c>Members</c> by name
    /// when the original expression had them (anonymous-type initializers).
    /// </summary>
    NewExpression TransformNewCall(NewExpression node)
    {
        if (node.Constructor == null)
        {
            // Parameterless value-type creation: there is no constructor or argument to rebind.
            return Expression.New(VisitType(node.Type));
        }

        // Materialize once: Visit has side effects (parameter mappings, LastNew, eager evaluation),
        // and a lazy projection would re-run them on every enumeration.
        var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
        Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);
        var constructor = node.Constructor;
        // generic class constructor
        if (constructor.DeclaringType.IsGenericType && constructor.DeclaringType.GetGenericArguments().Any(NeedsTypeChange))
        {
            var newType = VisitType(constructor.DeclaringType);
            Debug.Assert(NeedsTypeChange(newType) == false);
            var constructorTypes = constructor.GetParameters().Select(s => s.ParameterType).Select(VisitType).ToArray();
            Debug.Assert(constructorTypes.Any(NeedsTypeChange) == false);
            constructor = newType.GetConstructor(constructorTypes);
            Debug.Assert(NeedsTypeChange(constructor.DeclaringType) == false);
        }

        if (node.Members == null)
        {
            // Plain constructor call (no member initialization list).
            return Expression.New(constructor, argParams);
        }

        // Join preserves the order of node.Members, which must line up with argParams.
        var membersTransformed = (from fMember in node.Members
                                  join nMember in constructor.DeclaringType.GetMembers()
                                  on fMember.Name equals nMember.Name
                                  select nMember).ToArray();
        // Safe only because the from type is assumed to be an interface, and a new would be an anonymous type
        var visited = Expression.New(constructor, argParams, membersTransformed);
        Debug.Assert(visited.Members.Count == node.Members.Count);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Rebuilds a <see cref="MethodCallExpression"/> against the replacement types.
    /// A generic method is re-instantiated with visited type arguments. A non-generic method of a
    /// constructed generic type is looked up again on the retyped declaring type, using visited
    /// parameter types, and called on the visited instance.
    /// </summary>
    MethodCallExpression TransformMethodCall(MethodCallExpression node)
    {
        Debug.Assert(node.Method != null);
        var argTypes = TransformMethodArgs(node.Method).ToArray();
        Debug.Assert(argTypes.Any(NeedsTypeChange) == false);
        var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
        Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);
        MethodInfo methodInfo = node.Method;
        var target = node.Object;
        // IsGenericMethodDefinition implies IsGenericMethod.
        if (node.Method.IsGenericMethod)
        {
            methodInfo = node.Method.GetGenericMethodDefinition().MakeGenericMethod(argTypes);
            Debug.Assert(NeedsTypeChange(methodInfo.DeclaringType) == false);
        }
        else if (node.Method.DeclaringType.IsGenericType)
        {
            var type = VisitType(node.Method.DeclaringType);
            // The method is not generic here, so resolve the overload by its (visited) parameter types.
            var parameterTypes = node.Method.GetParameters().Select(p => VisitType(p.ParameterType)).ToArray();
            methodInfo = type.GetMethod(node.Method.Name, parameterTypes);
            // The receiver must be retyped too, otherwise it would not match the rebound method.
            target = Visit(node.Object);
        }
        var visited = Expression.Call(target, methodInfo, argParams);
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Method calls are handled by the first rule that applies:
    /// <list type="number">
    /// <item>A call named "Query" whose return type needs replacing is evaluated eagerly. It becomes a
    /// property access on the evaluated receiver; the property is named after the pluralized element type
    /// of the replacement return type.</item>
    /// <item>An instance call whose receiver is a member access and whose declaring type's FullName contains
    /// "Domain" is evaluated eagerly and inlined as a constant.</item>
    /// <item>Calls whose return type or argument types need replacing are rebuilt via <see cref="TransformMethodCall"/>.</item>
    /// <item>Anything else is delegated to the base visitor.</item>
    /// </list>
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression node)
    {
        if (node.Method != null && node.Method.Name == "Query" && NeedsTypeChange(node.Method.ReturnType))
        {
            // Query against another domain model
            var targetType = VisitType(node.Method.ReturnType);
            // NOTE(review): assumes an instance call; a static "Query" (null Object) would throw here.
            var instance = Expression.Lambda(node.Object).Compile().DynamicInvoke();
            var pluralizer = System.Data.Entity.Design.PluralizationServices.PluralizationService.CreateService(System.Threading.Thread.CurrentThread.CurrentCulture);
            var targetName = pluralizer.Pluralize(targetType.GetGenericArguments()[0].Name);
            var instanceConstant = Expression.Constant(instance);
            return Expression.Property(instanceConstant, targetName);
        }
        if (node.Object != null && node.Object.NodeType == ExpressionType.MemberAccess && node.Method.DeclaringType.FullName.Contains("Domain"))
        {
            // handle method calls on local instances/objects
            var result = Expression.Lambda(node).Compile().DynamicInvoke();
            return Expression.Constant(result);
        }
        if (node.Method != null
            && (NeedsTypeChange(node.Method.ReturnType)
                || (node.Arguments != null && node.Arguments.Any(a => NeedsTypeChange(a.Type)))))
        {
            var transformed = TransformMethodCall(node);
            return base.VisitMethodCall(transformed);
        }
        return base.VisitMethodCall(node);
    }

    /// <summary>
    /// Rebuilds a unary expression (e.g. a conversion) whose result type needs replacing, using the visited
    /// operand and the replacement type. Other unary expressions go to the base visitor.
    /// </summary>
    /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
    protected override Expression VisitUnary(UnaryExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            var operand = Visit(node.Operand);
            var newType = VisitType(node.Type);
            visited = Expression.MakeUnary(node.NodeType, operand, newType);
        }
        else
        {
            visited = base.VisitUnary(node);
        }
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>No type-specific handling; delegates to the base visitor.</summary>
    protected override Expression VisitInvocation(InvocationExpression node)
    {
        return base.VisitInvocation(node);
    }

    /// <summary>The most recent <see cref="NewExpression"/> passed to <see cref="VisitNew"/> (presumably for debugging).</summary>
    public NewExpression LastNew { get; private set; }

    /// <summary>The result produced for <see cref="LastNew"/>.</summary>
    public NewExpression LastNewResult { get; private set; }

    /// <summary>
    /// Rebuilds a <c>new</c> whose type or argument types need replacing via <see cref="TransformNewCall"/>.
    /// Also records the input and output in <see cref="LastNew"/> / <see cref="LastNewResult"/>.
    /// </summary>
    protected override Expression VisitNew(NewExpression node)
    {
        LastNew = node;
        Expression visited;
        if (NeedsTypeChange(node.Type) || node.Arguments.Any(a => NeedsTypeChange(a.Type)))
        {
            var transformed = TransformNewCall(node);
            Debug.Assert(NeedsTypeChange(transformed.Type) == false);
            visited = base.VisitNew(transformed);
        }
        else
        {
            visited = base.VisitNew(node);
        }
        LastNewResult = (NewExpression)visited;
        return visited;
    }

    /// <summary>
    /// Retypes constants whose static type needs replacing:
    /// <list type="bullet">
    /// <item>A null value becomes a null constant of the replacement type.</item>
    /// <item>An array whose own type needs replacing is copied into an array of the replacement element type.</item>
    /// <item>Any other value becomes a constant typed by the value's runtime type.</item>
    /// </list>
    /// </summary>
    protected override Expression VisitConstant(ConstantExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            if (node.Value == null)
                return Expression.Constant(null, VisitType(node.Type));
            var valueType = node.Value.GetType();
            if (valueType.IsArray)
            {
                var value = node.Value;
                if (NeedsTypeChange(valueType)) //array is of bad type even though elements may not be
                {
                    var newElementType = VisitType(valueType.GetElementType());
                    var oldArray = (Array)node.Value;
                    var newArray = Array.CreateInstance(newElementType, oldArray.Length);
                    oldArray.CopyTo(newArray, 0);
                    value = newArray;
                }
                else
                {
                    //types in the array need changing, but the array itself does not
                    //should never happen
                    Debug.Assert(false);
                }
                visited = base.VisitConstant(Expression.Constant(value));
            }
            else
            {
                visited = Expression.Constant(node.Value);
            }
        }
        else
        {
            visited = base.VisitConstant(node);
        }
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Entry point for every node: delegates to the base dispatcher, tracks the nesting depth and,
    /// in Debug builds, asserts that no result still needs a type change.
    /// </summary>
    public override Expression Visit(Expression node)
    { // general substitutions (for example, parameter swaps)
        if (node == null)
            return null;
        visitStack++;
        try
        {
            var found = base.Visit(node);
            Debug.Assert(found == null || NeedsTypeChange(found.Type) == false);
            return found;
        }
        finally
        {
            // Keep the depth counter balanced even when visiting throws.
            visitStack--;
        }
    }

    /// <summary>No type-specific handling; delegates to the base visitor.</summary>
    protected override Expression VisitBinary(BinaryExpression node)
    {
        return base.VisitBinary(node);
    }

    /// <summary>
    /// In a given query the params must be the same instance, not just the same name/type
    /// </summary>
    /// <remarks>
    /// The parameter xxx was not bound in the specified LINQ to Entities query expression
    /// http://social.msdn.microsoft.com/Forums/en/adodotnetentityframework/thread/8c2b0b1c-01bb-4de2-af46-0b4ea866cf8f
    /// </remarks>
    readonly Dictionary<ParameterExpression, ParameterExpression> paramMappings = new Dictionary<ParameterExpression, ParameterExpression>();

    /// <summary>
    /// Rebuilds a lambda whose return type or parameter types need replacing. The result is only
    /// passed back through the base visitor when it is still an <c>Expression&lt;T&gt;</c>; once its
    /// delegate type has changed it is returned as-is.
    /// </summary>
    protected override Expression VisitLambda<T>(Expression<T> node)
    {
        Expression visited;
        if (NeedsTypeChange(node.ReturnType) || node.Parameters.Any(p => NeedsTypeChange(p.Type)))
        {
            // Visit the body first: VisitParameter records each replacement in paramMappings, so the
            // parameter list built below reuses the exact instances referenced by the body.
            var visitedBody = Visit(node.Body);
            Debug.Assert(NeedsTypeChange(visitedBody.Type) == false);
            var transformedParams = new List<ParameterExpression>();
            foreach (var p in node.Parameters)
            {
                var transformedParam = VisitParameter(p);
                Debug.Assert(transformedParam is ParameterExpression);
                transformedParams.Add((ParameterExpression)transformedParam);
            }
            Debug.Assert(transformedParams.Any(t => NeedsTypeChange(t.Type)) == false);
            var transformed = Expression.Lambda(visitedBody, transformedParams.ToArray());
            Debug.Assert(NeedsTypeChange(transformed.Type) == false);
            Debug.Assert(NeedsTypeChange(transformed.ReturnType) == false);
            var transformedTyped = transformed as Expression<T>;
            if (transformedTyped != null)
                visited = base.VisitLambda<T>(transformedTyped);
            else
                visited = transformed;
        }
        else
        {
            visited = base.VisitLambda<T>(node);
        }
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Replaces a parameter whose type needs changing with a new parameter of the replacement type
    /// (same name). It reuses the replacement already created for that parameter, so every reference
    /// binds to one instance.
    /// </summary>
    protected override Expression VisitParameter(ParameterExpression node)
    {
        Expression visited;
        if (NeedsTypeChange(node.Type))
        {
            ParameterExpression replacement;
            if (!paramMappings.TryGetValue(node, out replacement))
            {
                replacement = Expression.Parameter(VisitType(node.Type), node.Name);
                paramMappings.Add(node, replacement);
            }
            visited = base.VisitParameter(replacement);
        }
        else
        {
            visited = base.VisitParameter(node);
        }
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }

    /// <summary>
    /// Rebinds a member access (x.Name) on a type that needs replacing to the same-named property or
    /// field of the replacement type. A field whose type cannot be retyped is evaluated eagerly and
    /// inlined as a constant.
    /// </summary>
    protected override Expression VisitMember(MemberExpression node)
    { // if we see x.Name on the old type, substitute for new type
        Expression visited = node;
        if (NeedsTypeChange(node.Type) || NeedsTypeChange(node.Member.DeclaringType))
        {
            var newtype = VisitType(node.Member.DeclaringType);
            Debug.Assert(NeedsTypeChange(newtype) == false);
            // node.Expression (and therefore visitedExpression) is null for static members.
            var visitedExpression = Visit(node.Expression);
            Debug.Assert(visitedExpression == null || NeedsTypeChange(visitedExpression.Type) == false);
            //either it is not a member expression, or the visit made sure it doesn't need a change now
            Debug.Assert(!(visitedExpression is MemberExpression) || NeedsTypeChange(((MemberExpression)visitedExpression).Member.DeclaringType) == false);
            if (node.Member.MemberType == MemberTypes.Property)
            {
                var targetProperty = newtype.GetProperty(node.Member.Name).GetGetMethod();
                visited = Expression.Property(visitedExpression, targetProperty);
            }
            else if (node.Member.MemberType == MemberTypes.Field)
            {
                var targetField = newtype.GetField(node.Member.Name, BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static);
                Debug.Assert(targetField != null);
                Debug.Assert(NeedsTypeChange(targetField.DeclaringType) == false);
                if (NeedsTypeChange(targetField.FieldType))
                {
                    // The field itself cannot be retyped: evaluate the original access and inline its value.
                    var result = Expression.Lambda(node).Compile().DynamicInvoke();
                    return Expression.Constant(result);
                }
                visited = Expression.Field(visitedExpression, targetField);
            }
            Debug.Assert(NeedsTypeChange(visited.Type) == false);
            visited = base.VisitMember((MemberExpression)visited);
        }
        else
        {
            visited = base.VisitMember(node);
        }
        Debug.Assert(NeedsTypeChange(visited.Type) == false);
        return visited;
    }
}
}