package encrypter
import (
"crypto/aes"
"crypto/cipher"
"crypto/ecdh"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"errors"
"math/big"
"github.com/siherrmann/encrypter/model"
)
// Encrypter holds a P-256 ECDH key pair and fields for AES-GCM state.
//
// NewEncrypter fills PrivateKey and PublicKey. DecryptECC reads PrivateKey
// and GetECCHandshake reads PublicKey; SetPeerPublicKey and EncryptECC do not
// read any field of the receiver.
type Encrypter struct {
	// PrivateKey is the long-term private key used by DecryptECC.
	PrivateKey *ecdh.PrivateKey
	// PublicKey is the public half of PrivateKey, exported by GetECCHandshake.
	PublicKey *ecdh.PublicKey
	// NOTE(review): AESKey and GCM are never set or read in the code visible
	// here; confirm whether external callers rely on them before removing.
	AESKey []byte
	GCM    cipher.AEAD
}
// NewEncrypter returns an Encrypter holding a freshly generated P-256 ECDH
// key pair drawn from crypto/rand. AESKey and GCM are left at their zero
// values.
func NewEncrypter() (*Encrypter, error) {
	key, err := ecdh.P256().GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	enc := &Encrypter{PrivateKey: key}
	enc.PublicKey = key.PublicKey()
	return enc, nil
}
// GetECCHandshake exports e.PublicKey as a Handshake whose PublicKeyX and
// PublicKeyY are the base64 (standard encoding) big-endian coordinates.
//
// The coordinates pass through big.Int, which drops leading zero bytes, so
// either value may decode to fewer than 32 bytes; SetPeerPublicKey left-pads
// them back to full width.
func (e *Encrypter) GetECCHandshake() model.Handshake {
	// Uncompressed point encoding: 0x04 || X (32 bytes) || Y (32 bytes).
	point := e.PublicKey.Bytes()

	encodeCoord := func(raw []byte) string {
		trimmed := new(big.Int).SetBytes(raw).Bytes()
		return base64.StdEncoding.EncodeToString(trimmed)
	}

	return model.Handshake{
		PublicKeyX: encodeCoord(point[1:33]),
		PublicKeyY: encodeCoord(point[33:65]),
	}
}
// SetPeerPublicKey decodes a peer's P-256 public key from a Handshake.
//
// Despite its name it does not store anything on e; the parsed key is only
// returned. PublicKeyX and PublicKeyY must be base64 (standard encoding)
// big-endian coordinates of at most 32 bytes each; shorter values are
// left-padded with zeros, matching what GetECCHandshake produces.
//
// An error is returned for invalid base64, for a coordinate longer than 32
// bytes, or when ecdh's NewPublicKey rejects the reconstructed point.
func (e *Encrypter) SetPeerPublicKey(h model.Handshake) (*ecdh.PublicKey, error) {
	bx, err := base64.StdEncoding.DecodeString(h.PublicKeyX)
	if err != nil {
		return nil, err
	}
	by, err := base64.StdEncoding.DecodeString(h.PublicKeyY)
	if err != nil {
		return nil, err
	}
	// The handshake may come from untrusted input (e.g. request headers). A
	// coordinate wider than a P-256 field element would make the padding
	// offsets below negative and panic, so reject it up front.
	if len(bx) > 32 || len(by) > 32 {
		return nil, errors.New("peer public key coordinate too long")
	}

	// Reconstruct uncompressed point: 0x04 || X || Y, right-aligning each
	// coordinate to restore any leading zero bytes that were stripped.
	point := make([]byte, 1+32+32)
	point[0] = 0x04
	copy(point[1+32-len(bx):33], bx)
	copy(point[33+32-len(by):65], by)

	return ecdh.P256().NewPublicKey(point)
}
// EncryptECC encrypts plaintext for recipientPub using ECIES over P-256.
//
// Steps:
//  1. Generate an ephemeral private key r
//  2. Compute R = r·G (ephemeral public key)
//  3. Compute the shared secret S = r·Q with the recipient's public key Q
//  4. Derive the AES-256 key as SHA-256(S)
//  5. Seal plaintext with AES-GCM under a random nonce
//
// The returned message holds R's coordinates as base64 (leading zero bytes
// dropped) and base64(nonce || sealed ciphertext). The receiver's own key
// pair is not used. An error is returned if recipientPub is nil or if ECDH
// fails (for example, because recipientPub is not a P-256 key).
//
// NOTE(review): the AES key is a bare hash of S. Standard ECIES feeds S (and
// usually R) through a KDF such as HKDF; changing it would break existing
// peers, so it is only flagged here.
func (e *Encrypter) EncryptECC(recipientPub *ecdh.PublicKey, plaintext []byte) (*model.EncryptedMessage, error) {
	// ecdh.PrivateKey.ECDH dereferences its argument, so a nil key would panic.
	if recipientPub == nil {
		return nil, errors.New("recipient public key is nil")
	}

	// Generate ephemeral ECDH private key r.
	curve := ecdh.P256()
	ephemeralPriv, err := curve.GenerateKey(rand.Reader)
	if err != nil {
		return nil, err
	}
	// R = r·G (ephemeral public key).
	ephemeralPub := ephemeralPriv.PublicKey()

	// Compute shared secret S = r·Q using ECDH.
	sharedSecret, err := ephemeralPriv.ECDH(recipientPub)
	if err != nil {
		return nil, err
	}

	// Derive a 32-byte AES key (AES-256) from the shared secret.
	aesKey := sha256.Sum256(sharedSecret)

	block, err := aes.NewCipher(aesKey[:])
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}
	nonce := make([]byte, gcm.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return nil, err
	}
	// Prepend the nonce so DecryptECC can split it off again.
	ciphertext := gcm.Seal(nonce, nonce, plaintext, nil)

	// Extract X and Y from the uncompressed ephemeral point 0x04 || X || Y.
	pubBytes := ephemeralPub.Bytes()
	rx := new(big.Int).SetBytes(pubBytes[1:33])
	ry := new(big.Int).SetBytes(pubBytes[33:65])

	return &model.EncryptedMessage{
		Rx:         base64.StdEncoding.EncodeToString(rx.Bytes()),
		Ry:         base64.StdEncoding.EncodeToString(ry.Bytes()),
		Ciphertext: base64.StdEncoding.EncodeToString(ciphertext),
	}, nil
}
// DecryptECC decrypts a message produced by EncryptECC using e.PrivateKey.
//
// Steps:
//  1. Rebuild the sender's ephemeral public key R from Rx and Ry
//  2. Compute the shared secret S = d·R with the private key d
//  3. Derive the AES-256 key as SHA-256(S)
//  4. Split nonce || ciphertext and open it with AES-GCM
//
// An error is returned for a nil message, a missing private key, invalid
// base64, an Rx/Ry coordinate longer than 32 bytes, a point rejected by
// ecdh, a ciphertext shorter than the nonce, or a failed GCM open.
func (e *Encrypter) DecryptECC(encMsg *model.EncryptedMessage) ([]byte, error) {
	if encMsg == nil {
		return nil, errors.New("encrypted message is nil")
	}
	if e.PrivateKey == nil {
		return nil, errors.New("encrypter has no private key")
	}

	// Reconstruct ephemeral public key R.
	rxBytes, err := base64.StdEncoding.DecodeString(encMsg.Rx)
	if err != nil {
		return nil, err
	}
	ryBytes, err := base64.StdEncoding.DecodeString(encMsg.Ry)
	if err != nil {
		return nil, err
	}
	// Rx/Ry arrive from the network. A coordinate wider than 32 bytes would
	// make the left-padding offsets below negative and panic.
	if len(rxBytes) > 32 || len(ryBytes) > 32 {
		return nil, errors.New("ephemeral public key coordinate too long")
	}

	// Reconstruct uncompressed point: 0x04 || X || Y, right-aligning each
	// coordinate to restore leading zero bytes stripped by the sender.
	pubBytes := make([]byte, 1+32+32)
	pubBytes[0] = 0x04
	copy(pubBytes[1+32-len(rxBytes):33], rxBytes)
	copy(pubBytes[33+32-len(ryBytes):65], ryBytes)

	ephemeralPub, err := ecdh.P256().NewPublicKey(pubBytes)
	if err != nil {
		return nil, err
	}

	// Compute shared secret S = d·R using ECDH.
	sharedSecret, err := e.PrivateKey.ECDH(ephemeralPub)
	if err != nil {
		return nil, err
	}

	// Derive the same 32-byte AES key the sender used.
	aesKey := sha256.Sum256(sharedSecret)

	block, err := aes.NewCipher(aesKey[:])
	if err != nil {
		return nil, err
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, err
	}

	ciphertextBytes, err := base64.StdEncoding.DecodeString(encMsg.Ciphertext)
	if err != nil {
		return nil, err
	}
	nonceSize := gcm.NonceSize()
	if len(ciphertextBytes) < nonceSize {
		return nil, errors.New("ciphertext too short")
	}
	nonce, ciphertext := ciphertextBytes[:nonceSize], ciphertextBytes[nonceSize:]
	return gcm.Open(nil, nonce, ciphertext, nil)
}
package encrypter
import (
"encoding/json"
"fmt"
"net/http"
"github.com/siherrmann/encrypter/model"
)
// EncrypterClient sends requests carrying its public key in the
// X-Encryption-Public-Key-X/Y headers and decrypts the EncryptedMessage
// responses with its own key pair. It is presumably meant to talk to a server
// wrapped by EncryptionMiddleware (same header names) — confirm with callers.
type EncrypterClient struct {
	// Encrypter holds the key pair whose public half is sent with each request.
	Encrypter *Encrypter
	// serverURL is the base URL; RequestData appends the path verbatim.
	serverURL string
}
// NewEncrypterClient returns a client that sends requests to serverURL, using
// a freshly generated Encrypter key pair. Request paths passed to RequestData
// are appended to serverURL verbatim, so no separator is inserted.
func NewEncrypterClient(serverURL string) (*EncrypterClient, error) {
	enc, err := NewEncrypter()
	if err != nil {
		return nil, err
	}
	return &EncrypterClient{Encrypter: enc, serverURL: serverURL}, nil
}
// RequestData performs an encrypted request against c.serverURL+urlPath.
//
// It POSTs an empty body carrying the client's public key in the
// X-Encryption-Public-Key-X and X-Encryption-Public-Key-Y headers. It expects
// a JSON model.EncryptedMessage in the response and returns the plaintext
// obtained by decrypting it with the client's private key. A response status
// >= 300 is reported as an error without reading the body.
//
// NOTE(review): this uses http.DefaultClient, which has no timeout, so a
// stalled server blocks the call indefinitely.
func (c *EncrypterClient) RequestData(urlPath string) ([]byte, error) {
	handshake := c.Encrypter.GetECCHandshake()

	req, err := http.NewRequest(http.MethodPost, c.serverURL+urlPath, nil)
	if err != nil {
		return nil, err
	}
	req.Header.Set("X-Encryption-Public-Key-X", handshake.PublicKeyX)
	req.Header.Set("X-Encryption-Public-Key-Y", handshake.PublicKeyY)

	// Errors from Do already name the method and URL, so they are returned as-is.
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	if resp.StatusCode >= 300 {
		return nil, fmt.Errorf("server returned error status: %s", resp.Status)
	}

	var encMsg model.EncryptedMessage
	if err := json.NewDecoder(resp.Body).Decode(&encMsg); err != nil {
		return nil, fmt.Errorf("decoding encrypted response: %w", err)
	}

	data, err := c.Encrypter.DecryptECC(&encMsg)
	if err != nil {
		return nil, fmt.Errorf("decrypting response: %w", err)
	}
	return data, nil
}
package encrypter
import (
"encoding/json"
"net/http"
"github.com/siherrmann/encrypter/model"
)
// EncryptionMiddleware wraps an http.Handler to encrypt responses
// Client must send their public key in X-Encryption-Public-Key-X and X-Encryption-Public-Key-Y headers
// Response will be encrypted using ECIES and returned as EncryptedMessage JSON
func EncryptionMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Check if encryption is requested
pubKeyX := r.Header.Get("X-Encryption-Public-Key-X")
pubKeyY := r.Header.Get("X-Encryption-Public-Key-Y")
if pubKeyX == "" || pubKeyY == "" {
http.Error(w, "encryption headers missing", http.StatusBadRequest)
return
}
// Parse client's public key
encrypter, err := NewEncrypter()
if err != nil {
http.Error(w, "encryption setup failed", http.StatusInternalServerError)
return
}
handshake := model.Handshake{
PublicKeyX: pubKeyX,
PublicKeyY: pubKeyY,
}
recipientPub, err := encrypter.SetPeerPublicKey(handshake)
if err != nil {
http.Error(w, "invalid public key", http.StatusBadRequest)
return
}
// Call the next handler and capture the response
recorder := model.NewResponseRecorder()
next.ServeHTTP(recorder, r)
// Encrypt the response body
encMsg, err := encrypter.EncryptECC(recipientPub, recorder.BodyBytes())
if err != nil {
http.Error(w, "encryption failed", http.StatusInternalServerError)
return
}
payload, err := json.Marshal(encMsg)
if err != nil {
http.Error(w, "response encoding failed", http.StatusInternalServerError)
return
}
// Clear any headers that were set and write only encrypted response
for k := range w.Header() {
w.Header().Del(k)
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
if _, err := w.Write(payload); err != nil {
return
}
})
}
package model
import (
"bytes"
"net/http"
)
// ResponseRecorder is an in-memory http.ResponseWriter. It buffers the
// response body, keeps the headers set through Header, and remembers the
// first status code written, instead of sending anything over the network.
type ResponseRecorder struct {
	header      http.Header
	body        *bytes.Buffer
	statusCode  int
	wroteHeader bool
}

// Compile-time check that ResponseRecorder satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*ResponseRecorder)(nil)

// NewResponseRecorder returns an empty recorder whose status defaults to 200.
func NewResponseRecorder() *ResponseRecorder {
	return &ResponseRecorder{
		header:     http.Header{},
		body:       &bytes.Buffer{},
		statusCode: http.StatusOK,
	}
}

// Header returns the recorder's header map. Changes to it are kept and are
// not sent anywhere.
func (r *ResponseRecorder) Header() http.Header {
	return r.header
}

// Write appends b to the buffered body. It records an implicit 200 status if
// WriteHeader has not been called yet.
func (r *ResponseRecorder) Write(b []byte) (int, error) {
	if !r.wroteHeader {
		r.WriteHeader(http.StatusOK)
	}
	return r.body.Write(b)
}

// WriteHeader records statusCode. Only the first call, explicit or implied
// by Write, has any effect.
func (r *ResponseRecorder) WriteHeader(statusCode int) {
	if !r.wroteHeader {
		r.statusCode = statusCode
		r.wroteHeader = true
	}
}

// StatusCode reports the recorded status code, or 200 if nothing has been
// written yet.
func (r *ResponseRecorder) StatusCode() int {
	return r.statusCode
}

// BodyBytes returns the buffered body. The slice aliases the internal buffer
// and is only valid until the next Write.
func (r *ResponseRecorder) BodyBytes() []byte {
	return r.body.Bytes()
}