A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

2542 lines
76 KiB

  1. /*
  2. * Copyright © 2018 A Bunch Tell LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "database/sql"
  14. "fmt"
  15. "net/http"
  16. "strings"
  17. "time"
  18. "github.com/guregu/null"
  19. "github.com/guregu/null/zero"
  20. uuid "github.com/nu7hatch/gouuid"
  21. "github.com/writeas/impart"
  22. "github.com/writeas/nerds/store"
  23. "github.com/writeas/web-core/activitypub"
  24. "github.com/writeas/web-core/auth"
  25. "github.com/writeas/web-core/data"
  26. "github.com/writeas/web-core/id"
  27. "github.com/writeas/web-core/log"
  28. "github.com/writeas/web-core/query"
  29. "github.com/writeas/writefreely/author"
  30. "github.com/writeas/writefreely/config"
  31. "github.com/writeas/writefreely/key"
  32. )
  33. const (
  34. mySQLErrDuplicateKey = 1062
  35. driverMySQL = "mysql"
  36. driverSQLite = "sqlite3"
  37. )
  38. var (
  39. SQLiteEnabled bool
  40. )
  41. type writestore interface {
  42. CreateUser(*config.Config, *User, string) error
  43. UpdateUserEmail(keys *key.Keychain, userID int64, email string) error
  44. UpdateEncryptedUserEmail(int64, []byte) error
  45. GetUserByID(int64) (*User, error)
  46. GetUserForAuth(string) (*User, error)
  47. GetUserForAuthByID(int64) (*User, error)
  48. GetUserNameFromToken(string) (string, error)
  49. GetUserDataFromToken(string) (int64, string, error)
  50. GetAPIUser(header string) (*User, error)
  51. GetUserID(accessToken string) int64
  52. GetUserIDPrivilege(accessToken string) (userID int64, sudo bool)
  53. DeleteToken(accessToken []byte) error
  54. FetchLastAccessToken(userID int64) string
  55. GetAccessToken(userID int64) (string, error)
  56. GetTemporaryAccessToken(userID int64, validSecs int) (string, error)
  57. GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error)
  58. DeleteAccount(userID int64) (l *string, err error)
  59. ChangeSettings(app *App, u *User, s *userSettings) error
  60. ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error
  61. GetCollections(u *User, hostName string) (*[]Collection, error)
  62. GetPublishableCollections(u *User, hostName string) (*[]Collection, error)
  63. GetMeStats(u *User) userMeStats
  64. GetTotalCollections() (int64, error)
  65. GetTotalPosts() (int64, error)
  66. GetTopPosts(u *User, alias string) (*[]PublicPost, error)
  67. GetAnonymousPosts(u *User) (*[]PublicPost, error)
  68. GetUserPosts(u *User) (*[]PublicPost, error)
  69. CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error)
  70. CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error)
  71. UpdateOwnedPost(post *AuthenticatedPost, userID int64) error
  72. GetEditablePost(id, editToken string) (*PublicPost, error)
  73. PostIDExists(id string) bool
  74. GetPost(id string, collectionID int64) (*PublicPost, error)
  75. GetOwnedPost(id string, ownerID int64) (*PublicPost, error)
  76. GetPostProperty(id string, collectionID int64, property string) (interface{}, error)
  77. CreateCollectionFromToken(*config.Config, string, string, string) (*Collection, error)
  78. CreateCollection(*config.Config, string, string, int64) (*Collection, error)
  79. GetCollectionBy(condition string, value interface{}) (*Collection, error)
  80. GetCollection(alias string) (*Collection, error)
  81. GetCollectionForPad(alias string) (*Collection, error)
  82. GetCollectionByID(id int64) (*Collection, error)
  83. UpdateCollection(c *SubmittedCollection, alias string) error
  84. DeleteCollection(alias string, userID int64) error
  85. UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error
  86. GetLastPinnedPostPos(collID int64) int64
  87. GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error)
  88. RemoveCollectionRedirect(t *sql.Tx, alias string) error
  89. GetCollectionRedirect(alias string) (new string)
  90. IsCollectionAttributeOn(id int64, attr string) bool
  91. CollectionHasAttribute(id int64, attr string) bool
  92. CanCollect(cpr *ClaimPostRequest, userID int64) bool
  93. AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error)
  94. DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error)
  95. ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error)
  96. GetPostsCount(c *CollectionObj, includeFuture bool)
  97. GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error)
  98. GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error)
  99. GetAPFollowers(c *Collection) (*[]RemoteUser, error)
  100. GetAPActorKeys(collectionID int64) ([]byte, []byte)
  101. CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error
  102. GetUserInvites(userID int64) (*[]Invite, error)
  103. GetUserInvite(id string) (*Invite, error)
  104. GetUsersInvitedCount(id string) int64
  105. CreateInvitedUser(inviteID string, userID int64) error
  106. GetDynamicContent(id string) (*instanceContent, error)
  107. UpdateDynamicContent(id, title, content, contentType string) error
  108. GetAllUsers(page uint) (*[]User, error)
  109. GetAllUsersCount() int64
  110. GetUserLastPostTime(id int64) (*time.Time, error)
  111. GetCollectionLastPostTime(id int64) (*time.Time, error)
  112. GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
  113. RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
  114. ValidateOAuthState(ctx context.Context, state string) error
  115. GenerateOAuthState(ctx context.Context) (string, error)
  116. DatabaseInitialized() bool
  117. }
  118. type datastore struct {
  119. *sql.DB
  120. driverName string
  121. }
  122. func (db *datastore) now() string {
  123. if db.driverName == driverSQLite {
  124. return "strftime('%Y-%m-%d %H:%M:%S','now')"
  125. }
  126. return "NOW()"
  127. }
  128. func (db *datastore) clip(field string, l int) string {
  129. if db.driverName == driverSQLite {
  130. return fmt.Sprintf("SUBSTR(%s, 0, %d)", field, l)
  131. }
  132. return fmt.Sprintf("LEFT(%s, %d)", field, l)
  133. }
  134. func (db *datastore) upsert(indexedCols ...string) string {
  135. if db.driverName == driverSQLite {
  136. // NOTE: SQLite UPSERT syntax only works in v3.24.0 (2018-06-04) or later
  137. // Leaving this for whenever we can upgrade and include it in our binary
  138. cc := strings.Join(indexedCols, ", ")
  139. return "ON CONFLICT(" + cc + ") DO UPDATE SET"
  140. }
  141. return "ON DUPLICATE KEY UPDATE"
  142. }
  143. func (db *datastore) dateSub(l int, unit string) string {
  144. if db.driverName == driverSQLite {
  145. return fmt.Sprintf("DATETIME('now', '-%d %s')", l, unit)
  146. }
  147. return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
  148. }
  149. func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
  150. if db.PostIDExists(u.Username) {
  151. return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  152. }
  153. // New users get a `users` and `collections` row.
  154. t, err := db.Begin()
  155. if err != nil {
  156. return err
  157. }
  158. // 1. Add to `users` table
  159. // NOTE: Assumes User's Password is already hashed!
  160. res, err := t.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email)
  161. if err != nil {
  162. t.Rollback()
  163. if db.isDuplicateKeyErr(err) {
  164. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  165. }
  166. log.Error("Rolling back users INSERT: %v\n", err)
  167. return err
  168. }
  169. u.ID, err = res.LastInsertId()
  170. if err != nil {
  171. t.Rollback()
  172. log.Error("Rolling back after LastInsertId: %v\n", err)
  173. return err
  174. }
  175. // 2. Create user's Collection
  176. if collectionTitle == "" {
  177. collectionTitle = u.Username
  178. }
  179. res, err = t.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, collectionTitle, "", defaultVisibility(cfg), u.ID, 0)
  180. if err != nil {
  181. t.Rollback()
  182. if db.isDuplicateKeyErr(err) {
  183. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  184. }
  185. log.Error("Rolling back collections INSERT: %v\n", err)
  186. return err
  187. }
  188. db.RemoveCollectionRedirect(t, u.Username)
  189. err = t.Commit()
  190. if err != nil {
  191. t.Rollback()
  192. log.Error("Rolling back after Commit(): %v\n", err)
  193. return err
  194. }
  195. return nil
  196. }
  197. // FIXME: We're returning errors inconsistently in this file. Do we use Errorf
  198. // for returned value, or impart?
  199. func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error {
  200. encEmail, err := data.Encrypt(keys.EmailKey, email)
  201. if err != nil {
  202. return fmt.Errorf("Couldn't encrypt email %s: %s\n", email, err)
  203. }
  204. return db.UpdateEncryptedUserEmail(userID, encEmail)
  205. }
  206. func (db *datastore) UpdateEncryptedUserEmail(userID int64, encEmail []byte) error {
  207. _, err := db.Exec("UPDATE users SET email = ? WHERE id = ?", encEmail, userID)
  208. if err != nil {
  209. return fmt.Errorf("Unable to update user email: %s", err)
  210. }
  211. return nil
  212. }
  213. func (db *datastore) CreateCollectionFromToken(cfg *config.Config, alias, title, accessToken string) (*Collection, error) {
  214. userID := db.GetUserID(accessToken)
  215. if userID == -1 {
  216. return nil, ErrBadAccessToken
  217. }
  218. return db.CreateCollection(cfg, alias, title, userID)
  219. }
  220. func (db *datastore) GetUserCollectionCount(userID int64) (uint64, error) {
  221. var collCount uint64
  222. err := db.QueryRow("SELECT COUNT(*) FROM collections WHERE owner_id = ?", userID).Scan(&collCount)
  223. switch {
  224. case err == sql.ErrNoRows:
  225. return 0, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user from database."}
  226. case err != nil:
  227. log.Error("Couldn't get collections count for user %d: %v", userID, err)
  228. return 0, err
  229. }
  230. return collCount, nil
  231. }
  232. func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, userID int64) (*Collection, error) {
  233. if db.PostIDExists(alias) {
  234. return nil, impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  235. }
  236. // All good, so create new collection
  237. res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", defaultVisibility(cfg), userID, 0)
  238. if err != nil {
  239. if db.isDuplicateKeyErr(err) {
  240. return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."}
  241. }
  242. log.Error("Couldn't add to collections: %v\n", err)
  243. return nil, err
  244. }
  245. c := &Collection{
  246. Alias: alias,
  247. Title: title,
  248. OwnerID: userID,
  249. PublicOwner: false,
  250. Public: defaultVisibility(cfg) == CollPublic,
  251. }
  252. c.ID, err = res.LastInsertId()
  253. if err != nil {
  254. log.Error("Couldn't get collection LastInsertId: %v\n", err)
  255. }
  256. return c, nil
  257. }
  258. func (db *datastore) GetUserByID(id int64) (*User, error) {
  259. u := &User{ID: id}
  260. err := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  261. switch {
  262. case err == sql.ErrNoRows:
  263. return nil, ErrUserNotFound
  264. case err != nil:
  265. log.Error("Couldn't SELECT user password: %v", err)
  266. return nil, err
  267. }
  268. return u, nil
  269. }
  270. // IsUserSuspended returns true if the user account associated with id is
  271. // currently suspended.
  272. func (db *datastore) IsUserSuspended(id int64) (bool, error) {
  273. u := &User{ID: id}
  274. err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&u.Status)
  275. switch {
  276. case err == sql.ErrNoRows:
  277. return false, fmt.Errorf("is user suspended: %v", ErrUserNotFound)
  278. case err != nil:
  279. log.Error("Couldn't SELECT user password: %v", err)
  280. return false, fmt.Errorf("is user suspended: %v", err)
  281. }
  282. return u.IsSilenced(), nil
  283. }
  284. // DoesUserNeedAuth returns true if the user hasn't provided any methods for
  285. // authenticating with the account, such a passphrase or email address.
  286. // Any errors are reported to admin and silently quashed, returning false as the
  287. // result.
  288. func (db *datastore) DoesUserNeedAuth(id int64) bool {
  289. var pass, email []byte
  290. // Find out if user has an email set first
  291. err := db.QueryRow("SELECT password, email FROM users WHERE id = ?", id).Scan(&pass, &email)
  292. switch {
  293. case err == sql.ErrNoRows:
  294. // ERROR. Don't give false positives on needing auth methods
  295. return false
  296. case err != nil:
  297. // ERROR. Don't give false positives on needing auth methods
  298. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  299. return false
  300. }
  301. // User doesn't need auth if there's an email
  302. return len(email) == 0 && len(pass) == 0
  303. }
  304. func (db *datastore) IsUserPassSet(id int64) (bool, error) {
  305. var pass []byte
  306. err := db.QueryRow("SELECT password FROM users WHERE id = ?", id).Scan(&pass)
  307. switch {
  308. case err == sql.ErrNoRows:
  309. return false, nil
  310. case err != nil:
  311. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  312. return false, err
  313. }
  314. return len(pass) > 0, nil
  315. }
  316. func (db *datastore) GetUserForAuth(username string) (*User, error) {
  317. u := &User{Username: username}
  318. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  319. switch {
  320. case err == sql.ErrNoRows:
  321. // Check if they've entered the wrong, unnormalized username
  322. username = getSlug(username, "")
  323. if username != u.Username {
  324. err = db.QueryRow("SELECT id FROM users WHERE username = ? LIMIT 1", username).Scan(&u.ID)
  325. if err == nil {
  326. return db.GetUserForAuth(username)
  327. }
  328. }
  329. return nil, ErrUserNotFound
  330. case err != nil:
  331. log.Error("Couldn't SELECT user password: %v", err)
  332. return nil, err
  333. }
  334. return u, nil
  335. }
  336. func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) {
  337. u := &User{ID: userID}
  338. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  339. switch {
  340. case err == sql.ErrNoRows:
  341. return nil, ErrUserNotFound
  342. case err != nil:
  343. log.Error("Couldn't SELECT userForAuthByID: %v", err)
  344. return nil, err
  345. }
  346. return u, nil
  347. }
  348. func (db *datastore) GetUserNameFromToken(accessToken string) (string, error) {
  349. t := auth.GetToken(accessToken)
  350. if len(t) == 0 {
  351. return "", ErrNoAccessToken
  352. }
  353. var oneTime bool
  354. var username string
  355. err := db.QueryRow("SELECT username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&username, &oneTime)
  356. switch {
  357. case err == sql.ErrNoRows:
  358. return "", ErrBadAccessToken
  359. case err != nil:
  360. return "", ErrInternalGeneral
  361. }
  362. // Delete token if it was one-time
  363. if oneTime {
  364. db.DeleteToken(t[:])
  365. }
  366. return username, nil
  367. }
  368. func (db *datastore) GetUserDataFromToken(accessToken string) (int64, string, error) {
  369. t := auth.GetToken(accessToken)
  370. if len(t) == 0 {
  371. return 0, "", ErrNoAccessToken
  372. }
  373. var userID int64
  374. var oneTime bool
  375. var username string
  376. err := db.QueryRow("SELECT user_id, username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &username, &oneTime)
  377. switch {
  378. case err == sql.ErrNoRows:
  379. return 0, "", ErrBadAccessToken
  380. case err != nil:
  381. return 0, "", ErrInternalGeneral
  382. }
  383. // Delete token if it was one-time
  384. if oneTime {
  385. db.DeleteToken(t[:])
  386. }
  387. return userID, username, nil
  388. }
  389. func (db *datastore) GetAPIUser(header string) (*User, error) {
  390. uID := db.GetUserID(header)
  391. if uID == -1 {
  392. return nil, fmt.Errorf(ErrUserNotFound.Error())
  393. }
  394. return db.GetUserByID(uID)
  395. }
  396. // GetUserID takes a hexadecimal accessToken, parses it into its binary
  397. // representation, and gets any user ID associated with the token. If no user
  398. // is associated, -1 is returned.
  399. func (db *datastore) GetUserID(accessToken string) int64 {
  400. i, _ := db.GetUserIDPrivilege(accessToken)
  401. return i
  402. }
  403. func (db *datastore) GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) {
  404. t := auth.GetToken(accessToken)
  405. if len(t) == 0 {
  406. return -1, false
  407. }
  408. var oneTime bool
  409. err := db.QueryRow("SELECT user_id, sudo, one_time FROM accesstokens WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &sudo, &oneTime)
  410. switch {
  411. case err == sql.ErrNoRows:
  412. return -1, false
  413. case err != nil:
  414. return -1, false
  415. }
  416. // Delete token if it was one-time
  417. if oneTime {
  418. db.DeleteToken(t[:])
  419. }
  420. return
  421. }
  422. func (db *datastore) DeleteToken(accessToken []byte) error {
  423. res, err := db.Exec("DELETE FROM accesstokens WHERE token LIKE ?", accessToken)
  424. if err != nil {
  425. return err
  426. }
  427. rowsAffected, _ := res.RowsAffected()
  428. if rowsAffected == 0 {
  429. return impart.HTTPError{http.StatusNotFound, "Token is invalid or doesn't exist"}
  430. }
  431. return nil
  432. }
  433. // FetchLastAccessToken creates a new non-expiring, valid access token for the given
  434. // userID.
  435. func (db *datastore) FetchLastAccessToken(userID int64) string {
  436. var t []byte
  437. err := db.QueryRow("SELECT token FROM accesstokens WHERE user_id = ? AND (expires IS NULL OR expires > "+db.now()+") ORDER BY created DESC LIMIT 1", userID).Scan(&t)
  438. switch {
  439. case err == sql.ErrNoRows:
  440. return ""
  441. case err != nil:
  442. log.Error("Failed selecting from accesstoken: %v", err)
  443. return ""
  444. }
  445. u, err := uuid.Parse(t)
  446. if err != nil {
  447. return ""
  448. }
  449. return u.String()
  450. }
  451. // GetAccessToken creates a new non-expiring, valid access token for the given
  452. // userID.
  453. func (db *datastore) GetAccessToken(userID int64) (string, error) {
  454. return db.GetTemporaryOneTimeAccessToken(userID, 0, false)
  455. }
  456. // GetTemporaryAccessToken creates a new valid access token for the given
  457. // userID that remains valid for the given time in seconds. If validSecs is 0,
  458. // the access token doesn't automatically expire.
  459. func (db *datastore) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) {
  460. return db.GetTemporaryOneTimeAccessToken(userID, validSecs, false)
  461. }
  462. // GetTemporaryOneTimeAccessToken creates a new valid access token for the given
  463. // userID that remains valid for the given time in seconds and can only be used
  464. // once if oneTime is true. If validSecs is 0, the access token doesn't
  465. // automatically expire.
  466. func (db *datastore) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) {
  467. u, err := uuid.NewV4()
  468. if err != nil {
  469. log.Error("Unable to generate token: %v", err)
  470. return "", err
  471. }
  472. // Insert UUID to `accesstokens`
  473. binTok := u[:]
  474. expirationVal := "NULL"
  475. if validSecs > 0 {
  476. expirationVal = fmt.Sprintf("DATE_ADD("+db.now()+", INTERVAL %d SECOND)", validSecs)
  477. }
  478. _, err = db.Exec("INSERT INTO accesstokens (token, user_id, one_time, expires) VALUES (?, ?, ?, "+expirationVal+")", string(binTok), userID, oneTime)
  479. if err != nil {
  480. log.Error("Couldn't INSERT accesstoken: %v", err)
  481. return "", err
  482. }
  483. return u.String(), nil
  484. }
  485. func (db *datastore) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) {
  486. var userID, collID int64 = -1, -1
  487. var coll *Collection
  488. var err error
  489. if accessToken != "" {
  490. userID = db.GetUserID(accessToken)
  491. if userID == -1 {
  492. return nil, ErrBadAccessToken
  493. }
  494. if collAlias != "" {
  495. coll, err = db.GetCollection(collAlias)
  496. if err != nil {
  497. return nil, err
  498. }
  499. coll.hostName = hostName
  500. if coll.OwnerID != userID {
  501. return nil, ErrForbiddenCollection
  502. }
  503. collID = coll.ID
  504. }
  505. }
  506. rp := &PublicPost{}
  507. rp.Post, err = db.CreatePost(userID, collID, post)
  508. if err != nil {
  509. return rp, err
  510. }
  511. if coll != nil {
  512. coll.ForPublic()
  513. rp.Collection = &CollectionObj{Collection: *coll}
  514. }
  515. return rp, nil
  516. }
  517. func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) {
  518. idLen := postIDLen
  519. friendlyID := store.GenerateFriendlyRandomString(idLen)
  520. // Handle appearance / font face
  521. appearance := post.Font
  522. if !post.isFontValid() {
  523. appearance = "norm"
  524. }
  525. var err error
  526. ownerID := sql.NullInt64{
  527. Valid: false,
  528. }
  529. ownerCollID := sql.NullInt64{
  530. Valid: false,
  531. }
  532. slug := sql.NullString{"", false}
  533. // If an alias was supplied, we'll add this to the collection as well.
  534. if userID > 0 {
  535. ownerID.Int64 = userID
  536. ownerID.Valid = true
  537. if collID > 0 {
  538. ownerCollID.Int64 = collID
  539. ownerCollID.Valid = true
  540. var slugVal string
  541. if post.Title != nil && *post.Title != "" {
  542. slugVal = getSlug(*post.Title, post.Language.String)
  543. if slugVal == "" {
  544. slugVal = getSlug(*post.Content, post.Language.String)
  545. }
  546. } else {
  547. slugVal = getSlug(*post.Content, post.Language.String)
  548. }
  549. if slugVal == "" {
  550. slugVal = friendlyID
  551. }
  552. slug = sql.NullString{slugVal, true}
  553. }
  554. }
  555. created := time.Now()
  556. if db.driverName == driverSQLite {
  557. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  558. created = created.UTC()
  559. }
  560. if post.Created != nil {
  561. created, err = time.Parse("2006-01-02T15:04:05Z", *post.Created)
  562. if err != nil {
  563. log.Error("Unable to parse Created time '%s': %v", *post.Created, err)
  564. created = time.Now()
  565. if db.driverName == driverSQLite {
  566. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  567. created = created.UTC()
  568. }
  569. }
  570. }
  571. stmt, err := db.Prepare("INSERT INTO posts (id, slug, title, content, text_appearance, language, rtl, privacy, owner_id, collection_id, created, updated, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + db.now() + ", ?)")
  572. if err != nil {
  573. return nil, err
  574. }
  575. defer stmt.Close()
  576. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  577. if err != nil {
  578. if db.isDuplicateKeyErr(err) {
  579. // Duplicate entry error; try a new slug
  580. // TODO: make this a little more robust
  581. slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true}
  582. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  583. if err != nil {
  584. return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err))
  585. }
  586. } else {
  587. return nil, handleFailedPostInsert(err)
  588. }
  589. }
  590. // TODO: return Created field in proper format
  591. return &Post{
  592. ID: friendlyID,
  593. Slug: null.NewString(slug.String, slug.Valid),
  594. Font: appearance,
  595. Language: zero.NewString(post.Language.String, post.Language.Valid),
  596. RTL: zero.NewBool(post.IsRTL.Bool, post.IsRTL.Valid),
  597. OwnerID: null.NewInt(userID, true),
  598. CollectionID: null.NewInt(userID, true),
  599. Created: created.Truncate(time.Second).UTC(),
  600. Updated: time.Now().Truncate(time.Second).UTC(),
  601. Title: zero.NewString(*(post.Title), true),
  602. Content: *(post.Content),
  603. }, nil
  604. }
  605. // UpdateOwnedPost updates an existing post with only the given fields in the
  606. // supplied AuthenticatedPost.
  607. func (db *datastore) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error {
  608. params := []interface{}{}
  609. var queryUpdates, sep, authCondition string
  610. if post.Slug != nil && *post.Slug != "" {
  611. queryUpdates += sep + "slug = ?"
  612. sep = ", "
  613. params = append(params, getSlug(*post.Slug, ""))
  614. }
  615. if post.Content != nil {
  616. queryUpdates += sep + "content = ?"
  617. sep = ", "
  618. params = append(params, post.Content)
  619. }
  620. if post.Title != nil {
  621. queryUpdates += sep + "title = ?"
  622. sep = ", "
  623. params = append(params, post.Title)
  624. }
  625. if post.Language.Valid {
  626. queryUpdates += sep + "language = ?"
  627. sep = ", "
  628. params = append(params, post.Language.String)
  629. }
  630. if post.IsRTL.Valid {
  631. queryUpdates += sep + "rtl = ?"
  632. sep = ", "
  633. params = append(params, post.IsRTL.Bool)
  634. }
  635. if post.Font != "" {
  636. queryUpdates += sep + "text_appearance = ?"
  637. sep = ", "
  638. params = append(params, post.Font)
  639. }
  640. if post.Created != nil {
  641. createTime, err := time.Parse(postMetaDateFormat, *post.Created)
  642. if err != nil {
  643. log.Error("Unable to parse Created date: %v", err)
  644. return fmt.Errorf("That's the incorrect format for Created date.")
  645. }
  646. queryUpdates += sep + "created = ?"
  647. sep = ", "
  648. params = append(params, createTime)
  649. }
  650. // WHERE parameters...
  651. // id = ?
  652. params = append(params, post.ID)
  653. // AND owner_id = ?
  654. authCondition = "(owner_id = ?)"
  655. params = append(params, userID)
  656. if queryUpdates == "" {
  657. return ErrPostNoUpdatableVals
  658. }
  659. queryUpdates += sep + "updated = " + db.now()
  660. res, err := db.Exec("UPDATE posts SET "+queryUpdates+" WHERE id = ? AND "+authCondition, params...)
  661. if err != nil {
  662. log.Error("Unable to update owned post: %v", err)
  663. return err
  664. }
  665. rowsAffected, _ := res.RowsAffected()
  666. if rowsAffected == 0 {
  667. // Show the correct error message if nothing was updated
  668. var dummy int
  669. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ? AND "+authCondition, post.ID, params[len(params)-1]).Scan(&dummy)
  670. switch {
  671. case err == sql.ErrNoRows:
  672. return ErrUnauthorizedEditPost
  673. case err != nil:
  674. log.Error("Failed selecting from posts: %v", err)
  675. }
  676. return nil
  677. }
  678. return nil
  679. }
  680. func (db *datastore) GetCollectionBy(condition string, value interface{}) (*Collection, error) {
  681. c := &Collection{}
  682. // FIXME: change Collection to reflect database values. Add helper functions to get actual values
  683. var styleSheet, script, format zero.String
  684. row := db.QueryRow("SELECT id, alias, title, description, style_sheet, script, format, owner_id, privacy, view_count FROM collections WHERE "+condition, value)
  685. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &styleSheet, &script, &format, &c.OwnerID, &c.Visibility, &c.Views)
  686. switch {
  687. case err == sql.ErrNoRows:
  688. return nil, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  689. case err != nil:
  690. log.Error("Failed selecting from collections: %v", err)
  691. return nil, err
  692. }
  693. c.StyleSheet = styleSheet.String
  694. c.Script = script.String
  695. c.Format = format.String
  696. c.Public = c.IsPublic()
  697. c.db = db
  698. return c, nil
  699. }
  700. func (db *datastore) GetCollection(alias string) (*Collection, error) {
  701. return db.GetCollectionBy("alias = ?", alias)
  702. }
  703. func (db *datastore) GetCollectionForPad(alias string) (*Collection, error) {
  704. c := &Collection{Alias: alias}
  705. row := db.QueryRow("SELECT id, alias, title, description, privacy FROM collections WHERE alias = ?", alias)
  706. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility)
  707. switch {
  708. case err == sql.ErrNoRows:
  709. return c, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  710. case err != nil:
  711. log.Error("Failed selecting from collections: %v", err)
  712. return c, ErrInternalGeneral
  713. }
  714. c.Public = c.IsPublic()
  715. return c, nil
  716. }
  717. func (db *datastore) GetCollectionByID(id int64) (*Collection, error) {
  718. return db.GetCollectionBy("id = ?", id)
  719. }
  720. func (db *datastore) GetCollectionFromDomain(host string) (*Collection, error) {
  721. return db.GetCollectionBy("host = ?", host)
  722. }
  723. func (db *datastore) UpdateCollection(c *SubmittedCollection, alias string) error {
  724. q := query.NewUpdate().
  725. SetStringPtr(c.Title, "title").
  726. SetStringPtr(c.Description, "description").
  727. SetNullString(c.StyleSheet, "style_sheet").
  728. SetNullString(c.Script, "script")
  729. if c.Format != nil {
  730. cf := &CollectionFormat{Format: c.Format.String}
  731. if cf.Valid() {
  732. q.SetNullString(c.Format, "format")
  733. }
  734. }
  735. var updatePass bool
  736. if c.Visibility != nil && (collVisibility(*c.Visibility)&CollProtected == 0 || c.Pass != "") {
  737. q.SetIntPtr(c.Visibility, "privacy")
  738. if c.Pass != "" {
  739. updatePass = true
  740. }
  741. }
  742. // WHERE values
  743. q.Where("alias = ? AND owner_id = ?", alias, c.OwnerID)
  744. if q.Updates == "" {
  745. return ErrPostNoUpdatableVals
  746. }
  747. // Find any current domain
  748. var collID int64
  749. var rowsAffected int64
  750. var changed bool
  751. var res sql.Result
  752. err := db.QueryRow("SELECT id FROM collections WHERE alias = ?", alias).Scan(&collID)
  753. if err != nil {
  754. log.Error("Failed selecting from collections: %v. Some things won't work.", err)
  755. }
  756. // Update MathJax value
  757. if c.MathJax {
  758. if db.driverName == driverSQLite {
  759. _, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "render_mathjax", "1")
  760. } else {
  761. _, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) "+db.upsert("collection_id", "attribute")+" value = ?", collID, "render_mathjax", "1", "1")
  762. }
  763. if err != nil {
  764. log.Error("Unable to insert render_mathjax value: %v", err)
  765. return err
  766. }
  767. } else {
  768. _, err = db.Exec("DELETE FROM collectionattributes WHERE collection_id = ? AND attribute = ?", collID, "render_mathjax")
  769. if err != nil {
  770. log.Error("Unable to delete render_mathjax value: %v", err)
  771. return err
  772. }
  773. }
  774. // Update rest of the collection data
  775. res, err = db.Exec("UPDATE collections SET "+q.Updates+" WHERE "+q.Conditions, q.Params...)
  776. if err != nil {
  777. log.Error("Unable to update collection: %v", err)
  778. return err
  779. }
  780. rowsAffected, _ = res.RowsAffected()
  781. if !changed || rowsAffected == 0 {
  782. // Show the correct error message if nothing was updated
  783. var dummy int
  784. err := db.QueryRow("SELECT 1 FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&dummy)
  785. switch {
  786. case err == sql.ErrNoRows:
  787. return ErrUnauthorizedEditPost
  788. case err != nil:
  789. log.Error("Failed selecting from collections: %v", err)
  790. }
  791. if !updatePass {
  792. return nil
  793. }
  794. }
  795. if updatePass {
  796. hashedPass, err := auth.HashPass([]byte(c.Pass))
  797. if err != nil {
  798. log.Error("Unable to create hash: %s", err)
  799. return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  800. }
  801. if db.driverName == driverSQLite {
  802. _, err = db.Exec("INSERT OR REPLACE INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?)", alias, hashedPass)
  803. } else {
  804. _, err = db.Exec("INSERT INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?) "+db.upsert("collection_id")+" password = ?", alias, hashedPass, hashedPass)
  805. }
  806. if err != nil {
  807. return err
  808. }
  809. }
  810. return nil
  811. }
  812. const postCols = "id, slug, text_appearance, language, rtl, privacy, owner_id, collection_id, pinned_position, created, updated, view_count, title, content"
  813. // getEditablePost returns a PublicPost with the given ID only if the given
  814. // edit token is valid for the post.
  815. func (db *datastore) GetEditablePost(id, editToken string) (*PublicPost, error) {
  816. // FIXME: code duplicated from getPost()
  817. // TODO: add slight logic difference to getPost / one func
  818. var ownerName sql.NullString
  819. p := &Post{}
  820. row := db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE id = ? LIMIT 1", id)
  821. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  822. switch {
  823. case err == sql.ErrNoRows:
  824. return nil, ErrPostNotFound
  825. case err != nil:
  826. log.Error("Failed selecting from collections: %v", err)
  827. return nil, err
  828. }
  829. if p.Content == "" && p.Title.String == "" {
  830. return nil, ErrPostUnpublished
  831. }
  832. res := p.processPost()
  833. if ownerName.Valid {
  834. res.Owner = &PublicUser{Username: ownerName.String}
  835. }
  836. return &res, nil
  837. }
  838. func (db *datastore) PostIDExists(id string) bool {
  839. var dummy bool
  840. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ?", id).Scan(&dummy)
  841. return err == nil && dummy
  842. }
  843. // GetPost gets a public-facing post object from the database. If collectionID
  844. // is > 0, the post will be retrieved by slug and collection ID, rather than
  845. // post ID.
  846. // TODO: break this into two functions:
  847. // - GetPost(id string)
  848. // - GetCollectionPost(slug string, collectionID int64)
  849. func (db *datastore) GetPost(id string, collectionID int64) (*PublicPost, error) {
  850. var ownerName sql.NullString
  851. p := &Post{}
  852. var row *sql.Row
  853. var where string
  854. params := []interface{}{id}
  855. if collectionID > 0 {
  856. where = "slug = ? AND collection_id = ?"
  857. params = append(params, collectionID)
  858. } else {
  859. where = "id = ?"
  860. }
  861. row = db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE "+where+" LIMIT 1", params...)
  862. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  863. switch {
  864. case err == sql.ErrNoRows:
  865. if collectionID > 0 {
  866. return nil, ErrCollectionPageNotFound
  867. }
  868. return nil, ErrPostNotFound
  869. case err != nil:
  870. log.Error("Failed selecting from collections: %v", err)
  871. return nil, err
  872. }
  873. if p.Content == "" && p.Title.String == "" {
  874. return nil, ErrPostUnpublished
  875. }
  876. res := p.processPost()
  877. if ownerName.Valid {
  878. res.Owner = &PublicUser{Username: ownerName.String}
  879. }
  880. return &res, nil
  881. }
  882. // TODO: don't duplicate getPost() functionality
  883. func (db *datastore) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) {
  884. p := &Post{}
  885. var row *sql.Row
  886. where := "id = ? AND owner_id = ?"
  887. params := []interface{}{id, ownerID}
  888. row = db.QueryRow("SELECT "+postCols+" FROM posts WHERE "+where+" LIMIT 1", params...)
  889. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  890. switch {
  891. case err == sql.ErrNoRows:
  892. return nil, ErrPostNotFound
  893. case err != nil:
  894. log.Error("Failed selecting from collections: %v", err)
  895. return nil, err
  896. }
  897. if p.Content == "" && p.Title.String == "" {
  898. return nil, ErrPostUnpublished
  899. }
  900. res := p.processPost()
  901. return &res, nil
  902. }
  903. func (db *datastore) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) {
  904. propSelects := map[string]string{
  905. "views": "view_count AS views",
  906. }
  907. selectQuery, ok := propSelects[property]
  908. if !ok {
  909. return nil, impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid property: %s.", property)}
  910. }
  911. var res interface{}
  912. var row *sql.Row
  913. if collectionID != 0 {
  914. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE slug = ? AND collection_id = ? LIMIT 1", id, collectionID)
  915. } else {
  916. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE id = ? LIMIT 1", id)
  917. }
  918. err := row.Scan(&res)
  919. switch {
  920. case err == sql.ErrNoRows:
  921. return nil, impart.HTTPError{http.StatusNotFound, "Post not found."}
  922. case err != nil:
  923. log.Error("Failed selecting post: %v", err)
  924. return nil, err
  925. }
  926. return res, nil
  927. }
  928. // GetPostsCount modifies the CollectionObj to include the correct number of
  929. // standard (non-pinned) posts. It will return future posts if `includeFuture`
  930. // is true.
  931. func (db *datastore) GetPostsCount(c *CollectionObj, includeFuture bool) {
  932. var count int64
  933. timeCondition := ""
  934. if !includeFuture {
  935. timeCondition = "AND created <= " + db.now()
  936. }
  937. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE collection_id = ? AND pinned_position IS NULL "+timeCondition, c.ID).Scan(&count)
  938. switch {
  939. case err == sql.ErrNoRows:
  940. c.TotalPosts = 0
  941. case err != nil:
  942. log.Error("Failed selecting from collections: %v", err)
  943. c.TotalPosts = 0
  944. }
  945. c.TotalPosts = int(count)
  946. }
  947. // GetPosts retrieves all posts for the given Collection.
  948. // It will return future posts if `includeFuture` is true.
  949. // It will include only standard (non-pinned) posts unless `includePinned` is true.
  950. // TODO: change includeFuture to isOwner, since that's how it's used
  951. func (db *datastore) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) {
  952. collID := c.ID
  953. cf := c.NewFormat()
  954. order := "DESC"
  955. if cf.Ascending() && !forceRecentFirst {
  956. order = "ASC"
  957. }
  958. pagePosts := cf.PostsPerPage()
  959. start := page*pagePosts - pagePosts
  960. if page == 0 {
  961. start = 0
  962. pagePosts = 1000
  963. }
  964. limitStr := ""
  965. if page > 0 {
  966. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  967. }
  968. timeCondition := ""
  969. if !includeFuture {
  970. timeCondition = "AND created <= " + db.now()
  971. }
  972. pinnedCondition := ""
  973. if !includePinned {
  974. pinnedCondition = "AND pinned_position IS NULL"
  975. }
  976. rows, err := db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? "+pinnedCondition+" "+timeCondition+" ORDER BY created "+order+limitStr, collID)
  977. if err != nil {
  978. log.Error("Failed selecting from posts: %v", err)
  979. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  980. }
  981. defer rows.Close()
  982. // TODO: extract this common row scanning logic for queries using `postCols`
  983. posts := []PublicPost{}
  984. for rows.Next() {
  985. p := &Post{}
  986. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  987. if err != nil {
  988. log.Error("Failed scanning row: %v", err)
  989. break
  990. }
  991. p.extractData()
  992. p.formatContent(cfg, c, includeFuture)
  993. posts = append(posts, p.processPost())
  994. }
  995. err = rows.Err()
  996. if err != nil {
  997. log.Error("Error after Next() on rows: %v", err)
  998. }
  999. return &posts, nil
  1000. }
  1001. // GetPostsTagged retrieves all posts on the given Collection that contain the
  1002. // given tag.
  1003. // It will return future posts if `includeFuture` is true.
  1004. // TODO: change includeFuture to isOwner, since that's how it's used
  1005. func (db *datastore) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) {
  1006. collID := c.ID
  1007. cf := c.NewFormat()
  1008. order := "DESC"
  1009. if cf.Ascending() {
  1010. order = "ASC"
  1011. }
  1012. pagePosts := cf.PostsPerPage()
  1013. start := page*pagePosts - pagePosts
  1014. if page == 0 {
  1015. start = 0
  1016. pagePosts = 1000
  1017. }
  1018. limitStr := ""
  1019. if page > 0 {
  1020. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  1021. }
  1022. timeCondition := ""
  1023. if !includeFuture {
  1024. timeCondition = "AND created <= " + db.now()
  1025. }
  1026. var rows *sql.Rows
  1027. var err error
  1028. if db.driverName == driverSQLite {
  1029. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) regexp ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, `.*#`+strings.ToLower(tag)+`\b.*`)
  1030. } else {
  1031. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) RLIKE ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, "#"+strings.ToLower(tag)+"[[:>:]]")
  1032. }
  1033. if err != nil {
  1034. log.Error("Failed selecting from posts: %v", err)
  1035. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  1036. }
  1037. defer rows.Close()
  1038. // TODO: extract this common row scanning logic for queries using `postCols`
  1039. posts := []PublicPost{}
  1040. for rows.Next() {
  1041. p := &Post{}
  1042. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  1043. if err != nil {
  1044. log.Error("Failed scanning row: %v", err)
  1045. break
  1046. }
  1047. p.extractData()
  1048. p.formatContent(cfg, c, includeFuture)
  1049. posts = append(posts, p.processPost())
  1050. }
  1051. err = rows.Err()
  1052. if err != nil {
  1053. log.Error("Error after Next() on rows: %v", err)
  1054. }
  1055. return &posts, nil
  1056. }
  1057. func (db *datastore) GetAPFollowers(c *Collection) (*[]RemoteUser, error) {
  1058. rows, err := db.Query("SELECT actor_id, inbox, shared_inbox FROM remotefollows f INNER JOIN remoteusers u ON f.remote_user_id = u.id WHERE collection_id = ?", c.ID)
  1059. if err != nil {
  1060. log.Error("Failed selecting from followers: %v", err)
  1061. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve followers."}
  1062. }
  1063. defer rows.Close()
  1064. followers := []RemoteUser{}
  1065. for rows.Next() {
  1066. f := RemoteUser{}
  1067. err = rows.Scan(&f.ActorID, &f.Inbox, &f.SharedInbox)
  1068. followers = append(followers, f)
  1069. }
  1070. return &followers, nil
  1071. }
  1072. // CanCollect returns whether or not the given user can add the given post to a
  1073. // collection. This is true when a post is already owned by the user.
  1074. // NOTE: this is currently only used to potentially add owned posts to a
  1075. // collection. This has the SIDE EFFECT of also generating a slug for the post.
  1076. // FIXME: make this side effect more explicit (or extract it)
  1077. func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool {
  1078. var title, content string
  1079. var lang sql.NullString
  1080. err := db.QueryRow("SELECT title, content, language FROM posts WHERE id = ? AND owner_id = ?", cpr.ID, userID).Scan(&title, &content, &lang)
  1081. switch {
  1082. case err == sql.ErrNoRows:
  1083. return false
  1084. case err != nil:
  1085. log.Error("Failed on post CanCollect(%s, %d): %v", cpr.ID, userID, err)
  1086. return false
  1087. }
  1088. // Since we have the post content and the post is collectable, generate the
  1089. // post's slug now.
  1090. cpr.Slug = getSlugFromPost(title, content, lang.String)
  1091. return true
  1092. }
  1093. func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) {
  1094. qRes, err := db.Exec(query, params...)
  1095. if err != nil {
  1096. if db.isDuplicateKeyErr(err) && slugIdx > -1 {
  1097. s := id.GenSafeUniqueSlug(p.Slug)
  1098. if s == p.Slug {
  1099. // Sanity check to prevent infinite recursion
  1100. return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s)
  1101. }
  1102. p.Slug = s
  1103. params[slugIdx] = p.Slug
  1104. return db.AttemptClaim(p, query, params, slugIdx)
  1105. }
  1106. return qRes, fmt.Errorf("attemptClaim: %s", err)
  1107. }
  1108. return qRes, nil
  1109. }
  1110. func (db *datastore) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) {
  1111. postClaimReqs := map[string]bool{}
  1112. res := []ClaimPostResult{}
  1113. for i := range postIDs {
  1114. postID := postIDs[i]
  1115. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1116. // Perform post validation
  1117. if postID == "" {
  1118. r.ErrorMessage = "Missing post ID. "
  1119. }
  1120. if _, ok := postClaimReqs[postID]; ok {
  1121. r.Code = 429
  1122. r.ErrorMessage = "You've already tried anonymizing this post."
  1123. r.ID = postID
  1124. res = append(res, r)
  1125. continue
  1126. }
  1127. postClaimReqs[postID] = true
  1128. var err error
  1129. // Get full post information to return
  1130. var fullPost *PublicPost
  1131. fullPost, err = db.GetPost(postID, 0)
  1132. if err != nil {
  1133. if err, ok := err.(impart.HTTPError); ok {
  1134. r.Code = err.Status
  1135. r.ErrorMessage = err.Message
  1136. r.ID = postID
  1137. res = append(res, r)
  1138. continue
  1139. } else {
  1140. log.Error("Error getting post in dispersePosts: %v", err)
  1141. }
  1142. }
  1143. if fullPost.OwnerID.Int64 != userID {
  1144. r.Code = http.StatusConflict
  1145. r.ErrorMessage = "Post is already owned by someone else."
  1146. r.ID = postID
  1147. res = append(res, r)
  1148. continue
  1149. }
  1150. var qRes sql.Result
  1151. var query string
  1152. var params []interface{}
  1153. // Do AND owner_id = ? for sanity.
  1154. // This should've been caught and returned with a good error message
  1155. // just above.
  1156. query = "UPDATE posts SET collection_id = NULL WHERE id = ? AND owner_id = ?"
  1157. params = []interface{}{postID, userID}
  1158. qRes, err = db.Exec(query, params...)
  1159. if err != nil {
  1160. r.Code = http.StatusInternalServerError
  1161. r.ErrorMessage = "A glitch happened on our end."
  1162. r.ID = postID
  1163. res = append(res, r)
  1164. log.Error("dispersePosts (post %s): %v", postID, err)
  1165. continue
  1166. }
  1167. // Post was successfully dispersed
  1168. r.Code = http.StatusOK
  1169. r.Post = fullPost
  1170. rowsAffected, _ := qRes.RowsAffected()
  1171. if rowsAffected == 0 {
  1172. // This was already claimed, but return 200
  1173. r.Code = http.StatusOK
  1174. }
  1175. res = append(res, r)
  1176. }
  1177. return &res, nil
  1178. }
  1179. func (db *datastore) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) {
  1180. postClaimReqs := map[string]bool{}
  1181. res := []ClaimPostResult{}
  1182. postCollAlias := collAlias
  1183. for i := range *posts {
  1184. p := (*posts)[i]
  1185. if &p == nil {
  1186. continue
  1187. }
  1188. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1189. // Perform post validation
  1190. if p.ID == "" {
  1191. r.ErrorMessage = "Missing post ID `id`. "
  1192. }
  1193. if _, ok := postClaimReqs[p.ID]; ok {
  1194. r.Code = 429
  1195. r.ErrorMessage = "You've already tried claiming this post."
  1196. r.ID = p.ID
  1197. res = append(res, r)
  1198. continue
  1199. }
  1200. postClaimReqs[p.ID] = true
  1201. canCollect := db.CanCollect(&p, userID)
  1202. if !canCollect && p.Token == "" {
  1203. // TODO: ensure post isn't owned by anyone else when a valid modify
  1204. // token is given.
  1205. r.ErrorMessage += "Missing post Edit Token `token`."
  1206. }
  1207. if r.ErrorMessage != "" {
  1208. // Post validate failed
  1209. r.Code = http.StatusBadRequest
  1210. r.ID = p.ID
  1211. res = append(res, r)
  1212. continue
  1213. }
  1214. var err error
  1215. var qRes sql.Result
  1216. var query string
  1217. var params []interface{}
  1218. var slugIdx int = -1
  1219. var coll *Collection
  1220. if collAlias == "" {
  1221. // Posts are being claimed at /posts/claim, not
  1222. // /collections/{alias}/collect, so use given individual collection
  1223. // to associate post with.
  1224. postCollAlias = p.CollectionAlias
  1225. }
  1226. if postCollAlias != "" {
  1227. // Associate this post with a collection
  1228. if p.CreateCollection {
  1229. // This is a new collection
  1230. // TODO: consider removing this. This seriously complicates this
  1231. // method and adds another (unnecessary?) logic path.
  1232. coll, err = db.CreateCollection(cfg, postCollAlias, "", userID)
  1233. if err != nil {
  1234. if err, ok := err.(impart.HTTPError); ok {
  1235. r.Code = err.Status
  1236. r.ErrorMessage = err.Message
  1237. } else {
  1238. r.Code = http.StatusInternalServerError
  1239. r.ErrorMessage = "Unknown error occurred creating collection"
  1240. }
  1241. r.ID = p.ID
  1242. res = append(res, r)
  1243. continue
  1244. }
  1245. } else {
  1246. // Attempt to add to existing collection
  1247. coll, err = db.GetCollection(postCollAlias)
  1248. if err != nil {
  1249. if err, ok := err.(impart.HTTPError); ok {
  1250. if err.Status == http.StatusNotFound {
  1251. // Show obfuscated "forbidden" response, as if attempting to add to an
  1252. // unowned blog.
  1253. r.Code = ErrForbiddenCollection.Status
  1254. r.ErrorMessage = ErrForbiddenCollection.Message
  1255. } else {
  1256. r.Code = err.Status
  1257. r.ErrorMessage = err.Message
  1258. }
  1259. } else {
  1260. r.Code = http.StatusInternalServerError
  1261. r.ErrorMessage = "Unknown error occurred claiming post with collection"
  1262. }
  1263. r.ID = p.ID
  1264. res = append(res, r)
  1265. continue
  1266. }
  1267. if coll.OwnerID != userID {
  1268. r.Code = ErrForbiddenCollection.Status
  1269. r.ErrorMessage = ErrForbiddenCollection.Message
  1270. r.ID = p.ID
  1271. res = append(res, r)
  1272. continue
  1273. }
  1274. }
  1275. if p.Slug == "" {
  1276. p.Slug = p.ID
  1277. }
  1278. if canCollect {
  1279. // User already owns this post, so just add it to the given
  1280. // collection.
  1281. query = "UPDATE posts SET collection_id = ?, slug = ? WHERE id = ? AND owner_id = ?"
  1282. params = []interface{}{coll.ID, p.Slug, p.ID, userID}
  1283. slugIdx = 1
  1284. } else {
  1285. query = "UPDATE posts SET owner_id = ?, collection_id = ?, slug = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1286. params = []interface{}{userID, coll.ID, p.Slug, p.ID, p.Token}
  1287. slugIdx = 2
  1288. }
  1289. } else {
  1290. query = "UPDATE posts SET owner_id = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1291. params = []interface{}{userID, p.ID, p.Token}
  1292. }
  1293. qRes, err = db.AttemptClaim(&p, query, params, slugIdx)
  1294. if err != nil {
  1295. r.Code = http.StatusInternalServerError
  1296. r.ErrorMessage = "An unknown error occurred."
  1297. r.ID = p.ID
  1298. res = append(res, r)
  1299. log.Error("claimPosts (post %s): %v", p.ID, err)
  1300. continue
  1301. }
  1302. // Get full post information to return
  1303. var fullPost *PublicPost
  1304. if p.Token != "" {
  1305. fullPost, err = db.GetEditablePost(p.ID, p.Token)
  1306. } else {
  1307. fullPost, err = db.GetPost(p.ID, 0)
  1308. }
  1309. if err != nil {
  1310. if err, ok := err.(impart.HTTPError); ok {
  1311. r.Code = err.Status
  1312. r.ErrorMessage = err.Message
  1313. r.ID = p.ID
  1314. res = append(res, r)
  1315. continue
  1316. }
  1317. }
  1318. if fullPost.OwnerID.Int64 != userID {
  1319. r.Code = http.StatusConflict
  1320. r.ErrorMessage = "Post is already owned by someone else."
  1321. r.ID = p.ID
  1322. res = append(res, r)
  1323. continue
  1324. }
  1325. // Post was successfully claimed
  1326. r.Code = http.StatusOK
  1327. r.Post = fullPost
  1328. if coll != nil {
  1329. r.Post.Collection = &CollectionObj{Collection: *coll}
  1330. }
  1331. rowsAffected, _ := qRes.RowsAffected()
  1332. if rowsAffected == 0 {
  1333. // This was already claimed, but return 200
  1334. r.Code = http.StatusOK
  1335. }
  1336. res = append(res, r)
  1337. }
  1338. return &res, nil
  1339. }
  1340. func (db *datastore) UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error {
  1341. if pos <= 0 || pos > 20 {
  1342. pos = db.GetLastPinnedPostPos(collID) + 1
  1343. if pos == -1 {
  1344. pos = 1
  1345. }
  1346. }
  1347. var err error
  1348. if pinned {
  1349. _, err = db.Exec("UPDATE posts SET pinned_position = ? WHERE id = ?", pos, postID)
  1350. } else {
  1351. _, err = db.Exec("UPDATE posts SET pinned_position = NULL WHERE id = ?", postID)
  1352. }
  1353. if err != nil {
  1354. log.Error("Unable to update pinned post: %v", err)
  1355. return err
  1356. }
  1357. return nil
  1358. }
  1359. func (db *datastore) GetLastPinnedPostPos(collID int64) int64 {
  1360. var lastPos sql.NullInt64
  1361. err := db.QueryRow("SELECT MAX(pinned_position) FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL", collID).Scan(&lastPos)
  1362. switch {
  1363. case err == sql.ErrNoRows:
  1364. return -1
  1365. case err != nil:
  1366. log.Error("Failed selecting from posts: %v", err)
  1367. return -1
  1368. }
  1369. if !lastPos.Valid {
  1370. return -1
  1371. }
  1372. return lastPos.Int64
  1373. }
  1374. func (db *datastore) GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) {
  1375. // FIXME: sqlite-backed instances don't include ellipsis on truncated titles
  1376. timeCondition := ""
  1377. if !includeFuture {
  1378. timeCondition = "AND created <= " + db.now()
  1379. }
  1380. rows, err := db.Query("SELECT id, slug, title, "+db.clip("content", 80)+", pinned_position FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL "+timeCondition+" ORDER BY pinned_position ASC", coll.ID)
  1381. if err != nil {
  1382. log.Error("Failed selecting pinned posts: %v", err)
  1383. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve pinned posts."}
  1384. }
  1385. defer rows.Close()
  1386. posts := []PublicPost{}
  1387. for rows.Next() {
  1388. p := &Post{}
  1389. err = rows.Scan(&p.ID, &p.Slug, &p.Title, &p.Content, &p.PinnedPosition)
  1390. if err != nil {
  1391. log.Error("Failed scanning row: %v", err)
  1392. break
  1393. }
  1394. p.extractData()
  1395. pp := p.processPost()
  1396. pp.Collection = coll
  1397. posts = append(posts, pp)
  1398. }
  1399. return &posts, nil
  1400. }
  1401. func (db *datastore) GetCollections(u *User, hostName string) (*[]Collection, error) {
  1402. rows, err := db.Query("SELECT id, alias, title, description, privacy, view_count FROM collections WHERE owner_id = ? ORDER BY id ASC", u.ID)
  1403. if err != nil {
  1404. log.Error("Failed selecting from collections: %v", err)
  1405. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user collections."}
  1406. }
  1407. defer rows.Close()
  1408. colls := []Collection{}
  1409. for rows.Next() {
  1410. c := Collection{}
  1411. err = rows.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility, &c.Views)
  1412. if err != nil {
  1413. log.Error("Failed scanning row: %v", err)
  1414. break
  1415. }
  1416. c.hostName = hostName
  1417. c.URL = c.CanonicalURL()
  1418. c.Public = c.IsPublic()
  1419. colls = append(colls, c)
  1420. }
  1421. err = rows.Err()
  1422. if err != nil {
  1423. log.Error("Error after Next() on rows: %v", err)
  1424. }
  1425. return &colls, nil
  1426. }
  1427. func (db *datastore) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) {
  1428. c, err := db.GetCollections(u, hostName)
  1429. if err != nil {
  1430. return nil, err
  1431. }
  1432. if len(*c) == 0 {
  1433. return nil, impart.HTTPError{http.StatusInternalServerError, "You don't seem to have any blogs; they might've moved to another account. Try logging out and logging into your other account."}
  1434. }
  1435. return c, nil
  1436. }
  1437. func (db *datastore) GetMeStats(u *User) userMeStats {
  1438. s := userMeStats{}
  1439. // User counts
  1440. colls, _ := db.GetUserCollectionCount(u.ID)
  1441. s.TotalCollections = colls
  1442. var articles, collPosts uint64
  1443. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NULL", u.ID).Scan(&articles)
  1444. if err != nil && err != sql.ErrNoRows {
  1445. log.Error("Couldn't get articles count for user %d: %v", u.ID, err)
  1446. }
  1447. s.TotalArticles = articles
  1448. err = db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NOT NULL", u.ID).Scan(&collPosts)
  1449. if err != nil && err != sql.ErrNoRows {
  1450. log.Error("Couldn't get coll posts count for user %d: %v", u.ID, err)
  1451. }
  1452. s.CollectionPosts = collPosts
  1453. return s
  1454. }
  1455. func (db *datastore) GetTotalCollections() (collCount int64, err error) {
  1456. err = db.QueryRow(`
  1457. SELECT COUNT(*)
  1458. FROM collections c
  1459. LEFT JOIN users u ON u.id = c.owner_id
  1460. WHERE u.status = 0`).Scan(&collCount)
  1461. if err != nil {
  1462. log.Error("Unable to fetch collections count: %v", err)
  1463. }
  1464. return
  1465. }
  1466. func (db *datastore) GetTotalPosts() (postCount int64, err error) {
  1467. err = db.QueryRow(`
  1468. SELECT COUNT(*)
  1469. FROM posts p
  1470. LEFT JOIN users u ON u.id = p.owner_id
  1471. WHERE u.status = 0`).Scan(&postCount)
  1472. if err != nil {
  1473. log.Error("Unable to fetch posts count: %v", err)
  1474. }
  1475. return
  1476. }
  1477. func (db *datastore) GetTopPosts(u *User, alias string) (*[]PublicPost, error) {
  1478. params := []interface{}{u.ID}
  1479. where := ""
  1480. if alias != "" {
  1481. where = " AND alias = ?"
  1482. params = append(params, alias)
  1483. }
  1484. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON p.collection_id = c.id WHERE p.owner_id = ?"+where+" ORDER BY p.view_count DESC, created DESC LIMIT 25", params...)
  1485. if err != nil {
  1486. log.Error("Failed selecting from posts: %v", err)
  1487. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user top posts."}
  1488. }
  1489. defer rows.Close()
  1490. posts := []PublicPost{}
  1491. var gotErr bool
  1492. for rows.Next() {
  1493. p := Post{}
  1494. c := Collection{}
  1495. var alias, title, description sql.NullString
  1496. var views sql.NullInt64
  1497. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &alias, &title, &description, &views)
  1498. if err != nil {
  1499. log.Error("Failed scanning User.getPosts() row: %v", err)
  1500. gotErr = true
  1501. break
  1502. }
  1503. p.extractData()
  1504. pubPost := p.processPost()
  1505. if alias.Valid && alias.String != "" {
  1506. c.Alias = alias.String
  1507. c.Title = title.String
  1508. c.Description = description.String
  1509. c.Views = views.Int64
  1510. pubPost.Collection = &CollectionObj{Collection: c}
  1511. }
  1512. posts = append(posts, pubPost)
  1513. }
  1514. err = rows.Err()
  1515. if err != nil {
  1516. log.Error("Error after Next() on rows: %v", err)
  1517. }
  1518. if gotErr && len(posts) == 0 {
  1519. // There were a lot of errors
  1520. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1521. }
  1522. return &posts, nil
  1523. }
  1524. func (db *datastore) GetAnonymousPosts(u *User) (*[]PublicPost, error) {
  1525. rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC", u.ID)
  1526. if err != nil {
  1527. log.Error("Failed selecting from posts: %v", err)
  1528. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user anonymous posts."}
  1529. }
  1530. defer rows.Close()
  1531. posts := []PublicPost{}
  1532. for rows.Next() {
  1533. p := Post{}
  1534. err = rows.Scan(&p.ID, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content)
  1535. if err != nil {
  1536. log.Error("Failed scanning row: %v", err)
  1537. break
  1538. }
  1539. p.extractData()
  1540. posts = append(posts, p.processPost())
  1541. }
  1542. err = rows.Err()
  1543. if err != nil {
  1544. log.Error("Error after Next() on rows: %v", err)
  1545. }
  1546. return &posts, nil
  1547. }
  1548. func (db *datastore) GetUserPosts(u *User) (*[]PublicPost, error) {
  1549. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, p.created, p.updated, p.content, p.text_appearance, p.language, p.rtl, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON collection_id = c.id WHERE p.owner_id = ? ORDER BY created ASC", u.ID)
  1550. if err != nil {
  1551. log.Error("Failed selecting from posts: %v", err)
  1552. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."}
  1553. }
  1554. defer rows.Close()
  1555. posts := []PublicPost{}
  1556. var gotErr bool
  1557. for rows.Next() {
  1558. p := Post{}
  1559. c := Collection{}
  1560. var alias, title, description sql.NullString
  1561. var views sql.NullInt64
  1562. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content, &p.Font, &p.Language, &p.RTL, &alias, &title, &description, &views)
  1563. if err != nil {
  1564. log.Error("Failed scanning User.getPosts() row: %v", err)
  1565. gotErr = true
  1566. break
  1567. }
  1568. p.extractData()
  1569. pubPost := p.processPost()
  1570. if alias.Valid && alias.String != "" {
  1571. c.Alias = alias.String
  1572. c.Title = title.String
  1573. c.Description = description.String
  1574. c.Views = views.Int64
  1575. pubPost.Collection = &CollectionObj{Collection: c}
  1576. }
  1577. posts = append(posts, pubPost)
  1578. }
  1579. err = rows.Err()
  1580. if err != nil {
  1581. log.Error("Error after Next() on rows: %v", err)
  1582. }
  1583. if gotErr && len(posts) == 0 {
  1584. // There were a lot of errors
  1585. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1586. }
  1587. return &posts, nil
  1588. }
  1589. func (db *datastore) GetUserPostsCount(userID int64) int64 {
  1590. var count int64
  1591. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ?", userID).Scan(&count)
  1592. switch {
  1593. case err == sql.ErrNoRows:
  1594. return 0
  1595. case err != nil:
  1596. log.Error("Failed selecting posts count for user %d: %v", userID, err)
  1597. return 0
  1598. }
  1599. return count
  1600. }
  1601. // ChangeSettings takes a User and applies the changes in the given
  1602. // userSettings, MODIFYING THE USER with successful changes.
  1603. func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error {
  1604. var errPass error
  1605. q := query.NewUpdate()
  1606. // Update email if given
  1607. if s.Email != "" {
  1608. encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email)
  1609. if err != nil {
  1610. log.Error("Couldn't encrypt email %s: %s\n", s.Email, err)
  1611. return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."}
  1612. }
  1613. q.SetBytes(encEmail, "email")
  1614. // Update the email if something goes awry updating the password
  1615. defer func() {
  1616. if errPass != nil {
  1617. db.UpdateEncryptedUserEmail(u.ID, encEmail)
  1618. }
  1619. }()
  1620. u.Email = zero.StringFrom(s.Email)
  1621. }
  1622. // Update username if given
  1623. var newUsername string
  1624. if s.Username != "" {
  1625. var ie *impart.HTTPError
  1626. newUsername, ie = getValidUsername(app, s.Username, u.Username)
  1627. if ie != nil {
  1628. // Username is invalid
  1629. return *ie
  1630. }
  1631. if !author.IsValidUsername(app.cfg, newUsername) {
  1632. // Ensure the username is syntactically correct.
  1633. return impart.HTTPError{http.StatusPreconditionFailed, "Username isn't valid."}
  1634. }
  1635. t, err := db.Begin()
  1636. if err != nil {
  1637. log.Error("Couldn't start username change transaction: %v", err)
  1638. return err
  1639. }
  1640. _, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID)
  1641. if err != nil {
  1642. t.Rollback()
  1643. if db.isDuplicateKeyErr(err) {
  1644. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1645. }
  1646. log.Error("Unable to update users table: %v", err)
  1647. return ErrInternalGeneral
  1648. }
  1649. _, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID)
  1650. if err != nil {
  1651. t.Rollback()
  1652. if db.isDuplicateKeyErr(err) {
  1653. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1654. }
  1655. log.Error("Unable to update collection: %v", err)
  1656. return ErrInternalGeneral
  1657. }
  1658. // Keep track of name changes for redirection
  1659. db.RemoveCollectionRedirect(t, newUsername)
  1660. _, err = t.Exec("UPDATE collectionredirects SET new_alias = ? WHERE new_alias = ?", newUsername, u.Username)
  1661. if err != nil {
  1662. log.Error("Unable to update collectionredirects: %v", err)
  1663. }
  1664. _, err = t.Exec("INSERT INTO collectionredirects (prev_alias, new_alias) VALUES (?, ?)", u.Username, newUsername)
  1665. if err != nil {
  1666. log.Error("Unable to add new collectionredirect: %v", err)
  1667. }
  1668. err = t.Commit()
  1669. if err != nil {
  1670. t.Rollback()
  1671. log.Error("Rolling back after Commit(): %v\n", err)
  1672. return err
  1673. }
  1674. u.Username = newUsername
  1675. }
  1676. // Update passphrase if given
  1677. if s.NewPass != "" {
  1678. // Check if user has already set a password
  1679. var err error
  1680. u.HasPass, err = db.IsUserPassSet(u.ID)
  1681. if err != nil {
  1682. errPass = impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data."}
  1683. return errPass
  1684. }
  1685. if u.HasPass {
  1686. // Check if currently-set password is correct
  1687. hashedPass := u.HashedPass
  1688. if len(hashedPass) == 0 {
  1689. authUser, err := db.GetUserForAuthByID(u.ID)
  1690. if err != nil {
  1691. errPass = err
  1692. return errPass
  1693. }
  1694. hashedPass = authUser.HashedPass
  1695. }
  1696. if !auth.Authenticated(hashedPass, []byte(s.OldPass)) {
  1697. errPass = impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1698. return errPass
  1699. }
  1700. }
  1701. hashedPass, err := auth.HashPass([]byte(s.NewPass))
  1702. if err != nil {
  1703. errPass = impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  1704. return errPass
  1705. }
  1706. q.SetBytes(hashedPass, "password")
  1707. }
  1708. // WHERE values
  1709. q.Append(u.ID)
  1710. if q.Updates == "" {
  1711. if s.Username == "" {
  1712. return ErrPostNoUpdatableVals
  1713. }
  1714. // Nothing to update except username. That was successful, so return now.
  1715. return nil
  1716. }
  1717. res, err := db.Exec("UPDATE users SET "+q.Updates+" WHERE id = ?", q.Params...)
  1718. if err != nil {
  1719. log.Error("Unable to update collection: %v", err)
  1720. return err
  1721. }
  1722. rowsAffected, _ := res.RowsAffected()
  1723. if rowsAffected == 0 {
  1724. // Show the correct error message if nothing was updated
  1725. var dummy int
  1726. err := db.QueryRow("SELECT 1 FROM users WHERE id = ?", u.ID).Scan(&dummy)
  1727. switch {
  1728. case err == sql.ErrNoRows:
  1729. return ErrUnauthorizedGeneral
  1730. case err != nil:
  1731. log.Error("Failed selecting from users: %v", err)
  1732. }
  1733. return nil
  1734. }
  1735. if s.NewPass != "" && !u.HasPass {
  1736. u.HasPass = true
  1737. }
  1738. return nil
  1739. }
  1740. func (db *datastore) ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error {
  1741. var dbPass []byte
  1742. err := db.QueryRow("SELECT password FROM users WHERE id = ?", userID).Scan(&dbPass)
  1743. switch {
  1744. case err == sql.ErrNoRows:
  1745. return ErrUserNotFound
  1746. case err != nil:
  1747. log.Error("Couldn't SELECT user password for change: %v", err)
  1748. return err
  1749. }
  1750. if !sudo && !auth.Authenticated(dbPass, []byte(curPass)) {
  1751. return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1752. }
  1753. _, err = db.Exec("UPDATE users SET password = ? WHERE id = ?", hashedPass, userID)
  1754. if err != nil {
  1755. log.Error("Could not update passphrase: %v", err)
  1756. return err
  1757. }
  1758. return nil
  1759. }
  1760. func (db *datastore) RemoveCollectionRedirect(t *sql.Tx, alias string) error {
  1761. _, err := t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ?", alias)
  1762. if err != nil {
  1763. log.Error("Unable to delete from collectionredirects: %v", err)
  1764. return err
  1765. }
  1766. return nil
  1767. }
  1768. func (db *datastore) GetCollectionRedirect(alias string) (new string) {
  1769. row := db.QueryRow("SELECT new_alias FROM collectionredirects WHERE prev_alias = ?", alias)
  1770. err := row.Scan(&new)
  1771. if err != nil && err != sql.ErrNoRows {
  1772. log.Error("Failed selecting from collectionredirects: %v", err)
  1773. }
  1774. return
  1775. }
  1776. func (db *datastore) DeleteCollection(alias string, userID int64) error {
  1777. c := &Collection{Alias: alias}
  1778. var username string
  1779. row := db.QueryRow("SELECT username FROM users WHERE id = ?", userID)
  1780. err := row.Scan(&username)
  1781. if err != nil {
  1782. return err
  1783. }
  1784. // Ensure user isn't deleting their main blog
  1785. if alias == username {
  1786. return impart.HTTPError{http.StatusForbidden, "You cannot currently delete your primary blog."}
  1787. }
  1788. row = db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, userID)
  1789. err = row.Scan(&c.ID)
  1790. switch {
  1791. case err == sql.ErrNoRows:
  1792. return impart.HTTPError{http.StatusNotFound, "Collection doesn't exist or you're not allowed to delete it."}
  1793. case err != nil:
  1794. log.Error("Failed selecting from collections: %v", err)
  1795. return ErrInternalGeneral
  1796. }
  1797. t, err := db.Begin()
  1798. if err != nil {
  1799. return err
  1800. }
  1801. // Float all collection's posts
  1802. _, err = t.Exec("UPDATE posts SET collection_id = NULL WHERE collection_id = ? AND owner_id = ?", c.ID, userID)
  1803. if err != nil {
  1804. t.Rollback()
  1805. return err
  1806. }
  1807. // Remove redirects to or from this collection
  1808. _, err = t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ? OR new_alias = ?", alias, alias)
  1809. if err != nil {
  1810. t.Rollback()
  1811. return err
  1812. }
  1813. // Remove any optional collection password
  1814. _, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1815. if err != nil {
  1816. t.Rollback()
  1817. return err
  1818. }
  1819. // Finally, delete collection itself
  1820. _, err = t.Exec("DELETE FROM collections WHERE id = ?", c.ID)
  1821. if err != nil {
  1822. t.Rollback()
  1823. return err
  1824. }
  1825. err = t.Commit()
  1826. if err != nil {
  1827. t.Rollback()
  1828. return err
  1829. }
  1830. return nil
  1831. }
  1832. func (db *datastore) IsCollectionAttributeOn(id int64, attr string) bool {
  1833. var v string
  1834. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&v)
  1835. switch {
  1836. case err == sql.ErrNoRows:
  1837. return false
  1838. case err != nil:
  1839. log.Error("Couldn't SELECT value in isCollectionAttributeOn for attribute '%s': %v", attr, err)
  1840. return false
  1841. }
  1842. return v == "1"
  1843. }
  1844. func (db *datastore) CollectionHasAttribute(id int64, attr string) bool {
  1845. var dummy string
  1846. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&dummy)
  1847. switch {
  1848. case err == sql.ErrNoRows:
  1849. return false
  1850. case err != nil:
  1851. log.Error("Couldn't SELECT value in collectionHasAttribute for attribute '%s': %v", attr, err)
  1852. return false
  1853. }
  1854. return true
  1855. }
  1856. func (db *datastore) DeleteAccount(userID int64) (l *string, err error) {
  1857. debug := ""
  1858. l = &debug
  1859. t, err := db.Begin()
  1860. if err != nil {
  1861. stringLogln(l, "Unable to begin: %v", err)
  1862. return
  1863. }
  1864. // Get all collections
  1865. rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID)
  1866. if err != nil {
  1867. t.Rollback()
  1868. stringLogln(l, "Unable to get collections: %v", err)
  1869. return
  1870. }
  1871. defer rows.Close()
  1872. colls := []Collection{}
  1873. var c Collection
  1874. for rows.Next() {
  1875. err = rows.Scan(&c.ID, &c.Alias)
  1876. if err != nil {
  1877. t.Rollback()
  1878. stringLogln(l, "Unable to scan collection cols: %v", err)
  1879. return
  1880. }
  1881. colls = append(colls, c)
  1882. }
  1883. var res sql.Result
  1884. for _, c := range colls {
  1885. // TODO: user deleteCollection() func
  1886. // Delete tokens
  1887. res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID)
  1888. if err != nil {
  1889. t.Rollback()
  1890. stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err)
  1891. return
  1892. }
  1893. rs, _ := res.RowsAffected()
  1894. stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias)
  1895. // Remove any optional collection password
  1896. res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1897. if err != nil {
  1898. t.Rollback()
  1899. stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err)
  1900. return
  1901. }
  1902. rs, _ = res.RowsAffected()
  1903. stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias)
  1904. // Remove redirects to this collection
  1905. res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias)
  1906. if err != nil {
  1907. t.Rollback()
  1908. stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err)
  1909. return
  1910. }
  1911. rs, _ = res.RowsAffected()
  1912. stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias)
  1913. }
  1914. // Delete collections
  1915. res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID)
  1916. if err != nil {
  1917. t.Rollback()
  1918. stringLogln(l, "Unable to delete collections: %v", err)
  1919. return
  1920. }
  1921. rs, _ := res.RowsAffected()
  1922. stringLogln(l, "Deleted %d from collections", rs)
  1923. // Delete tokens
  1924. res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID)
  1925. if err != nil {
  1926. t.Rollback()
  1927. stringLogln(l, "Unable to delete access tokens: %v", err)
  1928. return
  1929. }
  1930. rs, _ = res.RowsAffected()
  1931. stringLogln(l, "Deleted %d from accesstokens", rs)
  1932. // Delete posts
  1933. res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID)
  1934. if err != nil {
  1935. t.Rollback()
  1936. stringLogln(l, "Unable to delete posts: %v", err)
  1937. return
  1938. }
  1939. rs, _ = res.RowsAffected()
  1940. stringLogln(l, "Deleted %d from posts", rs)
  1941. res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID)
  1942. if err != nil {
  1943. t.Rollback()
  1944. stringLogln(l, "Unable to delete attributes: %v", err)
  1945. return
  1946. }
  1947. rs, _ = res.RowsAffected()
  1948. stringLogln(l, "Deleted %d from userattributes", rs)
  1949. res, err = t.Exec("DELETE FROM users WHERE id = ?", userID)
  1950. if err != nil {
  1951. t.Rollback()
  1952. stringLogln(l, "Unable to delete user: %v", err)
  1953. return
  1954. }
  1955. rs, _ = res.RowsAffected()
  1956. stringLogln(l, "Deleted %d from users", rs)
  1957. err = t.Commit()
  1958. if err != nil {
  1959. t.Rollback()
  1960. stringLogln(l, "Unable to commit: %v", err)
  1961. return
  1962. }
  1963. return
  1964. }
  1965. func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) {
  1966. var pub, priv []byte
  1967. err := db.QueryRow("SELECT public_key, private_key FROM collectionkeys WHERE collection_id = ?", collectionID).Scan(&pub, &priv)
  1968. switch {
  1969. case err == sql.ErrNoRows:
  1970. // Generate keys
  1971. pub, priv = activitypub.GenerateKeys()
  1972. _, err = db.Exec("INSERT INTO collectionkeys (collection_id, public_key, private_key) VALUES (?, ?, ?)", collectionID, pub, priv)
  1973. if err != nil {
  1974. log.Error("Unable to INSERT new activitypub keypair: %v", err)
  1975. return nil, nil
  1976. }
  1977. case err != nil:
  1978. log.Error("Couldn't SELECT collectionkeys: %v", err)
  1979. return nil, nil
  1980. }
  1981. return pub, priv
  1982. }
  1983. func (db *datastore) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error {
  1984. _, err := db.Exec("INSERT INTO userinvites (id, owner_id, max_uses, created, expires, inactive) VALUES (?, ?, ?, "+db.now()+", ?, 0)", id, userID, maxUses, expires)
  1985. return err
  1986. }
  1987. func (db *datastore) GetUserInvites(userID int64) (*[]Invite, error) {
  1988. rows, err := db.Query("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE owner_id = ? ORDER BY created DESC", userID)
  1989. if err != nil {
  1990. log.Error("Failed selecting from userinvites: %v", err)
  1991. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user invites."}
  1992. }
  1993. defer rows.Close()
  1994. is := []Invite{}
  1995. for rows.Next() {
  1996. i := Invite{}
  1997. err = rows.Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  1998. is = append(is, i)
  1999. }
  2000. return &is, nil
  2001. }
  2002. func (db *datastore) GetUserInvite(id string) (*Invite, error) {
  2003. var i Invite
  2004. err := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id).Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  2005. switch {
  2006. case err == sql.ErrNoRows:
  2007. return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."}
  2008. case err != nil:
  2009. log.Error("Failed selecting invite: %v", err)
  2010. return nil, err
  2011. }
  2012. return &i, nil
  2013. }
  2014. // IsUsersInvite returns true if the user with ID created the invite with code
  2015. // and an error other than sql no rows, if any. Will return false in the event
  2016. // of an error.
  2017. func (db *datastore) IsUsersInvite(code string, userID int64) (bool, error) {
  2018. var id string
  2019. err := db.QueryRow("SELECT id FROM userinvites WHERE id = ? AND owner_id = ?", code, userID).Scan(&id)
  2020. if err != nil && err != sql.ErrNoRows {
  2021. log.Error("Failed selecting invite: %v", err)
  2022. return false, err
  2023. }
  2024. return id != "", nil
  2025. }
  2026. func (db *datastore) GetUsersInvitedCount(id string) int64 {
  2027. var count int64
  2028. err := db.QueryRow("SELECT COUNT(*) FROM usersinvited WHERE invite_id = ?", id).Scan(&count)
  2029. switch {
  2030. case err == sql.ErrNoRows:
  2031. return 0
  2032. case err != nil:
  2033. log.Error("Failed selecting users invited count: %v", err)
  2034. return 0
  2035. }
  2036. return count
  2037. }
  2038. func (db *datastore) CreateInvitedUser(inviteID string, userID int64) error {
  2039. _, err := db.Exec("INSERT INTO usersinvited (invite_id, user_id) VALUES (?, ?)", inviteID, userID)
  2040. return err
  2041. }
  2042. func (db *datastore) GetInstancePages() ([]*instanceContent, error) {
  2043. return db.GetAllDynamicContent("page")
  2044. }
  2045. func (db *datastore) GetAllDynamicContent(t string) ([]*instanceContent, error) {
  2046. where := ""
  2047. params := []interface{}{}
  2048. if t != "" {
  2049. where = " WHERE content_type = ?"
  2050. params = append(params, t)
  2051. }
  2052. rows, err := db.Query("SELECT id, title, content, updated, content_type FROM appcontent"+where, params...)
  2053. if err != nil {
  2054. log.Error("Failed selecting from appcontent: %v", err)
  2055. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve instance pages."}
  2056. }
  2057. defer rows.Close()
  2058. pages := []*instanceContent{}
  2059. for rows.Next() {
  2060. c := &instanceContent{}
  2061. err = rows.Scan(&c.ID, &c.Title, &c.Content, &c.Updated, &c.Type)
  2062. if err != nil {
  2063. log.Error("Failed scanning row: %v", err)
  2064. break
  2065. }
  2066. pages = append(pages, c)
  2067. }
  2068. err = rows.Err()
  2069. if err != nil {
  2070. log.Error("Error after Next() on rows: %v", err)
  2071. }
  2072. return pages, nil
  2073. }
  2074. func (db *datastore) GetDynamicContent(id string) (*instanceContent, error) {
  2075. c := &instanceContent{
  2076. ID: id,
  2077. }
  2078. err := db.QueryRow("SELECT title, content, updated, content_type FROM appcontent WHERE id = ?", id).Scan(&c.Title, &c.Content, &c.Updated, &c.Type)
  2079. switch {
  2080. case err == sql.ErrNoRows:
  2081. return nil, nil
  2082. case err != nil:
  2083. log.Error("Couldn't SELECT FROM appcontent for id '%s': %v", id, err)
  2084. return nil, err
  2085. }
  2086. return c, nil
  2087. }
  2088. func (db *datastore) UpdateDynamicContent(id, title, content, contentType string) error {
  2089. var err error
  2090. if db.driverName == driverSQLite {
  2091. _, err = db.Exec("INSERT OR REPLACE INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?)", id, title, content, contentType)
  2092. } else {
  2093. _, err = db.Exec("INSERT INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?) "+db.upsert("id")+" title = ?, content = ?, updated = "+db.now(), id, title, content, contentType, title, content)
  2094. }
  2095. if err != nil {
  2096. log.Error("Unable to INSERT appcontent for '%s': %v", id, err)
  2097. }
  2098. return err
  2099. }
  2100. func (db *datastore) GetAllUsers(page uint) (*[]User, error) {
  2101. limitStr := fmt.Sprintf("0, %d", adminUsersPerPage)
  2102. if page > 1 {
  2103. limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage)
  2104. }
  2105. rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr)
  2106. if err != nil {
  2107. log.Error("Failed selecting from users: %v", err)
  2108. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."}
  2109. }
  2110. defer rows.Close()
  2111. users := []User{}
  2112. for rows.Next() {
  2113. u := User{}
  2114. err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status)
  2115. if err != nil {
  2116. log.Error("Failed scanning GetAllUsers() row: %v", err)
  2117. break
  2118. }
  2119. users = append(users, u)
  2120. }
  2121. return &users, nil
  2122. }
  2123. func (db *datastore) GetAllUsersCount() int64 {
  2124. var count int64
  2125. err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
  2126. switch {
  2127. case err == sql.ErrNoRows:
  2128. return 0
  2129. case err != nil:
  2130. log.Error("Failed selecting all users count: %v", err)
  2131. return 0
  2132. }
  2133. return count
  2134. }
  2135. func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) {
  2136. var t time.Time
  2137. err := db.QueryRow("SELECT created FROM posts WHERE owner_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2138. switch {
  2139. case err == sql.ErrNoRows:
  2140. return nil, nil
  2141. case err != nil:
  2142. log.Error("Failed selecting last post time from posts: %v", err)
  2143. return nil, err
  2144. }
  2145. return &t, nil
  2146. }
  2147. // SetUserStatus changes a user's status in the database. see Users.UserStatus
  2148. func (db *datastore) SetUserStatus(id int64, status UserStatus) error {
  2149. _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id)
  2150. if err != nil {
  2151. return fmt.Errorf("failed to update user status: %v", err)
  2152. }
  2153. return nil
  2154. }
  2155. func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
  2156. var t time.Time
  2157. err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2158. switch {
  2159. case err == sql.ErrNoRows:
  2160. return nil, nil
  2161. case err != nil:
  2162. log.Error("Failed selecting last post time from posts: %v", err)
  2163. return nil, err
  2164. }
  2165. return &t, nil
  2166. }
  2167. func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
  2168. state := store.Generate62RandomString(24)
  2169. _, err := db.ExecContext(ctx, "INSERT INTO oauth_client_states (state, used, created_at) VALUES (?, FALSE, NOW())", state)
  2170. if err != nil {
  2171. return "", fmt.Errorf("unable to record oauth client state: %w", err)
  2172. }
  2173. return state, nil
  2174. }
  2175. func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
  2176. res, err := db.ExecContext(ctx, "UPDATE oauth_client_states SET used = TRUE WHERE state = ?", state)
  2177. if err != nil {
  2178. return err
  2179. }
  2180. rowsAffected, err := res.RowsAffected()
  2181. if err != nil {
  2182. return err
  2183. }
  2184. if rowsAffected != 1 {
  2185. return fmt.Errorf("state not found")
  2186. }
  2187. return nil
  2188. }
  2189. func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
  2190. var err error
  2191. if db.driverName == driverSQLite {
  2192. _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO oauth_users (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
  2193. } else {
  2194. _, err = db.ExecContext(ctx, "INSERT INTO oauth_users (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id")+" user_id = ?", localUserID, remoteUserID, localUserID)
  2195. }
  2196. if err != nil {
  2197. log.Error("Unable to INSERT oauth_users for '%d': %v", localUserID, err)
  2198. }
  2199. return err
  2200. }
  2201. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
  2202. func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
  2203. var userID int64 = -1
  2204. err := db.
  2205. QueryRowContext(ctx, "SELECT user_id FROM oauth_users WHERE remote_user_id = ?", remoteUserID).
  2206. Scan(&userID)
  2207. // Not finding a record is OK.
  2208. if err != nil && err != sql.ErrNoRows {
  2209. return -1, err
  2210. }
  2211. return userID, nil
  2212. }
  2213. // DatabaseInitialized returns whether or not the current datastore has been
  2214. // initialized with the correct schema.
  2215. // Currently, it checks to see if the `users` table exists.
  2216. func (db *datastore) DatabaseInitialized() bool {
  2217. var dummy string
  2218. var err error
  2219. if db.driverName == driverSQLite {
  2220. err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy)
  2221. } else {
  2222. err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy)
  2223. }
  2224. switch {
  2225. case err == sql.ErrNoRows:
  2226. return false
  2227. case err != nil:
  2228. log.Error("Couldn't SHOW TABLES: %v", err)
  2229. return false
  2230. }
  2231. return true
  2232. }
  // stringLogln formats s with v (fmt.Sprintf rules) plus a trailing newline,
// and appends the result to the string that log points to.
func stringLogln(log *string, s string, v ...interface{}) {
	line := fmt.Sprintf(s+"\n", v...)
	*log = *log + line
}
  2236. func handleFailedPostInsert(err error) error {
  2237. log.Error("Couldn't insert into posts: %v", err)
  2238. return err
  2239. }