fork download
  /**
   * Base type of every node in the SQL expression tree.
   *
   * Besides rendering itself, an expression can be compared with another one
   * through the `%`-prefixed operators below; each of them wraps both operands
   * in the corresponding comparison node.
   */
  trait Expression {
    /** Renders this node as an SQL fragment together with its bind values. */
    def asSQL: (String, Seq[Any])

    /** Wraps `this` and `other` in an [[EqualToExpression]]. */
    def %==(other: Expression): EqualToExpression =
      new EqualToExpression(this, other)

    /** Wraps `this` and `other` in a [[NotEqualToExpression]]. */
    def %!=(other: Expression): NotEqualToExpression =
      new NotEqualToExpression(this, other)

    /** Wraps `this` and `other` in a [[SmallerThanExpression]]. */
    def %<(other: Expression): SmallerThanExpression =
      new SmallerThanExpression(this, other)

    /** Wraps `this` and `other` in a [[GreaterThanExpression]]. */
    def %>(other: Expression): GreaterThanExpression =
      new GreaterThanExpression(this, other)

    /** Wraps `this` and `other` in a [[SmallerThanOrEqualToExpression]]. */
    def %<=(other: Expression): SmallerThanOrEqualToExpression =
      new SmallerThanOrEqualToExpression(this, other)

    /** Wraps `this` and `other` in a [[GreaterThanOrEqualToExpression]]. */
    def %>=(other: Expression): GreaterThanOrEqualToExpression =
      new GreaterThanOrEqualToExpression(this, other)
  }
  11.  
  /** A database schema, identified by `name`. */
  class Schema(
    val name: String
  )
  13.  
  /** A relation called `name` that belongs to `schema`. */
  class Relation(
    val schema: Schema,
    val name: String
  )
  15.  
  /** A reference to the attribute `name` of `relation`. */
  class Attribute(val relation: Relation, val name: String) extends Expression {
    /** Yields the dotted `schema.relation.attribute` path and no bind values. */
    def asSQL = {
      val owner = relation
      val qualifiedName = s"${owner.schema.name}.${owner.name}.$name"
      (qualifiedName, Seq.empty)
    }
  }
  20.  
  /**
   * Common implementation for infix operators taking two operands.
   *
   * @param a        left-hand operand
   * @param b        right-hand operand
   * @param operator SQL text placed between the two parenthesised operands
   */
  abstract class BinaryExpression(a: Expression, b: Expression, operator: String) extends Expression {
    /**
     * Renders `(a) operator (b)`; the bind values of `a` come first,
     * followed by those of `b`.
     */
    def asSQL: (String, Seq[Any]) = {
      val (lhsSQL, lhsBindings) = a.asSQL
      val (rhsSQL, rhsBindings) = b.asSQL
      val rendered = s"($lhsSQL) $operator ($rhsSQL)"
      (rendered, lhsBindings ++ rhsBindings)
    }
  }
  29.  
  /** Comparison rendered with the `=` operator. */
  class EqualToExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, "=")

  /** Comparison rendered with the `<>` operator. */
  class NotEqualToExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, "<>")

  /** Comparison rendered with the `<` operator. */
  class SmallerThanExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, "<")

  /** Comparison rendered with the `>` operator. */
  class GreaterThanExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, ">")

  /** Comparison rendered with the `<=` operator. */
  class SmallerThanOrEqualToExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, "<=")

  /** Comparison rendered with the `>=` operator. */
  class GreaterThanOrEqualToExpression(a: Expression, b: Expression)
    extends BinaryExpression(a, b, ">=")
  36.  
  /**
   * An integer literal.
   *
   * The value is never spliced into the SQL text; it is rendered as a `?`
   * placeholder and handed back as the single bind value.
   *
   * @param value the integer to bind
   */
  class IntExpression(value: Int) extends Expression {
    // Build a Seq directly: returning an Array here only type-checks through
    // an implicit Array-to-Seq conversion, which is deprecated (and copies)
    // when the expected type is an immutable Seq.
    def asSQL: (String, Seq[Any]) =
      ("?", Seq(value))
  }
  41.  
  /**
   * An immutable SELECT statement over a single relation.
   *
   * `select` and `where` never modify the receiver; each returns a new query
   * with the additional fields or filters appended.
   *
   * @param relation relation named in the FROM clause
   * @param fields   expressions listed in the SELECT clause, in order
   * @param filters  expressions combined with AND in the WHERE clause
   */
  class Query(relation: Relation, fields: Seq[Expression], filters: Seq[Expression]) extends Expression {
    /**
     * Renders the statement and collects its bind values.
     *
     * @return the SQL text and a flat sequence of bind values, ordered as
     *         their placeholders appear in the text: those of the selected
     *         fields first, then those of the filters
     */
    def asSQL: (String, Seq[Any]) = {
      val (fieldSQLs, fieldBindings) = fields.map(_.asSQL).unzip
      val selectStr = "SELECT (" + fieldSQLs.mkString("), (") + ")\n"

      val fromStr = s" FROM ${relation.schema.name}.${relation.name}\n"

      val (filterSQLs, filterBindings) = filters.map(_.asSQL).unzip
      // The `else` branch is required: without it the `if` evaluates to Unit
      // when there are no filters and "()" ends up appended to the SQL text.
      val whereStr =
        if (filterSQLs.nonEmpty) " WHERE (" + filterSQLs.mkString(")\n AND (") + ")\n"
        else ""

      // `unzip` produces one binding sequence per expression; flatten them so
      // callers receive the bind values themselves rather than a list of lists.
      (selectStr + fromStr + whereStr, fieldBindings.flatten ++ filterBindings.flatten)
    }

    /** Returns a copy of this query with `newFields` appended to the SELECT list. */
    def select(newFields: Expression*): Query =
      new Query(relation, fields ++ newFields, filters)

    /** Returns a copy of this query with `newFilters` appended to the WHERE conditions. */
    def where(newFilters: Expression*): Query =
      new Query(relation, fields, filters ++ newFilters)
  }
  64.  
  /** Entry point for building queries. */
  object Query {
    /**
     * Starts a query on `relation` with no selected fields and no filters.
     *
     * Empty `Seq`s are passed directly instead of empty `Array`s, so no
     * implicit Array-to-Seq conversion is involved.
     */
    def from(relation: Relation): Query =
      new Query(relation, Seq.empty, Seq.empty)
  }
  68.  
  /** Small demonstration of the query builder. */
  object Main {
    /**
     * Builds a query over `public.users`, then prints the generated SQL
     * followed by each element of the returned bindings on its own line.
     *
     * Declared with an explicit `: Unit =` instead of procedure syntax,
     * which is deprecated.
     */
    def main(args: Array[String]): Unit = {
      val public = new Schema("public")
      val users = new Relation(public, "users")
      val userId = new Attribute(users, "id")
      val userName = new Attribute(users, "name")
      val userEmailAddress = new Attribute(users, "email_address")
      val userAge = new Attribute(users, "age")

      val query =
        Query.from(users)
          .where(userAge %> new IntExpression(20))
          .select(userId, userName, userEmailAddress)

      val (sql, bindings) = query.asSQL
      println(sql)
      bindings.foreach(println(_))
    }
  }
  88.  
Success #stdin #stdout 0.39s 382144KB
stdin
Standard input is empty
stdout
SELECT (public.users.id), (public.users.name), (public.users.email_address)
  FROM public.users
 WHERE ((public.users.age) > (?))

List()
List()
List()
List(20)