A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

2571 lines
77 KiB

  1. /*
  2. * Copyright © 2018 A Bunch Tell LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "context"
  13. "crypto/rand"
  14. "database/sql"
  15. "fmt"
  16. "github.com/pkg/errors"
  17. "math/big"
  18. "net/http"
  19. "strings"
  20. "time"
  21. "github.com/guregu/null"
  22. "github.com/guregu/null/zero"
  23. uuid "github.com/nu7hatch/gouuid"
  24. "github.com/writeas/impart"
  25. "github.com/writeas/nerds/store"
  26. "github.com/writeas/web-core/activitypub"
  27. "github.com/writeas/web-core/auth"
  28. "github.com/writeas/web-core/data"
  29. "github.com/writeas/web-core/id"
  30. "github.com/writeas/web-core/log"
  31. "github.com/writeas/web-core/query"
  32. "github.com/writeas/writefreely/author"
  33. "github.com/writeas/writefreely/config"
  34. "github.com/writeas/writefreely/key"
  35. )
  36. const (
  37. mySQLErrDuplicateKey = 1062
  38. driverMySQL = "mysql"
  39. driverSQLite = "sqlite3"
  40. )
  41. var (
  42. SQLiteEnabled bool
  43. )
// writestore declares the data-access operations available on the
// application's SQL store. Methods are listed in loose groups by the records
// they deal with.
type writestore interface {
	// Users
	CreateUser(*config.Config, *User, string) error
	UpdateUserEmail(keys *key.Keychain, userID int64, email string) error
	UpdateEncryptedUserEmail(int64, []byte) error
	GetUserByID(int64) (*User, error)
	GetUserForAuth(string) (*User, error)
	GetUserForAuthByID(int64) (*User, error)

	// Resolving access tokens to users
	GetUserNameFromToken(string) (string, error)
	GetUserDataFromToken(string) (int64, string, error)
	GetAPIUser(header string) (*User, error)
	GetUserID(accessToken string) int64
	GetUserIDPrivilege(accessToken string) (userID int64, sudo bool)

	// Issuing and revoking access tokens
	DeleteToken(accessToken []byte) error
	FetchLastAccessToken(userID int64) string
	GetAccessToken(userID int64) (string, error)
	GetTemporaryAccessToken(userID int64, validSecs int) (string, error)
	GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error)

	// Account management
	DeleteAccount(userID int64) (l *string, err error)
	ChangeSettings(app *App, u *User, s *userSettings) error
	ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error

	// Per-user and instance-wide listings and counts
	GetCollections(u *User, hostName string) (*[]Collection, error)
	GetPublishableCollections(u *User, hostName string) (*[]Collection, error)
	GetMeStats(u *User) userMeStats
	GetTotalCollections() (int64, error)
	GetTotalPosts() (int64, error)
	GetTopPosts(u *User, alias string) (*[]PublicPost, error)
	GetAnonymousPosts(u *User) (*[]PublicPost, error)
	GetUserPosts(u *User) (*[]PublicPost, error)

	// Posts
	CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error)
	CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error)
	UpdateOwnedPost(post *AuthenticatedPost, userID int64) error
	GetEditablePost(id, editToken string) (*PublicPost, error)
	PostIDExists(id string) bool
	GetPost(id string, collectionID int64) (*PublicPost, error)
	GetOwnedPost(id string, ownerID int64) (*PublicPost, error)
	GetPostProperty(id string, collectionID int64, property string) (interface{}, error)

	// Collections
	CreateCollectionFromToken(*config.Config, string, string, string) (*Collection, error)
	CreateCollection(*config.Config, string, string, int64) (*Collection, error)
	GetCollectionBy(condition string, value interface{}) (*Collection, error)
	GetCollection(alias string) (*Collection, error)
	GetCollectionForPad(alias string) (*Collection, error)
	GetCollectionByID(id int64) (*Collection, error)
	UpdateCollection(c *SubmittedCollection, alias string) error
	DeleteCollection(alias string, userID int64) error

	// Pinned posts
	UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error
	GetLastPinnedPostPos(collID int64) int64
	GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error)

	// Collection redirects and attributes
	RemoveCollectionRedirect(t *sql.Tx, alias string) error
	GetCollectionRedirect(alias string) (new string)
	IsCollectionAttributeOn(id int64, attr string) bool
	CollectionHasAttribute(id int64, attr string) bool

	// Claiming and dispersing posts
	CanCollect(cpr *ClaimPostRequest, userID int64) bool
	AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error)
	DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error)
	ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error)

	// Collection post listings.
	// NOTE(review): GetPostsCount returns nothing; presumably it stores the
	// count on c — confirm against the implementation.
	GetPostsCount(c *CollectionObj, includeFuture bool)
	GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error)
	GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error)

	// ActivityPub
	GetAPFollowers(c *Collection) (*[]RemoteUser, error)
	GetAPActorKeys(collectionID int64) ([]byte, []byte)

	// Invites
	CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error
	GetUserInvites(userID int64) (*[]Invite, error)
	GetUserInvite(id string) (*Invite, error)
	GetUsersInvitedCount(id string) int64
	CreateInvitedUser(inviteID string, userID int64) error

	// Instance-level content
	GetDynamicContent(id string) (*instanceContent, error)
	UpdateDynamicContent(id, title, content, contentType string) error

	// All-user listings and last-activity timestamps
	GetAllUsers(page uint) (*[]User, error)
	GetAllUsersCount() int64
	GetUserLastPostTime(id int64) (*time.Time, error)
	GetCollectionLastPostTime(id int64) (*time.Time, error)

	// Remote (OAuth) users
	GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error)
	RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error
	ValidateOAuthState(ctx context.Context, state string) error
	GenerateOAuthState(ctx context.Context) (string, error)

	// Schema state
	DatabaseInitialized() bool
}
// datastore is the SQL-backed store whose methods appear to implement
// writestore. It embeds *sql.DB so query methods (QueryRow, Exec, Begin, ...)
// are called directly on it.
type datastore struct {
	*sql.DB
	// driverName is compared against driverSQLite to choose SQLite- or
	// MySQL-flavored SQL.
	driverName string
}
  125. func (db *datastore) now() string {
  126. if db.driverName == driverSQLite {
  127. return "strftime('%Y-%m-%d %H:%M:%S','now')"
  128. }
  129. return "NOW()"
  130. }
  131. func (db *datastore) clip(field string, l int) string {
  132. if db.driverName == driverSQLite {
  133. return fmt.Sprintf("SUBSTR(%s, 0, %d)", field, l)
  134. }
  135. return fmt.Sprintf("LEFT(%s, %d)", field, l)
  136. }
  137. func (db *datastore) upsert(indexedCols ...string) string {
  138. if db.driverName == driverSQLite {
  139. // NOTE: SQLite UPSERT syntax only works in v3.24.0 (2018-06-04) or later
  140. // Leaving this for whenever we can upgrade and include it in our binary
  141. cc := strings.Join(indexedCols, ", ")
  142. return "ON CONFLICT(" + cc + ") DO UPDATE SET"
  143. }
  144. return "ON DUPLICATE KEY UPDATE"
  145. }
  146. func (db *datastore) dateSub(l int, unit string) string {
  147. if db.driverName == driverSQLite {
  148. return fmt.Sprintf("DATETIME('now', '-%d %s')", l, unit)
  149. }
  150. return fmt.Sprintf("DATE_SUB(NOW(), INTERVAL %d %s)", l, unit)
  151. }
  152. func (db *datastore) CreateUser(cfg *config.Config, u *User, collectionTitle string) error {
  153. if db.PostIDExists(u.Username) {
  154. return impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  155. }
  156. // New users get a `users` and `collections` row.
  157. t, err := db.Begin()
  158. if err != nil {
  159. return err
  160. }
  161. // 1. Add to `users` table
  162. // NOTE: Assumes User's Password is already hashed!
  163. res, err := t.Exec("INSERT INTO users (username, password, email) VALUES (?, ?, ?)", u.Username, u.HashedPass, u.Email)
  164. if err != nil {
  165. t.Rollback()
  166. if db.isDuplicateKeyErr(err) {
  167. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  168. }
  169. log.Error("Rolling back users INSERT: %v\n", err)
  170. return err
  171. }
  172. u.ID, err = res.LastInsertId()
  173. if err != nil {
  174. t.Rollback()
  175. log.Error("Rolling back after LastInsertId: %v\n", err)
  176. return err
  177. }
  178. // 2. Create user's Collection
  179. if collectionTitle == "" {
  180. collectionTitle = u.Username
  181. }
  182. res, err = t.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", u.Username, collectionTitle, "", defaultVisibility(cfg), u.ID, 0)
  183. if err != nil {
  184. t.Rollback()
  185. if db.isDuplicateKeyErr(err) {
  186. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  187. }
  188. log.Error("Rolling back collections INSERT: %v\n", err)
  189. return err
  190. }
  191. db.RemoveCollectionRedirect(t, u.Username)
  192. err = t.Commit()
  193. if err != nil {
  194. t.Rollback()
  195. log.Error("Rolling back after Commit(): %v\n", err)
  196. return err
  197. }
  198. return nil
  199. }
  200. // FIXME: We're returning errors inconsistently in this file. Do we use Errorf
  201. // for returned value, or impart?
  202. func (db *datastore) UpdateUserEmail(keys *key.Keychain, userID int64, email string) error {
  203. encEmail, err := data.Encrypt(keys.EmailKey, email)
  204. if err != nil {
  205. return fmt.Errorf("Couldn't encrypt email %s: %s\n", email, err)
  206. }
  207. return db.UpdateEncryptedUserEmail(userID, encEmail)
  208. }
  209. func (db *datastore) UpdateEncryptedUserEmail(userID int64, encEmail []byte) error {
  210. _, err := db.Exec("UPDATE users SET email = ? WHERE id = ?", encEmail, userID)
  211. if err != nil {
  212. return fmt.Errorf("Unable to update user email: %s", err)
  213. }
  214. return nil
  215. }
  216. func (db *datastore) CreateCollectionFromToken(cfg *config.Config, alias, title, accessToken string) (*Collection, error) {
  217. userID := db.GetUserID(accessToken)
  218. if userID == -1 {
  219. return nil, ErrBadAccessToken
  220. }
  221. return db.CreateCollection(cfg, alias, title, userID)
  222. }
  223. func (db *datastore) GetUserCollectionCount(userID int64) (uint64, error) {
  224. var collCount uint64
  225. err := db.QueryRow("SELECT COUNT(*) FROM collections WHERE owner_id = ?", userID).Scan(&collCount)
  226. switch {
  227. case err == sql.ErrNoRows:
  228. return 0, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user from database."}
  229. case err != nil:
  230. log.Error("Couldn't get collections count for user %d: %v", userID, err)
  231. return 0, err
  232. }
  233. return collCount, nil
  234. }
  235. func (db *datastore) CreateCollection(cfg *config.Config, alias, title string, userID int64) (*Collection, error) {
  236. if db.PostIDExists(alias) {
  237. return nil, impart.HTTPError{http.StatusConflict, "Invalid collection name."}
  238. }
  239. // All good, so create new collection
  240. res, err := db.Exec("INSERT INTO collections (alias, title, description, privacy, owner_id, view_count) VALUES (?, ?, ?, ?, ?, ?)", alias, title, "", defaultVisibility(cfg), userID, 0)
  241. if err != nil {
  242. if db.isDuplicateKeyErr(err) {
  243. return nil, impart.HTTPError{http.StatusConflict, "Collection already exists."}
  244. }
  245. log.Error("Couldn't add to collections: %v\n", err)
  246. return nil, err
  247. }
  248. c := &Collection{
  249. Alias: alias,
  250. Title: title,
  251. OwnerID: userID,
  252. PublicOwner: false,
  253. Public: defaultVisibility(cfg) == CollPublic,
  254. }
  255. c.ID, err = res.LastInsertId()
  256. if err != nil {
  257. log.Error("Couldn't get collection LastInsertId: %v\n", err)
  258. }
  259. return c, nil
  260. }
  261. func (db *datastore) GetUserByID(id int64) (*User, error) {
  262. u := &User{ID: id}
  263. err := db.QueryRow("SELECT username, password, email, created, status FROM users WHERE id = ?", id).Scan(&u.Username, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  264. switch {
  265. case err == sql.ErrNoRows:
  266. return nil, ErrUserNotFound
  267. case err != nil:
  268. log.Error("Couldn't SELECT user password: %v", err)
  269. return nil, err
  270. }
  271. return u, nil
  272. }
  273. // IsUserSuspended returns true if the user account associated with id is
  274. // currently suspended.
  275. func (db *datastore) IsUserSuspended(id int64) (bool, error) {
  276. u := &User{ID: id}
  277. err := db.QueryRow("SELECT status FROM users WHERE id = ?", id).Scan(&u.Status)
  278. switch {
  279. case err == sql.ErrNoRows:
  280. return false, fmt.Errorf("is user suspended: %v", ErrUserNotFound)
  281. case err != nil:
  282. log.Error("Couldn't SELECT user password: %v", err)
  283. return false, fmt.Errorf("is user suspended: %v", err)
  284. }
  285. return u.IsSilenced(), nil
  286. }
  287. // DoesUserNeedAuth returns true if the user hasn't provided any methods for
  288. // authenticating with the account, such a passphrase or email address.
  289. // Any errors are reported to admin and silently quashed, returning false as the
  290. // result.
  291. func (db *datastore) DoesUserNeedAuth(id int64) bool {
  292. var pass, email []byte
  293. // Find out if user has an email set first
  294. err := db.QueryRow("SELECT password, email FROM users WHERE id = ?", id).Scan(&pass, &email)
  295. switch {
  296. case err == sql.ErrNoRows:
  297. // ERROR. Don't give false positives on needing auth methods
  298. return false
  299. case err != nil:
  300. // ERROR. Don't give false positives on needing auth methods
  301. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  302. return false
  303. }
  304. // User doesn't need auth if there's an email
  305. return len(email) == 0 && len(pass) == 0
  306. }
  307. func (db *datastore) IsUserPassSet(id int64) (bool, error) {
  308. var pass []byte
  309. err := db.QueryRow("SELECT password FROM users WHERE id = ?", id).Scan(&pass)
  310. switch {
  311. case err == sql.ErrNoRows:
  312. return false, nil
  313. case err != nil:
  314. log.Error("Couldn't SELECT user %d from users: %v", id, err)
  315. return false, err
  316. }
  317. return len(pass) > 0, nil
  318. }
  319. func (db *datastore) GetUserForAuth(username string) (*User, error) {
  320. u := &User{Username: username}
  321. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE username = ?", username).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  322. switch {
  323. case err == sql.ErrNoRows:
  324. // Check if they've entered the wrong, unnormalized username
  325. username = getSlug(username, "")
  326. if username != u.Username {
  327. err = db.QueryRow("SELECT id FROM users WHERE username = ? LIMIT 1", username).Scan(&u.ID)
  328. if err == nil {
  329. return db.GetUserForAuth(username)
  330. }
  331. }
  332. return nil, ErrUserNotFound
  333. case err != nil:
  334. log.Error("Couldn't SELECT user password: %v", err)
  335. return nil, err
  336. }
  337. return u, nil
  338. }
  339. func (db *datastore) GetUserForAuthByID(userID int64) (*User, error) {
  340. u := &User{ID: userID}
  341. err := db.QueryRow("SELECT id, password, email, created, status FROM users WHERE id = ?", u.ID).Scan(&u.ID, &u.HashedPass, &u.Email, &u.Created, &u.Status)
  342. switch {
  343. case err == sql.ErrNoRows:
  344. return nil, ErrUserNotFound
  345. case err != nil:
  346. log.Error("Couldn't SELECT userForAuthByID: %v", err)
  347. return nil, err
  348. }
  349. return u, nil
  350. }
  351. func (db *datastore) GetUserNameFromToken(accessToken string) (string, error) {
  352. t := auth.GetToken(accessToken)
  353. if len(t) == 0 {
  354. return "", ErrNoAccessToken
  355. }
  356. var oneTime bool
  357. var username string
  358. err := db.QueryRow("SELECT username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&username, &oneTime)
  359. switch {
  360. case err == sql.ErrNoRows:
  361. return "", ErrBadAccessToken
  362. case err != nil:
  363. return "", ErrInternalGeneral
  364. }
  365. // Delete token if it was one-time
  366. if oneTime {
  367. db.DeleteToken(t[:])
  368. }
  369. return username, nil
  370. }
  371. func (db *datastore) GetUserDataFromToken(accessToken string) (int64, string, error) {
  372. t := auth.GetToken(accessToken)
  373. if len(t) == 0 {
  374. return 0, "", ErrNoAccessToken
  375. }
  376. var userID int64
  377. var oneTime bool
  378. var username string
  379. err := db.QueryRow("SELECT user_id, username, one_time FROM accesstokens LEFT JOIN users ON user_id = id WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &username, &oneTime)
  380. switch {
  381. case err == sql.ErrNoRows:
  382. return 0, "", ErrBadAccessToken
  383. case err != nil:
  384. return 0, "", ErrInternalGeneral
  385. }
  386. // Delete token if it was one-time
  387. if oneTime {
  388. db.DeleteToken(t[:])
  389. }
  390. return userID, username, nil
  391. }
  392. func (db *datastore) GetAPIUser(header string) (*User, error) {
  393. uID := db.GetUserID(header)
  394. if uID == -1 {
  395. return nil, fmt.Errorf(ErrUserNotFound.Error())
  396. }
  397. return db.GetUserByID(uID)
  398. }
  399. // GetUserID takes a hexadecimal accessToken, parses it into its binary
  400. // representation, and gets any user ID associated with the token. If no user
  401. // is associated, -1 is returned.
  402. func (db *datastore) GetUserID(accessToken string) int64 {
  403. i, _ := db.GetUserIDPrivilege(accessToken)
  404. return i
  405. }
  406. func (db *datastore) GetUserIDPrivilege(accessToken string) (userID int64, sudo bool) {
  407. t := auth.GetToken(accessToken)
  408. if len(t) == 0 {
  409. return -1, false
  410. }
  411. var oneTime bool
  412. err := db.QueryRow("SELECT user_id, sudo, one_time FROM accesstokens WHERE token LIKE ? AND (expires IS NULL OR expires > "+db.now()+")", t).Scan(&userID, &sudo, &oneTime)
  413. switch {
  414. case err == sql.ErrNoRows:
  415. return -1, false
  416. case err != nil:
  417. return -1, false
  418. }
  419. // Delete token if it was one-time
  420. if oneTime {
  421. db.DeleteToken(t[:])
  422. }
  423. return
  424. }
  425. func (db *datastore) DeleteToken(accessToken []byte) error {
  426. res, err := db.Exec("DELETE FROM accesstokens WHERE token LIKE ?", accessToken)
  427. if err != nil {
  428. return err
  429. }
  430. rowsAffected, _ := res.RowsAffected()
  431. if rowsAffected == 0 {
  432. return impart.HTTPError{http.StatusNotFound, "Token is invalid or doesn't exist"}
  433. }
  434. return nil
  435. }
  436. // FetchLastAccessToken creates a new non-expiring, valid access token for the given
  437. // userID.
  438. func (db *datastore) FetchLastAccessToken(userID int64) string {
  439. var t []byte
  440. err := db.QueryRow("SELECT token FROM accesstokens WHERE user_id = ? AND (expires IS NULL OR expires > "+db.now()+") ORDER BY created DESC LIMIT 1", userID).Scan(&t)
  441. switch {
  442. case err == sql.ErrNoRows:
  443. return ""
  444. case err != nil:
  445. log.Error("Failed selecting from accesstoken: %v", err)
  446. return ""
  447. }
  448. u, err := uuid.Parse(t)
  449. if err != nil {
  450. return ""
  451. }
  452. return u.String()
  453. }
  454. // GetAccessToken creates a new non-expiring, valid access token for the given
  455. // userID.
  456. func (db *datastore) GetAccessToken(userID int64) (string, error) {
  457. return db.GetTemporaryOneTimeAccessToken(userID, 0, false)
  458. }
  459. // GetTemporaryAccessToken creates a new valid access token for the given
  460. // userID that remains valid for the given time in seconds. If validSecs is 0,
  461. // the access token doesn't automatically expire.
  462. func (db *datastore) GetTemporaryAccessToken(userID int64, validSecs int) (string, error) {
  463. return db.GetTemporaryOneTimeAccessToken(userID, validSecs, false)
  464. }
  465. // GetTemporaryOneTimeAccessToken creates a new valid access token for the given
  466. // userID that remains valid for the given time in seconds and can only be used
  467. // once if oneTime is true. If validSecs is 0, the access token doesn't
  468. // automatically expire.
  469. func (db *datastore) GetTemporaryOneTimeAccessToken(userID int64, validSecs int, oneTime bool) (string, error) {
  470. u, err := uuid.NewV4()
  471. if err != nil {
  472. log.Error("Unable to generate token: %v", err)
  473. return "", err
  474. }
  475. // Insert UUID to `accesstokens`
  476. binTok := u[:]
  477. expirationVal := "NULL"
  478. if validSecs > 0 {
  479. expirationVal = fmt.Sprintf("DATE_ADD("+db.now()+", INTERVAL %d SECOND)", validSecs)
  480. }
  481. _, err = db.Exec("INSERT INTO accesstokens (token, user_id, one_time, expires) VALUES (?, ?, ?, "+expirationVal+")", string(binTok), userID, oneTime)
  482. if err != nil {
  483. log.Error("Couldn't INSERT accesstoken: %v", err)
  484. return "", err
  485. }
  486. return u.String(), nil
  487. }
  488. func (db *datastore) CreateOwnedPost(post *SubmittedPost, accessToken, collAlias, hostName string) (*PublicPost, error) {
  489. var userID, collID int64 = -1, -1
  490. var coll *Collection
  491. var err error
  492. if accessToken != "" {
  493. userID = db.GetUserID(accessToken)
  494. if userID == -1 {
  495. return nil, ErrBadAccessToken
  496. }
  497. if collAlias != "" {
  498. coll, err = db.GetCollection(collAlias)
  499. if err != nil {
  500. return nil, err
  501. }
  502. coll.hostName = hostName
  503. if coll.OwnerID != userID {
  504. return nil, ErrForbiddenCollection
  505. }
  506. collID = coll.ID
  507. }
  508. }
  509. rp := &PublicPost{}
  510. rp.Post, err = db.CreatePost(userID, collID, post)
  511. if err != nil {
  512. return rp, err
  513. }
  514. if coll != nil {
  515. coll.ForPublic()
  516. rp.Collection = &CollectionObj{Collection: *coll}
  517. }
  518. return rp, nil
  519. }
  520. func (db *datastore) CreatePost(userID, collID int64, post *SubmittedPost) (*Post, error) {
  521. idLen := postIDLen
  522. friendlyID := store.GenerateFriendlyRandomString(idLen)
  523. // Handle appearance / font face
  524. appearance := post.Font
  525. if !post.isFontValid() {
  526. appearance = "norm"
  527. }
  528. var err error
  529. ownerID := sql.NullInt64{
  530. Valid: false,
  531. }
  532. ownerCollID := sql.NullInt64{
  533. Valid: false,
  534. }
  535. slug := sql.NullString{"", false}
  536. // If an alias was supplied, we'll add this to the collection as well.
  537. if userID > 0 {
  538. ownerID.Int64 = userID
  539. ownerID.Valid = true
  540. if collID > 0 {
  541. ownerCollID.Int64 = collID
  542. ownerCollID.Valid = true
  543. var slugVal string
  544. if post.Title != nil && *post.Title != "" {
  545. slugVal = getSlug(*post.Title, post.Language.String)
  546. if slugVal == "" {
  547. slugVal = getSlug(*post.Content, post.Language.String)
  548. }
  549. } else {
  550. slugVal = getSlug(*post.Content, post.Language.String)
  551. }
  552. if slugVal == "" {
  553. slugVal = friendlyID
  554. }
  555. slug = sql.NullString{slugVal, true}
  556. }
  557. }
  558. created := time.Now()
  559. if db.driverName == driverSQLite {
  560. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  561. created = created.UTC()
  562. }
  563. if post.Created != nil {
  564. created, err = time.Parse("2006-01-02T15:04:05Z", *post.Created)
  565. if err != nil {
  566. log.Error("Unable to parse Created time '%s': %v", *post.Created, err)
  567. created = time.Now()
  568. if db.driverName == driverSQLite {
  569. // SQLite stores datetimes in UTC, so convert time.Now() to it here
  570. created = created.UTC()
  571. }
  572. }
  573. }
  574. stmt, err := db.Prepare("INSERT INTO posts (id, slug, title, content, text_appearance, language, rtl, privacy, owner_id, collection_id, created, updated, view_count) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, " + db.now() + ", ?)")
  575. if err != nil {
  576. return nil, err
  577. }
  578. defer stmt.Close()
  579. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  580. if err != nil {
  581. if db.isDuplicateKeyErr(err) {
  582. // Duplicate entry error; try a new slug
  583. // TODO: make this a little more robust
  584. slug = sql.NullString{id.GenSafeUniqueSlug(slug.String), true}
  585. _, err = stmt.Exec(friendlyID, slug, post.Title, post.Content, appearance, post.Language, post.IsRTL, 0, ownerID, ownerCollID, created, 0)
  586. if err != nil {
  587. return nil, handleFailedPostInsert(fmt.Errorf("Retried slug generation, still failed: %v", err))
  588. }
  589. } else {
  590. return nil, handleFailedPostInsert(err)
  591. }
  592. }
  593. // TODO: return Created field in proper format
  594. return &Post{
  595. ID: friendlyID,
  596. Slug: null.NewString(slug.String, slug.Valid),
  597. Font: appearance,
  598. Language: zero.NewString(post.Language.String, post.Language.Valid),
  599. RTL: zero.NewBool(post.IsRTL.Bool, post.IsRTL.Valid),
  600. OwnerID: null.NewInt(userID, true),
  601. CollectionID: null.NewInt(userID, true),
  602. Created: created.Truncate(time.Second).UTC(),
  603. Updated: time.Now().Truncate(time.Second).UTC(),
  604. Title: zero.NewString(*(post.Title), true),
  605. Content: *(post.Content),
  606. }, nil
  607. }
  608. // UpdateOwnedPost updates an existing post with only the given fields in the
  609. // supplied AuthenticatedPost.
  610. func (db *datastore) UpdateOwnedPost(post *AuthenticatedPost, userID int64) error {
  611. params := []interface{}{}
  612. var queryUpdates, sep, authCondition string
  613. if post.Slug != nil && *post.Slug != "" {
  614. queryUpdates += sep + "slug = ?"
  615. sep = ", "
  616. params = append(params, getSlug(*post.Slug, ""))
  617. }
  618. if post.Content != nil {
  619. queryUpdates += sep + "content = ?"
  620. sep = ", "
  621. params = append(params, post.Content)
  622. }
  623. if post.Title != nil {
  624. queryUpdates += sep + "title = ?"
  625. sep = ", "
  626. params = append(params, post.Title)
  627. }
  628. if post.Language.Valid {
  629. queryUpdates += sep + "language = ?"
  630. sep = ", "
  631. params = append(params, post.Language.String)
  632. }
  633. if post.IsRTL.Valid {
  634. queryUpdates += sep + "rtl = ?"
  635. sep = ", "
  636. params = append(params, post.IsRTL.Bool)
  637. }
  638. if post.Font != "" {
  639. queryUpdates += sep + "text_appearance = ?"
  640. sep = ", "
  641. params = append(params, post.Font)
  642. }
  643. if post.Created != nil {
  644. createTime, err := time.Parse(postMetaDateFormat, *post.Created)
  645. if err != nil {
  646. log.Error("Unable to parse Created date: %v", err)
  647. return fmt.Errorf("That's the incorrect format for Created date.")
  648. }
  649. queryUpdates += sep + "created = ?"
  650. sep = ", "
  651. params = append(params, createTime)
  652. }
  653. // WHERE parameters...
  654. // id = ?
  655. params = append(params, post.ID)
  656. // AND owner_id = ?
  657. authCondition = "(owner_id = ?)"
  658. params = append(params, userID)
  659. if queryUpdates == "" {
  660. return ErrPostNoUpdatableVals
  661. }
  662. queryUpdates += sep + "updated = " + db.now()
  663. res, err := db.Exec("UPDATE posts SET "+queryUpdates+" WHERE id = ? AND "+authCondition, params...)
  664. if err != nil {
  665. log.Error("Unable to update owned post: %v", err)
  666. return err
  667. }
  668. rowsAffected, _ := res.RowsAffected()
  669. if rowsAffected == 0 {
  670. // Show the correct error message if nothing was updated
  671. var dummy int
  672. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ? AND "+authCondition, post.ID, params[len(params)-1]).Scan(&dummy)
  673. switch {
  674. case err == sql.ErrNoRows:
  675. return ErrUnauthorizedEditPost
  676. case err != nil:
  677. log.Error("Failed selecting from posts: %v", err)
  678. }
  679. return nil
  680. }
  681. return nil
  682. }
  683. func (db *datastore) GetCollectionBy(condition string, value interface{}) (*Collection, error) {
  684. c := &Collection{}
  685. // FIXME: change Collection to reflect database values. Add helper functions to get actual values
  686. var styleSheet, script, format zero.String
  687. row := db.QueryRow("SELECT id, alias, title, description, style_sheet, script, format, owner_id, privacy, view_count FROM collections WHERE "+condition, value)
  688. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &styleSheet, &script, &format, &c.OwnerID, &c.Visibility, &c.Views)
  689. switch {
  690. case err == sql.ErrNoRows:
  691. return nil, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  692. case err != nil:
  693. log.Error("Failed selecting from collections: %v", err)
  694. return nil, err
  695. }
  696. c.StyleSheet = styleSheet.String
  697. c.Script = script.String
  698. c.Format = format.String
  699. c.Public = c.IsPublic()
  700. c.db = db
  701. return c, nil
  702. }
  703. func (db *datastore) GetCollection(alias string) (*Collection, error) {
  704. return db.GetCollectionBy("alias = ?", alias)
  705. }
  706. func (db *datastore) GetCollectionForPad(alias string) (*Collection, error) {
  707. c := &Collection{Alias: alias}
  708. row := db.QueryRow("SELECT id, alias, title, description, privacy FROM collections WHERE alias = ?", alias)
  709. err := row.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility)
  710. switch {
  711. case err == sql.ErrNoRows:
  712. return c, impart.HTTPError{http.StatusNotFound, "Collection doesn't exist."}
  713. case err != nil:
  714. log.Error("Failed selecting from collections: %v", err)
  715. return c, ErrInternalGeneral
  716. }
  717. c.Public = c.IsPublic()
  718. return c, nil
  719. }
  720. func (db *datastore) GetCollectionByID(id int64) (*Collection, error) {
  721. return db.GetCollectionBy("id = ?", id)
  722. }
  723. func (db *datastore) GetCollectionFromDomain(host string) (*Collection, error) {
  724. return db.GetCollectionBy("host = ?", host)
  725. }
  726. func (db *datastore) UpdateCollection(c *SubmittedCollection, alias string) error {
  727. q := query.NewUpdate().
  728. SetStringPtr(c.Title, "title").
  729. SetStringPtr(c.Description, "description").
  730. SetNullString(c.StyleSheet, "style_sheet").
  731. SetNullString(c.Script, "script")
  732. if c.Format != nil {
  733. cf := &CollectionFormat{Format: c.Format.String}
  734. if cf.Valid() {
  735. q.SetNullString(c.Format, "format")
  736. }
  737. }
  738. var updatePass bool
  739. if c.Visibility != nil && (collVisibility(*c.Visibility)&CollProtected == 0 || c.Pass != "") {
  740. q.SetIntPtr(c.Visibility, "privacy")
  741. if c.Pass != "" {
  742. updatePass = true
  743. }
  744. }
  745. // WHERE values
  746. q.Where("alias = ? AND owner_id = ?", alias, c.OwnerID)
  747. if q.Updates == "" {
  748. return ErrPostNoUpdatableVals
  749. }
  750. // Find any current domain
  751. var collID int64
  752. var rowsAffected int64
  753. var changed bool
  754. var res sql.Result
  755. err := db.QueryRow("SELECT id FROM collections WHERE alias = ?", alias).Scan(&collID)
  756. if err != nil {
  757. log.Error("Failed selecting from collections: %v. Some things won't work.", err)
  758. }
  759. // Update MathJax value
  760. if c.MathJax {
  761. if db.driverName == driverSQLite {
  762. _, err = db.Exec("INSERT OR REPLACE INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?)", collID, "render_mathjax", "1")
  763. } else {
  764. _, err = db.Exec("INSERT INTO collectionattributes (collection_id, attribute, value) VALUES (?, ?, ?) "+db.upsert("collection_id", "attribute")+" value = ?", collID, "render_mathjax", "1", "1")
  765. }
  766. if err != nil {
  767. log.Error("Unable to insert render_mathjax value: %v", err)
  768. return err
  769. }
  770. } else {
  771. _, err = db.Exec("DELETE FROM collectionattributes WHERE collection_id = ? AND attribute = ?", collID, "render_mathjax")
  772. if err != nil {
  773. log.Error("Unable to delete render_mathjax value: %v", err)
  774. return err
  775. }
  776. }
  777. // Update rest of the collection data
  778. res, err = db.Exec("UPDATE collections SET "+q.Updates+" WHERE "+q.Conditions, q.Params...)
  779. if err != nil {
  780. log.Error("Unable to update collection: %v", err)
  781. return err
  782. }
  783. rowsAffected, _ = res.RowsAffected()
  784. if !changed || rowsAffected == 0 {
  785. // Show the correct error message if nothing was updated
  786. var dummy int
  787. err := db.QueryRow("SELECT 1 FROM collections WHERE alias = ? AND owner_id = ?", alias, c.OwnerID).Scan(&dummy)
  788. switch {
  789. case err == sql.ErrNoRows:
  790. return ErrUnauthorizedEditPost
  791. case err != nil:
  792. log.Error("Failed selecting from collections: %v", err)
  793. }
  794. if !updatePass {
  795. return nil
  796. }
  797. }
  798. if updatePass {
  799. hashedPass, err := auth.HashPass([]byte(c.Pass))
  800. if err != nil {
  801. log.Error("Unable to create hash: %s", err)
  802. return impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  803. }
  804. if db.driverName == driverSQLite {
  805. _, err = db.Exec("INSERT OR REPLACE INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?)", alias, hashedPass)
  806. } else {
  807. _, err = db.Exec("INSERT INTO collectionpasswords (collection_id, password) VALUES ((SELECT id FROM collections WHERE alias = ?), ?) "+db.upsert("collection_id")+" password = ?", alias, hashedPass, hashedPass)
  808. }
  809. if err != nil {
  810. return err
  811. }
  812. }
  813. return nil
  814. }
  815. const postCols = "id, slug, text_appearance, language, rtl, privacy, owner_id, collection_id, pinned_position, created, updated, view_count, title, content"
  816. // getEditablePost returns a PublicPost with the given ID only if the given
  817. // edit token is valid for the post.
  818. func (db *datastore) GetEditablePost(id, editToken string) (*PublicPost, error) {
  819. // FIXME: code duplicated from getPost()
  820. // TODO: add slight logic difference to getPost / one func
  821. var ownerName sql.NullString
  822. p := &Post{}
  823. row := db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE id = ? LIMIT 1", id)
  824. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  825. switch {
  826. case err == sql.ErrNoRows:
  827. return nil, ErrPostNotFound
  828. case err != nil:
  829. log.Error("Failed selecting from collections: %v", err)
  830. return nil, err
  831. }
  832. if p.Content == "" && p.Title.String == "" {
  833. return nil, ErrPostUnpublished
  834. }
  835. res := p.processPost()
  836. if ownerName.Valid {
  837. res.Owner = &PublicUser{Username: ownerName.String}
  838. }
  839. return &res, nil
  840. }
  841. func (db *datastore) PostIDExists(id string) bool {
  842. var dummy bool
  843. err := db.QueryRow("SELECT 1 FROM posts WHERE id = ?", id).Scan(&dummy)
  844. return err == nil && dummy
  845. }
  846. // GetPost gets a public-facing post object from the database. If collectionID
  847. // is > 0, the post will be retrieved by slug and collection ID, rather than
  848. // post ID.
  849. // TODO: break this into two functions:
  850. // - GetPost(id string)
  851. // - GetCollectionPost(slug string, collectionID int64)
  852. func (db *datastore) GetPost(id string, collectionID int64) (*PublicPost, error) {
  853. var ownerName sql.NullString
  854. p := &Post{}
  855. var row *sql.Row
  856. var where string
  857. params := []interface{}{id}
  858. if collectionID > 0 {
  859. where = "slug = ? AND collection_id = ?"
  860. params = append(params, collectionID)
  861. } else {
  862. where = "id = ?"
  863. }
  864. row = db.QueryRow("SELECT "+postCols+", (SELECT username FROM users WHERE users.id = posts.owner_id) AS username FROM posts WHERE "+where+" LIMIT 1", params...)
  865. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content, &ownerName)
  866. switch {
  867. case err == sql.ErrNoRows:
  868. if collectionID > 0 {
  869. return nil, ErrCollectionPageNotFound
  870. }
  871. return nil, ErrPostNotFound
  872. case err != nil:
  873. log.Error("Failed selecting from collections: %v", err)
  874. return nil, err
  875. }
  876. if p.Content == "" && p.Title.String == "" {
  877. return nil, ErrPostUnpublished
  878. }
  879. res := p.processPost()
  880. if ownerName.Valid {
  881. res.Owner = &PublicUser{Username: ownerName.String}
  882. }
  883. return &res, nil
  884. }
  885. // TODO: don't duplicate getPost() functionality
  886. func (db *datastore) GetOwnedPost(id string, ownerID int64) (*PublicPost, error) {
  887. p := &Post{}
  888. var row *sql.Row
  889. where := "id = ? AND owner_id = ?"
  890. params := []interface{}{id, ownerID}
  891. row = db.QueryRow("SELECT "+postCols+" FROM posts WHERE "+where+" LIMIT 1", params...)
  892. err := row.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  893. switch {
  894. case err == sql.ErrNoRows:
  895. return nil, ErrPostNotFound
  896. case err != nil:
  897. log.Error("Failed selecting from collections: %v", err)
  898. return nil, err
  899. }
  900. if p.Content == "" && p.Title.String == "" {
  901. return nil, ErrPostUnpublished
  902. }
  903. res := p.processPost()
  904. return &res, nil
  905. }
  906. func (db *datastore) GetPostProperty(id string, collectionID int64, property string) (interface{}, error) {
  907. propSelects := map[string]string{
  908. "views": "view_count AS views",
  909. }
  910. selectQuery, ok := propSelects[property]
  911. if !ok {
  912. return nil, impart.HTTPError{http.StatusBadRequest, fmt.Sprintf("Invalid property: %s.", property)}
  913. }
  914. var res interface{}
  915. var row *sql.Row
  916. if collectionID != 0 {
  917. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE slug = ? AND collection_id = ? LIMIT 1", id, collectionID)
  918. } else {
  919. row = db.QueryRow("SELECT "+selectQuery+" FROM posts WHERE id = ? LIMIT 1", id)
  920. }
  921. err := row.Scan(&res)
  922. switch {
  923. case err == sql.ErrNoRows:
  924. return nil, impart.HTTPError{http.StatusNotFound, "Post not found."}
  925. case err != nil:
  926. log.Error("Failed selecting post: %v", err)
  927. return nil, err
  928. }
  929. return res, nil
  930. }
  931. // GetPostsCount modifies the CollectionObj to include the correct number of
  932. // standard (non-pinned) posts. It will return future posts if `includeFuture`
  933. // is true.
  934. func (db *datastore) GetPostsCount(c *CollectionObj, includeFuture bool) {
  935. var count int64
  936. timeCondition := ""
  937. if !includeFuture {
  938. timeCondition = "AND created <= " + db.now()
  939. }
  940. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE collection_id = ? AND pinned_position IS NULL "+timeCondition, c.ID).Scan(&count)
  941. switch {
  942. case err == sql.ErrNoRows:
  943. c.TotalPosts = 0
  944. case err != nil:
  945. log.Error("Failed selecting from collections: %v", err)
  946. c.TotalPosts = 0
  947. }
  948. c.TotalPosts = int(count)
  949. }
  950. // GetPosts retrieves all posts for the given Collection.
  951. // It will return future posts if `includeFuture` is true.
  952. // It will include only standard (non-pinned) posts unless `includePinned` is true.
  953. // TODO: change includeFuture to isOwner, since that's how it's used
  954. func (db *datastore) GetPosts(cfg *config.Config, c *Collection, page int, includeFuture, forceRecentFirst, includePinned bool) (*[]PublicPost, error) {
  955. collID := c.ID
  956. cf := c.NewFormat()
  957. order := "DESC"
  958. if cf.Ascending() && !forceRecentFirst {
  959. order = "ASC"
  960. }
  961. pagePosts := cf.PostsPerPage()
  962. start := page*pagePosts - pagePosts
  963. if page == 0 {
  964. start = 0
  965. pagePosts = 1000
  966. }
  967. limitStr := ""
  968. if page > 0 {
  969. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  970. }
  971. timeCondition := ""
  972. if !includeFuture {
  973. timeCondition = "AND created <= " + db.now()
  974. }
  975. pinnedCondition := ""
  976. if !includePinned {
  977. pinnedCondition = "AND pinned_position IS NULL"
  978. }
  979. rows, err := db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? "+pinnedCondition+" "+timeCondition+" ORDER BY created "+order+limitStr, collID)
  980. if err != nil {
  981. log.Error("Failed selecting from posts: %v", err)
  982. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  983. }
  984. defer rows.Close()
  985. // TODO: extract this common row scanning logic for queries using `postCols`
  986. posts := []PublicPost{}
  987. for rows.Next() {
  988. p := &Post{}
  989. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  990. if err != nil {
  991. log.Error("Failed scanning row: %v", err)
  992. break
  993. }
  994. p.extractData()
  995. p.formatContent(cfg, c, includeFuture)
  996. posts = append(posts, p.processPost())
  997. }
  998. err = rows.Err()
  999. if err != nil {
  1000. log.Error("Error after Next() on rows: %v", err)
  1001. }
  1002. return &posts, nil
  1003. }
  1004. // GetPostsTagged retrieves all posts on the given Collection that contain the
  1005. // given tag.
  1006. // It will return future posts if `includeFuture` is true.
  1007. // TODO: change includeFuture to isOwner, since that's how it's used
  1008. func (db *datastore) GetPostsTagged(cfg *config.Config, c *Collection, tag string, page int, includeFuture bool) (*[]PublicPost, error) {
  1009. collID := c.ID
  1010. cf := c.NewFormat()
  1011. order := "DESC"
  1012. if cf.Ascending() {
  1013. order = "ASC"
  1014. }
  1015. pagePosts := cf.PostsPerPage()
  1016. start := page*pagePosts - pagePosts
  1017. if page == 0 {
  1018. start = 0
  1019. pagePosts = 1000
  1020. }
  1021. limitStr := ""
  1022. if page > 0 {
  1023. limitStr = fmt.Sprintf(" LIMIT %d, %d", start, pagePosts)
  1024. }
  1025. timeCondition := ""
  1026. if !includeFuture {
  1027. timeCondition = "AND created <= " + db.now()
  1028. }
  1029. var rows *sql.Rows
  1030. var err error
  1031. if db.driverName == driverSQLite {
  1032. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) regexp ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, `.*#`+strings.ToLower(tag)+`\b.*`)
  1033. } else {
  1034. rows, err = db.Query("SELECT "+postCols+" FROM posts WHERE collection_id = ? AND LOWER(content) RLIKE ? "+timeCondition+" ORDER BY created "+order+limitStr, collID, "#"+strings.ToLower(tag)+"[[:>:]]")
  1035. }
  1036. if err != nil {
  1037. log.Error("Failed selecting from posts: %v", err)
  1038. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve collection posts."}
  1039. }
  1040. defer rows.Close()
  1041. // TODO: extract this common row scanning logic for queries using `postCols`
  1042. posts := []PublicPost{}
  1043. for rows.Next() {
  1044. p := &Post{}
  1045. err = rows.Scan(&p.ID, &p.Slug, &p.Font, &p.Language, &p.RTL, &p.Privacy, &p.OwnerID, &p.CollectionID, &p.PinnedPosition, &p.Created, &p.Updated, &p.ViewCount, &p.Title, &p.Content)
  1046. if err != nil {
  1047. log.Error("Failed scanning row: %v", err)
  1048. break
  1049. }
  1050. p.extractData()
  1051. p.formatContent(cfg, c, includeFuture)
  1052. posts = append(posts, p.processPost())
  1053. }
  1054. err = rows.Err()
  1055. if err != nil {
  1056. log.Error("Error after Next() on rows: %v", err)
  1057. }
  1058. return &posts, nil
  1059. }
  1060. func (db *datastore) GetAPFollowers(c *Collection) (*[]RemoteUser, error) {
  1061. rows, err := db.Query("SELECT actor_id, inbox, shared_inbox FROM remotefollows f INNER JOIN remoteusers u ON f.remote_user_id = u.id WHERE collection_id = ?", c.ID)
  1062. if err != nil {
  1063. log.Error("Failed selecting from followers: %v", err)
  1064. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve followers."}
  1065. }
  1066. defer rows.Close()
  1067. followers := []RemoteUser{}
  1068. for rows.Next() {
  1069. f := RemoteUser{}
  1070. err = rows.Scan(&f.ActorID, &f.Inbox, &f.SharedInbox)
  1071. followers = append(followers, f)
  1072. }
  1073. return &followers, nil
  1074. }
  1075. // CanCollect returns whether or not the given user can add the given post to a
  1076. // collection. This is true when a post is already owned by the user.
  1077. // NOTE: this is currently only used to potentially add owned posts to a
  1078. // collection. This has the SIDE EFFECT of also generating a slug for the post.
  1079. // FIXME: make this side effect more explicit (or extract it)
  1080. func (db *datastore) CanCollect(cpr *ClaimPostRequest, userID int64) bool {
  1081. var title, content string
  1082. var lang sql.NullString
  1083. err := db.QueryRow("SELECT title, content, language FROM posts WHERE id = ? AND owner_id = ?", cpr.ID, userID).Scan(&title, &content, &lang)
  1084. switch {
  1085. case err == sql.ErrNoRows:
  1086. return false
  1087. case err != nil:
  1088. log.Error("Failed on post CanCollect(%s, %d): %v", cpr.ID, userID, err)
  1089. return false
  1090. }
  1091. // Since we have the post content and the post is collectable, generate the
  1092. // post's slug now.
  1093. cpr.Slug = getSlugFromPost(title, content, lang.String)
  1094. return true
  1095. }
  1096. func (db *datastore) AttemptClaim(p *ClaimPostRequest, query string, params []interface{}, slugIdx int) (sql.Result, error) {
  1097. qRes, err := db.Exec(query, params...)
  1098. if err != nil {
  1099. if db.isDuplicateKeyErr(err) && slugIdx > -1 {
  1100. s := id.GenSafeUniqueSlug(p.Slug)
  1101. if s == p.Slug {
  1102. // Sanity check to prevent infinite recursion
  1103. return qRes, fmt.Errorf("GenSafeUniqueSlug generated nothing unique: %s", s)
  1104. }
  1105. p.Slug = s
  1106. params[slugIdx] = p.Slug
  1107. return db.AttemptClaim(p, query, params, slugIdx)
  1108. }
  1109. return qRes, fmt.Errorf("attemptClaim: %s", err)
  1110. }
  1111. return qRes, nil
  1112. }
  1113. func (db *datastore) DispersePosts(userID int64, postIDs []string) (*[]ClaimPostResult, error) {
  1114. postClaimReqs := map[string]bool{}
  1115. res := []ClaimPostResult{}
  1116. for i := range postIDs {
  1117. postID := postIDs[i]
  1118. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1119. // Perform post validation
  1120. if postID == "" {
  1121. r.ErrorMessage = "Missing post ID. "
  1122. }
  1123. if _, ok := postClaimReqs[postID]; ok {
  1124. r.Code = 429
  1125. r.ErrorMessage = "You've already tried anonymizing this post."
  1126. r.ID = postID
  1127. res = append(res, r)
  1128. continue
  1129. }
  1130. postClaimReqs[postID] = true
  1131. var err error
  1132. // Get full post information to return
  1133. var fullPost *PublicPost
  1134. fullPost, err = db.GetPost(postID, 0)
  1135. if err != nil {
  1136. if err, ok := err.(impart.HTTPError); ok {
  1137. r.Code = err.Status
  1138. r.ErrorMessage = err.Message
  1139. r.ID = postID
  1140. res = append(res, r)
  1141. continue
  1142. } else {
  1143. log.Error("Error getting post in dispersePosts: %v", err)
  1144. }
  1145. }
  1146. if fullPost.OwnerID.Int64 != userID {
  1147. r.Code = http.StatusConflict
  1148. r.ErrorMessage = "Post is already owned by someone else."
  1149. r.ID = postID
  1150. res = append(res, r)
  1151. continue
  1152. }
  1153. var qRes sql.Result
  1154. var query string
  1155. var params []interface{}
  1156. // Do AND owner_id = ? for sanity.
  1157. // This should've been caught and returned with a good error message
  1158. // just above.
  1159. query = "UPDATE posts SET collection_id = NULL WHERE id = ? AND owner_id = ?"
  1160. params = []interface{}{postID, userID}
  1161. qRes, err = db.Exec(query, params...)
  1162. if err != nil {
  1163. r.Code = http.StatusInternalServerError
  1164. r.ErrorMessage = "A glitch happened on our end."
  1165. r.ID = postID
  1166. res = append(res, r)
  1167. log.Error("dispersePosts (post %s): %v", postID, err)
  1168. continue
  1169. }
  1170. // Post was successfully dispersed
  1171. r.Code = http.StatusOK
  1172. r.Post = fullPost
  1173. rowsAffected, _ := qRes.RowsAffected()
  1174. if rowsAffected == 0 {
  1175. // This was already claimed, but return 200
  1176. r.Code = http.StatusOK
  1177. }
  1178. res = append(res, r)
  1179. }
  1180. return &res, nil
  1181. }
  1182. func (db *datastore) ClaimPosts(cfg *config.Config, userID int64, collAlias string, posts *[]ClaimPostRequest) (*[]ClaimPostResult, error) {
  1183. postClaimReqs := map[string]bool{}
  1184. res := []ClaimPostResult{}
  1185. postCollAlias := collAlias
  1186. for i := range *posts {
  1187. p := (*posts)[i]
  1188. if &p == nil {
  1189. continue
  1190. }
  1191. r := ClaimPostResult{Code: 0, ErrorMessage: ""}
  1192. // Perform post validation
  1193. if p.ID == "" {
  1194. r.ErrorMessage = "Missing post ID `id`. "
  1195. }
  1196. if _, ok := postClaimReqs[p.ID]; ok {
  1197. r.Code = 429
  1198. r.ErrorMessage = "You've already tried claiming this post."
  1199. r.ID = p.ID
  1200. res = append(res, r)
  1201. continue
  1202. }
  1203. postClaimReqs[p.ID] = true
  1204. canCollect := db.CanCollect(&p, userID)
  1205. if !canCollect && p.Token == "" {
  1206. // TODO: ensure post isn't owned by anyone else when a valid modify
  1207. // token is given.
  1208. r.ErrorMessage += "Missing post Edit Token `token`."
  1209. }
  1210. if r.ErrorMessage != "" {
  1211. // Post validate failed
  1212. r.Code = http.StatusBadRequest
  1213. r.ID = p.ID
  1214. res = append(res, r)
  1215. continue
  1216. }
  1217. var err error
  1218. var qRes sql.Result
  1219. var query string
  1220. var params []interface{}
  1221. var slugIdx int = -1
  1222. var coll *Collection
  1223. if collAlias == "" {
  1224. // Posts are being claimed at /posts/claim, not
  1225. // /collections/{alias}/collect, so use given individual collection
  1226. // to associate post with.
  1227. postCollAlias = p.CollectionAlias
  1228. }
  1229. if postCollAlias != "" {
  1230. // Associate this post with a collection
  1231. if p.CreateCollection {
  1232. // This is a new collection
  1233. // TODO: consider removing this. This seriously complicates this
  1234. // method and adds another (unnecessary?) logic path.
  1235. coll, err = db.CreateCollection(cfg, postCollAlias, "", userID)
  1236. if err != nil {
  1237. if err, ok := err.(impart.HTTPError); ok {
  1238. r.Code = err.Status
  1239. r.ErrorMessage = err.Message
  1240. } else {
  1241. r.Code = http.StatusInternalServerError
  1242. r.ErrorMessage = "Unknown error occurred creating collection"
  1243. }
  1244. r.ID = p.ID
  1245. res = append(res, r)
  1246. continue
  1247. }
  1248. } else {
  1249. // Attempt to add to existing collection
  1250. coll, err = db.GetCollection(postCollAlias)
  1251. if err != nil {
  1252. if err, ok := err.(impart.HTTPError); ok {
  1253. if err.Status == http.StatusNotFound {
  1254. // Show obfuscated "forbidden" response, as if attempting to add to an
  1255. // unowned blog.
  1256. r.Code = ErrForbiddenCollection.Status
  1257. r.ErrorMessage = ErrForbiddenCollection.Message
  1258. } else {
  1259. r.Code = err.Status
  1260. r.ErrorMessage = err.Message
  1261. }
  1262. } else {
  1263. r.Code = http.StatusInternalServerError
  1264. r.ErrorMessage = "Unknown error occurred claiming post with collection"
  1265. }
  1266. r.ID = p.ID
  1267. res = append(res, r)
  1268. continue
  1269. }
  1270. if coll.OwnerID != userID {
  1271. r.Code = ErrForbiddenCollection.Status
  1272. r.ErrorMessage = ErrForbiddenCollection.Message
  1273. r.ID = p.ID
  1274. res = append(res, r)
  1275. continue
  1276. }
  1277. }
  1278. if p.Slug == "" {
  1279. p.Slug = p.ID
  1280. }
  1281. if canCollect {
  1282. // User already owns this post, so just add it to the given
  1283. // collection.
  1284. query = "UPDATE posts SET collection_id = ?, slug = ? WHERE id = ? AND owner_id = ?"
  1285. params = []interface{}{coll.ID, p.Slug, p.ID, userID}
  1286. slugIdx = 1
  1287. } else {
  1288. query = "UPDATE posts SET owner_id = ?, collection_id = ?, slug = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1289. params = []interface{}{userID, coll.ID, p.Slug, p.ID, p.Token}
  1290. slugIdx = 2
  1291. }
  1292. } else {
  1293. query = "UPDATE posts SET owner_id = ? WHERE id = ? AND modify_token = ? AND owner_id IS NULL"
  1294. params = []interface{}{userID, p.ID, p.Token}
  1295. }
  1296. qRes, err = db.AttemptClaim(&p, query, params, slugIdx)
  1297. if err != nil {
  1298. r.Code = http.StatusInternalServerError
  1299. r.ErrorMessage = "An unknown error occurred."
  1300. r.ID = p.ID
  1301. res = append(res, r)
  1302. log.Error("claimPosts (post %s): %v", p.ID, err)
  1303. continue
  1304. }
  1305. // Get full post information to return
  1306. var fullPost *PublicPost
  1307. if p.Token != "" {
  1308. fullPost, err = db.GetEditablePost(p.ID, p.Token)
  1309. } else {
  1310. fullPost, err = db.GetPost(p.ID, 0)
  1311. }
  1312. if err != nil {
  1313. if err, ok := err.(impart.HTTPError); ok {
  1314. r.Code = err.Status
  1315. r.ErrorMessage = err.Message
  1316. r.ID = p.ID
  1317. res = append(res, r)
  1318. continue
  1319. }
  1320. }
  1321. if fullPost.OwnerID.Int64 != userID {
  1322. r.Code = http.StatusConflict
  1323. r.ErrorMessage = "Post is already owned by someone else."
  1324. r.ID = p.ID
  1325. res = append(res, r)
  1326. continue
  1327. }
  1328. // Post was successfully claimed
  1329. r.Code = http.StatusOK
  1330. r.Post = fullPost
  1331. if coll != nil {
  1332. r.Post.Collection = &CollectionObj{Collection: *coll}
  1333. }
  1334. rowsAffected, _ := qRes.RowsAffected()
  1335. if rowsAffected == 0 {
  1336. // This was already claimed, but return 200
  1337. r.Code = http.StatusOK
  1338. }
  1339. res = append(res, r)
  1340. }
  1341. return &res, nil
  1342. }
  1343. func (db *datastore) UpdatePostPinState(pinned bool, postID string, collID, ownerID, pos int64) error {
  1344. if pos <= 0 || pos > 20 {
  1345. pos = db.GetLastPinnedPostPos(collID) + 1
  1346. if pos == -1 {
  1347. pos = 1
  1348. }
  1349. }
  1350. var err error
  1351. if pinned {
  1352. _, err = db.Exec("UPDATE posts SET pinned_position = ? WHERE id = ?", pos, postID)
  1353. } else {
  1354. _, err = db.Exec("UPDATE posts SET pinned_position = NULL WHERE id = ?", postID)
  1355. }
  1356. if err != nil {
  1357. log.Error("Unable to update pinned post: %v", err)
  1358. return err
  1359. }
  1360. return nil
  1361. }
  1362. func (db *datastore) GetLastPinnedPostPos(collID int64) int64 {
  1363. var lastPos sql.NullInt64
  1364. err := db.QueryRow("SELECT MAX(pinned_position) FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL", collID).Scan(&lastPos)
  1365. switch {
  1366. case err == sql.ErrNoRows:
  1367. return -1
  1368. case err != nil:
  1369. log.Error("Failed selecting from posts: %v", err)
  1370. return -1
  1371. }
  1372. if !lastPos.Valid {
  1373. return -1
  1374. }
  1375. return lastPos.Int64
  1376. }
  1377. func (db *datastore) GetPinnedPosts(coll *CollectionObj, includeFuture bool) (*[]PublicPost, error) {
  1378. // FIXME: sqlite-backed instances don't include ellipsis on truncated titles
  1379. timeCondition := ""
  1380. if !includeFuture {
  1381. timeCondition = "AND created <= " + db.now()
  1382. }
  1383. rows, err := db.Query("SELECT id, slug, title, "+db.clip("content", 80)+", pinned_position FROM posts WHERE collection_id = ? AND pinned_position IS NOT NULL "+timeCondition+" ORDER BY pinned_position ASC", coll.ID)
  1384. if err != nil {
  1385. log.Error("Failed selecting pinned posts: %v", err)
  1386. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve pinned posts."}
  1387. }
  1388. defer rows.Close()
  1389. posts := []PublicPost{}
  1390. for rows.Next() {
  1391. p := &Post{}
  1392. err = rows.Scan(&p.ID, &p.Slug, &p.Title, &p.Content, &p.PinnedPosition)
  1393. if err != nil {
  1394. log.Error("Failed scanning row: %v", err)
  1395. break
  1396. }
  1397. p.extractData()
  1398. pp := p.processPost()
  1399. pp.Collection = coll
  1400. posts = append(posts, pp)
  1401. }
  1402. return &posts, nil
  1403. }
  1404. func (db *datastore) GetCollections(u *User, hostName string) (*[]Collection, error) {
  1405. rows, err := db.Query("SELECT id, alias, title, description, privacy, view_count FROM collections WHERE owner_id = ? ORDER BY id ASC", u.ID)
  1406. if err != nil {
  1407. log.Error("Failed selecting from collections: %v", err)
  1408. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user collections."}
  1409. }
  1410. defer rows.Close()
  1411. colls := []Collection{}
  1412. for rows.Next() {
  1413. c := Collection{}
  1414. err = rows.Scan(&c.ID, &c.Alias, &c.Title, &c.Description, &c.Visibility, &c.Views)
  1415. if err != nil {
  1416. log.Error("Failed scanning row: %v", err)
  1417. break
  1418. }
  1419. c.hostName = hostName
  1420. c.URL = c.CanonicalURL()
  1421. c.Public = c.IsPublic()
  1422. colls = append(colls, c)
  1423. }
  1424. err = rows.Err()
  1425. if err != nil {
  1426. log.Error("Error after Next() on rows: %v", err)
  1427. }
  1428. return &colls, nil
  1429. }
  1430. func (db *datastore) GetPublishableCollections(u *User, hostName string) (*[]Collection, error) {
  1431. c, err := db.GetCollections(u, hostName)
  1432. if err != nil {
  1433. return nil, err
  1434. }
  1435. if len(*c) == 0 {
  1436. return nil, impart.HTTPError{http.StatusInternalServerError, "You don't seem to have any blogs; they might've moved to another account. Try logging out and logging into your other account."}
  1437. }
  1438. return c, nil
  1439. }
  1440. func (db *datastore) GetMeStats(u *User) userMeStats {
  1441. s := userMeStats{}
  1442. // User counts
  1443. colls, _ := db.GetUserCollectionCount(u.ID)
  1444. s.TotalCollections = colls
  1445. var articles, collPosts uint64
  1446. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NULL", u.ID).Scan(&articles)
  1447. if err != nil && err != sql.ErrNoRows {
  1448. log.Error("Couldn't get articles count for user %d: %v", u.ID, err)
  1449. }
  1450. s.TotalArticles = articles
  1451. err = db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ? AND collection_id IS NOT NULL", u.ID).Scan(&collPosts)
  1452. if err != nil && err != sql.ErrNoRows {
  1453. log.Error("Couldn't get coll posts count for user %d: %v", u.ID, err)
  1454. }
  1455. s.CollectionPosts = collPosts
  1456. return s
  1457. }
  1458. func (db *datastore) GetTotalCollections() (collCount int64, err error) {
  1459. err = db.QueryRow(`
  1460. SELECT COUNT(*)
  1461. FROM collections c
  1462. LEFT JOIN users u ON u.id = c.owner_id
  1463. WHERE u.status = 0`).Scan(&collCount)
  1464. if err != nil {
  1465. log.Error("Unable to fetch collections count: %v", err)
  1466. }
  1467. return
  1468. }
  1469. func (db *datastore) GetTotalPosts() (postCount int64, err error) {
  1470. err = db.QueryRow(`
  1471. SELECT COUNT(*)
  1472. FROM posts p
  1473. LEFT JOIN users u ON u.id = p.owner_id
  1474. WHERE u.status = 0`).Scan(&postCount)
  1475. if err != nil {
  1476. log.Error("Unable to fetch posts count: %v", err)
  1477. }
  1478. return
  1479. }
  1480. func (db *datastore) GetTopPosts(u *User, alias string) (*[]PublicPost, error) {
  1481. params := []interface{}{u.ID}
  1482. where := ""
  1483. if alias != "" {
  1484. where = " AND alias = ?"
  1485. params = append(params, alias)
  1486. }
  1487. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON p.collection_id = c.id WHERE p.owner_id = ?"+where+" ORDER BY p.view_count DESC, created DESC LIMIT 25", params...)
  1488. if err != nil {
  1489. log.Error("Failed selecting from posts: %v", err)
  1490. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user top posts."}
  1491. }
  1492. defer rows.Close()
  1493. posts := []PublicPost{}
  1494. var gotErr bool
  1495. for rows.Next() {
  1496. p := Post{}
  1497. c := Collection{}
  1498. var alias, title, description sql.NullString
  1499. var views sql.NullInt64
  1500. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &alias, &title, &description, &views)
  1501. if err != nil {
  1502. log.Error("Failed scanning User.getPosts() row: %v", err)
  1503. gotErr = true
  1504. break
  1505. }
  1506. p.extractData()
  1507. pubPost := p.processPost()
  1508. if alias.Valid && alias.String != "" {
  1509. c.Alias = alias.String
  1510. c.Title = title.String
  1511. c.Description = description.String
  1512. c.Views = views.Int64
  1513. pubPost.Collection = &CollectionObj{Collection: c}
  1514. }
  1515. posts = append(posts, pubPost)
  1516. }
  1517. err = rows.Err()
  1518. if err != nil {
  1519. log.Error("Error after Next() on rows: %v", err)
  1520. }
  1521. if gotErr && len(posts) == 0 {
  1522. // There were a lot of errors
  1523. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1524. }
  1525. return &posts, nil
  1526. }
  1527. func (db *datastore) GetAnonymousPosts(u *User) (*[]PublicPost, error) {
  1528. rows, err := db.Query("SELECT id, view_count, title, created, updated, content FROM posts WHERE owner_id = ? AND collection_id IS NULL ORDER BY created DESC", u.ID)
  1529. if err != nil {
  1530. log.Error("Failed selecting from posts: %v", err)
  1531. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user anonymous posts."}
  1532. }
  1533. defer rows.Close()
  1534. posts := []PublicPost{}
  1535. for rows.Next() {
  1536. p := Post{}
  1537. err = rows.Scan(&p.ID, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content)
  1538. if err != nil {
  1539. log.Error("Failed scanning row: %v", err)
  1540. break
  1541. }
  1542. p.extractData()
  1543. posts = append(posts, p.processPost())
  1544. }
  1545. err = rows.Err()
  1546. if err != nil {
  1547. log.Error("Error after Next() on rows: %v", err)
  1548. }
  1549. return &posts, nil
  1550. }
  1551. func (db *datastore) GetUserPosts(u *User) (*[]PublicPost, error) {
  1552. rows, err := db.Query("SELECT p.id, p.slug, p.view_count, p.title, p.created, p.updated, p.content, p.text_appearance, p.language, p.rtl, c.alias, c.title, c.description, c.view_count FROM posts p LEFT JOIN collections c ON collection_id = c.id WHERE p.owner_id = ? ORDER BY created ASC", u.ID)
  1553. if err != nil {
  1554. log.Error("Failed selecting from posts: %v", err)
  1555. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user posts."}
  1556. }
  1557. defer rows.Close()
  1558. posts := []PublicPost{}
  1559. var gotErr bool
  1560. for rows.Next() {
  1561. p := Post{}
  1562. c := Collection{}
  1563. var alias, title, description sql.NullString
  1564. var views sql.NullInt64
  1565. err = rows.Scan(&p.ID, &p.Slug, &p.ViewCount, &p.Title, &p.Created, &p.Updated, &p.Content, &p.Font, &p.Language, &p.RTL, &alias, &title, &description, &views)
  1566. if err != nil {
  1567. log.Error("Failed scanning User.getPosts() row: %v", err)
  1568. gotErr = true
  1569. break
  1570. }
  1571. p.extractData()
  1572. pubPost := p.processPost()
  1573. if alias.Valid && alias.String != "" {
  1574. c.Alias = alias.String
  1575. c.Title = title.String
  1576. c.Description = description.String
  1577. c.Views = views.Int64
  1578. pubPost.Collection = &CollectionObj{Collection: c}
  1579. }
  1580. posts = append(posts, pubPost)
  1581. }
  1582. err = rows.Err()
  1583. if err != nil {
  1584. log.Error("Error after Next() on rows: %v", err)
  1585. }
  1586. if gotErr && len(posts) == 0 {
  1587. // There were a lot of errors
  1588. return nil, impart.HTTPError{http.StatusInternalServerError, "Unable to get data."}
  1589. }
  1590. return &posts, nil
  1591. }
  1592. func (db *datastore) GetUserPostsCount(userID int64) int64 {
  1593. var count int64
  1594. err := db.QueryRow("SELECT COUNT(*) FROM posts WHERE owner_id = ?", userID).Scan(&count)
  1595. switch {
  1596. case err == sql.ErrNoRows:
  1597. return 0
  1598. case err != nil:
  1599. log.Error("Failed selecting posts count for user %d: %v", userID, err)
  1600. return 0
  1601. }
  1602. return count
  1603. }
  1604. // ChangeSettings takes a User and applies the changes in the given
  1605. // userSettings, MODIFYING THE USER with successful changes.
  1606. func (db *datastore) ChangeSettings(app *App, u *User, s *userSettings) error {
  1607. var errPass error
  1608. q := query.NewUpdate()
  1609. // Update email if given
  1610. if s.Email != "" {
  1611. encEmail, err := data.Encrypt(app.keys.EmailKey, s.Email)
  1612. if err != nil {
  1613. log.Error("Couldn't encrypt email %s: %s\n", s.Email, err)
  1614. return impart.HTTPError{http.StatusInternalServerError, "Unable to encrypt email address."}
  1615. }
  1616. q.SetBytes(encEmail, "email")
  1617. // Update the email if something goes awry updating the password
  1618. defer func() {
  1619. if errPass != nil {
  1620. db.UpdateEncryptedUserEmail(u.ID, encEmail)
  1621. }
  1622. }()
  1623. u.Email = zero.StringFrom(s.Email)
  1624. }
  1625. // Update username if given
  1626. var newUsername string
  1627. if s.Username != "" {
  1628. var ie *impart.HTTPError
  1629. newUsername, ie = getValidUsername(app, s.Username, u.Username)
  1630. if ie != nil {
  1631. // Username is invalid
  1632. return *ie
  1633. }
  1634. if !author.IsValidUsername(app.cfg, newUsername) {
  1635. // Ensure the username is syntactically correct.
  1636. return impart.HTTPError{http.StatusPreconditionFailed, "Username isn't valid."}
  1637. }
  1638. t, err := db.Begin()
  1639. if err != nil {
  1640. log.Error("Couldn't start username change transaction: %v", err)
  1641. return err
  1642. }
  1643. _, err = t.Exec("UPDATE users SET username = ? WHERE id = ?", newUsername, u.ID)
  1644. if err != nil {
  1645. t.Rollback()
  1646. if db.isDuplicateKeyErr(err) {
  1647. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1648. }
  1649. log.Error("Unable to update users table: %v", err)
  1650. return ErrInternalGeneral
  1651. }
  1652. _, err = t.Exec("UPDATE collections SET alias = ? WHERE alias = ? AND owner_id = ?", newUsername, u.Username, u.ID)
  1653. if err != nil {
  1654. t.Rollback()
  1655. if db.isDuplicateKeyErr(err) {
  1656. return impart.HTTPError{http.StatusConflict, "Username is already taken."}
  1657. }
  1658. log.Error("Unable to update collection: %v", err)
  1659. return ErrInternalGeneral
  1660. }
  1661. // Keep track of name changes for redirection
  1662. db.RemoveCollectionRedirect(t, newUsername)
  1663. _, err = t.Exec("UPDATE collectionredirects SET new_alias = ? WHERE new_alias = ?", newUsername, u.Username)
  1664. if err != nil {
  1665. log.Error("Unable to update collectionredirects: %v", err)
  1666. }
  1667. _, err = t.Exec("INSERT INTO collectionredirects (prev_alias, new_alias) VALUES (?, ?)", u.Username, newUsername)
  1668. if err != nil {
  1669. log.Error("Unable to add new collectionredirect: %v", err)
  1670. }
  1671. err = t.Commit()
  1672. if err != nil {
  1673. t.Rollback()
  1674. log.Error("Rolling back after Commit(): %v\n", err)
  1675. return err
  1676. }
  1677. u.Username = newUsername
  1678. }
  1679. // Update passphrase if given
  1680. if s.NewPass != "" {
  1681. // Check if user has already set a password
  1682. var err error
  1683. u.HasPass, err = db.IsUserPassSet(u.ID)
  1684. if err != nil {
  1685. errPass = impart.HTTPError{http.StatusInternalServerError, "Unable to retrieve user data."}
  1686. return errPass
  1687. }
  1688. if u.HasPass {
  1689. // Check if currently-set password is correct
  1690. hashedPass := u.HashedPass
  1691. if len(hashedPass) == 0 {
  1692. authUser, err := db.GetUserForAuthByID(u.ID)
  1693. if err != nil {
  1694. errPass = err
  1695. return errPass
  1696. }
  1697. hashedPass = authUser.HashedPass
  1698. }
  1699. if !auth.Authenticated(hashedPass, []byte(s.OldPass)) {
  1700. errPass = impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1701. return errPass
  1702. }
  1703. }
  1704. hashedPass, err := auth.HashPass([]byte(s.NewPass))
  1705. if err != nil {
  1706. errPass = impart.HTTPError{http.StatusInternalServerError, "Could not create password hash."}
  1707. return errPass
  1708. }
  1709. q.SetBytes(hashedPass, "password")
  1710. }
  1711. // WHERE values
  1712. q.Append(u.ID)
  1713. if q.Updates == "" {
  1714. if s.Username == "" {
  1715. return ErrPostNoUpdatableVals
  1716. }
  1717. // Nothing to update except username. That was successful, so return now.
  1718. return nil
  1719. }
  1720. res, err := db.Exec("UPDATE users SET "+q.Updates+" WHERE id = ?", q.Params...)
  1721. if err != nil {
  1722. log.Error("Unable to update collection: %v", err)
  1723. return err
  1724. }
  1725. rowsAffected, _ := res.RowsAffected()
  1726. if rowsAffected == 0 {
  1727. // Show the correct error message if nothing was updated
  1728. var dummy int
  1729. err := db.QueryRow("SELECT 1 FROM users WHERE id = ?", u.ID).Scan(&dummy)
  1730. switch {
  1731. case err == sql.ErrNoRows:
  1732. return ErrUnauthorizedGeneral
  1733. case err != nil:
  1734. log.Error("Failed selecting from users: %v", err)
  1735. }
  1736. return nil
  1737. }
  1738. if s.NewPass != "" && !u.HasPass {
  1739. u.HasPass = true
  1740. }
  1741. return nil
  1742. }
  1743. func (db *datastore) ChangePassphrase(userID int64, sudo bool, curPass string, hashedPass []byte) error {
  1744. var dbPass []byte
  1745. err := db.QueryRow("SELECT password FROM users WHERE id = ?", userID).Scan(&dbPass)
  1746. switch {
  1747. case err == sql.ErrNoRows:
  1748. return ErrUserNotFound
  1749. case err != nil:
  1750. log.Error("Couldn't SELECT user password for change: %v", err)
  1751. return err
  1752. }
  1753. if !sudo && !auth.Authenticated(dbPass, []byte(curPass)) {
  1754. return impart.HTTPError{http.StatusUnauthorized, "Incorrect password."}
  1755. }
  1756. _, err = db.Exec("UPDATE users SET password = ? WHERE id = ?", hashedPass, userID)
  1757. if err != nil {
  1758. log.Error("Could not update passphrase: %v", err)
  1759. return err
  1760. }
  1761. return nil
  1762. }
  1763. func (db *datastore) RemoveCollectionRedirect(t *sql.Tx, alias string) error {
  1764. _, err := t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ?", alias)
  1765. if err != nil {
  1766. log.Error("Unable to delete from collectionredirects: %v", err)
  1767. return err
  1768. }
  1769. return nil
  1770. }
  1771. func (db *datastore) GetCollectionRedirect(alias string) (new string) {
  1772. row := db.QueryRow("SELECT new_alias FROM collectionredirects WHERE prev_alias = ?", alias)
  1773. err := row.Scan(&new)
  1774. if err != nil && err != sql.ErrNoRows {
  1775. log.Error("Failed selecting from collectionredirects: %v", err)
  1776. }
  1777. return
  1778. }
  1779. func (db *datastore) DeleteCollection(alias string, userID int64) error {
  1780. c := &Collection{Alias: alias}
  1781. var username string
  1782. row := db.QueryRow("SELECT username FROM users WHERE id = ?", userID)
  1783. err := row.Scan(&username)
  1784. if err != nil {
  1785. return err
  1786. }
  1787. // Ensure user isn't deleting their main blog
  1788. if alias == username {
  1789. return impart.HTTPError{http.StatusForbidden, "You cannot currently delete your primary blog."}
  1790. }
  1791. row = db.QueryRow("SELECT id FROM collections WHERE alias = ? AND owner_id = ?", alias, userID)
  1792. err = row.Scan(&c.ID)
  1793. switch {
  1794. case err == sql.ErrNoRows:
  1795. return impart.HTTPError{http.StatusNotFound, "Collection doesn't exist or you're not allowed to delete it."}
  1796. case err != nil:
  1797. log.Error("Failed selecting from collections: %v", err)
  1798. return ErrInternalGeneral
  1799. }
  1800. t, err := db.Begin()
  1801. if err != nil {
  1802. return err
  1803. }
  1804. // Float all collection's posts
  1805. _, err = t.Exec("UPDATE posts SET collection_id = NULL WHERE collection_id = ? AND owner_id = ?", c.ID, userID)
  1806. if err != nil {
  1807. t.Rollback()
  1808. return err
  1809. }
  1810. // Remove redirects to or from this collection
  1811. _, err = t.Exec("DELETE FROM collectionredirects WHERE prev_alias = ? OR new_alias = ?", alias, alias)
  1812. if err != nil {
  1813. t.Rollback()
  1814. return err
  1815. }
  1816. // Remove any optional collection password
  1817. _, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1818. if err != nil {
  1819. t.Rollback()
  1820. return err
  1821. }
  1822. // Finally, delete collection itself
  1823. _, err = t.Exec("DELETE FROM collections WHERE id = ?", c.ID)
  1824. if err != nil {
  1825. t.Rollback()
  1826. return err
  1827. }
  1828. err = t.Commit()
  1829. if err != nil {
  1830. t.Rollback()
  1831. return err
  1832. }
  1833. return nil
  1834. }
  1835. func (db *datastore) IsCollectionAttributeOn(id int64, attr string) bool {
  1836. var v string
  1837. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&v)
  1838. switch {
  1839. case err == sql.ErrNoRows:
  1840. return false
  1841. case err != nil:
  1842. log.Error("Couldn't SELECT value in isCollectionAttributeOn for attribute '%s': %v", attr, err)
  1843. return false
  1844. }
  1845. return v == "1"
  1846. }
  1847. func (db *datastore) CollectionHasAttribute(id int64, attr string) bool {
  1848. var dummy string
  1849. err := db.QueryRow("SELECT value FROM collectionattributes WHERE collection_id = ? AND attribute = ?", id, attr).Scan(&dummy)
  1850. switch {
  1851. case err == sql.ErrNoRows:
  1852. return false
  1853. case err != nil:
  1854. log.Error("Couldn't SELECT value in collectionHasAttribute for attribute '%s': %v", attr, err)
  1855. return false
  1856. }
  1857. return true
  1858. }
  1859. func (db *datastore) DeleteAccount(userID int64) (l *string, err error) {
  1860. debug := ""
  1861. l = &debug
  1862. t, err := db.Begin()
  1863. if err != nil {
  1864. stringLogln(l, "Unable to begin: %v", err)
  1865. return
  1866. }
  1867. // Get all collections
  1868. rows, err := db.Query("SELECT id, alias FROM collections WHERE owner_id = ?", userID)
  1869. if err != nil {
  1870. t.Rollback()
  1871. stringLogln(l, "Unable to get collections: %v", err)
  1872. return
  1873. }
  1874. defer rows.Close()
  1875. colls := []Collection{}
  1876. var c Collection
  1877. for rows.Next() {
  1878. err = rows.Scan(&c.ID, &c.Alias)
  1879. if err != nil {
  1880. t.Rollback()
  1881. stringLogln(l, "Unable to scan collection cols: %v", err)
  1882. return
  1883. }
  1884. colls = append(colls, c)
  1885. }
  1886. var res sql.Result
  1887. for _, c := range colls {
  1888. // TODO: user deleteCollection() func
  1889. // Delete tokens
  1890. res, err = t.Exec("DELETE FROM collectionattributes WHERE collection_id = ?", c.ID)
  1891. if err != nil {
  1892. t.Rollback()
  1893. stringLogln(l, "Unable to delete attributes on %s: %v", c.Alias, err)
  1894. return
  1895. }
  1896. rs, _ := res.RowsAffected()
  1897. stringLogln(l, "Deleted %d for %s from collectionattributes", rs, c.Alias)
  1898. // Remove any optional collection password
  1899. res, err = t.Exec("DELETE FROM collectionpasswords WHERE collection_id = ?", c.ID)
  1900. if err != nil {
  1901. t.Rollback()
  1902. stringLogln(l, "Unable to delete passwords on %s: %v", c.Alias, err)
  1903. return
  1904. }
  1905. rs, _ = res.RowsAffected()
  1906. stringLogln(l, "Deleted %d for %s from collectionpasswords", rs, c.Alias)
  1907. // Remove redirects to this collection
  1908. res, err = t.Exec("DELETE FROM collectionredirects WHERE new_alias = ?", c.Alias)
  1909. if err != nil {
  1910. t.Rollback()
  1911. stringLogln(l, "Unable to delete redirects on %s: %v", c.Alias, err)
  1912. return
  1913. }
  1914. rs, _ = res.RowsAffected()
  1915. stringLogln(l, "Deleted %d for %s from collectionredirects", rs, c.Alias)
  1916. }
  1917. // Delete collections
  1918. res, err = t.Exec("DELETE FROM collections WHERE owner_id = ?", userID)
  1919. if err != nil {
  1920. t.Rollback()
  1921. stringLogln(l, "Unable to delete collections: %v", err)
  1922. return
  1923. }
  1924. rs, _ := res.RowsAffected()
  1925. stringLogln(l, "Deleted %d from collections", rs)
  1926. // Delete tokens
  1927. res, err = t.Exec("DELETE FROM accesstokens WHERE user_id = ?", userID)
  1928. if err != nil {
  1929. t.Rollback()
  1930. stringLogln(l, "Unable to delete access tokens: %v", err)
  1931. return
  1932. }
  1933. rs, _ = res.RowsAffected()
  1934. stringLogln(l, "Deleted %d from accesstokens", rs)
  1935. // Delete posts
  1936. res, err = t.Exec("DELETE FROM posts WHERE owner_id = ?", userID)
  1937. if err != nil {
  1938. t.Rollback()
  1939. stringLogln(l, "Unable to delete posts: %v", err)
  1940. return
  1941. }
  1942. rs, _ = res.RowsAffected()
  1943. stringLogln(l, "Deleted %d from posts", rs)
  1944. res, err = t.Exec("DELETE FROM userattributes WHERE user_id = ?", userID)
  1945. if err != nil {
  1946. t.Rollback()
  1947. stringLogln(l, "Unable to delete attributes: %v", err)
  1948. return
  1949. }
  1950. rs, _ = res.RowsAffected()
  1951. stringLogln(l, "Deleted %d from userattributes", rs)
  1952. res, err = t.Exec("DELETE FROM users WHERE id = ?", userID)
  1953. if err != nil {
  1954. t.Rollback()
  1955. stringLogln(l, "Unable to delete user: %v", err)
  1956. return
  1957. }
  1958. rs, _ = res.RowsAffected()
  1959. stringLogln(l, "Deleted %d from users", rs)
  1960. err = t.Commit()
  1961. if err != nil {
  1962. t.Rollback()
  1963. stringLogln(l, "Unable to commit: %v", err)
  1964. return
  1965. }
  1966. return
  1967. }
  1968. func (db *datastore) GetAPActorKeys(collectionID int64) ([]byte, []byte) {
  1969. var pub, priv []byte
  1970. err := db.QueryRow("SELECT public_key, private_key FROM collectionkeys WHERE collection_id = ?", collectionID).Scan(&pub, &priv)
  1971. switch {
  1972. case err == sql.ErrNoRows:
  1973. // Generate keys
  1974. pub, priv = activitypub.GenerateKeys()
  1975. _, err = db.Exec("INSERT INTO collectionkeys (collection_id, public_key, private_key) VALUES (?, ?, ?)", collectionID, pub, priv)
  1976. if err != nil {
  1977. log.Error("Unable to INSERT new activitypub keypair: %v", err)
  1978. return nil, nil
  1979. }
  1980. case err != nil:
  1981. log.Error("Couldn't SELECT collectionkeys: %v", err)
  1982. return nil, nil
  1983. }
  1984. return pub, priv
  1985. }
  1986. func (db *datastore) CreateUserInvite(id string, userID int64, maxUses int, expires *time.Time) error {
  1987. _, err := db.Exec("INSERT INTO userinvites (id, owner_id, max_uses, created, expires, inactive) VALUES (?, ?, ?, "+db.now()+", ?, 0)", id, userID, maxUses, expires)
  1988. return err
  1989. }
  1990. func (db *datastore) GetUserInvites(userID int64) (*[]Invite, error) {
  1991. rows, err := db.Query("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE owner_id = ? ORDER BY created DESC", userID)
  1992. if err != nil {
  1993. log.Error("Failed selecting from userinvites: %v", err)
  1994. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve user invites."}
  1995. }
  1996. defer rows.Close()
  1997. is := []Invite{}
  1998. for rows.Next() {
  1999. i := Invite{}
  2000. err = rows.Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  2001. is = append(is, i)
  2002. }
  2003. return &is, nil
  2004. }
  2005. func (db *datastore) GetUserInvite(id string) (*Invite, error) {
  2006. var i Invite
  2007. err := db.QueryRow("SELECT id, max_uses, created, expires, inactive FROM userinvites WHERE id = ?", id).Scan(&i.ID, &i.MaxUses, &i.Created, &i.Expires, &i.Inactive)
  2008. switch {
  2009. case err == sql.ErrNoRows:
  2010. return nil, impart.HTTPError{http.StatusNotFound, "Invite doesn't exist."}
  2011. case err != nil:
  2012. log.Error("Failed selecting invite: %v", err)
  2013. return nil, err
  2014. }
  2015. return &i, nil
  2016. }
  2017. // IsUsersInvite returns true if the user with ID created the invite with code
  2018. // and an error other than sql no rows, if any. Will return false in the event
  2019. // of an error.
  2020. func (db *datastore) IsUsersInvite(code string, userID int64) (bool, error) {
  2021. var id string
  2022. err := db.QueryRow("SELECT id FROM userinvites WHERE id = ? AND owner_id = ?", code, userID).Scan(&id)
  2023. if err != nil && err != sql.ErrNoRows {
  2024. log.Error("Failed selecting invite: %v", err)
  2025. return false, err
  2026. }
  2027. return id != "", nil
  2028. }
  2029. func (db *datastore) GetUsersInvitedCount(id string) int64 {
  2030. var count int64
  2031. err := db.QueryRow("SELECT COUNT(*) FROM usersinvited WHERE invite_id = ?", id).Scan(&count)
  2032. switch {
  2033. case err == sql.ErrNoRows:
  2034. return 0
  2035. case err != nil:
  2036. log.Error("Failed selecting users invited count: %v", err)
  2037. return 0
  2038. }
  2039. return count
  2040. }
  2041. func (db *datastore) CreateInvitedUser(inviteID string, userID int64) error {
  2042. _, err := db.Exec("INSERT INTO usersinvited (invite_id, user_id) VALUES (?, ?)", inviteID, userID)
  2043. return err
  2044. }
  2045. func (db *datastore) GetInstancePages() ([]*instanceContent, error) {
  2046. return db.GetAllDynamicContent("page")
  2047. }
  2048. func (db *datastore) GetAllDynamicContent(t string) ([]*instanceContent, error) {
  2049. where := ""
  2050. params := []interface{}{}
  2051. if t != "" {
  2052. where = " WHERE content_type = ?"
  2053. params = append(params, t)
  2054. }
  2055. rows, err := db.Query("SELECT id, title, content, updated, content_type FROM appcontent"+where, params...)
  2056. if err != nil {
  2057. log.Error("Failed selecting from appcontent: %v", err)
  2058. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve instance pages."}
  2059. }
  2060. defer rows.Close()
  2061. pages := []*instanceContent{}
  2062. for rows.Next() {
  2063. c := &instanceContent{}
  2064. err = rows.Scan(&c.ID, &c.Title, &c.Content, &c.Updated, &c.Type)
  2065. if err != nil {
  2066. log.Error("Failed scanning row: %v", err)
  2067. break
  2068. }
  2069. pages = append(pages, c)
  2070. }
  2071. err = rows.Err()
  2072. if err != nil {
  2073. log.Error("Error after Next() on rows: %v", err)
  2074. }
  2075. return pages, nil
  2076. }
  2077. func (db *datastore) GetDynamicContent(id string) (*instanceContent, error) {
  2078. c := &instanceContent{
  2079. ID: id,
  2080. }
  2081. err := db.QueryRow("SELECT title, content, updated, content_type FROM appcontent WHERE id = ?", id).Scan(&c.Title, &c.Content, &c.Updated, &c.Type)
  2082. switch {
  2083. case err == sql.ErrNoRows:
  2084. return nil, nil
  2085. case err != nil:
  2086. log.Error("Couldn't SELECT FROM appcontent for id '%s': %v", id, err)
  2087. return nil, err
  2088. }
  2089. return c, nil
  2090. }
  2091. func (db *datastore) UpdateDynamicContent(id, title, content, contentType string) error {
  2092. var err error
  2093. if db.driverName == driverSQLite {
  2094. _, err = db.Exec("INSERT OR REPLACE INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?)", id, title, content, contentType)
  2095. } else {
  2096. _, err = db.Exec("INSERT INTO appcontent (id, title, content, updated, content_type) VALUES (?, ?, ?, "+db.now()+", ?) "+db.upsert("id")+" title = ?, content = ?, updated = "+db.now(), id, title, content, contentType, title, content)
  2097. }
  2098. if err != nil {
  2099. log.Error("Unable to INSERT appcontent for '%s': %v", id, err)
  2100. }
  2101. return err
  2102. }
  2103. func (db *datastore) GetAllUsers(page uint) (*[]User, error) {
  2104. limitStr := fmt.Sprintf("0, %d", adminUsersPerPage)
  2105. if page > 1 {
  2106. limitStr = fmt.Sprintf("%d, %d", (page-1)*adminUsersPerPage, adminUsersPerPage)
  2107. }
  2108. rows, err := db.Query("SELECT id, username, created, status FROM users ORDER BY created DESC LIMIT " + limitStr)
  2109. if err != nil {
  2110. log.Error("Failed selecting from users: %v", err)
  2111. return nil, impart.HTTPError{http.StatusInternalServerError, "Couldn't retrieve all users."}
  2112. }
  2113. defer rows.Close()
  2114. users := []User{}
  2115. for rows.Next() {
  2116. u := User{}
  2117. err = rows.Scan(&u.ID, &u.Username, &u.Created, &u.Status)
  2118. if err != nil {
  2119. log.Error("Failed scanning GetAllUsers() row: %v", err)
  2120. break
  2121. }
  2122. users = append(users, u)
  2123. }
  2124. return &users, nil
  2125. }
  2126. func (db *datastore) GetAllUsersCount() int64 {
  2127. var count int64
  2128. err := db.QueryRow("SELECT COUNT(*) FROM users").Scan(&count)
  2129. switch {
  2130. case err == sql.ErrNoRows:
  2131. return 0
  2132. case err != nil:
  2133. log.Error("Failed selecting all users count: %v", err)
  2134. return 0
  2135. }
  2136. return count
  2137. }
  2138. func (db *datastore) GetUserLastPostTime(id int64) (*time.Time, error) {
  2139. var t time.Time
  2140. err := db.QueryRow("SELECT created FROM posts WHERE owner_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2141. switch {
  2142. case err == sql.ErrNoRows:
  2143. return nil, nil
  2144. case err != nil:
  2145. log.Error("Failed selecting last post time from posts: %v", err)
  2146. return nil, err
  2147. }
  2148. return &t, nil
  2149. }
  2150. // SetUserStatus changes a user's status in the database. see Users.UserStatus
  2151. func (db *datastore) SetUserStatus(id int64, status UserStatus) error {
  2152. _, err := db.Exec("UPDATE users SET status = ? WHERE id = ?", status, id)
  2153. if err != nil {
  2154. return fmt.Errorf("failed to update user status: %v", err)
  2155. }
  2156. return nil
  2157. }
  2158. func (db *datastore) GetCollectionLastPostTime(id int64) (*time.Time, error) {
  2159. var t time.Time
  2160. err := db.QueryRow("SELECT created FROM posts WHERE collection_id = ? ORDER BY created DESC LIMIT 1", id).Scan(&t)
  2161. switch {
  2162. case err == sql.ErrNoRows:
  2163. return nil, nil
  2164. case err != nil:
  2165. log.Error("Failed selecting last post time from posts: %v", err)
  2166. return nil, err
  2167. }
  2168. return &t, nil
  2169. }
  2170. func (db *datastore) GenerateOAuthState(ctx context.Context) (string, error) {
  2171. state, err := randString(24)
  2172. if err != nil {
  2173. return "", err
  2174. }
  2175. _, err = db.ExecContext(ctx, "INSERT INTO oauth_client_state (state, used, created_at) VALUES (?, FALSE, NOW())", state)
  2176. if err != nil {
  2177. return "", fmt.Errorf("unable to record oauth client state: %w", err)
  2178. }
  2179. return state, nil
  2180. }
  2181. func (db *datastore) ValidateOAuthState(ctx context.Context, state string) error {
  2182. res, err := db.ExecContext(ctx, "UPDATE oauth_client_state SET used = TRUE WHERE state = ?", state)
  2183. if err != nil {
  2184. return err
  2185. }
  2186. rowsAffected, err := res.RowsAffected()
  2187. if err != nil {
  2188. return err
  2189. }
  2190. if rowsAffected != 1 {
  2191. return fmt.Errorf("state not found")
  2192. }
  2193. return nil
  2194. }
  2195. func (db *datastore) RecordRemoteUserID(ctx context.Context, localUserID, remoteUserID int64) error {
  2196. var err error
  2197. if db.driverName == driverSQLite {
  2198. _, err = db.ExecContext(ctx, "INSERT OR REPLACE INTO users_oauth (user_id, remote_user_id) VALUES (?, ?)", localUserID, remoteUserID)
  2199. } else {
  2200. _, err = db.ExecContext(ctx, "INSERT INTO users_oauth (user_id, remote_user_id) VALUES (?, ?) "+db.upsert("user_id") + " user_id = ?", localUserID, remoteUserID, localUserID)
  2201. }
  2202. if err != nil {
  2203. log.Error("Unable to INSERT users_oauth for '%d': %v", localUserID, err)
  2204. }
  2205. return err
  2206. }
  2207. // GetIDForRemoteUser returns a user ID associated with a remote user ID.
  2208. func (db *datastore) GetIDForRemoteUser(ctx context.Context, remoteUserID int64) (int64, error) {
  2209. var userID int64 = -1
  2210. err := db.
  2211. QueryRowContext(ctx, "SELECT user_id FROM users_oauth WHERE remote_user_id = ?", remoteUserID).
  2212. Scan(&userID)
  2213. // Not finding a record is OK.
  2214. if err != nil && err != sql.ErrNoRows {
  2215. return -1, err
  2216. }
  2217. return userID, nil
  2218. }
  2219. // DatabaseInitialized returns whether or not the current datastore has been
  2220. // initialized with the correct schema.
  2221. // Currently, it checks to see if the `users` table exists.
  2222. func (db *datastore) DatabaseInitialized() bool {
  2223. var dummy string
  2224. var err error
  2225. if db.driverName == driverSQLite {
  2226. err = db.QueryRow("SELECT name FROM sqlite_master WHERE type = 'table' AND name = 'users'").Scan(&dummy)
  2227. } else {
  2228. err = db.QueryRow("SHOW TABLES LIKE 'users'").Scan(&dummy)
  2229. }
  2230. switch {
  2231. case err == sql.ErrNoRows:
  2232. return false
  2233. case err != nil:
  2234. log.Error("Couldn't SHOW TABLES: %v", err)
  2235. return false
  2236. }
  2237. return true
  2238. }
  2239. func stringLogln(log *string, s string, v ...interface{}) {
  2240. *log += fmt.Sprintf(s+"\n", v...)
  2241. }
  2242. func handleFailedPostInsert(err error) error {
  2243. log.Error("Couldn't insert into posts: %v", err)
  2244. return err
  2245. }
  2246. func randString(length int) (string, error) {
  2247. // every printable character on a US keyboard
  2248. charset := []rune("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789")
  2249. out := make([]rune, length)
  2250. setLen := big.NewInt(int64(len(charset)))
  2251. for idx := 0; idx < length; idx++ {
  2252. offset, err := rand.Int(rand.Reader, setLen)
  2253. if err != nil {
  2254. return "", err
  2255. }
  2256. if !offset.IsUint64() {
  2257. // this should (in theory) never happen
  2258. return "", errors.Errorf("Non-Uint64 offset returned from rand.Int")
  2259. }
  2260. out[idx] = charset[offset.Uint64()]
  2261. }
  2262. return string(out), nil
  2263. }