fork download
  1. //imaginarydevelopment.blogspot.com
  2. using System;
  3. using System.Collections;
  4. using System.Collections.Generic;
  5. using System.Diagnostics;
  6. using System.Linq;
  7. using System.Linq.Expressions;
  8. using System.Reflection;
  9. using System.Text.RegularExpressions;
  10.  
  namespace Adapter
  {
      /// <summary>
      /// Rewrites an expression tree so that every use of a "from" type is replaced by its configured
      /// "to" type. It rebuilds parameters, constants, member accesses, method calls, constructor
      /// calls, unary conversions and lambda signatures against the replacement types.
      /// </summary>
      /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
      public class TypeChangeVisitor : ExpressionVisitor
      {
          /// <summary>
          /// Replacement map: a private copy of the caller's entries plus inferred base-interface entries.
          /// </summary>
          readonly Dictionary<Type, Type> _typeReplacements;

          /// <summary>
          /// Current nesting depth of Visit(Expression). It is maintained but never read by this class.
          /// </summary>
          int visitStack = 0;

          /// <summary>
          /// Creates a visitor for the given replacement map.
          /// </summary>
          /// <param name="typeReplacements">
          /// Maps each type to rewrite onto the type that replaces it. The dictionary is copied, so
          /// the caller's instance is never modified.
          /// </param>
          /// <remarks>
          /// For every interface key, each interface it inherits is also mapped to the same
          /// replacement, so members declared on a base interface get redirected too.
          /// Explicit entries always win over inferred ones. When several keys share a base
          /// interface, the first key in the caller's enumeration order supplies the mapping.
          /// </remarks>
          /// <exception cref="ArgumentNullException"><paramref name="typeReplacements"/> is null.</exception>
          public TypeChangeVisitor(IDictionary<Type, Type> typeReplacements)
          {
              if (typeReplacements == null)
                  throw new ArgumentNullException("typeReplacements");

              _typeReplacements = new Dictionary<Type, Type>(typeReplacements);
              foreach (var pair in typeReplacements)
              {
                  if (!pair.Key.IsInterface)
                      continue;

                  foreach (var baseInterface in pair.Key.GetInterfaces())
                  {
                      // Never overwrite: a plain Add() threw when two keys shared a base interface
                      // or when a base interface was also an explicit key.
                      if (!_typeReplacements.ContainsKey(baseInterface))
                          _typeReplacements.Add(baseInterface, pair.Value);
                  }
              }
          }

          /// <summary>
          /// Maps each generic type argument of <paramref name="method"/> through VisitType.
          /// Yields nothing for a non-generic method.
          /// </summary>
          IEnumerable<Type> TransformMethodArgs(MethodBase method)
          {
              foreach (var t in method.GetGenericArguments())
                  yield return VisitType(t);
          }

          /// <summary>
          /// True when <paramref name="t"/> has the same full name as a key of
          /// <paramref name="typeChanges"/>. Also true when any array element type or generic type
          /// argument nested inside it (at any depth) does.
          /// </summary>
          static bool NeedsTypeChange(Type t, IDictionary<Type, Type> typeChanges)
          {
              if (t == null)
                  return false;

              // A generic parameter, or a type built from one, has no FullName and cannot equal a
              // key by name. Its generic arguments are still inspected below.
              var fullName = t.FullName;
              if (fullName != null && typeChanges.Keys.Any(k => k.FullName == fullName))
                  return true;

              if (t.IsArray)
                  return NeedsTypeChange(t.GetElementType(), typeChanges);

              // Walk the type structure. A regex over the assembly-qualified FullName only saw leaf
              // argument names, so it missed generic keys nested inside other generics.
              return t.IsGenericType && t.GetGenericArguments().Any(a => NeedsTypeChange(a, typeChanges));
          }

          bool NeedsTypeChange(Type t)
          {
              return NeedsTypeChange(t, _typeReplacements);
          }

          /// <summary>
          /// Returns the replacement for <paramref name="t"/>. That is the mapped type for a key, or
          /// an array or closed generic type rebuilt over replaced element/argument types.
          /// Otherwise it is <paramref name="t"/> itself.
          /// </summary>
          Type VisitType(Type t)
          {
              Type replacement;
              if (_typeReplacements.TryGetValue(t, out replacement))
                  return replacement;

              if (t.IsArray && NeedsTypeChange(t.GetElementType()))
              {
                  var oldElement = t.GetElementType();
                  var newElement = VisitType(oldElement);
                  // A vector (T[]) and a rank-1 multi-dimensional array (T[*]) are distinct types.
                  return t == oldElement.MakeArrayType()
                      ? newElement.MakeArrayType()
                      : newElement.MakeArrayType(t.GetArrayRank());
              }

              if (t.IsGenericType && t.GetGenericArguments().Any(NeedsTypeChange))
              {
                  var types = t.GetGenericArguments().Select(VisitType).ToArray();
                  var newType = t.GetGenericTypeDefinition().MakeGenericType(types);
                  Debug.Assert(NeedsTypeChange(newType) == false);
                  return newType;
              }

              // Types whose FullName contains '+' are not rebuilt here. In debug builds, flag them
              // when a public field still uses a type that should have been replaced.
              var fullName = t.FullName;
              if (fullName != null && fullName.IndexOf('+') >= 0)
              {
                  if (t.GetMembers().OfType<FieldInfo>().Any(f => NeedsTypeChange(f.FieldType)))
                      Debug.Assert(false, "Anonymous type with bad field(s)");
              }

              Debug.Assert(NeedsTypeChange(t) == false);
              return t;
          }

          /// <summary>
          /// Rebuilds a constructor call against the replacement types. Arguments are visited once.
          /// A generic declaring type (such as an anonymous type) is re-closed over the new types.
          /// The member list, when present, is re-pointed by name at the new type's members.
          /// </summary>
          /// <exception cref="InvalidOperationException">The re-closed type has no matching public constructor.</exception>
          NewExpression TransformNewCall(NewExpression node)
          {
              // Materialize once. A lazy query enumerated twice would visit every argument twice,
              // and Visit() has side effects (parameter mappings, LastNew).
              var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
              Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);

              var constructor = node.Constructor;
              if (constructor == null)
              {
                  // Parameterless value-type construction ("new S()") carries no ConstructorInfo.
                  return Expression.New(VisitType(node.Type));
              }

              // Generic class constructor (e.g. an anonymous type closed over old types).
              if (constructor.DeclaringType.IsGenericType && constructor.DeclaringType.GetGenericArguments().Any(NeedsTypeChange))
              {
                  var newType = VisitType(constructor.DeclaringType);
                  Debug.Assert(NeedsTypeChange(newType) == false);

                  var constructorTypes = constructor.GetParameters().Select(p => VisitType(p.ParameterType)).ToArray();
                  Debug.Assert(constructorTypes.Any(NeedsTypeChange) == false);

                  constructor = newType.GetConstructor(constructorTypes);
                  if (constructor == null)
                      throw new InvalidOperationException("No public constructor on " + newType + " matches the rewritten parameter types.");
                  Debug.Assert(NeedsTypeChange(constructor.DeclaringType) == false);
              }

              NewExpression visited;
              if (node.Members == null)
              {
                  visited = Expression.New(constructor, argParams);
              }
              else
              {
                  // Re-point each initialized member (matched by name) at the new declaring type.
                  var membersTransformed = (from fMember in node.Members
                                            join nMember in constructor.DeclaringType.GetMembers()
                                                on fMember.Name equals nMember.Name
                                            select nMember).ToArray();
                  visited = Expression.New(constructor, argParams, membersTransformed);
                  Debug.Assert(visited.Members.Count == node.Members.Count);
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>
          /// Rebuilds a method call against the replacement types.
          /// The target object and the arguments are visited.
          /// A generic method is re-closed over the replaced type arguments.
          /// A method declared on a generic type is re-resolved on the re-closed type by its
          /// replaced parameter types.
          /// </summary>
          /// <exception cref="InvalidOperationException">The re-closed type has no matching public method.</exception>
          MethodCallExpression TransformMethodCall(MethodCallExpression node)
          {
              var argTypes = TransformMethodArgs(node.Method).ToArray();
              Debug.Assert(argTypes.Any(NeedsTypeChange) == false);

              var argParams = node.Arguments.Select(n => Visit(n)).ToArray();
              Debug.Assert(argParams.Any(a => NeedsTypeChange(a.Type)) == false);

              // Visit(null) returns null, so a static call keeps a null target.
              var target = Visit(node.Object);
              var methodInfo = node.Method;

              if (methodInfo.IsGenericMethod)
              {
                  methodInfo = methodInfo.GetGenericMethodDefinition().MakeGenericMethod(argTypes);
                  Debug.Assert(NeedsTypeChange(methodInfo.DeclaringType) == false);
              }
              else if (methodInfo.DeclaringType.IsGenericType)
              {
                  var declaringType = VisitType(methodInfo.DeclaringType);
                  // Resolve by the replaced parameter types. A non-generic method has no generic
                  // arguments, so resolving by those only ever matched parameterless overloads.
                  var parameterTypes = methodInfo.GetParameters().Select(p => VisitType(p.ParameterType)).ToArray();
                  var resolved = declaringType.GetMethod(methodInfo.Name, parameterTypes);
                  if (resolved == null)
                      throw new InvalidOperationException("No public method " + methodInfo.Name + " on " + declaringType + " matches the rewritten parameter types.");
                  methodInfo = resolved;
              }

              // Use the visited target. The original object may still have an old type that the
              // re-resolved method does not accept.
              var visited = Expression.Call(target, methodInfo, argParams);
              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>
          /// Handles method calls in one of four ways, checked in this order:
          /// <list type="number">
          /// <item>An instance call named "Query" whose return type needs changing is replaced by
          /// the pluralized property (named after the new return type's first generic argument) on
          /// the evaluated instance.</item>
          /// <item>An instance call whose target is a member access and whose declaring type's full
          /// name contains "Domain" is evaluated now and inlined as a constant.</item>
          /// <item>A call whose return type or any argument type needs changing is rebuilt.</item>
          /// <item>Anything else falls through to the base visitor.</item>
          /// </list>
          /// </summary>
          protected override Expression VisitMethodCall(MethodCallExpression node)
          {
              var method = node.Method;

              if (node.Object != null && method.Name == "Query" && NeedsTypeChange(method.ReturnType))
              {
                  // Query against another domain model.
                  var targetType = VisitType(method.ReturnType);
                  if (!targetType.IsGenericType)
                      throw new InvalidOperationException("Query must return a generic type; got " + targetType + ".");

                  var instance = Expression.Lambda(node.Object).Compile().DynamicInvoke();
                  // PluralizationService implements English only, and CreateService throws
                  // NotImplementedException for other cultures. So don't depend on the thread culture.
                  var pluralizer = System.Data.Entity.Design.PluralizationServices.PluralizationService.CreateService(
                      System.Globalization.CultureInfo.GetCultureInfo("en-US"));
                  var targetName = pluralizer.Pluralize(targetType.GetGenericArguments()[0].Name);
                  return Expression.Property(Expression.Constant(instance), targetName);
              }

              var declaringName = method.DeclaringType == null ? null : method.DeclaringType.FullName;
              if (node.Object != null && node.Object.NodeType == ExpressionType.MemberAccess
                  && declaringName != null && declaringName.Contains("Domain"))
              {
                  // Handle method calls on local instances/objects: evaluate now and inline the result.
                  return Expression.Constant(Expression.Lambda(node).Compile().DynamicInvoke());
              }

              if (NeedsTypeChange(method.ReturnType) || node.Arguments.Any(a => NeedsTypeChange(a.Type)))
                  return base.VisitMethodCall(TransformMethodCall(node));

              return base.VisitMethodCall(node);
          }

          /// <summary>
          /// Re-types a unary node (e.g. a conversion) whose result type needs changing.
          /// Its operand is visited first.
          /// </summary>
          /// <remarks>http://stackoverflow.com/a/9120931/57883</remarks>
          protected override Expression VisitUnary(UnaryExpression node)
          {
              Expression visited;
              if (NeedsTypeChange(node.Type))
              {
                  var operand = Visit(node.Operand);
                  visited = Expression.MakeUnary(node.NodeType, operand, VisitType(node.Type));
              }
              else
              {
                  visited = base.VisitUnary(node);
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>The last NewExpression handed to VisitNew, before rewriting.</summary>
          public NewExpression LastNew { get; private set; }

          /// <summary>The NewExpression that VisitNew produced for <see cref="LastNew"/>.</summary>
          public NewExpression LastNewResult { get; private set; }

          /// <summary>
          /// Rebuilds a constructor call when its type or any argument type needs changing.
          /// Records the input and result in <see cref="LastNew"/> and <see cref="LastNewResult"/>.
          /// </summary>
          protected override Expression VisitNew(NewExpression node)
          {
              LastNew = node;

              Expression visited;
              if (NeedsTypeChange(node.Type) || node.Arguments.Any(a => NeedsTypeChange(a.Type)))
              {
                  var transformed = TransformNewCall(node);
                  Debug.Assert(NeedsTypeChange(transformed.Type) == false);
                  visited = base.VisitNew(transformed);
              }
              else
              {
                  visited = base.VisitNew(node);
              }

              LastNewResult = (NewExpression)visited;
              return visited;
          }

          /// <summary>
          /// Re-types a constant whose static type needs changing.
          /// A null becomes a typed null of the replacement type.
          /// An array whose own type needs changing is copied into an array of the replacement
          /// element type.
          /// Any other value is re-wrapped so the constant takes the value's runtime type.
          /// </summary>
          protected override Expression VisitConstant(ConstantExpression node)
          {
              Expression visited;
              if (!NeedsTypeChange(node.Type))
              {
                  visited = base.VisitConstant(node);
              }
              else if (node.Value == null)
              {
                  return Expression.Constant(null, VisitType(node.Type));
              }
              else
              {
                  var value = node.Value;
                  var valueType = value.GetType();
                  if (valueType.IsArray && NeedsTypeChange(valueType))
                  {
                      // Assumes a single-dimensional array (Array.CopyTo requires one).
                      // Elements must be instances of the replacement element type.
                      var oldArray = (Array)value;
                      var newArray = Array.CreateInstance(VisitType(valueType.GetElementType()), oldArray.Length);
                      oldArray.CopyTo(newArray, 0);
                      value = newArray;
                  }
                  // If only the static type needed changing, the runtime type is already acceptable.
                  visited = base.VisitConstant(Expression.Constant(value));
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>
          /// Entry point for every node. It tracks nesting depth in <see cref="visitStack"/> and,
          /// in debug builds, asserts that no result still has a type that needs changing.
          /// </summary>
          public override Expression Visit(Expression node)
          {
              if (node == null)
                  return null;

              Expression found;
              visitStack++;
              try
              {
                  found = base.Visit(node);
              }
              finally
              {
                  // Keep the depth counter balanced even when a rewrite throws.
                  visitStack--;
              }

              Debug.Assert(found == null || NeedsTypeChange(found.Type) == false);
              return found;
          }

          /// <summary>
          /// In a given query the params must be the same instance, not just the same name/type.
          /// Maps each original parameter to its single replacement.
          /// </summary>
          /// <remarks>
          /// The parameter xxx was not bound in the specified LINQ to Entities query expression
          /// http://social.msdn.microsoft.com/Forums/en/adodotnetentityframework/thread/8c2b0b1c-01bb-4de2-af46-0b4ea866cf8f
          /// </remarks>
          readonly Dictionary<ParameterExpression, ParameterExpression> paramMappings = new Dictionary<ParameterExpression, ParameterExpression>();

          /// <summary>
          /// Rebuilds a lambda whose return type or any parameter type needs changing.
          /// The body is visited first, then the parameters, so both share the mapped parameter
          /// instances. The lambda's name and tail-call flag are preserved.
          /// </summary>
          protected override Expression VisitLambda<T>(Expression<T> node)
          {
              Expression visited;
              if (NeedsTypeChange(node.ReturnType) || node.Parameters.Any(p => NeedsTypeChange(p.Type)))
              {
                  var visitedBody = Visit(node.Body);
                  Debug.Assert(NeedsTypeChange(visitedBody.Type) == false);

                  // VisitParameter always yields a ParameterExpression.
                  var transformedParams = node.Parameters.Select(p => (ParameterExpression)VisitParameter(p)).ToArray();
                  Debug.Assert(transformedParams.Any(t => NeedsTypeChange(t.Type)) == false);

                  var transformed = Expression.Lambda(visitedBody, node.Name, node.TailCall, transformedParams);
                  Debug.Assert(NeedsTypeChange(transformed.Type) == false);
                  Debug.Assert(NeedsTypeChange(transformed.ReturnType) == false);

                  // Usually the delegate type changed, so the result is no longer an Expression<T>.
                  var sameDelegateType = transformed as Expression<T>;
                  visited = sameDelegateType != null ? base.VisitLambda(sameDelegateType) : transformed;
              }
              else
              {
                  visited = base.VisitLambda(node);
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>
          /// Replaces a parameter whose type needs changing with a new parameter of the replacement
          /// type. The replacement is created once and reused via <see cref="paramMappings"/>.
          /// </summary>
          protected override Expression VisitParameter(ParameterExpression node)
          {
              Expression visited;
              if (NeedsTypeChange(node.Type))
              {
                  ParameterExpression replacement;
                  if (!paramMappings.TryGetValue(node, out replacement))
                  {
                      replacement = Expression.Parameter(VisitType(node.Type), node.Name);
                      paramMappings.Add(node, replacement);
                  }
                  visited = base.VisitParameter(replacement);
              }
              else
              {
                  visited = base.VisitParameter(node);
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }

          /// <summary>
          /// If we see x.Name on the old type, substitute the same-named member of the new type.
          /// When a field on the new type still has a type that needs changing, the original access
          /// is evaluated instead and its value is inlined as a constant.
          /// </summary>
          /// <exception cref="InvalidOperationException">The new type has no member with that name.</exception>
          protected override Expression VisitMember(MemberExpression node)
          {
              if (!NeedsTypeChange(node.Type) && !NeedsTypeChange(node.Member.DeclaringType))
              {
                  var unchanged = base.VisitMember(node);
                  Debug.Assert(NeedsTypeChange(unchanged.Type) == false);
                  return unchanged;
              }

              var newType = VisitType(node.Member.DeclaringType);
              Debug.Assert(NeedsTypeChange(newType) == false);

              // node.Expression is null for static members; Visit(null) returns null.
              var visitedExpression = Visit(node.Expression);
              Debug.Assert(visitedExpression == null || NeedsTypeChange(visitedExpression.Type) == false);
              // Either it is not a member expression, or the visit made sure it doesn't need a change now.
              Debug.Assert(!(visitedExpression is MemberExpression) || NeedsTypeChange(((MemberExpression)visitedExpression).Member.DeclaringType) == false);

              Expression visited = node;
              if (node.Member.MemberType == MemberTypes.Property)
              {
                  var targetProperty = newType.GetProperty(node.Member.Name);
                  if (targetProperty == null)
                      throw new InvalidOperationException("Property " + node.Member.Name + " was not found on " + newType + ".");
                  visited = Expression.Property(visitedExpression, targetProperty);
              }
              else if (node.Member.MemberType == MemberTypes.Field)
              {
                  var targetField = newType.GetField(node.Member.Name,
                      BindingFlags.NonPublic | BindingFlags.Public | BindingFlags.Instance | BindingFlags.Static);
                  if (targetField == null)
                      throw new InvalidOperationException("Field " + node.Member.Name + " was not found on " + newType + ".");
                  Debug.Assert(NeedsTypeChange(targetField.DeclaringType) == false);

                  if (NeedsTypeChange(targetField.FieldType))
                      return Expression.Constant(Expression.Lambda(node).Compile().DynamicInvoke());

                  visited = Expression.Field(visitedExpression, targetField);
              }

              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              visited = base.VisitMember((MemberExpression)visited);
              Debug.Assert(NeedsTypeChange(visited.Type) == false);
              return visited;
          }
      }
  }
  450.  
Not running #stdin #stdout 0s 0KB
stdin
Standard input is empty
stdout
Standard output is empty