A clean, Markdown-based publishing platform made for writers. Write together, and build a community. https://writefreely.org
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

1004 lines
26 KiB

  1. /*
  2. * Copyright © 2018-2021 Musing Studio LLC.
  3. *
  4. * This file is part of WriteFreely.
  5. *
  6. * WriteFreely is free software: you can redistribute it and/or modify
  7. * it under the terms of the GNU Affero General Public License, included
  8. * in the LICENSE file in this source code package.
  9. */
  10. package writefreely
  11. import (
  12. "crypto/tls"
  13. "database/sql"
  14. _ "embed"
  15. "fmt"
  16. "html/template"
  17. "net"
  18. "net/http"
  19. "net/url"
  20. "os"
  21. "os/signal"
  22. "path/filepath"
  23. "regexp"
  24. "strings"
  25. "syscall"
  26. "time"
  27. "github.com/gorilla/mux"
  28. "github.com/gorilla/schema"
  29. "github.com/gorilla/sessions"
  30. "github.com/manifoldco/promptui"
  31. stripmd "github.com/writeas/go-strip-markdown/v2"
  32. "github.com/writeas/impart"
  33. "github.com/writeas/web-core/auth"
  34. "github.com/writeas/web-core/converter"
  35. "github.com/writeas/web-core/log"
  36. "golang.org/x/crypto/acme/autocert"
  37. "github.com/writefreely/writefreely/author"
  38. "github.com/writefreely/writefreely/config"
  39. "github.com/writefreely/writefreely/key"
  40. "github.com/writefreely/writefreely/migrations"
  41. "github.com/writefreely/writefreely/page"
  42. )
  43. const (
  44. staticDir = "static"
  45. assumedTitleLen = 80
  46. postsPerPage = 10
  47. serverSoftware = "WriteFreely"
  48. softwareURL = "https://writefreely.org"
  49. )
  50. var (
  51. debugging bool
  52. // Software version can be set from git env using -ldflags
  53. softwareVer = "0.15.0"
  54. // DEPRECATED VARS
  55. isSingleUser bool
  56. )
  57. // App holds data and configuration for an individual WriteFreely instance.
  58. type App struct {
  59. router *mux.Router
  60. shttp *http.ServeMux
  61. db *datastore
  62. cfg *config.Config
  63. cfgFile string
  64. keys *key.Keychain
  65. sessionStore sessions.Store
  66. formDecoder *schema.Decoder
  67. updates *updatesCache
  68. timeline *localTimeline
  69. }
  70. // DB returns the App's datastore
  71. func (app *App) DB() *datastore {
  72. return app.db
  73. }
  74. // Router returns the App's router
  75. func (app *App) Router() *mux.Router {
  76. return app.router
  77. }
  78. // Config returns the App's current configuration.
  79. func (app *App) Config() *config.Config {
  80. return app.cfg
  81. }
  82. // SetConfig updates the App's Config to the given value.
  83. func (app *App) SetConfig(cfg *config.Config) {
  84. app.cfg = cfg
  85. }
  86. // SetKeys updates the App's Keychain to the given value.
  87. func (app *App) SetKeys(k *key.Keychain) {
  88. app.keys = k
  89. }
  90. func (app *App) SessionStore() sessions.Store {
  91. return app.sessionStore
  92. }
  93. func (app *App) SetSessionStore(s sessions.Store) {
  94. app.sessionStore = s
  95. }
  96. // Apper is the interface for getting data into and out of a WriteFreely
  97. // instance (or "App").
  98. //
  99. // App returns the App for the current instance.
  100. //
  101. // LoadConfig reads an app configuration into the App, returning any error
  102. // encountered.
  103. //
  104. // SaveConfig persists the current App configuration.
  105. //
  106. // LoadKeys reads the App's encryption keys and loads them into its
  107. // key.Keychain.
  108. type Apper interface {
  109. App() *App
  110. LoadConfig() error
  111. SaveConfig(*config.Config) error
  112. LoadKeys() error
  113. ReqLog(r *http.Request, status int, timeSince time.Duration) string
  114. }
  115. // App returns the App
  116. func (app *App) App() *App {
  117. return app
  118. }
  119. // LoadConfig loads and parses a config file.
  120. func (app *App) LoadConfig() error {
  121. log.Info("Loading %s configuration...", app.cfgFile)
  122. cfg, err := config.Load(app.cfgFile)
  123. if err != nil {
  124. log.Error("Unable to load configuration: %v", err)
  125. os.Exit(1)
  126. return err
  127. }
  128. app.cfg = cfg
  129. return nil
  130. }
  131. // SaveConfig saves the given Config to disk -- namely, to the App's cfgFile.
  132. func (app *App) SaveConfig(c *config.Config) error {
  133. return config.Save(c, app.cfgFile)
  134. }
  135. // LoadKeys reads all needed keys from disk into the App. In order to use the
  136. // configured `Server.KeysParentDir`, you must call initKeyPaths(App) before
  137. // this.
  138. func (app *App) LoadKeys() error {
  139. var err error
  140. app.keys = &key.Keychain{}
  141. if debugging {
  142. log.Info(" %s", emailKeyPath)
  143. }
  144. executable, err := os.Executable()
  145. if err != nil {
  146. executable = "writefreely"
  147. } else {
  148. executable = filepath.Base(executable)
  149. }
  150. app.keys.EmailKey, err = os.ReadFile(emailKeyPath)
  151. if err != nil {
  152. return err
  153. }
  154. if debugging {
  155. log.Info(" %s", cookieAuthKeyPath)
  156. }
  157. app.keys.CookieAuthKey, err = os.ReadFile(cookieAuthKeyPath)
  158. if err != nil {
  159. return err
  160. }
  161. if debugging {
  162. log.Info(" %s", cookieKeyPath)
  163. }
  164. app.keys.CookieKey, err = os.ReadFile(cookieKeyPath)
  165. if err != nil {
  166. return err
  167. }
  168. if debugging {
  169. log.Info(" %s", csrfKeyPath)
  170. }
  171. app.keys.CSRFKey, err = os.ReadFile(csrfKeyPath)
  172. if err != nil {
  173. if os.IsNotExist(err) {
  174. log.Error(`Missing key: %s.
  175. Run this command to generate missing keys:
  176. %s keys generate
  177. `, csrfKeyPath, executable)
  178. }
  179. return err
  180. }
  181. return nil
  182. }
  183. func (app *App) ReqLog(r *http.Request, status int, timeSince time.Duration) string {
  184. return fmt.Sprintf("\"%s %s\" %d %s \"%s\"", r.Method, r.RequestURI, status, timeSince, r.UserAgent())
  185. }
  186. // handleViewHome shows page at root path. It checks the configuration and
  187. // authentication state to show the correct page.
  188. func handleViewHome(app *App, w http.ResponseWriter, r *http.Request) error {
  189. if app.cfg.App.SingleUser {
  190. // Render blog index
  191. return handleViewCollection(app, w, r)
  192. }
  193. // Multi-user instance
  194. forceLanding := r.FormValue("landing") == "1"
  195. if !forceLanding {
  196. // Show correct page based on user auth status and configured landing path
  197. u := getUserSession(app, r)
  198. if app.cfg.App.Chorus {
  199. // This instance is focused on reading, so show Reader on home route if not
  200. // private or a private-instance user is logged in.
  201. if !app.cfg.App.Private || u != nil {
  202. return viewLocalTimeline(app, w, r)
  203. }
  204. }
  205. if u != nil {
  206. // User is logged in, so show the Pad
  207. return handleViewPad(app, w, r)
  208. }
  209. if app.cfg.App.Private {
  210. return viewLogin(app, w, r)
  211. }
  212. if land := app.cfg.App.LandingPath(); land != "/" {
  213. return impart.HTTPError{http.StatusFound, land}
  214. }
  215. }
  216. return handleViewLanding(app, w, r)
  217. }
  218. func handleViewLanding(app *App, w http.ResponseWriter, r *http.Request) error {
  219. forceLanding := r.FormValue("landing") == "1"
  220. p := struct {
  221. page.StaticPage
  222. *OAuthButtons
  223. Flashes []template.HTML
  224. Banner template.HTML
  225. Content template.HTML
  226. ForcedLanding bool
  227. }{
  228. StaticPage: pageForReq(app, r),
  229. OAuthButtons: NewOAuthButtons(app.Config()),
  230. ForcedLanding: forceLanding,
  231. }
  232. banner, err := getLandingBanner(app)
  233. if err != nil {
  234. log.Error("unable to get landing banner: %v", err)
  235. return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get banner: %v", err)}
  236. }
  237. p.Banner = template.HTML(applyMarkdown([]byte(banner.Content), "", app.cfg))
  238. content, err := getLandingBody(app)
  239. if err != nil {
  240. log.Error("unable to get landing content: %v", err)
  241. return impart.HTTPError{http.StatusInternalServerError, fmt.Sprintf("Could not get content: %v", err)}
  242. }
  243. p.Content = template.HTML(applyMarkdown([]byte(content.Content), "", app.cfg))
  244. // Get error messages
  245. session, err := app.sessionStore.Get(r, cookieName)
  246. if err != nil {
  247. // Ignore this
  248. log.Error("Unable to get session in handleViewHome; ignoring: %v", err)
  249. }
  250. flashes, _ := getSessionFlashes(app, w, r, session)
  251. for _, flash := range flashes {
  252. p.Flashes = append(p.Flashes, template.HTML(flash))
  253. }
  254. // Show landing page
  255. return renderPage(w, "landing.tmpl", p)
  256. }
  257. func handleTemplatedPage(app *App, w http.ResponseWriter, r *http.Request, t *template.Template) error {
  258. p := struct {
  259. page.StaticPage
  260. ContentTitle string
  261. Content template.HTML
  262. PlainContent string
  263. Updated string
  264. AboutStats *InstanceStats
  265. }{
  266. StaticPage: pageForReq(app, r),
  267. }
  268. if r.URL.Path == "/about" || r.URL.Path == "/contact" || r.URL.Path == "/privacy" {
  269. var c *instanceContent
  270. var err error
  271. if r.URL.Path == "/about" {
  272. c, err = getAboutPage(app)
  273. // Fetch stats
  274. p.AboutStats = &InstanceStats{}
  275. p.AboutStats.NumPosts, _ = app.db.GetTotalPosts()
  276. p.AboutStats.NumBlogs, _ = app.db.GetTotalCollections()
  277. } else if r.URL.Path == "/contact" {
  278. c, err = getContactPage(app)
  279. if c.Updated.IsZero() {
  280. // Page was never set up, so return 404
  281. return ErrPostNotFound
  282. }
  283. } else {
  284. c, err = getPrivacyPage(app)
  285. }
  286. if err != nil {
  287. return err
  288. }
  289. p.ContentTitle = c.Title.String
  290. p.Content = template.HTML(applyMarkdown([]byte(c.Content), "", app.cfg))
  291. p.PlainContent = shortPostDescription(stripmd.Strip(c.Content))
  292. if !c.Updated.IsZero() {
  293. p.Updated = c.Updated.Format("January 2, 2006")
  294. }
  295. }
  296. // Serve templated page
  297. err := t.ExecuteTemplate(w, "base", p)
  298. if err != nil {
  299. log.Error("Unable to render page: %v", err)
  300. }
  301. return nil
  302. }
  303. func pageForReq(app *App, r *http.Request) page.StaticPage {
  304. p := page.StaticPage{
  305. AppCfg: app.cfg.App,
  306. Path: r.URL.Path,
  307. Version: "v" + softwareVer,
  308. }
  309. // Use custom style, if file exists
  310. if _, err := os.Stat(filepath.Join(app.cfg.Server.StaticParentDir, staticDir, "local", "custom.css")); err == nil {
  311. p.CustomCSS = true
  312. }
  313. // Add user information, if given
  314. var u *User
  315. accessToken := r.FormValue("t")
  316. if accessToken != "" {
  317. userID := app.db.GetUserID(accessToken)
  318. if userID != -1 {
  319. var err error
  320. u, err = app.db.GetUserByID(userID)
  321. if err == nil {
  322. p.Username = u.Username
  323. }
  324. }
  325. } else {
  326. u = getUserSession(app, r)
  327. if u != nil {
  328. p.Username = u.Username
  329. p.IsAdmin = u != nil && u.IsAdmin()
  330. p.CanInvite = canUserInvite(app.cfg, p.IsAdmin)
  331. }
  332. }
  333. p.CanViewReader = !app.cfg.App.Private || u != nil
  334. return p
  335. }
  336. var fileRegex = regexp.MustCompile("/([^/]*\\.[^/]*)$")
  337. // Initialize loads the app configuration and initializes templates, keys,
  338. // session, route handlers, and the database connection.
  339. func Initialize(apper Apper, debug bool) (*App, error) {
  340. debugging = debug
  341. apper.LoadConfig()
  342. // Load templates
  343. err := InitTemplates(apper.App().Config())
  344. if err != nil {
  345. return nil, fmt.Errorf("load templates: %s", err)
  346. }
  347. // Load keys and set up session
  348. initKeyPaths(apper.App()) // TODO: find a better way to do this, since it's unneeded in all Apper implementations
  349. err = InitKeys(apper)
  350. if err != nil {
  351. return nil, fmt.Errorf("init keys: %s", err)
  352. }
  353. apper.App().InitUpdates()
  354. apper.App().InitSession()
  355. apper.App().InitDecoder()
  356. err = ConnectToDatabase(apper.App())
  357. if err != nil {
  358. return nil, fmt.Errorf("connect to DB: %s", err)
  359. }
  360. initActivityPub(apper.App())
  361. if apper.App().cfg.Email.Domain != "" || apper.App().cfg.Email.MailgunPrivate != "" {
  362. if apper.App().cfg.Email.Domain == "" {
  363. log.Error("[FAILED] Starting publish jobs queue: no [letters]domain config value set.")
  364. } else if apper.App().cfg.Email.MailgunPrivate == "" {
  365. log.Error("[FAILED] Starting publish jobs queue: no [letters]mailgun_private config value set.")
  366. } else {
  367. log.Info("Starting publish jobs queue...")
  368. go startPublishJobsQueue(apper.App())
  369. }
  370. }
  371. // Handle local timeline, if enabled
  372. if apper.App().cfg.App.LocalTimeline {
  373. log.Info("Initializing local timeline...")
  374. initLocalTimeline(apper.App())
  375. }
  376. return apper.App(), nil
  377. }
  378. func Serve(app *App, r *mux.Router) {
  379. log.Info("Going to serve...")
  380. isSingleUser = app.cfg.App.SingleUser
  381. app.cfg.Server.Dev = debugging
  382. // Handle shutdown
  383. c := make(chan os.Signal, 2)
  384. signal.Notify(c, os.Interrupt, syscall.SIGTERM)
  385. go func() {
  386. <-c
  387. log.Info("Shutting down...")
  388. shutdown(app)
  389. log.Info("Done.")
  390. os.Exit(0)
  391. }()
  392. // Start gopher server
  393. if app.cfg.Server.GopherPort > 0 && !app.cfg.App.Private {
  394. go initGopher(app)
  395. }
  396. // Start web application server
  397. var bindAddress = app.cfg.Server.Bind
  398. if bindAddress == "" {
  399. bindAddress = "localhost"
  400. }
  401. var err error
  402. if app.cfg.IsSecureStandalone() {
  403. if app.cfg.Server.Autocert {
  404. m := &autocert.Manager{
  405. Prompt: autocert.AcceptTOS,
  406. Cache: autocert.DirCache(app.cfg.Server.TLSCertPath),
  407. }
  408. host, err := url.Parse(app.cfg.App.Host)
  409. if err != nil {
  410. log.Error("[WARNING] Unable to parse configured host! %s", err)
  411. log.Error(`[WARNING] ALL hosts are allowed, which can open you to an attack where
  412. clients connect to a server by IP address and pretend to be asking for an
  413. incorrect host name, and cause you to reach the CA's rate limit for certificate
  414. requests. We recommend supplying a valid host name.`)
  415. log.Info("Using autocert on ANY host")
  416. } else {
  417. log.Info("Using autocert on host %s", host.Host)
  418. m.HostPolicy = autocert.HostWhitelist(host.Host)
  419. }
  420. s := &http.Server{
  421. Addr: ":https",
  422. Handler: r,
  423. TLSConfig: &tls.Config{
  424. GetCertificate: m.GetCertificate,
  425. },
  426. }
  427. s.SetKeepAlivesEnabled(false)
  428. go func() {
  429. log.Info("Serving redirects on http://%s:80", bindAddress)
  430. err = http.ListenAndServe(":80", m.HTTPHandler(nil))
  431. log.Error("Unable to start redirect server: %v", err)
  432. }()
  433. log.Info("Serving on https://%s:443", bindAddress)
  434. log.Info("---")
  435. err = s.ListenAndServeTLS("", "")
  436. } else {
  437. go func() {
  438. log.Info("Serving redirects on http://%s:80", bindAddress)
  439. err = http.ListenAndServe(fmt.Sprintf("%s:80", bindAddress), http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  440. http.Redirect(w, r, app.cfg.App.Host, http.StatusMovedPermanently)
  441. }))
  442. log.Error("Unable to start redirect server: %v", err)
  443. }()
  444. log.Info("Serving on https://%s:443", bindAddress)
  445. log.Info("Using manual certificates")
  446. log.Info("---")
  447. err = http.ListenAndServeTLS(fmt.Sprintf("%s:443", bindAddress), app.cfg.Server.TLSCertPath, app.cfg.Server.TLSKeyPath, r)
  448. }
  449. } else {
  450. network := "tcp"
  451. protocol := "http"
  452. if strings.HasPrefix(bindAddress, "/") {
  453. network = "unix"
  454. protocol = "http+unix"
  455. // old sockets will remain after server closes;
  456. // we need to delete them in order to open new ones
  457. err = os.Remove(bindAddress)
  458. if err != nil && !os.IsNotExist(err) {
  459. log.Error("%s already exists but could not be removed: %v", bindAddress, err)
  460. os.Exit(1)
  461. }
  462. } else {
  463. bindAddress = fmt.Sprintf("%s:%d", bindAddress, app.cfg.Server.Port)
  464. }
  465. log.Info("Serving on %s://%s", protocol, bindAddress)
  466. log.Info("---")
  467. listener, err := net.Listen(network, bindAddress)
  468. if err != nil {
  469. log.Error("Could not bind to address: %v", err)
  470. os.Exit(1)
  471. }
  472. if network == "unix" {
  473. err = os.Chmod(bindAddress, 0o666)
  474. if err != nil {
  475. log.Error("Could not update socket permissions: %v", err)
  476. os.Exit(1)
  477. }
  478. }
  479. defer listener.Close()
  480. err = http.Serve(listener, r)
  481. }
  482. if err != nil {
  483. log.Error("Unable to start: %v", err)
  484. os.Exit(1)
  485. }
  486. }
  487. func (app *App) InitDecoder() {
  488. // TODO: do this at the package level, instead of the App level
  489. // Initialize modules
  490. app.formDecoder = schema.NewDecoder()
  491. app.formDecoder.RegisterConverter(converter.NullJSONString{}, converter.ConvertJSONNullString)
  492. app.formDecoder.RegisterConverter(converter.NullJSONBool{}, converter.ConvertJSONNullBool)
  493. app.formDecoder.RegisterConverter(sql.NullString{}, converter.ConvertSQLNullString)
  494. app.formDecoder.RegisterConverter(sql.NullBool{}, converter.ConvertSQLNullBool)
  495. app.formDecoder.RegisterConverter(sql.NullInt64{}, converter.ConvertSQLNullInt64)
  496. app.formDecoder.RegisterConverter(sql.NullFloat64{}, converter.ConvertSQLNullFloat64)
  497. }
  498. // ConnectToDatabase validates and connects to the configured database, then
  499. // tests the connection.
  500. func ConnectToDatabase(app *App) error {
  501. // Check database configuration
  502. if app.cfg.Database.Type == driverMySQL && app.cfg.Database.User == "" {
  503. return fmt.Errorf("Database user not set.")
  504. }
  505. if app.cfg.Database.Host == "" {
  506. app.cfg.Database.Host = "localhost"
  507. }
  508. if app.cfg.Database.Database == "" {
  509. app.cfg.Database.Database = "writefreely"
  510. }
  511. // TODO: check err
  512. connectToDatabase(app)
  513. // Test database connection
  514. err := app.db.Ping()
  515. if err != nil {
  516. return fmt.Errorf("Database ping failed: %s", err)
  517. }
  518. return nil
  519. }
  520. // FormatVersion constructs the version string for the application
  521. func FormatVersion() string {
  522. return serverSoftware + " " + softwareVer
  523. }
  524. // OutputVersion prints out the version of the application.
  525. func OutputVersion() {
  526. fmt.Println(FormatVersion())
  527. }
  528. // NewApp creates a new app instance.
  529. func NewApp(cfgFile string) *App {
  530. return &App{
  531. cfgFile: cfgFile,
  532. }
  533. }
  534. // CreateConfig creates a default configuration and saves it to the app's cfgFile.
  535. func CreateConfig(app *App) error {
  536. log.Info("Creating configuration...")
  537. c := config.New()
  538. log.Info("Saving configuration %s...", app.cfgFile)
  539. err := config.Save(c, app.cfgFile)
  540. if err != nil {
  541. return fmt.Errorf("Unable to save configuration: %v", err)
  542. }
  543. return nil
  544. }
  545. // DoConfig runs the interactive configuration process.
  546. func DoConfig(app *App, configSections string) {
  547. if configSections == "" {
  548. configSections = "server db app"
  549. }
  550. // let's check there aren't any garbage in the list
  551. configSectionsArray := strings.Split(configSections, " ")
  552. for _, element := range configSectionsArray {
  553. if element != "server" && element != "db" && element != "app" {
  554. log.Error("Invalid argument to --sections. Valid arguments are only \"server\", \"db\" and \"app\"")
  555. os.Exit(1)
  556. }
  557. }
  558. d, err := config.Configure(app.cfgFile, configSections)
  559. if err != nil {
  560. log.Error("Unable to configure: %v", err)
  561. os.Exit(1)
  562. }
  563. app.cfg = d.Config
  564. connectToDatabase(app)
  565. defer shutdown(app)
  566. if !app.db.DatabaseInitialized() {
  567. err = adminInitDatabase(app)
  568. if err != nil {
  569. log.Error(err.Error())
  570. os.Exit(1)
  571. }
  572. } else {
  573. log.Info("Database already initialized.")
  574. }
  575. if d.User != nil {
  576. u := &User{
  577. Username: d.User.Username,
  578. HashedPass: d.User.HashedPass,
  579. Created: time.Now().Truncate(time.Second).UTC(),
  580. }
  581. // Create blog
  582. log.Info("Creating user %s...\n", u.Username)
  583. err = app.db.CreateUser(app.cfg, u, app.cfg.App.SiteName, "")
  584. if err != nil {
  585. log.Error("Unable to create user: %s", err)
  586. os.Exit(1)
  587. }
  588. log.Info("Done!")
  589. }
  590. os.Exit(0)
  591. }
  592. // GenerateKeyFiles creates app encryption keys and saves them into the configured KeysParentDir.
  593. func GenerateKeyFiles(app *App) error {
  594. // Read keys path from config
  595. app.LoadConfig()
  596. // Create keys dir if it doesn't exist yet
  597. fullKeysDir := filepath.Join(app.cfg.Server.KeysParentDir, keysDir)
  598. if _, err := os.Stat(fullKeysDir); os.IsNotExist(err) {
  599. err = os.Mkdir(fullKeysDir, 0700)
  600. if err != nil {
  601. return err
  602. }
  603. }
  604. // Generate keys
  605. initKeyPaths(app)
  606. // TODO: use something like https://github.com/hashicorp/go-multierror to return errors
  607. var keyErrs error
  608. err := generateKey(emailKeyPath)
  609. if err != nil {
  610. keyErrs = err
  611. }
  612. err = generateKey(cookieAuthKeyPath)
  613. if err != nil {
  614. keyErrs = err
  615. }
  616. err = generateKey(cookieKeyPath)
  617. if err != nil {
  618. keyErrs = err
  619. }
  620. err = generateKey(csrfKeyPath)
  621. if err != nil {
  622. keyErrs = err
  623. }
  624. return keyErrs
  625. }
  626. // CreateSchema creates all database tables needed for the application.
  627. func CreateSchema(apper Apper) error {
  628. apper.LoadConfig()
  629. connectToDatabase(apper.App())
  630. defer shutdown(apper.App())
  631. err := adminInitDatabase(apper.App())
  632. if err != nil {
  633. return err
  634. }
  635. return nil
  636. }
  637. // Migrate runs all necessary database migrations.
  638. func Migrate(apper Apper) error {
  639. apper.LoadConfig()
  640. connectToDatabase(apper.App())
  641. defer shutdown(apper.App())
  642. err := migrations.Migrate(migrations.NewDatastore(apper.App().db.DB, apper.App().db.driverName))
  643. if err != nil {
  644. return fmt.Errorf("migrate: %s", err)
  645. }
  646. return nil
  647. }
  648. // ResetPassword runs the interactive password reset process.
  649. func ResetPassword(apper Apper, username string) error {
  650. // Connect to the database
  651. apper.LoadConfig()
  652. connectToDatabase(apper.App())
  653. defer shutdown(apper.App())
  654. // Fetch user
  655. u, err := apper.App().db.GetUserForAuth(username)
  656. if err != nil {
  657. log.Error("Get user: %s", err)
  658. os.Exit(1)
  659. }
  660. // Prompt for new password
  661. prompt := promptui.Prompt{
  662. Templates: &promptui.PromptTemplates{
  663. Success: "{{ . | bold | faint }}: ",
  664. },
  665. Label: "New password",
  666. Mask: '*',
  667. }
  668. newPass, err := prompt.Run()
  669. if err != nil {
  670. log.Error("%s", err)
  671. os.Exit(1)
  672. }
  673. // Do the update
  674. log.Info("Updating...")
  675. err = adminResetPassword(apper.App(), u, newPass)
  676. if err != nil {
  677. log.Error("%s", err)
  678. os.Exit(1)
  679. }
  680. log.Info("Success.")
  681. return nil
  682. }
  683. // DoDeleteAccount runs the confirmation and account delete process.
  684. func DoDeleteAccount(apper Apper, username string) error {
  685. // Connect to the database
  686. apper.LoadConfig()
  687. connectToDatabase(apper.App())
  688. defer shutdown(apper.App())
  689. // check user exists
  690. u, err := apper.App().db.GetUserForAuth(username)
  691. if err != nil {
  692. log.Error("%s", err)
  693. os.Exit(1)
  694. }
  695. userID := u.ID
  696. // do not delete the admin account
  697. // TODO: check for other admins and skip?
  698. if u.IsAdmin() {
  699. log.Error("Can not delete admin account")
  700. os.Exit(1)
  701. }
  702. // confirm deletion, w/ w/out posts
  703. prompt := promptui.Prompt{
  704. Templates: &promptui.PromptTemplates{
  705. Success: "{{ . | bold | faint }}: ",
  706. },
  707. Label: fmt.Sprintf("Really delete user : %s", username),
  708. IsConfirm: true,
  709. }
  710. _, err = prompt.Run()
  711. if err != nil {
  712. log.Info("Aborted...")
  713. os.Exit(0)
  714. }
  715. log.Info("Deleting...")
  716. err = apper.App().db.DeleteAccount(userID)
  717. if err != nil {
  718. log.Error("%s", err)
  719. os.Exit(1)
  720. }
  721. log.Info("Success.")
  722. return nil
  723. }
  724. func connectToDatabase(app *App) {
  725. log.Info("Connecting to %s database...", app.cfg.Database.Type)
  726. var db *sql.DB
  727. var err error
  728. if app.cfg.Database.Type == driverMySQL {
  729. db, err = sql.Open(app.cfg.Database.Type, fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=%s&tls=%t", app.cfg.Database.User, app.cfg.Database.Password, app.cfg.Database.Host, app.cfg.Database.Port, app.cfg.Database.Database, url.QueryEscape(time.Local.String()), app.cfg.Database.TLS))
  730. db.SetMaxOpenConns(50)
  731. } else if app.cfg.Database.Type == driverSQLite {
  732. if !SQLiteEnabled {
  733. log.Error("Invalid database type '%s'. Binary wasn't compiled with SQLite3 support.", app.cfg.Database.Type)
  734. os.Exit(1)
  735. }
  736. if app.cfg.Database.FileName == "" {
  737. log.Error("SQLite database filename value in config.ini is empty.")
  738. os.Exit(1)
  739. }
  740. db, err = sql.Open("sqlite3_with_regex", app.cfg.Database.FileName+"?parseTime=true&cached=shared")
  741. db.SetMaxOpenConns(2)
  742. } else {
  743. log.Error("Invalid database type '%s'. Only 'mysql' and 'sqlite3' are supported right now.", app.cfg.Database.Type)
  744. os.Exit(1)
  745. }
  746. if err != nil {
  747. log.Error("%s", err)
  748. os.Exit(1)
  749. }
  750. app.db = &datastore{db, app.cfg.Database.Type}
  751. }
  752. func shutdown(app *App) {
  753. log.Info("Closing database connection...")
  754. app.db.Close()
  755. if strings.HasPrefix(app.cfg.Server.Bind, "/") {
  756. // Clean up socket
  757. log.Info("Removing socket file...")
  758. err := os.Remove(app.cfg.Server.Bind)
  759. if err != nil {
  760. log.Error("Unable to remove socket: %s", err)
  761. os.Exit(1)
  762. }
  763. log.Info("Success.")
  764. }
  765. }
  766. // CreateUser creates a new admin or normal user from the given credentials.
  767. func CreateUser(apper Apper, username, password string, isAdmin bool) error {
  768. // Create an admin user with --create-admin
  769. apper.LoadConfig()
  770. connectToDatabase(apper.App())
  771. defer shutdown(apper.App())
  772. // Ensure an admin / first user doesn't already exist
  773. firstUser, _ := apper.App().db.GetUserByID(1)
  774. if isAdmin {
  775. // Abort if trying to create admin user, but one already exists
  776. if firstUser != nil {
  777. return fmt.Errorf("Admin user already exists (%s). Create a regular user with: writefreely --create-user", firstUser.Username)
  778. }
  779. } else {
  780. // Abort if trying to create regular user, but no admin exists yet
  781. if firstUser == nil {
  782. return fmt.Errorf("No admin user exists yet. Create an admin first with: writefreely --create-admin")
  783. }
  784. }
  785. // Create the user
  786. // Normalize and validate username
  787. desiredUsername := username
  788. username = getSlug(username, "")
  789. usernameDesc := username
  790. if username != desiredUsername {
  791. usernameDesc += " (originally: " + desiredUsername + ")"
  792. }
  793. if !author.IsValidUsername(apper.App().cfg, username) {
  794. return fmt.Errorf("Username %s is invalid, reserved, or shorter than configured minimum length (%d characters).", usernameDesc, apper.App().cfg.App.MinUsernameLen)
  795. }
  796. // Hash the password
  797. hashedPass, err := auth.HashPass([]byte(password))
  798. if err != nil {
  799. return fmt.Errorf("Unable to hash password: %v", err)
  800. }
  801. u := &User{
  802. Username: username,
  803. HashedPass: hashedPass,
  804. Created: time.Now().Truncate(time.Second).UTC(),
  805. }
  806. userType := "user"
  807. if isAdmin {
  808. userType = "admin"
  809. }
  810. log.Info("Creating %s %s...", userType, usernameDesc)
  811. err = apper.App().db.CreateUser(apper.App().Config(), u, desiredUsername, "")
  812. if err != nil {
  813. return fmt.Errorf("Unable to create user: %s", err)
  814. }
  815. log.Info("Done!")
  816. return nil
  817. }
  818. //go:embed schema.sql
  819. var schemaSql string
  820. //go:embed sqlite.sql
  821. var sqliteSql string
  822. func adminInitDatabase(app *App) error {
  823. var schema string
  824. if app.cfg.Database.Type == driverSQLite {
  825. schema = sqliteSql
  826. } else {
  827. schema = schemaSql
  828. }
  829. tblReg := regexp.MustCompile("CREATE TABLE (IF NOT EXISTS )?`([a-z_]+)`")
  830. queries := strings.Split(string(schema), ";\n")
  831. for _, q := range queries {
  832. if strings.TrimSpace(q) == "" {
  833. continue
  834. }
  835. parts := tblReg.FindStringSubmatch(q)
  836. if len(parts) >= 3 {
  837. log.Info("Creating table %s...", parts[2])
  838. } else {
  839. log.Info("Creating table ??? (Weird query) No match in: %v", parts)
  840. }
  841. _, err := app.db.Exec(q)
  842. if err != nil {
  843. log.Error("%s", err)
  844. } else {
  845. log.Info("Created.")
  846. }
  847. }
  848. // Set up migrations table
  849. log.Info("Initializing appmigrations table...")
  850. err := migrations.SetInitialMigrations(migrations.NewDatastore(app.db.DB, app.db.driverName))
  851. if err != nil {
  852. return fmt.Errorf("Unable to set initial migrations: %v", err)
  853. }
  854. log.Info("Running migrations...")
  855. err = migrations.Migrate(migrations.NewDatastore(app.db.DB, app.db.driverName))
  856. if err != nil {
  857. return fmt.Errorf("migrate: %s", err)
  858. }
  859. log.Info("Done.")
  860. return nil
  861. }
  862. // ServerUserAgent returns a User-Agent string to use in external requests. The
  863. // hostName parameter may be left empty.
  864. func ServerUserAgent(hostName string) string {
  865. hostUAStr := ""
  866. if hostName != "" {
  867. hostUAStr = "; +" + hostName
  868. }
  869. return "Go (" + serverSoftware + "/" + softwareVer + hostUAStr + ")"
  870. }