migrator.go 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430
  1. package sqlite
  2. import (
  3. "database/sql"
  4. "fmt"
  5. "strings"
  6. "gorm.io/gorm"
  7. "gorm.io/gorm/clause"
  8. "gorm.io/gorm/migrator"
  9. "gorm.io/gorm/schema"
  10. )
  11. type Migrator struct {
  12. migrator.Migrator
  13. }
  14. func (m *Migrator) RunWithoutForeignKey(fc func() error) error {
  15. var enabled int
  16. m.DB.Raw("PRAGMA foreign_keys").Scan(&enabled)
  17. if enabled == 1 {
  18. m.DB.Exec("PRAGMA foreign_keys = OFF")
  19. defer m.DB.Exec("PRAGMA foreign_keys = ON")
  20. }
  21. return fc()
  22. }
  23. func (m Migrator) HasTable(value interface{}) bool {
  24. var count int
  25. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  26. return m.DB.Raw("SELECT count(*) FROM sqlite_master WHERE type='table' AND name=?", stmt.Table).Row().Scan(&count)
  27. })
  28. return count > 0
  29. }
  30. func (m Migrator) DropTable(values ...interface{}) error {
  31. return m.RunWithoutForeignKey(func() error {
  32. values = m.ReorderModels(values, false)
  33. tx := m.DB.Session(&gorm.Session{})
  34. for i := len(values) - 1; i >= 0; i-- {
  35. if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error {
  36. return tx.Exec("DROP TABLE IF EXISTS ?", clause.Table{Name: stmt.Table}).Error
  37. }); err != nil {
  38. return err
  39. }
  40. }
  41. return nil
  42. })
  43. }
  44. func (m Migrator) GetTables() (tableList []string, err error) {
  45. return tableList, m.DB.Raw("SELECT name FROM sqlite_master where type=?", "table").Scan(&tableList).Error
  46. }
  47. func (m Migrator) HasColumn(value interface{}, name string) bool {
  48. var count int
  49. m.Migrator.RunWithValue(value, func(stmt *gorm.Statement) error {
  50. if stmt.Schema != nil {
  51. if field := stmt.Schema.LookUpField(name); field != nil {
  52. name = field.DBName
  53. }
  54. }
  55. if name != "" {
  56. m.DB.Raw(
  57. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  58. "table", stmt.Table, `%"`+name+`" %`, `%`+name+` %`, "%`"+name+"`%", "%["+name+"]%", "%\t"+name+"\t%",
  59. ).Row().Scan(&count)
  60. }
  61. return nil
  62. })
  63. return count > 0
  64. }
  65. func (m Migrator) AlterColumn(value interface{}, name string) error {
  66. return m.RunWithoutForeignKey(func() error {
  67. return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
  68. if field := stmt.Schema.LookUpField(name); field != nil {
  69. var sqlArgs []interface{}
  70. for i, f := range ddl.fields {
  71. if matches := columnRegexp.FindStringSubmatch(f); len(matches) > 1 && matches[1] == field.DBName {
  72. ddl.fields[i] = fmt.Sprintf("`%v` ?", field.DBName)
  73. sqlArgs = []interface{}{m.FullDataTypeOf(field)}
  74. // table created by old version might look like `CREATE TABLE ? (? varchar(10) UNIQUE)`.
  75. // FullDataTypeOf doesn't contain UNIQUE, so we need to add unique constraint.
  76. if strings.Contains(strings.ToUpper(matches[3]), " UNIQUE") {
  77. uniName := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName)
  78. uni, _ := m.GuessConstraintInterfaceAndTable(stmt, uniName)
  79. if uni != nil {
  80. uniSQL, uniArgs := uni.Build()
  81. ddl.addConstraint(uniName, uniSQL)
  82. sqlArgs = append(sqlArgs, uniArgs...)
  83. }
  84. }
  85. break
  86. }
  87. }
  88. return ddl, sqlArgs, nil
  89. }
  90. return nil, nil, fmt.Errorf("failed to alter field with name %v", name)
  91. })
  92. })
  93. }
  94. // ColumnTypes return columnTypes []gorm.ColumnType and execErr error
  95. func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) {
  96. columnTypes := make([]gorm.ColumnType, 0)
  97. execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) {
  98. var (
  99. sqls []string
  100. sqlDDL *ddl
  101. )
  102. if err := m.DB.Raw("SELECT sql FROM sqlite_master WHERE type IN ? AND tbl_name = ? AND sql IS NOT NULL order by type = ? desc", []string{"table", "index"}, stmt.Table, "table").Scan(&sqls).Error; err != nil {
  103. return err
  104. }
  105. if sqlDDL, err = parseDDL(sqls...); err != nil {
  106. return err
  107. }
  108. rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows()
  109. if err != nil {
  110. return err
  111. }
  112. defer func() {
  113. err = rows.Close()
  114. }()
  115. var rawColumnTypes []*sql.ColumnType
  116. rawColumnTypes, err = rows.ColumnTypes()
  117. if err != nil {
  118. return err
  119. }
  120. for _, c := range rawColumnTypes {
  121. columnType := migrator.ColumnType{SQLColumnType: c}
  122. for _, column := range sqlDDL.columns {
  123. if column.NameValue.String == c.Name() {
  124. column.SQLColumnType = c
  125. columnType = column
  126. break
  127. }
  128. }
  129. columnTypes = append(columnTypes, columnType)
  130. }
  131. return err
  132. })
  133. return columnTypes, execErr
  134. }
  135. func (m Migrator) DropColumn(value interface{}, name string) error {
  136. return m.recreateTable(value, nil, func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
  137. if field := stmt.Schema.LookUpField(name); field != nil {
  138. name = field.DBName
  139. }
  140. ddl.removeColumn(name)
  141. return ddl, nil, nil
  142. })
  143. }
  144. func (m Migrator) CreateConstraint(value interface{}, name string) error {
  145. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  146. constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
  147. return m.recreateTable(value, &table,
  148. func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
  149. var (
  150. constraintName string
  151. constraintSql string
  152. constraintValues []interface{}
  153. )
  154. if constraint != nil {
  155. constraintName = constraint.GetName()
  156. constraintSql, constraintValues = constraint.Build()
  157. } else {
  158. return nil, nil, nil
  159. }
  160. ddl.addConstraint(constraintName, constraintSql)
  161. return ddl, constraintValues, nil
  162. })
  163. })
  164. }
  165. func (m Migrator) DropConstraint(value interface{}, name string) error {
  166. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  167. constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
  168. if constraint != nil {
  169. name = constraint.GetName()
  170. }
  171. return m.recreateTable(value, &table,
  172. func(ddl *ddl, stmt *gorm.Statement) (*ddl, []interface{}, error) {
  173. ddl.removeConstraint(name)
  174. return ddl, nil, nil
  175. })
  176. })
  177. }
  178. func (m Migrator) HasConstraint(value interface{}, name string) bool {
  179. var count int64
  180. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  181. constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name)
  182. if constraint != nil {
  183. name = constraint.GetName()
  184. }
  185. m.DB.Raw(
  186. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND (sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ? OR sql LIKE ?)",
  187. "table", table, `%CONSTRAINT "`+name+`" %`, `%CONSTRAINT `+name+` %`, "%CONSTRAINT `"+name+"`%", "%CONSTRAINT ["+name+"]%", "%CONSTRAINT \t"+name+"\t%",
  188. ).Row().Scan(&count)
  189. return nil
  190. })
  191. return count > 0
  192. }
  193. func (m Migrator) CurrentDatabase() (name string) {
  194. var null interface{}
  195. m.DB.Raw("PRAGMA database_list").Row().Scan(&null, &name, &null)
  196. return
  197. }
  198. func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) {
  199. for _, opt := range opts {
  200. str := stmt.Quote(opt.DBName)
  201. if opt.Expression != "" {
  202. str = opt.Expression
  203. }
  204. if opt.Collate != "" {
  205. str += " COLLATE " + opt.Collate
  206. }
  207. if opt.Sort != "" {
  208. str += " " + opt.Sort
  209. }
  210. results = append(results, clause.Expr{SQL: str})
  211. }
  212. return
  213. }
  214. func (m Migrator) CreateIndex(value interface{}, name string) error {
  215. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  216. if stmt.Schema != nil {
  217. if idx := stmt.Schema.LookIndex(name); idx != nil {
  218. opts := m.BuildIndexOptions(idx.Fields, stmt)
  219. values := []interface{}{clause.Column{Name: idx.Name}, clause.Table{Name: stmt.Table}, opts}
  220. createIndexSQL := "CREATE "
  221. if idx.Class != "" {
  222. createIndexSQL += idx.Class + " "
  223. }
  224. createIndexSQL += "INDEX ?"
  225. if idx.Type != "" {
  226. createIndexSQL += " USING " + idx.Type
  227. }
  228. createIndexSQL += " ON ??"
  229. if idx.Where != "" {
  230. createIndexSQL += " WHERE " + idx.Where
  231. }
  232. return m.DB.Exec(createIndexSQL, values...).Error
  233. }
  234. }
  235. return fmt.Errorf("failed to create index with name %v", name)
  236. })
  237. }
  238. func (m Migrator) HasIndex(value interface{}, name string) bool {
  239. var count int
  240. m.RunWithValue(value, func(stmt *gorm.Statement) error {
  241. if stmt.Schema != nil {
  242. if idx := stmt.Schema.LookIndex(name); idx != nil {
  243. name = idx.Name
  244. }
  245. }
  246. if name != "" {
  247. m.DB.Raw(
  248. "SELECT count(*) FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, name,
  249. ).Row().Scan(&count)
  250. }
  251. return nil
  252. })
  253. return count > 0
  254. }
  255. func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
  256. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  257. var sql string
  258. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "index", stmt.Table, oldName).Row().Scan(&sql)
  259. if sql != "" {
  260. if err := m.DropIndex(value, oldName); err != nil {
  261. return err
  262. }
  263. return m.DB.Exec(strings.Replace(sql, oldName, newName, 1)).Error
  264. }
  265. return fmt.Errorf("failed to find index with name %v", oldName)
  266. })
  267. }
  268. func (m Migrator) DropIndex(value interface{}, name string) error {
  269. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  270. if stmt.Schema != nil {
  271. if idx := stmt.Schema.LookIndex(name); idx != nil {
  272. name = idx.Name
  273. }
  274. }
  275. return m.DB.Exec("DROP INDEX ?", clause.Column{Name: name}).Error
  276. })
  277. }
  // Index holds one row of SQLite's PRAGMA index_list output. GetIndexes scans
  // it from `SELECT * FROM PRAGMA_index_list(?)`. The fields are presumably
  // matched to the pragma's columns (seq, name, unique, origin, partial) by
  // gorm's naming strategy; verify before renaming or reordering them.
  type Index struct {
  	Seq     int    // position of the index in the pragma's result
  	Name    string // index name; GetIndexes passes it to PRAGMA_index_info
  	Unique  bool   // whether the index enforces uniqueness
  	Origin  string // "u" = created for a UNIQUE constraint (skipped by GetIndexes), "pk" = PRIMARY KEY; per SQLite docs "c" = CREATE INDEX
  	Partial bool   // whether the index is partial (has a WHERE clause), per SQLite docs
  }
  285. // GetIndexes return Indexes []gorm.Index and execErr error,
  286. // See the [doc]
  287. //
  288. // [doc]: https://www.sqlite.org/pragma.html#pragma_index_list
  289. func (m Migrator) GetIndexes(value interface{}) ([]gorm.Index, error) {
  290. indexes := make([]gorm.Index, 0)
  291. err := m.RunWithValue(value, func(stmt *gorm.Statement) error {
  292. rst := make([]*Index, 0)
  293. if err := m.DB.Debug().Raw("SELECT * FROM PRAGMA_index_list(?)", stmt.Table).Scan(&rst).Error; err != nil { // alias `PRAGMA index_list(?)`
  294. return err
  295. }
  296. for _, index := range rst {
  297. if index.Origin == "u" { // skip the index was created by a UNIQUE constraint
  298. continue
  299. }
  300. var columns []string
  301. if err := m.DB.Raw("SELECT name FROM PRAGMA_index_info(?)", index.Name).Scan(&columns).Error; err != nil { // alias `PRAGMA index_info(?)`
  302. return err
  303. }
  304. indexes = append(indexes, &migrator.Index{
  305. TableName: stmt.Table,
  306. NameValue: index.Name,
  307. ColumnList: columns,
  308. PrimaryKeyValue: sql.NullBool{Bool: index.Origin == "pk", Valid: true}, // The exceptions are INTEGER PRIMARY KEY
  309. UniqueValue: sql.NullBool{Bool: index.Unique, Valid: true},
  310. })
  311. }
  312. return nil
  313. })
  314. return indexes, err
  315. }
  316. func (m Migrator) getRawDDL(table string) (string, error) {
  317. var createSQL string
  318. m.DB.Raw("SELECT sql FROM sqlite_master WHERE type = ? AND tbl_name = ? AND name = ?", "table", table, table).Row().Scan(&createSQL)
  319. if m.DB.Error != nil {
  320. return "", m.DB.Error
  321. }
  322. return createSQL, nil
  323. }
  324. func (m Migrator) recreateTable(
  325. value interface{}, tablePtr *string,
  326. getCreateSQL func(ddl *ddl, stmt *gorm.Statement) (sql *ddl, sqlArgs []interface{}, err error),
  327. ) error {
  328. return m.RunWithValue(value, func(stmt *gorm.Statement) error {
  329. table := stmt.Table
  330. if tablePtr != nil {
  331. table = *tablePtr
  332. }
  333. rawDDL, err := m.getRawDDL(table)
  334. if err != nil {
  335. return err
  336. }
  337. originDDL, err := parseDDL(rawDDL)
  338. if err != nil {
  339. return err
  340. }
  341. createDDL, sqlArgs, err := getCreateSQL(originDDL.clone(), stmt)
  342. if err != nil {
  343. return err
  344. }
  345. if createDDL == nil {
  346. return nil
  347. }
  348. newTableName := table + "__temp"
  349. if err := createDDL.renameTable(newTableName, table); err != nil {
  350. return err
  351. }
  352. columns := createDDL.getColumns()
  353. createSQL := createDDL.compile()
  354. return m.DB.Transaction(func(tx *gorm.DB) error {
  355. if err := tx.Exec(createSQL, sqlArgs...).Error; err != nil {
  356. return err
  357. }
  358. queries := []string{
  359. fmt.Sprintf("INSERT INTO `%v`(%v) SELECT %v FROM `%v`", newTableName, strings.Join(columns, ","), strings.Join(columns, ","), table),
  360. fmt.Sprintf("DROP TABLE `%v`", table),
  361. fmt.Sprintf("ALTER TABLE `%v` RENAME TO `%v`", newTableName, table),
  362. }
  363. for _, query := range queries {
  364. if err := tx.Exec(query).Error; err != nil {
  365. return err
  366. }
  367. }
  368. return nil
  369. })
  370. })
  371. }