@@ -0,0 +1,83 @@ | |||||
// Package data provides utilities for interacting with database data | |||||
// throughout Write.as. | |||||
package data | |||||
import ( | |||||
"crypto/aes" | |||||
"crypto/cipher" | |||||
"crypto/rand" | |||||
"errors" | |||||
"fmt" | |||||
) | |||||
// Encryption parameters | |||||
const ( | |||||
keyLen = 32 | |||||
delimiter = '%' | |||||
) | |||||
// Encrypt AES-encrypts given text with the given key k. | |||||
// This is used for encrypting sensitive information in the database, such as | |||||
// oAuth tokens and email addresses. | |||||
func Encrypt(k []byte, text string) ([]byte, error) { | |||||
// Validate parameters | |||||
if len(k) != keyLen { | |||||
return nil, errors.New(fmt.Sprintf("Invalid key length (must be %d bytes).", keyLen)) | |||||
} | |||||
// Encrypt plaintext with AES-GCM | |||||
block, err := aes.NewCipher(k) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
gcm, err := cipher.NewGCM(block) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
// Generate nonce | |||||
ns := gcm.NonceSize() | |||||
nonce := make([]byte, ns) | |||||
if _, err := rand.Read(nonce); err != nil { | |||||
return nil, err | |||||
} | |||||
ciphertext := gcm.Seal(nil, nonce, []byte(text), nil) | |||||
// Build text output in the format: | |||||
// NonceCiphertext | |||||
outtext := append(nonce, ciphertext...) | |||||
return outtext, nil | |||||
} | |||||
// Decrypt decrypts the given ciphertext with the given key k. | |||||
func Decrypt(k, ciphertext []byte) ([]byte, error) { | |||||
// Decrypt ciphertext | |||||
block, err := aes.NewCipher(k) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
gcm, err := cipher.NewGCM(block) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
ns := gcm.NonceSize() | |||||
// Validate data | |||||
if len(ciphertext) < ns { | |||||
return nil, errors.New("Ciphertext is too short") | |||||
} | |||||
nonce := ciphertext[:ns] | |||||
ciphertext = ciphertext[ns:] | |||||
plaintext, err := gcm.Open(nil, nonce, ciphertext, nil) | |||||
if err != nil { | |||||
return nil, err | |||||
} | |||||
return plaintext, nil | |||||
} |
@@ -0,0 +1,70 @@ | |||||
// Package data provides utilities for interacting with database data | |||||
// throughout Write.as. | |||||
package data | |||||
import ( | |||||
"bytes" | |||||
"crypto/rand" | |||||
"strings" | |||||
"testing" | |||||
) | |||||
func TestEncDec(t *testing.T) { | |||||
// Generate a random key with a valid length | |||||
k := make([]byte, keyLen) | |||||
_, err := rand.Read(k) | |||||
if err != nil { | |||||
t.Fatal(err) | |||||
} | |||||
runEncDec(t, k, "this is my secret message™. 😄", nil) | |||||
runEncDec(t, k, "mygreatemailaddress@gmail.com", nil) | |||||
} | |||||
func TestAuthentication(t *testing.T) { | |||||
// Generate a random key with a valid length | |||||
k := make([]byte, keyLen) | |||||
_, err := rand.Read(k) | |||||
if err != nil { | |||||
t.Fatal(err) | |||||
} | |||||
runEncDec(t, k, "mygreatemailaddress@gmail.com", func(c []byte) []byte { | |||||
c[0] = 'a' | |||||
t.Logf("Modified: %s\n", c) | |||||
return c | |||||
}) | |||||
} | |||||
func runEncDec(t *testing.T, k []byte, plaintext string, transform func([]byte) []byte) { | |||||
t.Logf("Plaintext: %s\n", plaintext) | |||||
// Encrypt the data | |||||
ciphertext, err := Encrypt(k, plaintext) | |||||
if err != nil { | |||||
t.Fatal(err) | |||||
} | |||||
t.Logf("Ciphertext: %s\n", ciphertext) | |||||
if transform != nil { | |||||
ciphertext = transform(ciphertext) | |||||
} | |||||
// Decrypt the data | |||||
decryptedText, err := Decrypt(k, ciphertext) | |||||
if err != nil { | |||||
if transform != nil && strings.Contains(err.Error(), "message authentication failed") { | |||||
// We modified the ciphertext; make sure we're getting the right error | |||||
t.Logf("%v\n", err) | |||||
return | |||||
} | |||||
t.Fatal(err) | |||||
} | |||||
t.Logf("Decrypted: %s\n", string(decryptedText)) | |||||
if !bytes.Equal([]byte(plaintext), decryptedText) { | |||||
t.Errorf("Plaintext mismatch: got %x vs %x", plaintext, decryptedText) | |||||
} | |||||
} |