A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

245 lines
4.8 KiB

  1. package db
  2. import (
  3. "fmt"
  4. "strings"
  5. )
  6. type ColumnType int
  7. type OptionalInt struct {
  8. Set bool
  9. Value int
  10. }
  11. type OptionalString struct {
  12. Set bool
  13. Value string
  14. }
  15. type SQLBuilder interface {
  16. ToSQL() (string, error)
  17. }
  18. type Column struct {
  19. Dialect DialectType
  20. Name string
  21. Nullable bool
  22. Default OptionalString
  23. Type ColumnType
  24. Size OptionalInt
  25. PrimaryKey bool
  26. }
  27. type CreateTableSqlBuilder struct {
  28. Dialect DialectType
  29. Name string
  30. IfNotExists bool
  31. ColumnOrder []string
  32. Columns map[string]*Column
  33. Constraints []string
  34. }
  35. const (
  36. ColumnTypeBool ColumnType = iota
  37. ColumnTypeSmallInt ColumnType = iota
  38. ColumnTypeInteger ColumnType = iota
  39. ColumnTypeChar ColumnType = iota
  40. ColumnTypeVarChar ColumnType = iota
  41. ColumnTypeText ColumnType = iota
  42. ColumnTypeDateTime ColumnType = iota
  43. )
  44. var _ SQLBuilder = &CreateTableSqlBuilder{}
  45. var UnsetSize OptionalInt = OptionalInt{Set: false, Value: 0}
  46. var UnsetDefault OptionalString = OptionalString{Set: false, Value: ""}
  47. func (d ColumnType) Format(dialect DialectType, size OptionalInt) (string, error) {
  48. if dialect != DialectMySQL && dialect != DialectSQLite {
  49. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  50. }
  51. switch d {
  52. case ColumnTypeSmallInt:
  53. {
  54. if dialect == DialectSQLite {
  55. return "INTEGER", nil
  56. }
  57. mod := ""
  58. if size.Set {
  59. mod = fmt.Sprintf("(%d)", size.Value)
  60. }
  61. return "SMALLINT" + mod, nil
  62. }
  63. case ColumnTypeInteger:
  64. {
  65. if dialect == DialectSQLite {
  66. return "INTEGER", nil
  67. }
  68. mod := ""
  69. if size.Set {
  70. mod = fmt.Sprintf("(%d)", size.Value)
  71. }
  72. return "INT" + mod, nil
  73. }
  74. case ColumnTypeChar:
  75. {
  76. if dialect == DialectSQLite {
  77. return "TEXT", nil
  78. }
  79. mod := ""
  80. if size.Set {
  81. mod = fmt.Sprintf("(%d)", size.Value)
  82. }
  83. return "CHAR" + mod, nil
  84. }
  85. case ColumnTypeVarChar:
  86. {
  87. if dialect == DialectSQLite {
  88. return "TEXT", nil
  89. }
  90. mod := ""
  91. if size.Set {
  92. mod = fmt.Sprintf("(%d)", size.Value)
  93. }
  94. return "VARCHAR" + mod, nil
  95. }
  96. case ColumnTypeBool:
  97. {
  98. if dialect == DialectSQLite {
  99. return "INTEGER", nil
  100. }
  101. return "TINYINT(1)", nil
  102. }
  103. case ColumnTypeDateTime:
  104. return "DATETIME", nil
  105. case ColumnTypeText:
  106. return "TEXT", nil
  107. }
  108. return "", fmt.Errorf("unsupported column type %d for dialect %d and size %v", d, dialect, size)
  109. }
  110. func (c *Column) SetName(name string) *Column {
  111. c.Name = name
  112. return c
  113. }
  114. func (c *Column) SetNullable(nullable bool) *Column {
  115. c.Nullable = nullable
  116. return c
  117. }
  118. func (c *Column) SetPrimaryKey(pk bool) *Column {
  119. c.PrimaryKey = pk
  120. return c
  121. }
  122. func (c *Column) SetDefault(value string) *Column {
  123. c.Default = OptionalString{Set: true, Value: value}
  124. return c
  125. }
  126. func (c *Column) SetType(t ColumnType) *Column {
  127. c.Type = t
  128. return c
  129. }
  130. func (c *Column) SetSize(size int) *Column {
  131. c.Size = OptionalInt{Set: true, Value: size}
  132. return c
  133. }
  134. func (c *Column) String() (string, error) {
  135. var str strings.Builder
  136. str.WriteString(c.Name)
  137. str.WriteString(" ")
  138. typeStr, err := c.Type.Format(c.Dialect, c.Size)
  139. if err != nil {
  140. return "", err
  141. }
  142. str.WriteString(typeStr)
  143. if !c.Nullable {
  144. str.WriteString(" NOT NULL")
  145. }
  146. if c.Default.Set {
  147. str.WriteString(" DEFAULT ")
  148. str.WriteString(c.Default.Value)
  149. }
  150. if c.PrimaryKey {
  151. str.WriteString(" PRIMARY KEY")
  152. }
  153. return str.String(), nil
  154. }
  155. func (b *CreateTableSqlBuilder) Column(column *Column) *CreateTableSqlBuilder {
  156. if b.Columns == nil {
  157. b.Columns = make(map[string]*Column)
  158. }
  159. b.Columns[column.Name] = column
  160. b.ColumnOrder = append(b.ColumnOrder, column.Name)
  161. return b
  162. }
  163. func (b *CreateTableSqlBuilder) UniqueConstraint(columns ...string) *CreateTableSqlBuilder {
  164. for _, column := range columns {
  165. if _, ok := b.Columns[column]; !ok {
  166. // This fails silently.
  167. return b
  168. }
  169. }
  170. b.Constraints = append(b.Constraints, fmt.Sprintf("UNIQUE(%s)", strings.Join(columns, ",")))
  171. return b
  172. }
  173. func (b *CreateTableSqlBuilder) SetIfNotExists(ine bool) *CreateTableSqlBuilder {
  174. b.IfNotExists = ine
  175. return b
  176. }
  177. func (b *CreateTableSqlBuilder) ToSQL() (string, error) {
  178. var str strings.Builder
  179. str.WriteString("CREATE TABLE ")
  180. if b.IfNotExists {
  181. str.WriteString("IF NOT EXISTS ")
  182. }
  183. str.WriteString(b.Name)
  184. var things []string
  185. for _, columnName := range b.ColumnOrder {
  186. column, ok := b.Columns[columnName]
  187. if !ok {
  188. return "", fmt.Errorf("column not found: %s", columnName)
  189. }
  190. columnStr, err := column.String()
  191. if err != nil {
  192. return "", err
  193. }
  194. things = append(things, columnStr)
  195. }
  196. for _, constraint := range b.Constraints {
  197. things = append(things, constraint)
  198. }
  199. if thingLen := len(things); thingLen > 0 {
  200. str.WriteString(" ( ")
  201. for i, thing := range things {
  202. str.WriteString(thing)
  203. if i < thingLen-1 {
  204. str.WriteString(", ")
  205. }
  206. }
  207. str.WriteString(" )")
  208. }
  209. return str.String(), nil
  210. }