A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
25개 이상의 토픽을 선택하실 수 없습니다. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 

267 lines
5.3 KiB

  1. /*
  2. * Copyright © 2019-2020 Musing Studio LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package db
  11. import (
  12. "fmt"
  13. "strings"
  14. )
  15. type ColumnType int
  16. type OptionalInt struct {
  17. Set bool
  18. Value int
  19. }
  20. type OptionalString struct {
  21. Set bool
  22. Value string
  23. }
  24. type SQLBuilder interface {
  25. ToSQL() (string, error)
  26. }
  27. type Column struct {
  28. Dialect DialectType
  29. Name string
  30. Nullable bool
  31. Default OptionalString
  32. Type ColumnType
  33. Size OptionalInt
  34. PrimaryKey bool
  35. }
  36. type CreateTableSqlBuilder struct {
  37. Dialect DialectType
  38. Name string
  39. IfNotExists bool
  40. ColumnOrder []string
  41. Columns map[string]*Column
  42. Constraints []string
  43. }
  44. const (
  45. ColumnTypeBool ColumnType = iota
  46. ColumnTypeSmallInt ColumnType = iota
  47. ColumnTypeInteger ColumnType = iota
  48. ColumnTypeChar ColumnType = iota
  49. ColumnTypeVarChar ColumnType = iota
  50. ColumnTypeText ColumnType = iota
  51. ColumnTypeDateTime ColumnType = iota
  52. )
  53. var _ SQLBuilder = &CreateTableSqlBuilder{}
  54. var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
  55. var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
  56. func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
  57. if dialect != DialectMySQL && dialect != DialectSQLite {
  58. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  59. }
  60. switch d {
  61. case ColumnTypeSmallInt:
  62. {
  63. if dialect == DialectSQLite {
  64. return "INTEGER", nil
  65. }
  66. mod := ""
  67. if size.Set {
  68. mod = fmt.Sprintf("(%d)", size.Value)
  69. }
  70. return "SMALLINT" + mod, nil
  71. }
  72. case ColumnTypeInteger:
  73. {
  74. if dialect == DialectSQLite {
  75. return "INTEGER", nil
  76. }
  77. mod := ""
  78. if size.Set {
  79. mod = fmt.Sprintf("(%d)", size.Value)
  80. }
  81. return "INT" + mod, nil
  82. }
  83. case ColumnTypeChar:
  84. {
  85. if dialect == DialectSQLite {
  86. return "TEXT", nil
  87. }
  88. mod := ""
  89. if size.Set {
  90. mod = fmt.Sprintf("(%d)", size.Value)
  91. }
  92. return "CHAR" + mod, nil
  93. }
  94. case ColumnTypeVarChar:
  95. {
  96. if dialect == DialectSQLite {
  97. return "TEXT", nil
  98. }
  99. mod := ""
  100. if size.Set {
  101. mod = fmt.Sprintf("(%d)", size.Value)
  102. }
  103. return "VARCHAR" + mod, nil
  104. }
  105. case ColumnTypeBool:
  106. {
  107. if dialect == DialectSQLite {
  108. return "INTEGER", nil
  109. }
  110. return "TINYINT(1)", nil
  111. }
  112. case ColumnTypeDateTime:
  113. return "DATETIME", nil
  114. case ColumnTypeText:
  115. return "TEXT", nil
  116. }
  117. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  118. }
  119. func (c *Column) SetName(name string) *Column {
  120. c.Name = name
  121. return c
  122. }
  123. func (c *Column) SetNullable(nullable bool) *Column {
  124. c.Nullable = nullable
  125. return c
  126. }
  127. func (c *Column) SetPrimaryKey(pk bool) *Column {
  128. c.PrimaryKey = pk
  129. return c
  130. }
  131. func (c *Column) SetDefault(value string) *Column {
  132. c.Default = OptionalString{Set: true, Value: value}
  133. return c
  134. }
  135. func (c *Column) SetDefaultCurrentTimestamp() *Column {
  136. def := "NOW()"
  137. if c.Dialect == DialectSQLite {
  138. def = "CURRENT_TIMESTAMP"
  139. }
  140. c.Default = OptionalString{Set: true, Value: def}
  141. return c
  142. }
  143. func (c *Column) SetType(t ColumnType) *Column {
  144. c.Type = t
  145. return c
  146. }
  147. func (c *Column) SetSize(size int) *Column {
  148. c.Size = OptionalInt{Set: true, Value: size}
  149. return c
  150. }
  151. func (c *Column) String() (string, error) {
  152. var str strings.Builder
  153. str.WriteString(c.Name)
  154. str.WriteString(" ")
  155. typeStr, err := c.Type.Format(c.Dialect, c.Size)
  156. if err != nil {
  157. return "", err
  158. }
  159. str.WriteString(typeStr)
  160. if !c.Nullable {
  161. str.WriteString(" NOT NULL")
  162. }
  163. if c.Default.Set {
  164. str.WriteString(" DEFAULT ")
  165. val := c.Default.Value
  166. if val == "" {
  167. val = "''"
  168. }
  169. str.WriteString(val)
  170. }
  171. if c.PrimaryKey {
  172. str.WriteString(" PRIMARY KEY")
  173. }
  174. return str.String(), nil
  175. }
  176. func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
  177. if b.Columns == nil {
  178. b.Columns = make(map[string]*Column)
  179. }
  180. b.Columns[column.Name] = column
  181. b.ColumnOrder = append(b.ColumnOrder, column.Name)
  182. return b
  183. }
  184. func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
  185. for _, column := range columns {
  186. if _, ok := b.Columns[column]; !ok {
  187. // This fails silently.
  188. return b
  189. }
  190. }
  191. b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
  192. return b
  193. }
  194. func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
  195. b.IfNotExists = ine
  196. return b
  197. }
  198. func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
  199. var str strings.Builder
  200. str.WriteString("CREATE TABLE ")
  201. if b.IfNotExists {
  202. str.WriteString("IF NOT EXISTS ")
  203. }
  204. str.WriteString(b.Name)
  205. var things []string
  206. for _, columnName := range b.ColumnOrder {
  207. column, ok := b.Columns[columnName]
  208. if !ok {
  209. return "", fmt.Errorf("column not found: %s", columnName)
  210. }
  211. columnStr, err := column.String()
  212. if err != nil {
  213. return "", err
  214. }
  215. things = append(things, columnStr)
  216. }
  217. for _, constraint := range b.Constraints {
  218. things = append(things, constraint)
  219. }
  220. if thingLen := len(things); thingLen > 0 {
  221. str.WriteString(" ( ")
  222. for i, thing := range things {
  223. str.WriteString(thing)
  224. if i < thingLen-1 {
  225. str.WriteString(", ")
  226. }
  227. }
  228. str.WriteString(" )")
  229. }
  230. return str.String(), nil
  231. }