A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

272 lines
5.6 KiB

  1. package db
  2. import (
  3. "fmt"
  4. "strings"
  5. )
  6. type DialectType int
  7. type ColumnType int
  8. type OptionalInt struct {
  9. Set bool
  10. Value int
  11. }
  12. type OptionalString struct {
  13. Set bool
  14. Value string
  15. }
  16. type SQLBuilder interface {
  17. ToSQL() (string, error)
  18. }
  19. type Column struct {
  20. Dialect DialectType
  21. Name string
  22. Nullable bool
  23. Default OptionalString
  24. Type ColumnType
  25. Size OptionalInt
  26. PrimaryKey bool
  27. }
  28. type CreateTableSqlBuilder struct {
  29. Dialect DialectType
  30. Name string
  31. IfNotExists bool
  32. ColumnOrder []string
  33. Columns map[string]*Column
  34. Constraints []string
  35. }
  36. const (
  37. DialectSQLite DialectType = iota
  38. DialectMySQL DialectType = iota
  39. )
  40. const (
  41. ColumnTypeBool ColumnType = iota
  42. ColumnTypeSmallInt ColumnType = iota
  43. ColumnTypeInteger ColumnType = iota
  44. ColumnTypeChar ColumnType = iota
  45. ColumnTypeVarChar ColumnType = iota
  46. ColumnTypeText ColumnType = iota
  47. ColumnTypeDateTime ColumnType = iota
  48. )
  49. var _ SQLBuilder = &CreateTableSqlBuilder{}
  50. var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
  51. var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
  52. func (d DialectType) Column(name string, t ColumnType, size OptionalInt) *Column {
  53. switch d {
  54. case DialectSQLite:
  55. return &Column{Dialect: DialectSQLite, Name: name, Type: t, Size: size}
  56. case DialectMySQL:
  57. return &Column{Dialect: DialectMySQL, Name: name, Type: t, Size: size}
  58. default:
  59. panic(fmt.Sprintf("unexpected dialect: %d", d))
  60. }
  61. }
  62. func (d DialectType) Table(name string) *CreateTableSqlBuilder {
  63. switch d {
  64. case DialectSQLite:
  65. return &CreateTableSqlBuilder{Dialect: DialectSQLite, Name: name}
  66. case DialectMySQL:
  67. return &CreateTableSqlBuilder{Dialect: DialectMySQL, Name: name}
  68. default:
  69. panic(fmt.Sprintf("unexpected dialect: %d", d))
  70. }
  71. }
  72. func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
  73. if dialect != DialectMySQL && dialect != DialectSQLite {
  74. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  75. }
  76. switch d {
  77. case ColumnTypeSmallInt:
  78. {
  79. if dialect == DialectSQLite {
  80. return "INTEGER", nil
  81. }
  82. mod := ""
  83. if size.Set {
  84. mod = fmt.Sprintf("(%d)", size.Value)
  85. }
  86. return "SMALLINT" + mod, nil
  87. }
  88. case ColumnTypeInteger:
  89. {
  90. if dialect == DialectSQLite {
  91. return "INTEGER", nil
  92. }
  93. mod := ""
  94. if size.Set {
  95. mod = fmt.Sprintf("(%d)", size.Value)
  96. }
  97. return "INT" + mod, nil
  98. }
  99. case ColumnTypeChar:
  100. {
  101. if dialect == DialectSQLite {
  102. return "TEXT", nil
  103. }
  104. mod := ""
  105. if size.Set {
  106. mod = fmt.Sprintf("(%d)", size.Value)
  107. }
  108. return "CHAR" + mod, nil
  109. }
  110. case ColumnTypeVarChar:
  111. {
  112. if dialect == DialectSQLite {
  113. return "TEXT", nil
  114. }
  115. mod := ""
  116. if size.Set {
  117. mod = fmt.Sprintf("(%d)", size.Value)
  118. }
  119. return "VARCHAR" + mod, nil
  120. }
  121. case ColumnTypeBool:
  122. {
  123. if dialect == DialectSQLite {
  124. return "INTEGER", nil
  125. }
  126. return "TINYINT(1)", nil
  127. }
  128. case ColumnTypeDateTime:
  129. return "DATETIME", nil
  130. case ColumnTypeText:
  131. return "TEXT", nil
  132. }
  133. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  134. }
  135. func (c *Column) SetName(name string) *Column {
  136. c.Name = name
  137. return c
  138. }
  139. func (c *Column) SetNullable(nullable bool) *Column {
  140. c.Nullable = nullable
  141. return c
  142. }
  143. func (c *Column) SetPrimaryKey(pk bool) *Column {
  144. c.PrimaryKey = pk
  145. return c
  146. }
  147. func (c *Column) SetDefault(value string) *Column {
  148. c.Default = OptionalString{Set: true, Value: value}
  149. return c
  150. }
  151. func (c *Column) SetType(t ColumnType) *Column {
  152. c.Type = t
  153. return c
  154. }
  155. func (c *Column) SetSize(size int) *Column {
  156. c.Size = OptionalInt{Set: true, Value: size}
  157. return c
  158. }
  159. func (c *Column) String() (string, error) {
  160. var str strings.Builder
  161. str.WriteString(c.Name)
  162. str.WriteString(" ")
  163. typeStr, err := c.Type.Format(c.Dialect, c.Size)
  164. if err != nil {
  165. return "", err
  166. }
  167. str.WriteString(typeStr)
  168. if !c.Nullable {
  169. str.WriteString(" NOT NULL")
  170. }
  171. if c.Default.Set {
  172. str.WriteString(" DEFAULT ")
  173. str.WriteString(c.Default.Value)
  174. }
  175. if c.PrimaryKey {
  176. str.WriteString(" PRIMARY KEY")
  177. }
  178. return str.String(), nil
  179. }
  180. func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
  181. if b.Columns == nil {
  182. b.Columns = make(map[string]*Column)
  183. }
  184. b.Columns[column.Name] = column
  185. b.ColumnOrder = append(b.ColumnOrder, column.Name)
  186. return b
  187. }
  188. func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
  189. for _, column := range columns {
  190. if _, ok := b.Columns[column]; !ok {
  191. // This fails silently.
  192. return b
  193. }
  194. }
  195. b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
  196. return b
  197. }
  198. func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
  199. b.IfNotExists = ine
  200. return b
  201. }
  202. func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
  203. var str strings.Builder
  204. str.WriteString("CREATE TABLE ")
  205. if b.IfNotExists {
  206. str.WriteString("IF NOT EXISTS ")
  207. }
  208. str.WriteString(b.Name)
  209. var things []string
  210. for _, columnName := range b.ColumnOrder {
  211. column, ok := b.Columns[columnName]
  212. if !ok {
  213. return "", fmt.Errorf("column not found: %s", columnName)
  214. }
  215. columnStr, err := column.String()
  216. if err != nil {
  217. return "", err
  218. }
  219. things = append(things, columnStr)
  220. }
  221. for _, constraint := range b.Constraints {
  222. things = append(things, constraint)
  223. }
  224. if thingLen := len(things); thingLen > 0 {
  225. str.WriteString(" ( ")
  226. for i, thing := range things {
  227. str.WriteString(thing)
  228. if i < thingLen-1 {
  229. str.WriteString(", ")
  230. }
  231. }
  232. str.WriteString(" )")
  233. }
  234. return str.String(), nil
  235. }