package cryptography

import (
	"strings"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// Fixed test keys, one for each key size accepted by AES.
var (
	key16 = []byte("1234567890123456")                 // AES-128
	key24 = []byte("123456789012345678901234")         // AES-192
	key32 = []byte("12345678901234567890123456789012") // AES-256
)

// aesTestKeys labels each fixed key so that table-driven tests report which
// key size failed instead of just "a key in the loop".
var aesTestKeys = []struct {
	name string
	key  []byte
}{
	{name: "AES-128", key: key16},
	{name: "AES-192", key: key24},
	{name: "AES-256", key: key32},
}

// ---- AES-GCM ----

func TestAESGCMEncryptDecryptRoundTrip(t *testing.T) {
	for _, tc := range aesTestKeys {
		t.Run(tc.name, func(t *testing.T) {
			plaintext := []byte("hello, world")

			ciphertext, err := AESGCMEncrypt(plaintext, tc.key)
			require.NoError(t, err)
			require.NotEmpty(t, ciphertext)

			got, err := AESGCMDecrypt(ciphertext, tc.key)
			require.NoError(t, err)
			assert.Equal(t, plaintext, got)
		})
	}
}

func TestAESGCMEncryptProducesUniqueOutput(t *testing.T) {
	plaintext := []byte("same input")

	// Check both errors: a failed encryption must surface as such, not as a
	// confusing "values should not be equal" failure below.
	c1, err := AESGCMEncrypt(plaintext, key32)
	require.NoError(t, err)
	c2, err := AESGCMEncrypt(plaintext, key32)
	require.NoError(t, err)

	// nonces differ, so ciphertexts must differ
	assert.NotEqual(t, c1, c2)
}

func TestAESGCMEncryptBadKeyLength(t *testing.T) {
	_, err := AESGCMEncrypt([]byte("data"), []byte("shortkey"))
	require.Error(t, err)
	assert.Contains(t, err.Error(), "AES key length")
}

func TestAESGCMDecryptBadKeyLength(t *testing.T) {
	_, err := AESGCMDecrypt("somedata", []byte("bad"))
	require.Error(t, err)
}

func TestAESGCMDecryptInvalidBase64(t *testing.T) {
	_, err := AESGCMDecrypt("!!!not-base64!!!", key32)
	require.Error(t, err)
}

func TestAESGCMDecryptTooShort(t *testing.T) {
	ciphertext, err := AESGCMEncrypt([]byte("data"), key32)
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(ciphertext), 4)

	// The unpadded literal "YQ" used previously is rejected by
	// base64.StdEncoding itself, so it may never reach the length check.
	// A 4-character prefix of the encoded ciphertext stays well-formed
	// (assuming base64 or hex text encoding — one full base64 quantum or two
	// hex bytes) yet decodes to at most 3 bytes, shorter than any GCM nonce.
	_, err = AESGCMDecrypt(ciphertext[:4], key32)
	require.Error(t, err)
}

// ---- AES-CBC ----

func TestAESCBCEncryptDecryptRoundTrip(t *testing.T) {
	for _, tc := range aesTestKeys {
		t.Run(tc.name, func(t *testing.T) {
			plaintext := []byte("CBC round-trip test")

			ciphertext, err := AESCBCEncrypt(plaintext, tc.key)
			require.NoError(t, err)
			require.NotEmpty(t, ciphertext)

			got, err := AESCBCDecrypt(ciphertext, tc.key)
			require.NoError(t, err)
			assert.Equal(t, plaintext, got)
		})
	}
}

func TestAESCBCEncryptEmptyPlaintext(t *testing.T) {
	ciphertext, err := AESCBCEncrypt([]byte{}, key32)
	require.NoError(t, err)

	got, err := AESCBCDecrypt(ciphertext, key32)
	require.NoError(t, err)
	assert.Equal(t, []byte{}, got)
}

func TestAESCBCEncryptBadKeyLength(t *testing.T) {
	_, err := AESCBCEncrypt([]byte("data"), []byte("bad-key"))
	require.Error(t, err)
}

func TestAESCBCDecryptInvalidBase64(t *testing.T) {
	_, err := AESCBCDecrypt("!!!!", key32)
	require.Error(t, err)
}

func TestAESCBCDecryptTooShort(t *testing.T) {
	ciphertext, err := AESCBCEncrypt([]byte("data"), key32)
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(ciphertext), 4)

	// A well-formed 4-character prefix decodes to at most 3 bytes — shorter
	// than one AES block — so the length check, not the decoder, must reject it.
	_, err = AESCBCDecrypt(ciphertext[:4], key32)
	require.Error(t, err)
}

// ---- AES-CFB ----

func TestAESCFBEncryptDecryptRoundTrip(t *testing.T) {
	for _, tc := range aesTestKeys {
		t.Run(tc.name, func(t *testing.T) {
			original := []byte("CFB mode test data")
			// AESCFBEncrypt modifies plaintext in-place; keep original copy for assertion
			plaintext := append([]byte(nil), original...)

			ciphertext, err := AESCFBEncrypt(plaintext, tc.key)
			require.NoError(t, err)
			require.NotEmpty(t, ciphertext)

			got, err := AESCFBDecrypt(ciphertext, tc.key)
			require.NoError(t, err)
			assert.Equal(t, original, got)
		})
	}
}

func TestAESCFBEncryptBadKeyLength(t *testing.T) {
	_, err := AESCFBEncrypt([]byte("data"), []byte("x"))
	require.Error(t, err)
}

func TestAESCFBDecryptTooShort(t *testing.T) {
	ciphertext, err := AESCFBEncrypt([]byte("data"), key32)
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(ciphertext), 4)

	// A well-formed 4-character prefix decodes to at most 3 bytes — shorter
	// than the IV — so the length check, not the decoder, must reject it.
	_, err = AESCFBDecrypt(ciphertext[:4], key32)
	require.Error(t, err)
}

// ---- PKCS7 padding ----

// TestPkcs7PadUnpad checks padded lengths and round-trips around block
// boundaries. PKCS#7 always appends at least one byte, so empty and
// block-aligned inputs gain a whole extra block.
func TestPkcs7PadUnpad(t *testing.T) {
	cases := []struct {
		name      string
		data      []byte
		paddedLen int
	}{
		{name: "short", data: []byte("hello"), paddedLen: 16},
		{name: "empty", data: []byte{}, paddedLen: 16},
		{name: "one short of a block", data: []byte(strings.Repeat("a", 15)), paddedLen: 16},
		{name: "exactly one block", data: []byte(strings.Repeat("a", 16)), paddedLen: 32},
		{name: "one past a block", data: []byte(strings.Repeat("a", 17)), paddedLen: 32},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			padded := pkcs7Pad(tc.data, 16)
			assert.Equal(t, tc.paddedLen, len(padded))

			unpadded, err := pkcs7Unpad(padded)
			require.NoError(t, err)
			assert.Equal(t, tc.data, unpadded)
		})
	}
}

func TestPkcs7UnpadEmpty(t *testing.T) {
	_, err := pkcs7Unpad([]byte{})
	require.Error(t, err)
}

func TestPkcs7UnpadInvalidPadding(t *testing.T) {
	// last byte claims padding of 0 – invalid
	_, err := pkcs7Unpad([]byte{0x01, 0x02, 0x00})
	require.Error(t, err)
}

// ---- normalizeKey ----

func TestNormalizeKeyValidLengths(t *testing.T) {
	for _, tc := range aesTestKeys {
		t.Run(tc.name, func(t *testing.T) {
			got, err := normalizeKey(tc.key)
			require.NoError(t, err)
			assert.Equal(t, tc.key, got)
		})
	}
}

func TestNormalizeKeyInvalidLength(t *testing.T) {
	_, err := normalizeKey([]byte(strings.Repeat("x", 10)))
	require.Error(t, err)
}