SqlGenerator_Helper.cs 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. using Ant.Core;
  2. using Ant.DbExpressions;
  3. using System;
  4. using System.Collections;
  5. using System.Collections.Generic;
  6. using System.Collections.ObjectModel;
  7. using System.Linq;
  8. using System.Reflection;
  9. using System.Text;
  10. namespace Ant.SqlServer
  11. {
  12. /// <summary>
  13. /// 解析SQL语句帮助类
  14. /// </summary>
  15. partial class SqlGenerator : DbExpressionVisitor<DbExpression>
  16. {
  /// <summary>
  /// Returns the SQL parameter name for <paramref name="ordinal"/>: the entry from
  /// CacheParameterNames when the ordinal is inside that collection, otherwise
  /// ParameterPrefix followed by the ordinal's digits.
  /// </summary>
  /// <param name="ordinal">Zero-based parameter position.</param>
  static string GenParameterName(int ordinal)
  {
      bool hasCachedName = ordinal < CacheParameterNames.Count;
      return hasCachedName ? CacheParameterNames[ordinal] : ParameterPrefix + ordinal.ToString();
  }
  /// <summary>
  /// Picks a ROW_NUMBER column alias that does not collide (case-insensitively) with any
  /// alias already in <paramref name="columns"/>: ROW_NUMBER_0, then ROW_NUMBER_1, and so on.
  /// </summary>
  /// <param name="columns">Columns whose aliases must be avoided.</param>
  /// <returns>The first free ROW_NUMBER_n alias.</returns>
  static string CreateRowNumberName(List<DbColumnSegment> columns)
  {
      for (int suffix = 0; ; suffix++)
      {
          string candidate = "ROW_NUMBER_" + suffix.ToString();
          bool taken = columns.Any(c => string.Equals(c.Alias, candidate, StringComparison.OrdinalIgnoreCase));
          if (!taken)
              return candidate;
      }
  }
  /// <summary>
  /// Makes sure an expression can be used as a C# bool value in the generated SQL.
  /// Non-boolean expressions, and boolean expressions whose node type is listed in
  /// SafeDbExpressionTypes, are returned unchanged; every other boolean expression is
  /// wrapped in a CASE WHEN via ConstructReturnCSharpBooleanCaseWhenExpression.
  /// </summary>
  /// <param name="exp">The expression to check.</param>
  static DbExpression EnsureDbExpressionReturnCSharpBoolean(DbExpression exp)
  {
      bool isBoolean = exp.Type == UtilConstants.TypeOfBoolean || exp.Type == UtilConstants.TypeOfBoolean_Nullable;
      if (!isBoolean || SafeDbExpressionTypes.Contains(exp.NodeType))
          return exp;

      // Anything left is treated as a predicate such as a.Id > 1 or a.Name == "name",
      // which cannot be returned directly as a bool value, so build a CASE WHEN around it.
      return ConstructReturnCSharpBooleanCaseWhenExpression(exp);
  }
  /// <summary>
  /// Wraps a predicate so it yields a nullable boolean value:
  /// CASE WHEN exp THEN True WHEN NOT (exp) THEN False ELSE NULL END.
  /// </summary>
  /// <param name="exp">The predicate expression to wrap.</param>
  /// <returns>A CASE WHEN expression typed as bool.</returns>
  public static DbCaseWhenExpression ConstructReturnCSharpBooleanCaseWhenExpression(DbExpression exp)
  {
      // The explicit NOT branch means that when exp is neither true nor false
      // (e.g. it involves NULL), no WHEN matches and the ELSE NULL applies.
      var whenThenExps = new List<DbCaseWhenExpression.WhenThenExpressionPair>(2)
      {
          new DbCaseWhenExpression.WhenThenExpressionPair(exp, DbConstantExpression.True),
          new DbCaseWhenExpression.WhenThenExpressionPair(DbExpression.Not(exp), DbConstantExpression.False)
      };
      return DbExpression.CaseWhen(whenThenExps, DbConstantExpression.Null, UtilConstants.TypeOfBoolean);
  }
  /// <summary>
  /// Flattens a left-nested chain of binary expressions that share the root's node type,
  /// e.g. ((a op b) op c) op d, into a stack that pops as a, b, c, d.
  /// </summary>
  /// <param name="exp">Root of the chain.</param>
  /// <returns>The operands, leftmost on top.</returns>
  static Stack<DbExpression> GatherBinaryExpressionOperand(DbBinaryExpression exp)
  {
      DbExpressionType chainType = exp.NodeType;
      var operands = new Stack<DbExpression>();
      DbBinaryExpression node = exp;
      DbExpression leftOperand;
      for (;;)
      {
          operands.Push(node.Right);
          leftOperand = node.Left;
          // Stop descending once the left side is no longer part of the same operator chain.
          if (leftOperand.NodeType != chainType)
              break;
          node = (DbBinaryExpression)leftOperand;
      }
      operands.Push(leftOperand);
      return operands;
  }
  /// <summary>
  /// Throws the exception from UtilExceptions.NotSupportedMethod unless the called
  /// method is declared on <paramref name="ensureType"/>.
  /// </summary>
  static void EnsureMethodDeclaringType(DbMethodCallExpression exp, Type ensureType)
  {
      var calledMethod = exp.Method;
      if (calledMethod.DeclaringType == ensureType)
          return;
      throw UtilExceptions.NotSupportedMethod(calledMethod);
  }
  /// <summary>
  /// Throws the exception from UtilExceptions.NotSupportedMethod unless the called
  /// method is exactly <paramref name="methodInfo"/>.
  /// </summary>
  static void EnsureMethod(DbMethodCallExpression exp, MethodInfo methodInfo)
  {
      var calledMethod = exp.Method;
      if (calledMethod == methodInfo)
          return;
      throw UtilExceptions.NotSupportedMethod(calledMethod);
  }
  /// <summary>
  /// Validates a trim-characters argument: it must be a member expression that
  /// DbExpressionExtensions.TryParseToParameterExpression turns into a parameter whose
  /// value is a char[] containing exactly one space (' ').
  /// </summary>
  /// <param name="exp">The trim-characters argument expression.</param>
  /// <exception cref="NotSupportedException">The argument has any other shape or value.</exception>
  static void EnsureTrimCharArgumentIsSpaces(DbExpression exp)
  {
      var m = exp as DbMemberExpression;
      if (m == null)
          throw new NotSupportedException();

      DbParameterExpression p;
      if (!DbExpressionExtensions.TryParseToParameterExpression(m, out p))
          throw new NotSupportedException();

      // 'as' yields null both for a null value and for a value that is not a char[];
      // checking it here avoids the NullReferenceException the old unchecked
      // dereference threw for non-char[] values.
      var chars = p.Value as char[];
      if (chars == null || chars.Length != 1 || chars[0] != ' ')
          throw new NotSupportedException();
  }
  /// <summary>
  /// Finds the SQL type name to CAST <paramref name="sourceType"/> to <paramref name="targetType"/>.
  /// Both types are first unwrapped through Utils.GetUnderlyingType (presumably stripping Nullable&lt;T&gt;).
  /// </summary>
  /// <param name="dbTypeString">Receives the CastTypeMap entry for the target type on success; null otherwise.</param>
  /// <param name="throwNotSupportedException">When true, an unsupported conversion throws instead of returning false.</param>
  /// <returns>
  /// true when a target SQL type was found; false when both unwrapped types are equal, or when the
  /// conversion is unsupported and <paramref name="throwNotSupportedException"/> is false.
  /// </returns>
  /// <exception cref="NotSupportedException">The conversion is unsupported and throwing is enabled.</exception>
  static bool TryGetCastTargetDbTypeString(Type sourceType, Type targetType, out string dbTypeString, bool throwNotSupportedException = true)
  {
      dbTypeString = null;
      Type fromType = Utils.GetUnderlyingType(sourceType);
      Type toType = Utils.GetUnderlyingType(targetType);
      if (fromType == toType)
          return false;

      // Without precision/scale information a CAST to decimal is only allowed from Int16, Int32, Int64 or Byte.
      bool decimalCastBlocked = toType == UtilConstants.TypeOfDecimal
          && !(fromType == UtilConstants.TypeOfInt16 || fromType == UtilConstants.TypeOfInt32 || fromType == UtilConstants.TypeOfInt64 || fromType == UtilConstants.TypeOfByte);

      if (!decimalCastBlocked && CastTypeMap.TryGetValue(toType, out dbTypeString))
          return true;

      if (!throwNotSupportedException)
          return false;
      throw new NotSupportedException(AppendNotSupportedCastErrorMsg(fromType, toType));
  }
  /// <summary>
  /// Builds the error message for an unsupported CAST from <paramref name="sourceType"/> to <paramref name="targetType"/>,
  /// naming both by their FullName.
  /// </summary>
  static string AppendNotSupportedCastErrorMsg(Type sourceType, Type targetType)
  {
      string fromName = sourceType.FullName;
      string toName = targetType.FullName;
      return "Does not support the type '" + fromName + "' converted to type '" + toName + "'.";
  }
  /// <summary>
  /// Emits DATEADD(interval,amount,date): the amount is the method call's first argument
  /// and the date is the instance the method is called on.
  /// </summary>
  static void DbFunction_DATEADD(SqlGenerator generator, string interval, DbMethodCallExpression exp)
  {
      generator._sqlBuilder.Append("DATEADD(", interval, ",");
      exp.Arguments[0].Accept(generator);
      generator._sqlBuilder.Append(",");
      exp.Object.Accept(generator);
      generator._sqlBuilder.Append(")");
  }
  /// <summary>
  /// Emits DATEPART(interval,exp), visiting <paramref name="exp"/> for the date operand.
  /// </summary>
  static void DbFunction_DATEPART(SqlGenerator generator, string interval, DbExpression exp)
  {
      generator._sqlBuilder.Append("DATEPART(", interval, ",");
      exp.Accept(generator);
      generator._sqlBuilder.Append(")");
  }
  /// <summary>
  /// Emits DATEDIFF(interval,start,end) from the two visited date expressions.
  /// </summary>
  static void DbFunction_DATEDIFF(SqlGenerator generator, string interval, DbExpression startDateTimeExp, DbExpression endDateTimeExp)
  {
      generator._sqlBuilder.Append("DATEDIFF(", interval, ",");
      startDateTimeExp.Accept(generator);
      generator._sqlBuilder.Append(",");
      endDateTimeExp.Accept(generator);
      generator._sqlBuilder.Append(")");
  }
  162. #region AggregateFunction
  /// <summary>
  /// Emits the row-count aggregate <c>COUNT(1)</c>.
  /// </summary>
  /// <param name="generator">Generator whose SQL buffer receives the text.</param>
  static void Aggregate_Count(SqlGenerator generator)
  {
      generator._sqlBuilder.Append("COUNT(1)");
  }
  /// <summary>
  /// Emits a bare "=" operator.
  /// </summary>
  /// <param name="exp">Not read by this method.</param>
  /// <param name="generator">Generator whose SQL buffer receives the text.</param>
  /// <remarks>
  /// NOTE(review): unlike the sibling Aggregate_* helpers the parameter order is (exp, generator),
  /// and no operands are emitted here; presumably the caller writes them around the "=". Confirm.
  /// </remarks>
  static void Aggregate_Equals(DbMethodCallExpression exp, SqlGenerator generator)
  {
      generator._sqlBuilder.Append("=");
  }
  /// <summary>
  /// Emits <c>COUNT_BIG(1)</c>, the row count returned as bigint (presumably backing LongCount).
  /// </summary>
  /// <param name="generator">Generator whose SQL buffer receives the text.</param>
  static void Aggregate_LongCount(SqlGenerator generator)
  {
      generator._sqlBuilder.Append("COUNT_BIG(1)");
  }
  /// <summary>
  /// Emits MAX(exp) with no CAST wrapper.
  /// </summary>
  static void Aggregate_Max(SqlGenerator generator, DbExpression exp, Type retType)
  {
      AppendAggregateFunction(generator, exp, retType, functionName: "MAX", withCast: false);
  }
  /// <summary>
  /// Emits MIN(exp) with no CAST wrapper.
  /// </summary>
  static void Aggregate_Min(SqlGenerator generator, DbExpression exp, Type retType)
  {
      AppendAggregateFunction(generator, exp, retType, functionName: "MIN", withCast: false);
  }
  /// <summary>
  /// Emits SUM(exp), CAST to the SQL type mapped for <paramref name="retType"/> when one applies.
  /// </summary>
  static void Aggregate_Sum(SqlGenerator generator, DbExpression exp, Type retType)
  {
      AppendAggregateFunction(generator, exp, retType, functionName: "SUM", withCast: true);
  }
  /// <summary>
  /// Emits AVG(exp), CAST to the SQL type mapped for <paramref name="retType"/> when one applies.
  /// </summary>
  static void Aggregate_Average(SqlGenerator generator, DbExpression exp, Type retType)
  {
      AppendAggregateFunction(generator, exp, retType, functionName: "AVG", withCast: true);
  }
  /// <summary>
  /// Appends an aggregate call such as <c>SUM(x)</c> to the generator's SQL. When requested and
  /// possible, the call is wrapped as <c>CAST(SUM(x) AS type)</c> so the database result matches
  /// the expected CLR type.
  /// </summary>
  /// <param name="generator">Generator whose SQL buffer receives the text; also used to visit <paramref name="exp"/>.</param>
  /// <param name="exp">The aggregated argument expression.</param>
  /// <param name="retType">Expected CLR result type; its underlying type (via Utils.GetUnderlyingType) selects the CAST target from CastTypeMap.</param>
  /// <param name="functionName">SQL aggregate function name, e.g. MAX, MIN, SUM or AVG.</param>
  /// <param name="withCast">
  /// When true, wrap the call in CAST(... AS type) if CastTypeMap has an entry for the
  /// underlying <paramref name="retType"/> and that type is not decimal.
  /// </param>
  static void AppendAggregateFunction(SqlGenerator generator, DbExpression exp, Type retType, string functionName, bool withCast)
  {
      string dbTypeString = null;
      bool castOpened = false;
      if (withCast)
      {
          Type unType = Utils.GetUnderlyingType(retType);
          // Decimal is deliberately skipped: precision and scale are unknown here, and a CAST
          // without them (SQL Server's DECIMAL defaults to (18,0)) could silently drop the fraction.
          if (unType != UtilConstants.TypeOfDecimal && CastTypeMap.TryGetValue(unType, out dbTypeString))
          {
              generator._sqlBuilder.Append("CAST(");
              castOpened = true;
          }
      }

      generator._sqlBuilder.Append(functionName, "(");
      exp.Accept(generator);
      generator._sqlBuilder.Append(")");

      // Close the CAST only if one was opened above, so the emitted SQL stays balanced.
      if (castOpened)
      {
          generator._sqlBuilder.Append(" AS ", dbTypeString, ")");
      }
  }
  218. #endregion
  219. }
  220. }