A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

2566 lines
77 KiB

  1. /*
  2. * Copyright © 2018 A Bunch Tell LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "crypto/rand"
  14. "database/sql"
  15. "fmt"
  16. "github.com/pkg/errors"
  17. "math/big"
  18. "net/http"
  19. "strings"
  20. "time"
  21. "github.com/guregu/null"
  22. "github.com/guregu/null/zero"
  23. uuid "github.com/nu7hatch/gouuid"
  24. "github.com/writeas/impart"
  25. "github.com/writeas/nerds/store"
  26. "github.com/writeas/web-core/activitypub"
  27. "github.com/writeas/web-core/auth"
  28. "github.com/writeas/web-core/data"
  29. "github.com/writeas/web-core/id"
  30. "github.com/writeas/web-core/log"
  31. "github.com/writeas/web-core/query"
  32. "github.com/writeas/writefreely/author"
  33. "github.com/writeas/writefreely/config"
  34. "github.com/writeas/writefreely/key"
  35. )
  36. const (
  37. mySQLErrDuplicateKey = 1062
  38. driverMySQL = "mysql"
  39. driverSQLite = "sqlite3"
  40. )
  41. var (
  42. SQLiteEnabled bool
  43. )
  44. type writestore interface {
  45. CreateUser(*config.Config, *User, string) error
  46. UpdateUserEmail(keys *key.Keychain, userID int64, email string) error
  47. UpdateEncryptedUserEmail(int64, []byte) error
  48. GetUserByID(int64) (*User, error)
  49. GetUserForAuth(string) (*User, error)
  50. GetUserForAuthByID(int64) (*User, error)
  51. GetUserNameFromToken(string) (string, error)
  52. GetUserDataFromToken(string) (int64, string, error)
  53. GetAPIUser(header string) (*User, error)
  54. GetUserID(accessToken string) int64
  55. GetUserIDPrivilege(accessToken string) (userID int64, sudo bool)
  56. DeleteToken(accessToken []byte) error
  57. FetchLastAccessToken(userID int64) string
  58. GetAccessToken(userID int64) (string, error)
  59. GetTemporaryAccessToken(userID int64, validSecs int) (string, error)
  60. GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error)
  61. DeleteAccount(userID int64) (l *string, err error)
  62. ChangeSettings(app *App, u *User, s *userSettings) error
  63. ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error
  64. GetCollections(u *User, hostName string) (*[]Collection, error)
  65. GetPublishableCollections(u *User, hostName string) (*[]Collection, error)
  66. GetMeStats(u *User) userMeStats
  67. GetTotalCollections() (int64, error)
  68. GetTotalPosts() (int64, error)
  69. GetTopPosts(u *User, alias string) (*[]PublicPost, error)
  70. GetAnonymousPosts(u *User) (*[]PublicPost, error)
  71. GetUserPosts(u *User) (*[]PublicPost, error)
  72. CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error)
  73. CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error)
  74. UpdateOwnedPost(post *AuthenticatedPost, userID int64) error
  75. GetEditablePost(id, editToken string) (*PublicPost, error)
  76. PostIDExists(id string) bool
  77. GetPost(id string, collectionID int64) (*PublicPost, error)
  78. GetOwnedPost(id string, ownerID int64) (*PublicPost, error)
  79. GetPostProperty(id string, collectionID int64, property string) (interface{}, error)
  80. CreateCollectionFromToken(*config.Config, string, string, string) (*Collection, error)
  81. CreateCollection(*config.Config, string, string, int64) (*Collection, error)
  82. GetCollectionBy(condition string, value interface{}) (*Collection, error)
  83. GetCollection(alias string) (*Collection, error)
  84. GetCollectionForPad(alias string) (*Collection, error)
  85. GetCollectionByID(id int64) (*Collection, error)
  86. UpdateCollection(c *SubmittedCollection, alias string) error
  87. DeleteCollection(alias string, userID int64) error
  88. UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error
  89. GetLastPinnedPostPos(collID int64) int64
  90. GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error)
  91. RemoveCollectionRedirect(t *sql.Tx, alias string) error
  92. GetCollectionRedirect(alias string) (new string)
  93. IsCollectionAttributeOn(id int64, attr string) bool
  94. CollectionHasAttribute(id int64, attr string) bool
  95. CanCollect(cpr *ClaimPostRequest, userID int64) bool
  96. AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error)
  97. DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error)
  98. ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error)
  99. GetPostsCount(c *CollectionObj, includeFuture bool)
  100. GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error)
  101. GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error)
  102. GetAPFollowers(c *Collection) (*[]RemoteUser, error)
  103. GetAPActorKeys(collectionID int64) ([]byte, []byte)
  104. CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error
  105. GetUserInvites(userID int64) (*[]Invite, error)
  106. GetUserInvite(id string) (*Invite, error)
  107. GetUsersInvitedCount(id string) int64
  108. CreateInvitedUser(inviteID string, userID int64) error
  109. GetDynamicContent(id string) (*instanceContent, error)
  110. UpdateDynamicContent(id, title, content, contentType string) error
  111. GetAllUsers(page uint) (*[]User, error)
  112. GetAllUsersCount() int64
  113. GetUserLastPostTime(id int64) (*time.Time, error)
  114. GetCollectionLastPostTime(id int64) (*time.Time, error)
  115. DatabaseInitialized() bool
  116. }
  117. type datastore struct {
  118. *sql.DB
  119. driverName string
  120. }
  121. func (db *datastore) now() string {
  122. if db.driverName == driverSQLite {
  123. return "strftime('%Y-%m-%d %H:%M:%S','now')"
  124. }
  125. return "NOW()"
  126. }
  127. func (db *datastore) clip(field string, l int) string {
  128. if db.driverName == driverSQLite {
  129. return fmt.Sprintf("SUBSTR(%s, 0, %d)", field, l)
  130. }
  131. return fmt.Sprintf("LEFT(%s, %d)", field, l)
  132. }
  133. func (db *datastore) upsert(indexedCols ...string) string {
  134. if db.driverName == driverSQLite {
  135. // NOTE: SQLite UPSERT syntax only works in v3.24.0 (2018-06-04) or later
  136. // Leaving this for whenever we can upgrade and include it in our binary
  137. cc := strings.Join(indexedCols, ", ")
  138. return "ON CONFLICT(" + cc + ") DO UPDATE SET"
  139. }
  140. return "ON DUPLICATE KEY UPDATE"
  141. }
  142. func (db *datastore) dateSub(l int, unit string) string {
  143. if db.driverName == driverSQLite {
  144. return fmt.Sprintf("DATETIME('now', '-%d %s')", l, unit)
  145. }
  146. return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
  147. }
  148. func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
  149. if db.PostIDExists(u.Username) {
  150. return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  151. }
  152. // New users get a `users` and `collections` row.
  153. t, err := db.Begin()
  154. if err != nil {
  155. return err
  156. }
  157. // 1. Add to `users` table
  158. // NOTE: Assumes User's Password is already hashed!
  159. res, err := t.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email)
  160. if err != nil {
  161. t.Rollback()
  162. if db.isDuplicateKeyErr(err) {
  163. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  164. }
  165. log.Error("Rolling back users INSERT: %v\n", err)
  166. return err
  167. }
  168. u.ID, err = res.LastInsertId()
  169. if err != nil {
  170. t.Rollback()
  171. log.Error("Rolling back after LastInsertId: %v\n", err)
  172. return err
  173. }
  174. // 2. Create user's Collection
  175. if collectionTitle == "" {
  176. collectionTitle = u.Username
  177. }
  178. res, err = t.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, collectionTitle, "", defaultVisibility(cfg), u.ID, 0)
  179. if err != nil {
  180. t.Rollback()
  181. if db.isDuplicateKeyErr(err) {
  182. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  183. }
  184. log.Error("Rolling back collections INSERT: %v\n", err)
  185. return err
  186. }
  187. db.RemoveCollectionRedirect(t, u.Username)
  188. err = t.Commit()
  189. if err != nil {
  190. t.Rollback()
  191. log.Error("Rolling back after Commit(): %v\n", err)
  192. return err
  193. }
  194. return nil
  195. }
  196. // FIXME: We're returning errors inconsistently in this file. Do we use Errorf
  197. // for returned value, or impart?
  198. func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error {
  199. encEmail, err := data.Encrypt(keys.EmailKey, email)
  200. if err != nil {
  201. return fmt.Errorf("Couldn't encrypt email %s: %s\n", email, err)
  202. }
  203. return db.UpdateEncryptedUserEmail(userID, encEmail)
  204. }
  205. func (db *datastore) UpdateEncryptedUserEmail(userID int64, encEmail []byte) error {
  206. _, err := db.Exec("UPDATE users SET email = ? WHERE id = ?", encEmail, userID)
  207. if err != nil {
  208. return fmt.Errorf("Unable to update user email: %s", err)
  209. }
  210. return nil
  211. }
  212. func (db *datastore) CreateCollectionFromToken(cfg *config.Config, alias, title, accessToken string) (*Collection, error) {
  213. userID := db.GetUserID(accessToken)
  214. if userID == -1 {
  215. return nil, ErrBadAccessToken
  216. }
  217. return db.CreateCollection(cfg, alias, title, userID)
  218. }
  219. func (db *datastore) GetUserCollectionCount(userID int64) (uint64, error) {
  220. var collCount uint64
  221. err := db.QueryRow("SELECT COUNT(*) FROM collections WHERE owner_id = ?", userID).Scan(&collCount)
  222. switch {
  223. case err == sql.ErrNoRows:
  224. return 0, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user from database."}
  225. case err != nil:
  226. log.Error("Couldn't get collections count for user %d: %v", userID, err)
  227. return 0, err
  228. }
  229. return collCount, nil
  230. }
  231. func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, userID int64) (*Collection, error) {
  232. if db.PostIDExists(alias) {
  233. return nil, impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  234. }
  235. // All good, so create new collection
  236. res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", defaultVisibility(cfg), userID, 0)
  237. if err != nil {
  238. if db.isDuplicateKeyErr(err) {
  239. return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."}
  240. }
  241. log.Error("Couldn't add to collections: %v\n", err)
  242. return nil, err
  243. }
  244. c := &Collection{
  245. Alias: alias,
  246. Title: title,
  247. OwnerID: userID,
  248. PublicOwner: false,
  249. Public: defaultVisibility(cfg) == CollPublic,
  250. }
  251. c.ID, err = res.LastInsertId()
  252. if err != nil {
  253. log.Error("Couldn't get collection LastInsertId: %v\n", err)
  254. }
  255. return c, nil
  256. }
  257. func (db *datastore) GetUserByID(id int64) (*User, error) {
  258. u := &User{ID: id}
  259. err := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  260. switch {
  261. case err == sql.ErrNoRows:
  262. return nil, ErrUserNotFound
  263. case err != nil:
  264. log.Error("Couldn't SELECT user password: %v", err)
  265. return nil, err
  266. }
  267. return u, nil
  268. }
  269. // IsUserSuspended returns true if the user account associated with id is
  270. // currently suspended.
  271. func (db *datastore) IsUserSuspended(id int64) (bool, error) {
  272. u := &User{ID: id}
  273. err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&u.Status)
  274. switch {
  275. case err == sql.ErrNoRows:
  276. return false, fmt.Errorf("is user suspended: %v", ErrUserNotFound)
  277. case err != nil:
  278. log.Error("Couldn't SELECT user password: %v", err)
  279. return false, fmt.Errorf("is user suspended: %v", err)
  280. }
  281. return u.IsSilenced(), nil
  282. }
  283. // DoesUserNeedAuth returns true if the user hasn't provided any methods for
  284. // authenticating with the account, such a passphrase or email address.
  285. // Any errors are reported to admin and silently quashed, returning false as the
  286. // result.
  287. func (db *datastore) DoesUserNeedAuth(id int64) bool {
  288. var pass, email []byte
  289. // Find out if user has an email set first
  290. err := db.QueryRow("SELECT password, email FROM users WHERE id = ?", id).Scan(&pass, &email)
  291. switch {
  292. case err == sql.ErrNoRows:
  293. // ERROR. Don't give false positives on needing auth methods
  294. return false
  295. case err != nil:
  296. // ERROR. Don't give false positives on needing auth methods
  297. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  298. return false
  299. }
  300. // User doesn't need auth if there's an email
  301. return len(email) == 0 && len(pass) == 0
  302. }
  303. func (db *datastore) IsUserPassSet(id int64) (bool, error) {
  304. var pass []byte
  305. err := db.QueryRow("SELECT password FROM users WHERE id = ?", id).Scan(&pass)
  306. switch {
  307. case err == sql.ErrNoRows:
  308. return false, nil
  309. case err != nil:
  310. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  311. return false, err
  312. }
  313. return len(pass) > 0, nil
  314. }
  315. func (db *datastore) GetUserForAuth(username string) (*User, error) {
  316. u := &User{Username: username}
  317. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  318. switch {
  319. case err == sql.ErrNoRows:
  320. // Check if they've entered the wrong, unnormalized username
  321. username = getSlug(username, "")
  322. if username != u.Username {
  323. err = db.QueryRow("SELECT id FROM users WHERE username = ? LIMIT 1", username).Scan(&u.ID)
  324. if err == nil {
  325. return db.GetUserForAuth(username)
  326. }
  327. }
  328. return nil, ErrUserNotFound
  329. case err != nil:
  330. log.Error("Couldn't SELECT user password: %v", err)
  331. return nil, err
  332. }
  333. return u, nil
  334. }
  335. func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) {
  336. u := &User{ID: userID}
  337. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  338. switch {
  339. case err == sql.ErrNoRows:
  340. return nil, ErrUserNotFound
  341. case err != nil:
  342. log.Error("Couldn't SELECT userForAuthByID: %v", err)
  343. return nil, err
  344. }
  345. return u, nil
  346. }
  347. func (db *datastore) GetUserNameFromToken(accessToken string) (string, error) {
  348. t := auth.GetToken(accessToken)
  349. if len(t) == 0 {
  350. return "", ErrNoAccessToken
  351. }
  352. var oneTime bool
  353. var username string
  354. err := db.QueryRow("SELECT username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&username, &oneTime)
  355. switch {
  356. case err == sql.ErrNoRows:
  357. return "", ErrBadAccessToken
  358. case err != nil:
  359. return "", ErrInternalGeneral
  360. }
  361. // Delete token if it was one-time
  362. if oneTime {
  363. db.DeleteToken(t[:])
  364. }
  365. return username, nil
  366. }
  367. func (db *datastore) GetUserDataFromToken(accessToken string) (int64, string, error) {
  368. t := auth.GetToken(accessToken)
  369. if len(t) == 0 {
  370. return 0, "", ErrNoAccessToken
  371. }
  372. var userID int64
  373. var oneTime bool
  374. var username string
  375. err := db.QueryRow("SELECT user_id, username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &username, &oneTime)
  376. switch {
  377. case err == sql.ErrNoRows:
  378. return 0, "", ErrBadAccessToken
  379. case err != nil:
  380. return 0, "", ErrInternalGeneral
  381. }
  382. // Delete token if it was one-time
  383. if oneTime {
  384. db.DeleteToken(t[:])
  385. }
  386. return userID, username, nil
  387. }
  388. func (db *datastore) GetAPIUser(header string) (*User, error) {
  389. uID := db.GetUserID(header)
  390. if uID == -1 {
  391. return nil, fmt.Errorf(ErrUserNotFound.Error())
  392. }
  393. return db.GetUserByID(uID)
  394. }
  395. // GetUserID takes a hexadecimal accessToken, parses it into its binary
  396. // representation, and gets any user ID associated with the token. If no user
  397. // is associated, -1 is returned.
  398. func (db *datastore) GetUserID(accessToken string) int64 {
  399. i, _ := db.GetUserIDPrivilege(accessToken)
  400. return i
  401. }
  402. func (db *datastore) GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) {
  403. t := auth.GetToken(accessToken)
  404. if len(t) == 0 {
  405. return -1, false
  406. }
  407. var oneTime bool
  408. err := db.QueryRow("SELECT user_id, sudo, one_time FROM accesstokens WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &sudo, &oneTime)
  409. switch {
  410. case err == sql.ErrNoRows:
  411. return -1, false
  412. case err != nil:
  413. return -1, false
  414. }
  415. // Delete token if it was one-time
  416. if oneTime {
  417. db.DeleteToken(t[:])
  418. }
  419. return
  420. }
  421. func (db *datastore) DeleteToken(accessToken []byte) error {
  422. res, err := db.Exec("DELETE FROM accesstokens WHERE token LIKE ?", accessToken)
  423. if err != nil {
  424. return err
  425. }
  426. rowsAffected, _ := res.RowsAffected()
  427. if rowsAffected == 0 {
  428. return impart.HTTPError{http.StatusNotFound, "Token is invalid or doesn't exist"}
  429. }
  430. return nil
  431. }
  432. // FetchLastAccessToken creates a new non-expiring, valid access token for the given
  433. // userID.
  434. func (db *datastore) FetchLastAccessToken(userID int64) string {
  435. var t []byte
  436. err := db.QueryRow("SELECT token FROM accesstokens WHERE user_id = ? AND (expires IS NULL OR expires > "+db.now()+") ORDER BY created DESC LIMIT 1", userID).Scan(&t)
  437. switch {
  438. case err == sql.ErrNoRows:
  439. return ""
  440. case err != nil:
  441. log.Error("Failed selecting from accesstoken: %v", err)
  442. return ""
  443. }
  444. u, err := uuid.Parse(t)
  445. if err != nil {
  446. return ""
  447. }
  448. return u.String()
  449. }
  450. // GetAccessToken creates a new non-expiring, valid access token for the given
  451. // userID.
  452. func (db *datastore) GetAccessToken(userID int64) (string, error) {
  453. return db.GetTemporaryOneTimeAccessToken(userID, 0, false)
  454. }
  455. // GetTemporaryAccessToken creates a new valid access token for the given
  456. // userID that remains valid for the given time in seconds. If validSecs is 0,
  457. // the access token doesn't automatically expire.
  458. func (db *datastore) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) {
  459. return db.GetTemporaryOneTimeAccessToken(userID, validSecs, false)
  460. }
  461. // GetTemporaryOneTimeAccessToken creates a new valid access token for the given
  462. // userID that remains valid for the given time in seconds and can only be used
  463. // once if oneTime is true. If validSecs is 0, the access token doesn't
  464. // automatically expire.
  465. func (db *datastore) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) {
  466. u, err := uuid.NewV4()
  467. if err != nil {
  468. log.Error("Unable to generate token: %v", err)
  469. return "", err
  470. }
  471. // Insert UUID to `accesstokens`
  472. binTok := u[:]
  473. expirationVal := "NULL"
  474. if validSecs > 0 {
  475. expirationVal = fmt.Sprintf("DATE_ADD("+db.now()+", INTERVAL %d SECOND)", validSecs)
  476. }
  477. _, err = db.Exec("INSERT INTO accesstokens (token, user_id, one_time, expires) VALUES (?, ?, ?, "+expirationVal+")", string(binTok), userID, oneTime)
  478. if err != nil {
  479. log.Error("Couldn't INSERT accesstoken: %v", err)
  480. return "", err
  481. }
  482. return u.String(), nil
  483. }
  484. func (db *datastore) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) {
  485. var userID, collID int64 = -1, -1
  486. var coll *Collection
  487. var err error
  488. if accessToken != "" {
  489. userID = db.GetUserID(accessToken)
  490. if userID == -1 {
  491. return nil, ErrBadAccessToken
  492. }
  493. if collAlias != "" {
  494. coll, err = db.GetCollection(collAlias)
  495. if err != nil {
  496. return nil, err
  497. }
  498. coll.hostName = hostName
  499. if coll.OwnerID != userID {
  500. return nil, ErrForbiddenCollection
  501. }
  502. collID = coll.ID
  503. }
  504. }
  505. rp := &PublicPost{}
  506. rp.Post, err = db.CreatePost(userID, collID, post)
  507. if err != nil {
  508. return rp, err
  509. }
  510. if coll != nil {
  511. coll.ForPublic()
  512. rp.Collection = &CollectionObj{Collection: *coll}
  513. }
  514. return rp, nil
  515. }
  516. func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) {
  517. idLen := postIDLen
  518. friendlyID := store.GenerateFriendlyRandomString(idLen)
  519. // Handle appearance / font face
  520. appearance := post.Font
  521. if !post.isFontValid() {
  522. appearance = "norm"
  523. }
  524. var err error
  525. ownerID := sql.NullInt64{
  526. Valid: false,
  527. }
  528. ownerCollID := sql.NullInt64{
  529. Valid: false,
  530. }
  531. slug := sql.NullString{"", false}
  532. // If an alias was supplied, we'll add this to the collection as well.
  533. if userID > 0 {
  534. ownerID.Int64 = userID
  535. ownerID.Valid = true
  536. if collID > 0 {
  537. ownerCollID.Int64 = collID
  538. ownerCollID.Valid = true
  539. var slugVal string
  540. if post.Title != nil && *post.Title != "" {
  541. slugVal = getSlug(*post.Title, post.Language.String)
  542. if slugVal == "" {
  543. slugVal = getSlug(*post.Content, post.Language.String)
  544. }
  545. } else {
  546. slugVal = getSlug(*post.Content, post.Language.String)
  547. }
  548. if slugVal == "" {
  549. slugVal = friendlyID
  550. }
  551. slug = sql.NullString{slugVal, true}
  552. }
  553. }
  554. created := time.Now()
  555. if db.driverName == driverSQLite {
  556. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  557. created = created.UTC()
  558. }
  559. if post.Created != nil {
  560. created, err = time.Parse("2006-01-02T15:04:05Z", *post.Created)
  561. if err != nil {
  562. log.Error("Unable to parse Created time '%s': %v", *post.Created, err)
  563. created = time.Now()
  564. if db.driverName == driverSQLite {
  565. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  566. created = created.UTC()
  567. }
  568. }
  569. }
  570. stmt, err := db.Prepare("INSERT INTO posts (id, slug, title, content, text_appearance, language, rtl, privacy, owner_id, collection_id, created, updated, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + db.now() + ", ?)")
  571. if err != nil {
  572. return nil, err
  573. }
  574. defer stmt.Close()
  575. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  576. if err != nil {
  577. if db.isDuplicateKeyErr(err) {
  578. // Duplicate entry error; try a new slug
  579. // TODO: make this a little more robust
  580. slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true}
  581. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  582. if err != nil {
  583. return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err))
  584. }
  585. } else {
  586. return nil, handleFailedPostInsert(err)
  587. }
  588. }
  589. // TODO: return Created field in proper format
  590. return &Post{
  591. ID: friendlyID,
  592. Slug: null.NewString(slug.String, slug.Valid),
  593. Font: appearance,
  594. Language: zero.NewString(post.Language.String, post.Language.Valid),
  595. RTL: zero.NewBool(post.IsRTL.Bool, post.IsRTL.Valid),
  596. OwnerID: null.NewInt(userID, true),
  597. CollectionID: null.NewInt(userID, true),
  598. Created: created.Truncate(time.Second).UTC(),
  599. Updated: time.Now().Truncate(time.Second).UTC(),
  600. Title: zero.NewString(*(post.Title), true),
  601. Content: *(post.Content),
  602. }, nil
  603. }
  604. // UpdateOwnedPost updates an existing post with only the given fields in the
  605. // supplied AuthenticatedPost.
  606. func (db *datastore) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error {
  607. params := []interface{}{}
  608. var queryUpdates, sep, authCondition string
  609. if post.Slug != nil && *post.Slug != "" {
  610. queryUpdates += sep + "slug = ?"
  611. sep = ", "
  612. params = append(params, getSlug(*post.Slug, ""))
  613. }
  614. if post.Content != nil {
  615. queryUpdates += sep + "content = ?"
  616. sep = ", "
  617. params = append(params, post.Content)
  618. }
  619. if post.Title != nil {
  620. queryUpdates += sep + "title = ?"
  621. sep = ", "
  622. params = append(params, post.Title)
  623. }
  624. if post.Language.Valid {
  625. queryUpdates += sep + "language = ?"
  626. sep = ", "
  627. params = append(params, post.Language.String)
  628. }
  629. if post.IsRTL.Valid {
  630. queryUpdates += sep + "rtl = ?"
  631. sep = ", "
  632. params = append(params, post.IsRTL.Bool)
  633. }
  634. if post.Font != "" {
  635. queryUpdates += sep + "text_appearance = ?"
  636. sep = ", "
  637. params = append(params, post.Font)
  638. }
  639. if post.Created != nil {
  640. createTime, err := time.Parse(postMetaDateFormat, *post.Created)
  641. if err != nil {
  642. log.Error("Unable to parse Created date: %v", err)
  643. return fmt.Errorf("That's the incorrect format for Created date.")
  644. }
  645. queryUpdates += sep + "created = ?"
  646. sep = ", "
  647. params = append(params, createTime)
  648. }
  649. // WHERE parameters...
  650. // id = ?
  651. params = append(params, post.ID)
  652. // AND owner_id = ?
  653. authCondition = "(owner_id = ?)"
  654. params = append(params, userID)
  655. if queryUpdates == "" {
  656. return ErrPostNoUpdatableVals
  657. }
  658. queryUpdates += sep + "updated = " + db.now()
  659. res, err := db.Exec("UPDATE posts SET "+queryUpdates+" WHERE id = ? AND "+authCondition, params...)
  660. if err != nil {
  661. log.Error("Unable to update owned post: %v", err)
  662. return err
  663. }
  664. rowsAffected, _ := res.RowsAffected()
  665. if rowsAffected == 0 {
  666. // Show the correct error message if nothing was updated
  667. var dummy int
  668. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ? AND "+authCondition, post.ID, params[len(params)-1]).Scan(&dummy)
  669. switch {
  670. case err == sql.ErrNoRows:
  671. return ErrUnauthorizedEditPost
  672. case err != nil:
  673. log.Error("Failed selecting from posts: %v", err)
  674. }
  675. return nil
  676. }
  677. return nil
  678. }
  679. func (db *datastore) GetCollectionBy(condition string, value interface{}) (*Collection, error) {
  680. c := &Collection{}
  681. // FIXME: change Collection to reflect database values. Add helper functions to get actual values
  682. var styleSheet, script, format zero.String
  683. row := db.QueryRow("SELECT id, alias, title, description, style_sheet, script, format, owner_id, privacy, view_count FROM collections WHERE "+condition, value)
  684. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &styleSheet, &script, &format, &c.OwnerID, &c.Visibility, &c.Views)
  685. switch {
  686. case err == sql.ErrNoRows:
  687. return nil, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  688. case err != nil:
  689. log.Error("Failed selecting from collections: %v", err)
  690. return nil, err
  691. }
  692. c.StyleSheet = styleSheet.String
  693. c.Script = script.String
  694. c.Format = format.String
  695. c.Public = c.IsPublic()
  696. c.db = db
  697. return c, nil
  698. }
  699. func (db *datastore) GetCollection(alias string) (*Collection, error) {
  700. return db.GetCollectionBy("alias = ?", alias)
  701. }
  702. func (db *datastore) GetCollectionForPad(alias string) (*Collection, error) {
  703. c := &Collection{Alias: alias}
  704. row := db.QueryRow("SELECT id, alias, title, description, privacy FROM collections WHERE alias = ?", alias)
  705. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility)
  706. switch {
  707. case err == sql.ErrNoRows:
  708. return c, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  709. case err != nil:
  710. log.Error("Failed selecting from collections: %v", err)
  711. return c, ErrInternalGeneral
  712. }
  713. c.Public = c.IsPublic()
  714. return c, nil
  715. }
  716. func (db *datastore) GetCollectionByID(id int64) (*Collection, error) {
  717. return db.GetCollectionBy("id = ?", id)
  718. }
  719. func (db *datastore) GetCollectionFromDomain(host string) (*Collection, error) {
  720. return db.GetCollectionBy("host = ?", host)
  721. }
  722. func (db *datastore) UpdateCollection(c *SubmittedCollection, alias string) error {
  723. q := query.NewUpdate().
  724. SetStringPtr(c.Title, "title").
  725. SetStringPtr(c.Description, "description").
  726. SetNullString(c.StyleSheet, "style_sheet").
  727. SetNullString(c.Script, "script")
  728. if c.Format != nil {
  729. cf := &CollectionFormat{Format: c.Format.String}
  730. if cf.Valid() {
  731. q.SetNullString(c.Format, "format")
  732. }
  733. }
  734. var updatePass bool
  735. if c.Visibility != nil && (collVisibility(*c.Visibility)&CollProtected == 0 || c.Pass != "") {
  736. q.SetIntPtr(c.Visibility, "privacy")
  737. if c.Pass != "" {
  738. updatePass = true
  739. }
  740. }
  741. // WHERE values
  742. q.Where("alias = ? AND owner_id = ?", alias, c.OwnerID)
  743. if q.Updates == "" {
  744. return ErrPostNoUpdatableVals
  745. }
  746. // Find any current domain
  747. var collID int64
  748. var rowsAffected int64
  749. var changed bool
  750. var res sql.Result
  751. err := db.QueryRow("SELECT id FROM collections WHERE alias = ?", alias).Scan(&collID)
  752. if err != nil {
  753. log.Error("Failed selecting from collections: %v. Some things won't work.", err)
  754. }
  755. // Update MathJax value
  756. if c.MathJax {
  757. if db.driverName == driverSQLite {
  758. _, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "render_mathjax", "1")
  759. } else {
  760. _, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) "+db.upsert("collection_id", "attribute")+" value = ?", collID, "render_mathjax", "1", "1")
  761. }
  762. if err != nil {
  763. log.Error("Unable to insert render_mathjax value: %v", err)
  764. return err
  765. }
  766. } else {
  767. _, err = db.Exec("DELETE FROM collectionattributes WHERE collection_id = ? AND attribute = ?", collID, "render_mathjax")
  768. if err != nil {
  769. log.Error("Unable to delete render_mathjax value: %v", err)
  770. return err
  771. }
  772. }
  773. // Update rest of the collection data
  774. res, err = db.Exec("UPDATE collections SET "+q.Updates+" WHERE "+q.Conditions, q.Params...)
  775. if err != nil {
  776. log.Error("Unable to update collection: %v", err)
  777. return err
  778. }
  779. rowsAffected, _ = res.RowsAffected()
  780. if !changed || rowsAffected == 0 {
  781. // Show the correct error message if nothing was updated
  782. var dummy int
  783. err := db.QueryRow("SELECT 1 FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&dummy)
  784. switch {
  785. case err == sql.ErrNoRows:
  786. return ErrUnauthorizedEditPost
  787. case err != nil:
  788. log.Error("Failed selecting from collections: %v", err)
  789. }
  790. if !updatePass {
  791. return nil
  792. }
  793. }
  794. if updatePass {
  795. hashedPass, err := auth.HashPass([]byte(c.Pass))
  796. if err != nil {
  797. log.Error("Unable to create hash: %s", err)
  798. return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  799. }
  800. if db.driverName == driverSQLite {
  801. _, err = db.Exec("INSERT OR REPLACE INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?)", alias, hashedPass)
  802. } else {
  803. _, err = db.Exec("INSERT INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?) "+db.upsert("collection_id")+" password = ?", alias, hashedPass, hashedPass)
  804. }
  805. if err != nil {
  806. return err
  807. }
  808. }
  809. return nil
  810. }
// postCols is the column list shared by the single-post and collection-post
// SELECTs in this file. Those callers Scan the columns in exactly this order,
// so update every such Scan when adding or reordering columns here.
const postCols = "id, slug, text_appearance, language, rtl, privacy, owner_id, collection_id, pinned_position, created, updated, view_count, title, content"
  812. // getEditablePost returns a PublicPost with the given ID only if the given
  813. // edit token is valid for the post.
  814. func (db *datastore) GetEditablePost(id, editToken string) (*PublicPost, error) {
  815. // FIXME: code duplicated from getPost()
  816. // TODO: add slight logic difference to getPost / one func
  817. var ownerName sql.NullString
  818. p := &Post{}
  819. row := db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE id = ? LIMIT 1", id)
  820. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  821. switch {
  822. case err == sql.ErrNoRows:
  823. return nil, ErrPostNotFound
  824. case err != nil:
  825. log.Error("Failed selecting from collections: %v", err)
  826. return nil, err
  827. }
  828. if p.Content == "" && p.Title.String == "" {
  829. return nil, ErrPostUnpublished
  830. }
  831. res := p.processPost()
  832. if ownerName.Valid {
  833. res.Owner = &PublicUser{Username: ownerName.String}
  834. }
  835. return &res, nil
  836. }
  837. func (db *datastore) PostIDExists(id string) bool {
  838. var dummy bool
  839. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ?", id).Scan(&dummy)
  840. return err == nil && dummy
  841. }
  842. // GetPost gets a public-facing post object from the database. If collectionID
  843. // is > 0, the post will be retrieved by slug and collection ID, rather than
  844. // post ID.
  845. // TODO: break this into two functions:
  846. // - GetPost(id string)
  847. // - GetCollectionPost(slug string, collectionID int64)
  848. func (db *datastore) GetPost(id string, collectionID int64) (*PublicPost, error) {
  849. var ownerName sql.NullString
  850. p := &Post{}
  851. var row *sql.Row
  852. var where string
  853. params := []interface{}{id}
  854. if collectionID > 0 {
  855. where = "slug = ? AND collection_id = ?"
  856. params = append(params, collectionID)
  857. } else {
  858. where = "id = ?"
  859. }
  860. row = db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE "+where+" LIMIT 1", params...)
  861. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  862. switch {
  863. case err == sql.ErrNoRows:
  864. if collectionID > 0 {
  865. return nil, ErrCollectionPageNotFound
  866. }
  867. return nil, ErrPostNotFound
  868. case err != nil:
  869. log.Error("Failed selecting from collections: %v", err)
  870. return nil, err
  871. }
  872. if p.Content == "" && p.Title.String == "" {
  873. return nil, ErrPostUnpublished
  874. }
  875. res := p.processPost()
  876. if ownerName.Valid {
  877. res.Owner = &PublicUser{Username: ownerName.String}
  878. }
  879. return &res, nil
  880. }
  881. // TODO: don't duplicate getPost() functionality
  882. func (db *datastore) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) {
  883. p := &Post{}
  884. var row *sql.Row
  885. where := "id = ? AND owner_id = ?"
  886. params := []interface{}{id, ownerID}
  887. row = db.QueryRow("SELECT "+postCols+" FROM posts WHERE "+where+" LIMIT 1", params...)
  888. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  889. switch {
  890. case err == sql.ErrNoRows:
  891. return nil, ErrPostNotFound
  892. case err != nil:
  893. log.Error("Failed selecting from collections: %v", err)
  894. return nil, err
  895. }
  896. if p.Content == "" && p.Title.String == "" {
  897. return nil, ErrPostUnpublished
  898. }
  899. res := p.processPost()
  900. return &res, nil
  901. }
  902. func (db *datastore) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) {
  903. propSelects := map[string]string{
  904. "views": "view_count AS views",
  905. }
  906. selectQuery, ok := propSelects[property]
  907. if !ok {
  908. return nil, impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid property: %s.", property)}
  909. }
  910. var res interface{}
  911. var row *sql.Row
  912. if collectionID != 0 {
  913. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE slug = ? AND collection_id = ? LIMIT 1", id, collectionID)
  914. } else {
  915. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE id = ? LIMIT 1", id)
  916. }
  917. err := row.Scan(&res)
  918. switch {
  919. case err == sql.ErrNoRows:
  920. return nil, impart.HTTPError{http.StatusNotFound, "Post not found."}
  921. case err != nil:
  922. log.Error("Failed selecting post: %v", err)
  923. return nil, err
  924. }
  925. return res, nil
  926. }
  927. // GetPostsCount modifies the CollectionObj to include the correct number of
  928. // standard (non-pinned) posts. It will return future posts if `includeFuture`
  929. // is true.
  930. func (db *datastore) GetPostsCount(c *CollectionObj, includeFuture bool) {
  931. var count int64
  932. timeCondition := ""
  933. if !includeFuture {
  934. timeCondition = "AND created <= " + db.now()
  935. }
  936. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE collection_id = ? AND pinned_position IS NULL "+timeCondition, c.ID).Scan(&count)
  937. switch {
  938. case err == sql.ErrNoRows:
  939. c.TotalPosts = 0
  940. case err != nil:
  941. log.Error("Failed selecting from collections: %v", err)
  942. c.TotalPosts = 0
  943. }
  944. c.TotalPosts = int(count)
  945. }
  946. // GetPosts retrieves all posts for the given Collection.
  947. // It will return future posts if `includeFuture` is true.
  948. // It will include only standard (non-pinned) posts unless `includePinned` is true.
  949. // TODO: change includeFuture to isOwner, since that's how it's used
  950. func (db *datastore) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) {
  951. collID := c.ID
  952. cf := c.NewFormat()
  953. order := "DESC"
  954. if cf.Ascending() && !forceRecentFirst {
  955. order = "ASC"
  956. }
  957. pagePosts := cf.PostsPerPage()
  958. start := page*pagePosts - pagePosts
  959. if page == 0 {
  960. start = 0
  961. pagePosts = 1000
  962. }
  963. limitStr := ""
  964. if page > 0 {
  965. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  966. }
  967. timeCondition := ""
  968. if !includeFuture {
  969. timeCondition = "AND created <= " + db.now()
  970. }
  971. pinnedCondition := ""
  972. if !includePinned {
  973. pinnedCondition = "AND pinned_position IS NULL"
  974. }
  975. rows, err := db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? "+pinnedCondition+" "+timeCondition+" ORDER BY created "+order+limitStr, collID)
  976. if err != nil {
  977. log.Error("Failed selecting from posts: %v", err)
  978. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  979. }
  980. defer rows.Close()
  981. // TODO: extract this common row scanning logic for queries using `postCols`
  982. posts := []PublicPost{}
  983. for rows.Next() {
  984. p := &Post{}
  985. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  986. if err != nil {
  987. log.Error("Failed scanning row: %v", err)
  988. break
  989. }
  990. p.extractData()
  991. p.formatContent(cfg, c, includeFuture)
  992. posts = append(posts, p.processPost())
  993. }
  994. err = rows.Err()
  995. if err != nil {
  996. log.Error("Error after Next() on rows: %v", err)
  997. }
  998. return &posts, nil
  999. }
  1000. // GetPostsTagged retrieves all posts on the given Collection that contain the
  1001. // given tag.
  1002. // It will return future posts if `includeFuture` is true.
  1003. // TODO: change includeFuture to isOwner, since that's how it's used
  1004. func (db *datastore) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) {
  1005. collID := c.ID
  1006. cf := c.NewFormat()
  1007. order := "DESC"
  1008. if cf.Ascending() {
  1009. order = "ASC"
  1010. }
  1011. pagePosts := cf.PostsPerPage()
  1012. start := page*pagePosts - pagePosts
  1013. if page == 0 {
  1014. start = 0
  1015. pagePosts = 1000
  1016. }
  1017. limitStr := ""
  1018. if page > 0 {
  1019. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  1020. }
  1021. timeCondition := ""
  1022. if !includeFuture {
  1023. timeCondition = "AND created <= " + db.now()
  1024. }
  1025. var rows *sql.Rows
  1026. var err error
  1027. if db.driverName == driverSQLite {
  1028. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) regexp ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, `.*#`+strings.ToLower(tag)+`\b.*`)
  1029. } else {
  1030. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) RLIKE ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, "#"+strings.ToLower(tag)+"[[:>:]]")
  1031. }
  1032. if err != nil {
  1033. log.Error("Failed selecting from posts: %v", err)
  1034. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  1035. }
  1036. defer rows.Close()
  1037. // TODO: extract this common row scanning logic for queries using `postCols`
  1038. posts := []PublicPost{}
  1039. for rows.Next() {
  1040. p := &Post{}
  1041. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  1042. if err != nil {
  1043. log.Error("Failed scanning row: %v", err)
  1044. break
  1045. }
  1046. p.extractData()
  1047. p.formatContent(cfg, c, includeFuture)
  1048. posts = append(posts, p.processPost())
  1049. }
  1050. err = rows.Err()
  1051. if err != nil {
  1052. log.Error("Error after Next() on rows: %v", err)
  1053. }
  1054. return &posts, nil
  1055. }
  1056. func (db *datastore) GetAPFollowers(c *Collection) (*[]RemoteUser, error) {
  1057. rows, err := db.Query("SELECT actor_id, inbox, shared_inbox FROM remotefollows f INNER JOIN remoteusers u ON f.remote_user_id = u.id WHERE collection_id = ?", c.ID)
  1058. if err != nil {
  1059. log.Error("Failed selecting from followers: %v", err)
  1060. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve followers."}
  1061. }
  1062. defer rows.Close()
  1063. followers := []RemoteUser{}
  1064. for rows.Next() {
  1065. f := RemoteUser{}
  1066. err = rows.Scan(&f.ActorID, &f.Inbox, &f.SharedInbox)
  1067. followers = append(followers, f)
  1068. }
  1069. return &followers, nil
  1070. }
  1071. // CanCollect returns whether or not the given user can add the given post to a
  1072. // collection. This is true when a post is already owned by the user.
  1073. // NOTE: this is currently only used to potentially add owned posts to a
  1074. // collection. This has the SIDE EFFECT of also generating a slug for the post.
  1075. // FIXME: make this side effect more explicit (or extract it)
  1076. func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool {
  1077. var title, content string
  1078. var lang sql.NullString
  1079. err := db.QueryRow("SELECT title, content, language FROM posts WHERE id = ? AND owner_id = ?", cpr.ID, userID).Scan(&title, &content, &lang)
  1080. switch {
  1081. case err == sql.ErrNoRows:
  1082. return false
  1083. case err != nil:
  1084. log.Error("Failed on post CanCollect(%s, %d): %v", cpr.ID, userID, err)
  1085. return false
  1086. }
  1087. // Since we have the post content and the post is collectable, generate the
  1088. // post's slug now.
  1089. cpr.Slug = getSlugFromPost(title, content, lang.String)
  1090. return true
  1091. }
  1092. func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) {
  1093. qRes, err := db.Exec(query, params...)
  1094. if err != nil {
  1095. if db.isDuplicateKeyErr(err) && slugIdx > -1 {
  1096. s := id.GenSafeUniqueSlug(p.Slug)
  1097. if s == p.Slug {
  1098. // Sanity check to prevent infinite recursion
  1099. return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s)
  1100. }
  1101. p.Slug = s
  1102. params[slugIdx] = p.Slug
  1103. return db.AttemptClaim(p, query, params, slugIdx)
  1104. }
  1105. return qRes, fmt.Errorf("attemptClaim: %s", err)
  1106. }
  1107. return qRes, nil
  1108. }
  1109. func (db *datastore) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) {
  1110. postClaimReqs := map[string]bool{}
  1111. res := []ClaimPostResult{}
  1112. for i := range postIDs {
  1113. postID := postIDs[i]
  1114. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1115. // Perform post validation
  1116. if postID == "" {
  1117. r.ErrorMessage = "Missing post ID. "
  1118. }
  1119. if _, ok := postClaimReqs[postID]; ok {
  1120. r.Code = 429
  1121. r.ErrorMessage = "You've already tried anonymizing this post."
  1122. r.ID = postID
  1123. res = append(res, r)
  1124. continue
  1125. }
  1126. postClaimReqs[postID] = true
  1127. var err error
  1128. // Get full post information to return
  1129. var fullPost *PublicPost
  1130. fullPost, err = db.GetPost(postID, 0)
  1131. if err != nil {
  1132. if err, ok := err.(impart.HTTPError); ok {
  1133. r.Code = err.Status
  1134. r.ErrorMessage = err.Message
  1135. r.ID = postID
  1136. res = append(res, r)
  1137. continue
  1138. } else {
  1139. log.Error("Error getting post in dispersePosts: %v", err)
  1140. }
  1141. }
  1142. if fullPost.OwnerID.Int64 != userID {
  1143. r.Code = http.StatusConflict
  1144. r.ErrorMessage = "Post is already owned by someone else."
  1145. r.ID = postID
  1146. res = append(res, r)
  1147. continue
  1148. }
  1149. var qRes sql.Result
  1150. var query string
  1151. var params []interface{}
  1152. // Do AND owner_id = ? for sanity.
  1153. // This should've been caught and returned with a good error message
  1154. // just above.
  1155. query = "UPDATE posts SET collection_id = NULL WHERE id = ? AND owner_id = ?"
  1156. params = []interface{}{postID, userID}
  1157. qRes, err = db.Exec(query, params...)
  1158. if err != nil {
  1159. r.Code = http.StatusInternalServerError
  1160. r.ErrorMessage = "A glitch happened on our end."
  1161. r.ID = postID
  1162. res = append(res, r)
  1163. log.Error("dispersePosts (post %s): %v", postID, err)
  1164. continue
  1165. }
  1166. // Post was successfully dispersed
  1167. r.Code = http.StatusOK
  1168. r.Post = fullPost
  1169. rowsAffected, _ := qRes.RowsAffected()
  1170. if rowsAffected == 0 {
  1171. // This was already claimed, but return 200
  1172. r.Code = http.StatusOK
  1173. }
  1174. res = append(res, r)
  1175. }
  1176. return &res, nil
  1177. }
  1178. func (db *datastore) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) {
  1179. postClaimReqs := map[string]bool{}
  1180. res := []ClaimPostResult{}
  1181. postCollAlias := collAlias
  1182. for i := range *posts {
  1183. p := (*posts)[i]
  1184. if &p == nil {
  1185. continue
  1186. }
  1187. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1188. // Perform post validation
  1189. if p.ID == "" {
  1190. r.ErrorMessage = "Missing post ID `id`. "
  1191. }
  1192. if _, ok := postClaimReqs[p.ID]; ok {
  1193. r.Code = 429
  1194. r.ErrorMessage = "You've already tried claiming this post."
  1195. r.ID = p.ID
  1196. res = append(res, r)
  1197. continue
  1198. }
  1199. postClaimReqs[p.ID] = true
  1200. canCollect := db.CanCollect(&p, userID)
  1201. if !canCollect && p.Token == "" {
  1202. // TODO: ensure post isn't owned by anyone else when a valid modify
  1203. // token is given.
  1204. r.ErrorMessage += "Missing post Edit Token `token`."
  1205. }
  1206. if r.ErrorMessage != "" {
  1207. // Post validate failed
  1208. r.Code = http.StatusBadRequest
  1209. r.ID = p.ID
  1210. res = append(res, r)
  1211. continue
  1212. }
  1213. var err error
  1214. var qRes sql.Result
  1215. var query string
  1216. var params []interface{}
  1217. var slugIdx int = -1
  1218. var coll *Collection
  1219. if collAlias == "" {
  1220. // Posts are being claimed at /posts/claim, not
  1221. // /collections/{alias}/collect, so use given individual collection
  1222. // to associate post with.
  1223. postCollAlias = p.CollectionAlias
  1224. }
  1225. if postCollAlias != "" {
  1226. // Associate this post with a collection
  1227. if p.CreateCollection {
  1228. // This is a new collection
  1229. // TODO: consider removing this. This seriously complicates this
  1230. // method and adds another (unnecessary?) logic path.
  1231. coll, err = db.CreateCollection(cfg, postCollAlias, "", userID)
  1232. if err != nil {
  1233. if err, ok := err.(impart.HTTPError); ok {
  1234. r.Code = err.Status
  1235. r.ErrorMessage = err.Message
  1236. } else {
  1237. r.Code = http.StatusInternalServerError
  1238. r.ErrorMessage = "Unknown error occurred creating collection"
  1239. }
  1240. r.ID = p.ID
  1241. res = append(res, r)
  1242. continue
  1243. }
  1244. } else {
  1245. // Attempt to add to existing collection
  1246. coll, err = db.GetCollection(postCollAlias)
  1247. if err != nil {
  1248. if err, ok := err.(impart.HTTPError); ok {
  1249. if err.Status == http.StatusNotFound {
  1250. // Show obfuscated "forbidden" response, as if attempting to add to an
  1251. // unowned blog.
  1252. r.Code = ErrForbiddenCollection.Status
  1253. r.ErrorMessage = ErrForbiddenCollection.Message
  1254. } else {
  1255. r.Code = err.Status
  1256. r.ErrorMessage = err.Message
  1257. }
  1258. } else {
  1259. r.Code = http.StatusInternalServerError
  1260. r.ErrorMessage = "Unknown error occurred claiming post with collection"
  1261. }
  1262. r.ID = p.ID
  1263. res = append(res, r)
  1264. continue
  1265. }
  1266. if coll.OwnerID != userID {
  1267. r.Code = ErrForbiddenCollection.Status
  1268. r.ErrorMessage = ErrForbiddenCollection.Message
  1269. r.ID = p.ID
  1270. res = append(res, r)
  1271. continue
  1272. }
  1273. }
  1274. if p.Slug == "" {
  1275. p.Slug = p.ID
  1276. }
  1277. if canCollect {
  1278. // User already owns this post, so just add it to the given
  1279. // collection.
  1280. query = "UPDATE posts SET collection_id = ?, slug = ? WHERE id = ? AND owner_id = ?"
  1281. params = []interface{}{coll.ID, p.Slug, p.ID, userID}
  1282. slugIdx = 1
  1283. } else {
  1284. query = "UPDATE posts SET owner_id = ?, collection_id = ?, slug = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1285. params = []interface{}{userID, coll.ID, p.Slug, p.ID, p.Token}
  1286. slugIdx = 2
  1287. }
  1288. } else {
  1289. query = "UPDATE posts SET owner_id = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1290. params = []interface{}{userID, p.ID, p.Token}
  1291. }
  1292. qRes, err = db.AttemptClaim(&p, query, params, slugIdx)
  1293. if err != nil {
  1294. r.Code = http.StatusInternalServerError
  1295. r.ErrorMessage = "An unknown error occurred."
  1296. r.ID = p.ID
  1297. res = append(res, r)
  1298. log.Error("claimPosts (post %s): %v", p.ID, err)
  1299. continue
  1300. }
  1301. // Get full post information to return
  1302. var fullPost *PublicPost
  1303. if p.Token != "" {
  1304. fullPost, err = db.GetEditablePost(p.ID, p.Token)
  1305. } else {
  1306. fullPost, err = db.GetPost(p.ID, 0)
  1307. }
  1308. if err != nil {
  1309. if err, ok := err.(impart.HTTPError); ok {
  1310. r.Code = err.Status
  1311. r.ErrorMessage = err.Message
  1312. r.ID = p.ID
  1313. res = append(res, r)
  1314. continue
  1315. }
  1316. }
  1317. if fullPost.OwnerID.Int64 != userID {
  1318. r.Code = http.StatusConflict
  1319. r.ErrorMessage = "Post is already owned by someone else."
  1320. r.ID = p.ID
  1321. res = append(res, r)
  1322. continue
  1323. }
  1324. // Post was successfully claimed
  1325. r.Code = http.StatusOK
  1326. r.Post = fullPost
  1327. if coll != nil {
  1328. r.Post.Collection = &CollectionObj{Collection: *coll}
  1329. }
  1330. rowsAffected, _ := qRes.RowsAffected()
  1331. if rowsAffected == 0 {
  1332. // This was already claimed, but return 200
  1333. r.Code = http.StatusOK
  1334. }
  1335. res = append(res, r)
  1336. }
  1337. return &res, nil
  1338. }
  1339. func (db *datastore) UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error {
  1340. if pos <= 0 || pos > 20 {
  1341. pos = db.GetLastPinnedPostPos(collID) + 1
  1342. if pos == -1 {
  1343. pos = 1
  1344. }
  1345. }
  1346. var err error
  1347. if pinned {
  1348. _, err = db.Exec("UPDATE posts SET pinned_position = ? WHERE id = ?", pos, postID)
  1349. } else {
  1350. _, err = db.Exec("UPDATE posts SET pinned_position = NULL WHERE id = ?", postID)
  1351. }
  1352. if err != nil {
  1353. log.Error("Unable to update pinned post: %v", err)
  1354. return err
  1355. }
  1356. return nil
  1357. }
  1358. func (db *datastore) GetLastPinnedPostPos(collID int64) int64 {
  1359. var lastPos sql.NullInt64
  1360. err := db.QueryRow("SELECT MAX(pinned_position) FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL", collID).Scan(&lastPos)
  1361. switch {
  1362. case err == sql.ErrNoRows:
  1363. return -1
  1364. case err != nil:
  1365. log.Error("Failed selecting from posts: %v", err)
  1366. return -1
  1367. }
  1368. if !lastPos.Valid {
  1369. return -1
  1370. }
  1371. return lastPos.Int64
  1372. }
  1373. func (db *datastore) GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) {
  1374. // FIXME: sqlite-backed instances don't include ellipsis on truncated titles
  1375. timeCondition := ""
  1376. if !includeFuture {
  1377. timeCondition = "AND created <= " + db.now()
  1378. }
  1379. rows, err := db.Query("SELECT id, slug, title, "+db.clip("content", 80)+", pinned_position FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL "+timeCondition+" ORDER BY pinned_position ASC", coll.ID)
  1380. if err != nil {
  1381. log.Error("Failed selecting pinned posts: %v", err)
  1382. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve pinned posts."}
  1383. }
  1384. defer rows.Close()
  1385. posts := []PublicPost{}
  1386. for rows.Next() {
  1387. p := &Post{}
  1388. err = rows.Scan(&p.ID, &p.Slug, &p.Title, &p.Content, &p.PinnedPosition)
  1389. if err != nil {
  1390. log.Error("Failed scanning row: %v", err)
  1391. break
  1392. }
  1393. p.extractData()
  1394. pp := p.processPost()
  1395. pp.Collection = coll
  1396. posts = append(posts, pp)
  1397. }
  1398. return &posts, nil
  1399. }
  1400. func (db *datastore) GetCollections(u *User, hostName string) (*[]Collection, error) {
  1401. rows, err := db.Query("SELECT id, alias, title, description, privacy, view_count FROM collections WHERE owner_id = ? ORDER BY id ASC", u.ID)
  1402. if err != nil {
  1403. log.Error("Failed selecting from collections: %v", err)
  1404. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user collections."}
  1405. }
  1406. defer rows.Close()
  1407. colls := []Collection{}
  1408. for rows.Next() {
  1409. c := Collection{}
  1410. err = rows.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility, &c.Views)
  1411. if err != nil {
  1412. log.Error("Failed scanning row: %v", err)
  1413. break
  1414. }
  1415. c.hostName = hostName
  1416. c.URL = c.CanonicalURL()
  1417. c.Public = c.IsPublic()
  1418. colls = append(colls, c)
  1419. }
  1420. err = rows.Err()
  1421. if err != nil {
  1422. log.Error("Error after Next() on rows: %v", err)
  1423. }
  1424. return &colls, nil
  1425. }
  1426. func (db *datastore) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) {
  1427. c, err := db.GetCollections(u, hostName)
  1428. if err != nil {
  1429. return nil, err
  1430. }
  1431. if len(*c) == 0 {
  1432. return nil, impart.HTTPError{http.StatusInternalServerError, "You don't seem to have any blogs; they might've moved to another account. Try logging out and logging into your other account."}
  1433. }
  1434. return c, nil
  1435. }
  1436. func (db *datastore) GetMeStats(u *User) userMeStats {
  1437. s := userMeStats{}
  1438. // User counts
  1439. colls, _ := db.GetUserCollectionCount(u.ID)
  1440. s.TotalCollections = colls
  1441. var articles, collPosts uint64
  1442. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NULL", u.ID).Scan(&articles)
  1443. if err != nil && err != sql.ErrNoRows {
  1444. log.Error("Couldn't get articles count for user %d: %v", u.ID, err)
  1445. }
  1446. s.TotalArticles = articles
  1447. err = db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NOT NULL", u.ID).Scan(&collPosts)
  1448. if err != nil && err != sql.ErrNoRows {
  1449. log.Error("Couldn't get coll posts count for user %d: %v", u.ID, err)
  1450. }
  1451. s.CollectionPosts = collPosts
  1452. return s
  1453. }
  1454. func (db *datastore) GetTotalCollections() (collCount int64, err error) {
  1455. err = db.QueryRow(`
  1456. SELECT COUNT(*)
  1457. FROM collections c
  1458. LEFT JOIN users u ON u.id = c.owner_id
  1459. WHERE u.status = 0`).Scan(&collCount)
  1460. if err != nil {
  1461. log.Error("Unable to fetch collections count: %v", err)
  1462. }
  1463. return
  1464. }
  1465. func (db *datastore) GetTotalPosts() (postCount int64, err error) {
  1466. err = db.QueryRow(`
  1467. SELECT COUNT(*)
  1468. FROM posts p
  1469. LEFT JOIN users u ON u.id = p.owner_id
  1470. WHERE u.status = 0`).Scan(&postCount)
  1471. if err != nil {
  1472. log.Error("Unable to fetch posts count: %v", err)
  1473. }
  1474. return
  1475. }
  1476. func (db *datastore) GetTopPosts(u *User, alias string) (*[]PublicPost, error) {
  1477. params := []interface{}{u.ID}
  1478. where := ""
  1479. if alias != "" {
  1480. where = " AND alias = ?"
  1481. params = append(params, alias)
  1482. }
  1483. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON p.collection_id = c.id WHERE p.owner_id = ?"+where+" ORDER BY p.view_count DESC, created DESC LIMIT 25", params...)
  1484. if err != nil {
  1485. log.Error("Failed selecting from posts: %v", err)
  1486. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user top posts."}
  1487. }
  1488. defer rows.Close()
  1489. posts := []PublicPost{}
  1490. var gotErr bool
  1491. for rows.Next() {
  1492. p := Post{}
  1493. c := Collection{}
  1494. var alias, title, description sql.NullString
  1495. var views sql.NullInt64
  1496. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &alias, &title, &description, &views)
  1497. if err != nil {
  1498. log.Error("Failed scanning User.getPosts() row: %v", err)
  1499. gotErr = true
  1500. break
  1501. }
  1502. p.extractData()
  1503. pubPost := p.processPost()
  1504. if alias.Valid && alias.String != "" {
  1505. c.Alias = alias.String
  1506. c.Title = title.String
  1507. c.Description = description.String
  1508. c.Views = views.Int64
  1509. pubPost.Collection = &CollectionObj{Collection: c}
  1510. }
  1511. posts = append(posts, pubPost)
  1512. }
  1513. err = rows.Err()
  1514. if err != nil {
  1515. log.Error("Error after Next() on rows: %v", err)
  1516. }
  1517. if gotErr && len(posts) == 0 {
  1518. // There were a lot of errors
  1519. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1520. }
  1521. return &posts, nil
  1522. }
  1523. func (db *datastore) GetAnonymousPosts(u *User) (*[]PublicPost, error) {
  1524. rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC", u.ID)
  1525. if err != nil {
  1526. log.Error("Failed selecting from posts: %v", err)
  1527. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user anonymous posts."}
  1528. }
  1529. defer rows.Close()
  1530. posts := []PublicPost{}
  1531. for rows.Next() {
  1532. p := Post{}
  1533. err = rows.Scan(&p.ID, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content)
  1534. if err != nil {
  1535. log.Error("Failed scanning row: %v", err)
  1536. break
  1537. }
  1538. p.extractData()
  1539. posts = append(posts, p.processPost())
  1540. }
  1541. err = rows.Err()
  1542. if err != nil {
  1543. log.Error("Error after Next() on rows: %v", err)
  1544. }
  1545. return &posts, nil
  1546. }
  1547. func (db *datastore) GetUserPosts(u *User) (*[]PublicPost, error) {
  1548. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, p.created, p.updated, p.content, p.text_appearance, p.language, p.rtl, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON collection_id = c.id WHERE p.owner_id = ? ORDER BY created ASC", u.ID)
  1549. if err != nil {
  1550. log.Error("Failed selecting from posts: %v", err)
  1551. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."}
  1552. }
  1553. defer rows.Close()
  1554. posts := []PublicPost{}
  1555. var gotErr bool
  1556. for rows.Next() {
  1557. p := Post{}
  1558. c := Collection{}
  1559. var alias, title, description sql.NullString
  1560. var views sql.NullInt64
  1561. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content, &p.Font, &p.Language, &p.RTL, &alias, &title, &description, &views)
  1562. if err != nil {
  1563. log.Error("Failed scanning User.getPosts() row: %v", err)
  1564. gotErr = true
  1565. break
  1566. }
  1567. p.extractData()
  1568. pubPost := p.processPost()
  1569. if alias.Valid && alias.String != "" {
  1570. c.Alias = alias.String
  1571. c.Title = title.String
  1572. c.Description = description.String
  1573. c.Views = views.Int64
  1574. pubPost.Collection = &CollectionObj{Collection: c}
  1575. }
  1576. posts = append(posts, pubPost)
  1577. }
  1578. err = rows.Err()
  1579. if err != nil {
  1580. log.Error("Error after Next() on rows: %v", err)
  1581. }
  1582. if gotErr && len(posts) == 0 {
  1583. // There were a lot of errors
  1584. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1585. }
  1586. return &posts, nil
  1587. }
  1588. func (db *datastore) GetUserPostsCount(userID int64) int64 {
  1589. var count int64
  1590. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ?", userID).Scan(&count)
  1591. switch {
  1592. case err == sql.ErrNoRows:
  1593. return 0
  1594. case err != nil:
  1595. log.Error("Failed selecting posts count for user %d: %v", userID, err)
  1596. return 0
  1597. }
  1598. return count
  1599. }
  1600. // ChangeSettings takes a User and applies the changes in the given
  1601. // userSettings, MODIFYING THE USER with successful changes.
  1602. func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error {
  1603. var errPass error
  1604. q := query.NewUpdate()
  1605. // Update email if given
  1606. if s.Email != "" {
  1607. encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email)
  1608. if err != nil {
  1609. log.Error("Couldn't encrypt email %s: %s\n", s.Email, err)
  1610. return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."}
  1611. }
  1612. q.SetBytes(encEmail, "email")
  1613. // Update the email if something goes awry updating the password
  1614. defer func() {
  1615. if errPass != nil {
  1616. db.UpdateEncryptedUserEmail(u.ID, encEmail)
  1617. }
  1618. }()
  1619. u.Email = zero.StringFrom(s.Email)
  1620. }
  1621. // Update username if given
  1622. var newUsername string
  1623. if s.Username != "" {
  1624. var ie *impart.HTTPError
  1625. newUsername, ie = getValidUsername(app, s.Username, u.Username)
  1626. if ie != nil {
  1627. // Username is invalid
  1628. return *ie
  1629. }
  1630. if !author.IsValidUsername(app.cfg, newUsername) {
  1631. // Ensure the username is syntactically correct.
  1632. return impart.HTTPError{http.StatusPreconditionFailed, "Username isn't valid."}
  1633. }
  1634. t, err := db.Begin()
  1635. if err != nil {
  1636. log.Error("Couldn't start username change transaction: %v", err)
  1637. return err
  1638. }
  1639. _, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID)
  1640. if err != nil {
  1641. t.Rollback()
  1642. if db.isDuplicateKeyErr(err) {
  1643. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1644. }
  1645. log.Error("Unable to update users table: %v", err)
  1646. return ErrInternalGeneral
  1647. }
  1648. _, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID)
  1649. if err != nil {
  1650. t.Rollback()
  1651. if db.isDuplicateKeyErr(err) {
  1652. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1653. }
  1654. log.Error("Unable to update collection: %v", err)
  1655. return ErrInternalGeneral
  1656. }
  1657. // Keep track of name changes for redirection
  1658. db.RemoveCollectionRedirect(t, newUsername)
  1659. _, err = t.Exec("UPDATE collectionredirects SET new_alias = ? WHERE new_alias = ?", newUsername, u.Username)
  1660. if err != nil {
  1661. log.Error("Unable to update collectionredirects: %v", err)
  1662. }
  1663. _, err = t.Exec("INSERT INTO collectionredirects (prev_alias, new_alias) VALUES (?, ?)", u.Username, newUsername)
  1664. if err != nil {
  1665. log.Error("Unable to add new collectionredirect: %v", err)
  1666. }
  1667. err = t.Commit()
  1668. if err != nil {
  1669. t.Rollback()
  1670. log.Error("Rolling back after Commit(): %v\n", err)
  1671. return err
  1672. }
  1673. u.Username = newUsername
  1674. }
  1675. // Update passphrase if given
  1676. if s.NewPass != "" {
  1677. // Check if user has already set a password
  1678. var err error
  1679. u.HasPass, err = db.IsUserPassSet(u.ID)
  1680. if err != nil {
  1681. errPass = impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data."}
  1682. return errPass
  1683. }
  1684. if u.HasPass {
  1685. // Check if currently-set password is correct
  1686. hashedPass := u.HashedPass
  1687. if len(hashedPass) == 0 {
  1688. authUser, err := db.GetUserForAuthByID(u.ID)
  1689. if err != nil {
  1690. errPass = err
  1691. return errPass
  1692. }
  1693. hashedPass = authUser.HashedPass
  1694. }
  1695. if !auth.Authenticated(hashedPass, []byte(s.OldPass)) {
  1696. errPass = impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1697. return errPass
  1698. }
  1699. }
  1700. hashedPass, err := auth.HashPass([]byte(s.NewPass))
  1701. if err != nil {
  1702. errPass = impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  1703. return errPass
  1704. }
  1705. q.SetBytes(hashedPass, "password")
  1706. }
  1707. // WHERE values
  1708. q.Append(u.ID)
  1709. if q.Updates == "" {
  1710. if s.Username == "" {
  1711. return ErrPostNoUpdatableVals
  1712. }
  1713. // Nothing to update except username. That was successful, so return now.
  1714. return nil
  1715. }
  1716. res, err := db.Exec("UPDATE users SET "+q.Updates+" WHERE id = ?", q.Params...)
  1717. if err != nil {
  1718. log.Error("Unable to update collection: %v", err)
  1719. return err
  1720. }
  1721. rowsAffected, _ := res.RowsAffected()
  1722. if rowsAffected == 0 {
  1723. // Show the correct error message if nothing was updated
  1724. var dummy int
  1725. err := db.QueryRow("SELECT 1 FROM users WHERE id = ?", u.ID).Scan(&dummy)
  1726. switch {
  1727. case err == sql.ErrNoRows:
  1728. return ErrUnauthorizedGeneral
  1729. case err != nil:
  1730. log.Error("Failed selecting from users: %v", err)
  1731. }
  1732. return nil
  1733. }
  1734. if s.NewPass != "" && !u.HasPass {
  1735. u.HasPass = true
  1736. }
  1737. return nil
  1738. }
  1739. func (db *datastore) ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error {
  1740. var dbPass []byte
  1741. err := db.QueryRow("SELECT password FROM users WHERE id = ?", userID).Scan(&dbPass)
  1742. switch {
  1743. case err == sql.ErrNoRows:
  1744. return ErrUserNotFound
  1745. case err != nil:
  1746. log.Error("Couldn't SELECT user password for change: %v", err)
  1747. return err
  1748. }
  1749. if !sudo && !auth.Authenticated(dbPass, []byte(curPass)) {
  1750. return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1751. }
  1752. _, err = db.Exec("UPDATE users SET password = ? WHERE id = ?", hashedPass, userID)
  1753. if err != nil {
  1754. log.Error("Could not update passphrase: %v", err)
  1755. return err
  1756. }
  1757. return nil
  1758. }
  1759. func (db *datastore) RemoveCollectionRedirect(t *sql.Tx, alias string) error {
  1760. _, err := t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ?", alias)
  1761. if err != nil {
  1762. log.Error("Unable to delete from collectionredirects: %v", err)
  1763. return err
  1764. }
  1765. return nil
  1766. }
  1767. func (db *datastore) GetCollectionRedirect(alias string) (new string) {
  1768. row := db.QueryRow("SELECT new_alias FROM collectionredirects WHERE prev_alias = ?", alias)
  1769. err := row.Scan(&new)
  1770. if err != nil && err != sql.ErrNoRows {
  1771. log.Error("Failed selecting from collectionredirects: %v", err)
  1772. }
  1773. return
  1774. }
  1775. func (db *datastore) DeleteCollection(alias string, userID int64) error {
  1776. c := &Collection{Alias: alias}
  1777. var username string
  1778. row := db.QueryRow("SELECT username FROM users WHERE id = ?", userID)
  1779. err := row.Scan(&username)
  1780. if err != nil {
  1781. return err
  1782. }
  1783. // Ensure user isn't deleting their main blog
  1784. if alias == username {
  1785. return impart.HTTPError{http.StatusForbidden, "You cannot currently delete your primary blog."}
  1786. }
  1787. row = db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, userID)
  1788. err = row.Scan(&c.ID)
  1789. switch {
  1790. case err == sql.ErrNoRows:
  1791. return impart.HTTPError{http.StatusNotFound, "Collection doesn't exist or you're not allowed to delete it."}
  1792. case err != nil:
  1793. log.Error("Failed selecting from collections: %v", err)
  1794. return ErrInternalGeneral
  1795. }
  1796. t, err := db.Begin()
  1797. if err != nil {
  1798. return err
  1799. }
  1800. // Float all collection's posts
  1801. _, err = t.Exec("UPDATE posts SET collection_id = NULL WHERE collection_id = ? AND owner_id = ?", c.ID, userID)
  1802. if err != nil {
  1803. t.Rollback()
  1804. return err
  1805. }
  1806. // Remove redirects to or from this collection
  1807. _, err = t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ? OR new_alias = ?", alias, alias)
  1808. if err != nil {
  1809. t.Rollback()
  1810. return err
  1811. }
  1812. // Remove any optional collection password
  1813. _, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1814. if err != nil {
  1815. t.Rollback()
  1816. return err
  1817. }
  1818. // Finally, delete collection itself
  1819. _, err = t.Exec("DELETE FROM collections WHERE id = ?", c.ID)
  1820. if err != nil {
  1821. t.Rollback()
  1822. return err
  1823. }
  1824. err = t.Commit()
  1825. if err != nil {
  1826. t.Rollback()
  1827. return err
  1828. }
  1829. return nil
  1830. }
  1831. func (db *datastore) IsCollectionAttributeOn(id int64, attr string) bool {
  1832. var v string
  1833. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&v)
  1834. switch {
  1835. case err == sql.ErrNoRows:
  1836. return false
  1837. case err != nil:
  1838. log.Error("Couldn't SELECT value in isCollectionAttributeOn for attribute '%s': %v", attr, err)
  1839. return false
  1840. }
  1841. return v == "1"
  1842. }
  1843. func (db *datastore) CollectionHasAttribute(id int64, attr string) bool {
  1844. var dummy string
  1845. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&dummy)
  1846. switch {
  1847. case err == sql.ErrNoRows:
  1848. return false
  1849. case err != nil:
  1850. log.Error("Couldn't SELECT value in collectionHasAttribute for attribute '%s': %v", attr, err)
  1851. return false
  1852. }
  1853. return true
  1854. }
  1855. func (db *datastore) DeleteAccount(userID int64) (l *string, err error) {
  1856. debug := ""
  1857. l = &debug
  1858. t, err := db.Begin()
  1859. if err != nil {
  1860. stringLogln(l, "Unable to begin: %v", err)
  1861. return
  1862. }
  1863. // Get all collections
  1864. rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID)
  1865. if err != nil {
  1866. t.Rollback()
  1867. stringLogln(l, "Unable to get collections: %v", err)
  1868. return
  1869. }
  1870. defer rows.Close()
  1871. colls := []Collection{}
  1872. var c Collection
  1873. for rows.Next() {
  1874. err = rows.Scan(&c.ID, &c.Alias)
  1875. if err != nil {
  1876. t.Rollback()
  1877. stringLogln(l, "Unable to scan collection cols: %v", err)
  1878. return
  1879. }
  1880. colls = append(colls, c)
  1881. }
  1882. var res sql.Result
  1883. for _, c := range colls {
  1884. // TODO: user deleteCollection() func
  1885. // Delete tokens
  1886. res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID)
  1887. if err != nil {
  1888. t.Rollback()
  1889. stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err)
  1890. return
  1891. }
  1892. rs, _ := res.RowsAffected()
  1893. stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias)
  1894. // Remove any optional collection password
  1895. res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1896. if err != nil {
  1897. t.Rollback()
  1898. stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err)
  1899. return
  1900. }
  1901. rs, _ = res.RowsAffected()
  1902. stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias)
  1903. // Remove redirects to this collection
  1904. res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias)
  1905. if err != nil {
  1906. t.Rollback()
  1907. stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err)
  1908. return
  1909. }
  1910. rs, _ = res.RowsAffected()
  1911. stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias)
  1912. }
  1913. // Delete collections
  1914. res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID)
  1915. if err != nil {
  1916. t.Rollback()
  1917. stringLogln(l, "Unable to delete collections: %v", err)
  1918. return
  1919. }
  1920. rs, _ := res.RowsAffected()
  1921. stringLogln(l, "Deleted %d from collections", rs)
  1922. // Delete tokens
  1923. res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID)
  1924. if err != nil {
  1925. t.Rollback()
  1926. stringLogln(l, "Unable to delete access tokens: %v", err)
  1927. return
  1928. }
  1929. rs, _ = res.RowsAffected()
  1930. stringLogln(l, "Deleted %d from accesstokens", rs)
  1931. // Delete posts
  1932. res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID)
  1933. if err != nil {
  1934. t.Rollback()
  1935. stringLogln(l, "Unable to delete posts: %v", err)
  1936. return
  1937. }
  1938. rs, _ = res.RowsAffected()
  1939. stringLogln(l, "Deleted %d from posts", rs)
  1940. res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID)
  1941. if err != nil {
  1942. t.Rollback()
  1943. stringLogln(l, "Unable to delete attributes: %v", err)
  1944. return
  1945. }
  1946. rs, _ = res.RowsAffected()
  1947. stringLogln(l, "Deleted %d from userattributes", rs)
  1948. res, err = t.Exec("DELETE FROM users WHERE id = ?", userID)
  1949. if err != nil {
  1950. t.Rollback()
  1951. stringLogln(l, "Unable to delete user: %v", err)
  1952. return
  1953. }
  1954. rs, _ = res.RowsAffected()
  1955. stringLogln(l, "Deleted %d from users", rs)
  1956. err = t.Commit()
  1957. if err != nil {
  1958. t.Rollback()
  1959. stringLogln(l, "Unable to commit: %v", err)
  1960. return
  1961. }
  1962. return
  1963. }
  1964. func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) {
  1965. var pub, priv []byte
  1966. err := db.QueryRow("SELECT public_key, private_key FROM collectionkeys WHERE collection_id = ?", collectionID).Scan(&pub, &priv)
  1967. switch {
  1968. case err == sql.ErrNoRows:
  1969. // Generate keys
  1970. pub, priv = activitypub.GenerateKeys()
  1971. _, err = db.Exec("INSERT INTO collectionkeys (collection_id, public_key, private_key) VALUES (?, ?, ?)", collectionID, pub, priv)
  1972. if err != nil {
  1973. log.Error("Unable to INSERT new activitypub keypair: %v", err)
  1974. return nil, nil
  1975. }
  1976. case err != nil:
  1977. log.Error("Couldn't SELECT collectionkeys: %v", err)
  1978. return nil, nil
  1979. }
  1980. return pub, priv
  1981. }
  1982. func (db *datastore) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error {
  1983. _, err := db.Exec("INSERT INTO userinvites (id, owner_id, max_uses, created, expires, inactive) VALUES (?, ?, ?, "+db.now()+", ?, 0)", id, userID, maxUses, expires)
  1984. return err
  1985. }
  1986. func (db *datastore) GetUserInvites(userID int64) (*[]Invite, error) {
  1987. rows, err := db.Query("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE owner_id = ? ORDER BY created DESC", userID)
  1988. if err != nil {
  1989. log.Error("Failed selecting from userinvites: %v", err)
  1990. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user invites."}
  1991. }
  1992. defer rows.Close()
  1993. is := []Invite{}
  1994. for rows.Next() {
  1995. i := Invite{}
  1996. err = rows.Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  1997. is = append(is, i)
  1998. }
  1999. return &is, nil
  2000. }
  2001. func (db *datastore) GetUserInvite(id string) (*Invite, error) {
  2002. var i Invite
  2003. err := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id).Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  2004. switch {
  2005. case err == sql.ErrNoRows:
  2006. return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."}
  2007. case err != nil:
  2008. log.Error("Failed selecting invite: %v", err)
  2009. return nil, err
  2010. }
  2011. return &i, nil
  2012. }
  2013. // IsUsersInvite returns true if the user with ID created the invite with code
  2014. // and an error other than sql no rows, if any. Will return false in the event
  2015. // of an error.
  2016. func (db *datastore) IsUsersInvite(code string, userID int64) (bool, error) {
  2017. var id string
  2018. err := db.QueryRow("SELECT id FROM userinvites WHERE id = ? AND owner_id = ?", code, userID).Scan(&id)
  2019. if err != nil && err != sql.ErrNoRows {
  2020. log.Error("Failed selecting invite: %v", err)
  2021. return false, err
  2022. }
  2023. return id != "", nil
  2024. }
  2025. func (db *datastore) GetUsersInvitedCount(id string) int64 {
  2026. var count int64
  2027. err := db.QueryRow("SELECT COUNT(*) FROM usersinvited WHERE invite_id = ?", id).Scan(&count)
  2028. switch {
  2029. case err == sql.ErrNoRows:
  2030. return 0
  2031. case err != nil:
  2032. log.Error("Failed selecting users invited count: %v", err)
  2033. return 0
  2034. }
  2035. return count
  2036. }
  2037. func (db *datastore) CreateInvitedUser(inviteID string, userID int64) error {
  2038. _, err := db.Exec("INSERT INTO usersinvited (invite_id, user_id) VALUES (?, ?)", inviteID, userID)
  2039. return err
  2040. }
  2041. func (db *datastore) GetInstancePages() ([]*instanceContent, error) {
  2042. return db.GetAllDynamicContent("page")
  2043. }
  2044. func (db *datastore) GetAllDynamicContent(t string) ([]*instanceContent, error) {
  2045. where := ""
  2046. params := []interface{}{}
  2047. if t != "" {
  2048. where = " WHERE content_type = ?"
  2049. params = append(params, t)
  2050. }
  2051. rows, err := db.Query("SELECT id, title, content, updated, content_type FROM appcontent"+where, params...)
  2052. if err != nil {
  2053. log.Error("Failed selecting from appcontent: %v", err)
  2054. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve instance pages."}
  2055. }
  2056. defer rows.Close()
  2057. pages := []*instanceContent{}
  2058. for rows.Next() {
  2059. c := &instanceContent{}
  2060. err = rows.Scan(&c.ID, &c.Title, &c.Content, &c.Updated, &c.Type)
  2061. if err != nil {
  2062. log.Error("Failed scanning row: %v", err)
  2063. break
  2064. }
  2065. pages = append(pages, c)
  2066. }
  2067. err = rows.Err()
  2068. if err != nil {
  2069. log.Error("Error after Next() on rows: %v", err)
  2070. }
  2071. return pages, nil
  2072. }
  2073. func (db *datastore) GetDynamicContent(id string) (*instanceContent, error) {
  2074. c := &instanceContent{
  2075. ID: id,
  2076. }
  2077. err := db.QueryRow("SELECT title, content, updated, content_type FROM appcontent WHERE id = ?", id).Scan(&c.Title, &c.Content, &c.Updated, &c.Type)
  2078. switch {
  2079. case err == sql.ErrNoRows:
  2080. return nil, nil
  2081. case err != nil:
  2082. log.Error("Couldn't SELECT FROM appcontent for id '%s': %v", id, err)
  2083. return nil, err
  2084. }
  2085. return c, nil
  2086. }
  2087. func (db *datastore) UpdateDynamicContent(id, title, content, contentType string) error {
  2088. var err error
  2089. if db.driverName == driverSQLite {
  2090. _, err = db.Exec("INSERT OR REPLACE INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?)", id, title, content, contentType)
  2091. } else {
  2092. _, err = db.Exec("INSERT INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?) "+db.upsert("id")+" title = ?, content = ?, updated = "+db.now(), id, title, content, contentType, title, content)
  2093. }
  2094. if err != nil {
  2095. log.Error("Unable to INSERT appcontent for '%s': %v", id, err)
  2096. }
  2097. return err
  2098. }
  2099. func (db *datastore) GetAllUsers(page uint) (*[]User, error) {
  2100. limitStr := fmt.Sprintf("0, %d", adminUsersPerPage)
  2101. if page > 1 {
  2102. limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage)
  2103. }
  2104. rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr)
  2105. if err != nil {
  2106. log.Error("Failed selecting from users: %v", err)
  2107. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."}
  2108. }
  2109. defer rows.Close()
  2110. users := []User{}
  2111. for rows.Next() {
  2112. u := User{}
  2113. err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status)
  2114. if err != nil {
  2115. log.Error("Failed scanning GetAllUsers() row: %v", err)
  2116. break
  2117. }
  2118. users = append(users, u)
  2119. }
  2120. return &users, nil
  2121. }
  2122. func (db *datastore) GetAllUsersCount() int64 {
  2123. var count int64
  2124. err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
  2125. switch {
  2126. case err == sql.ErrNoRows:
  2127. return 0
  2128. case err != nil:
  2129. log.Error("Failed selecting all users count: %v", err)
  2130. return 0
  2131. }
  2132. return count
  2133. }
  2134. func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) {
  2135. var t time.Time
  2136. err := db.QueryRow("SELECT created FROM posts WHERE owner_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2137. switch {
  2138. case err == sql.ErrNoRows:
  2139. return nil, nil
  2140. case err != nil:
  2141. log.Error("Failed selecting last post time from posts: %v", err)
  2142. return nil, err
  2143. }
  2144. return &t, nil
  2145. }
  2146. // SetUserStatus changes a user's status in the database. see Users.UserStatus
  2147. func (db *datastore) SetUserStatus(id int64, status UserStatus) error {
  2148. _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id)
  2149. if err != nil {
  2150. return fmt.Errorf("failed to update user status: %v", err)
  2151. }
  2152. return nil
  2153. }
  2154. func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
  2155. var t time.Time
  2156. err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2157. switch {
  2158. case err == sql.ErrNoRows:
  2159. return nil, nil
  2160. case err != nil:
  2161. log.Error("Failed selecting last post time from posts: %v", err)
  2162. return nil, err
  2163. }
  2164. return &t, nil
  2165. }
  2166. func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
  2167. state, err := randString(24)
  2168. if err != nil {
  2169. return "", err
  2170. }
  2171. _, err = db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
  2172. if err != nil {
  2173. return "", fmt.Errorf("unable to record oauth client state: %w", err)
  2174. }
  2175. return state, nil
  2176. }
  2177. func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
  2178. res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
  2179. if err != nil {
  2180. return err
  2181. }
  2182. rowsAffected, err := res.RowsAffected()
  2183. if err != nil {
  2184. return err
  2185. }
  2186. if rowsAffected != 1 {
  2187. return fmt.Errorf("state not found")
  2188. }
  2189. return nil
  2190. }
  2191. func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
  2192. var err error
  2193. if db.driverName == driverSQLite {
  2194. _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
  2195. } else {
  2196. _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id"), localUserID, remoteUserID)
  2197. }
  2198. if err != nil {
  2199. log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err)
  2200. }
  2201. return err
  2202. }
  2203. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
  2204. func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
  2205. var userID int64 = -1
  2206. err := db.
  2207. QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID).
  2208. Scan(&userID)
  2209. // Not finding a record is OK.
  2210. if err != nil && err != sql.ErrNoRows {
  2211. return -1, err
  2212. }
  2213. return userID, nil
  2214. }
  2215. // DatabaseInitialized returns whether or not the current datastore has been
  2216. // initialized with the correct schema.
  2217. // Currently, it checks to see if the `users` table exists.
  2218. func (db *datastore) DatabaseInitialized() bool {
  2219. var dummy string
  2220. var err error
  2221. if db.driverName == driverSQLite {
  2222. err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy)
  2223. } else {
  2224. err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy)
  2225. }
  2226. switch {
  2227. case err == sql.ErrNoRows:
  2228. return false
  2229. case err != nil:
  2230. log.Error("Couldn't SHOW TABLES: %v", err)
  2231. return false
  2232. }
  2233. return true
  2234. }
  2235. func stringLogln(log *string, s string, v ...interface{}) {
  2236. *log += fmt.Sprintf(s+"\n", v...)
  2237. }
  2238. func handleFailedPostInsert(err error) error {
  2239. log.Error("Couldn't insert into posts: %v", err)
  2240. return err
  2241. }
  2242. func randString(length int) (string, error) {
  2243. // every printable character on a US keyboard
  2244. charset := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
  2245. out := make([]rune, length)
  2246. setLen := big.NewInt(int64(len(charset)))
  2247. for idx := 0; idx < length; idx++ {
  2248. offset, err := rand.Int(rand.Reader, setLen)
  2249. if err != nil {
  2250. return "", err
  2251. }
  2252. if !offset.IsUint64() {
  2253. // this should (in theory) never happen
  2254. return "", errors.Errorf("Non-Uint64 offset returned from rand.Int")
  2255. }
  2256. out[idx] = charset[offset.Uint64()]
  2257. }
  2258. return string(out), nil
  2259. }