Skip to content

Commit

Permalink
lnurl tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bbengfort committed Aug 23, 2023
1 parent 0b7c755 commit ab9e0b6
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 23 deletions.
35 changes: 15 additions & 20 deletions pkg/openvasp/lnurl/bech32.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@ import (
"strings"
)

// Set of characters used in the data of bech32 strings. Note that this string is
// ordered, such that for a given charset[i], i is the binary value of the character.
const charset = "qpzry9x8gf2tvdw0s3jn54khce6mua7l"

// gen encodes the generator polynomial for the bech32 BCH checksum.
var gen = []int{0x3b6a57b2, 0x26508e6d, 0x1ea119fa, 0x3d4233dd, 0x2a1462b3}

// decode decodes a bech32 encoded string, returning the human-readable
Expand All @@ -15,17 +18,15 @@ func decode(bech string) (string, []byte, error) {
// Only ASCII characters between 33 and 126 are allowed.
for i := 0; i < len(bech); i++ {
if bech[i] < 33 || bech[i] > 126 {
return "", nil, fmt.Errorf("invalid character in "+
"string: '%c'", bech[i])
return "", nil, ErrInvalidCharacter(bech[i])
}
}

// The characters must be either all lowercase or all uppercase.
lower := strings.ToLower(bech)
upper := strings.ToUpper(bech)
if bech != lower && bech != upper {
return "", nil, fmt.Errorf("string not all lowercase or all " +
"uppercase")
return "", nil, ErrMixedCase
}

// We'll work with the lowercase string from now on.
Expand All @@ -37,7 +38,7 @@ func decode(bech string) (string, []byte, error) {
// or if the string is more than 90 characters in total.
one := strings.LastIndexByte(bech, '1')
if one < 1 || one+7 > len(bech) {
return "", nil, fmt.Errorf("invalid index of 1")
return "", nil, ErrInvalidSeparatorIndex(1)
}

// The human-readable part is everything before the last '1'.
Expand All @@ -48,20 +49,16 @@ func decode(bech string) (string, []byte, error) {
// 'charset'.
decoded, err := toBytes(data)
if err != nil {
return "", nil, fmt.Errorf("failed converting data to bytes: "+
"%v", err)
return "", nil, fmt.Errorf("failed converting data to bytes: %w", err)
}

if !bech32VerifyChecksum(hrp, decoded) {
moreInfo := ""
checksum := bech[len(bech)-6:]
expected, err := toChars(bech32Checksum(hrp,
decoded[:len(decoded)-6]))
expected, err := toChars(bech32Checksum(hrp, decoded[:len(decoded)-6]))
if err == nil {
moreInfo = fmt.Sprintf("Expected %v, got %v.",
expected, checksum)
err = ErrInvalidChecksum{expected, checksum}
}
return "", nil, fmt.Errorf("checksum failed. " + moreInfo)
return "", nil, fmt.Errorf("checksum failed: %w", err)
}

// We exclude the last 6 bytes, which is the checksum.
Expand All @@ -81,8 +78,7 @@ func encode(hrp string, data []byte) (string, error) {
// represented using the specified charset.
dataChars, err := toChars(combined)
if err != nil {
return "", fmt.Errorf("unable to convert data bytes to chars: "+
"%v", err)
return "", fmt.Errorf("unable to convert data bytes to chars: %w", err)
}
return hrp + "1" + dataChars, nil
}
Expand All @@ -94,8 +90,7 @@ func toBytes(chars string) ([]byte, error) {
for i := 0; i < len(chars); i++ {
index := strings.IndexByte(charset, chars[i])
if index < 0 {
return nil, fmt.Errorf("invalid character not part of "+
"charset: %v", chars[i])
return nil, ErrNonCharsetChar(chars[i])
}
decoded = append(decoded, byte(index))
}
Expand All @@ -108,7 +103,7 @@ func toChars(data []byte) (string, error) {
result := make([]byte, 0, len(data))
for _, b := range data {
if int(b) >= len(charset) {
return "", fmt.Errorf("invalid data byte: %v", b)
return "", ErrInvalidDataByte(b)
}
result = append(result, charset[b])
}
Expand All @@ -119,7 +114,7 @@ func toChars(data []byte) (string, error) {
// to a byte slice where each byte is encoding toBits bits.
func convertBits(data []byte, fromBits, toBits uint8, pad bool) ([]byte, error) {
if fromBits < 1 || fromBits > 8 || toBits < 1 || toBits > 8 {
return nil, fmt.Errorf("only bit groups between 1 and 8 allowed")
return nil, ErrInvalidBitGroups
}

// The final bytes, each byte encoding toBits bits.
Expand Down Expand Up @@ -178,7 +173,7 @@ func convertBits(data []byte, fromBits, toBits uint8, pad bool) ([]byte, error)

// Any incomplete group must be <= 4 bits, and all zeroes.
if filledBits > 0 && (filledBits > 4 || nextByte != 0) {
return nil, fmt.Errorf("invalid incomplete group")
return nil, ErrInvalidIncompleteGroup
}

return regrouped, nil
Expand Down
56 changes: 56 additions & 0 deletions pkg/openvasp/lnurl/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package lnurl

import (
"errors"
"fmt"
)

var (
ErrUnhandledScheme = errors.New("unhandled lnurl scheme")
ErrMixedCase = errors.New("string is not all lowercase or all uppercase")
ErrInvalidBitGroups = errors.New("only bit groups between 1 and 8 allowed")
ErrInvalidIncompleteGroup = errors.New("invalid incomplete group")
)

// ErrNonCharsetChar is returned when a character outside of the specific
// bech32 charset is used in the string.
type ErrNonCharsetChar rune

func (e ErrNonCharsetChar) Error() string {
return fmt.Sprintf("invalid character not part of charset: %v", int(e))
}

// ErrInvalidDataByte is returned when a byte outside the range required for
// conversion into a string was found.
type ErrInvalidDataByte byte

func (e ErrInvalidDataByte) Error() string {
return fmt.Sprintf("invalid data byte: %v", byte(e))
}

// ErrInvalidChecksum is returned when the extracted checksum of the string
// is different than what was expected.
type ErrInvalidChecksum struct {
Expected string
Actual string
}

func (e ErrInvalidChecksum) Error() string {
return fmt.Sprintf("expected %v, got %v", e.Expected, e.Actual)
}

// ErrInvalidCharacter is returned when the bech32 string has a character
// outside the range of the supported charset.
type ErrInvalidCharacter rune

func (e ErrInvalidCharacter) Error() string {
return fmt.Sprintf("invalid character in string: '%c'", rune(e))
}

// ErrInvalidSeparatorIndex is returned when the separator character '1' is
// in an invalid position in the bech32 string.
type ErrInvalidSeparatorIndex int

func (e ErrInvalidSeparatorIndex) Error() string {
return fmt.Sprintf("invalid separator index %d", int(e))
}
4 changes: 3 additions & 1 deletion pkg/openvasp/lnurl/lnurl.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"strings"
)

// Encode a plain-text https URL into a bech32-encoded uppercased lnurl string.
func Encode(url string) (lnurl string, err error) {
var converted []byte
if converted, err = convertBits([]byte(url), 8, 5, true); err != nil {
Expand All @@ -29,11 +30,12 @@ func Encode(url string) (lnurl string, err error) {
return strings.ToUpper(lnurl), nil
}

// Decode a bech32 encoded lnurl string and returns a plain-text https URL.
func Decode(lnurl string) (url string, err error) {
lnurl = strings.ToLower(lnurl)

if !strings.HasPrefix(lnurl, "lnurl1") {
return "", errors.New("unhandled lnurl scheme")
return "", ErrUnhandledScheme
}

// bech32
Expand Down
48 changes: 46 additions & 2 deletions pkg/openvasp/lnurl/lnurl_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lnurl_test

import (
"strings"
"testing"

"github.com/stretchr/testify/require"
Expand All @@ -16,15 +17,58 @@ func TestLNURL(t *testing.T) {
"LNURL1DP68GURN8GHJ7MMSV4H8VCTNWQH8GETNWSKKUET59E5K7TE3XGEN7ARPVU7KJMN3W45HY7GF5KZ53",
"https://openvasp.test-net.io/123?tag=inquiry",
},
{
"lnurl1dp68gurn8ghj7cn9dejkv6trd9shy7fwvdhk6tm5wfcr7arpvu7hgunpwejkcun4d3jkjmn3w45hy7gmsy37e",
"https://beneficiary.com/trp?tag=travelruleinquiry",
},
}

for i, tc := range testCases {
actual, err := lnurl.Decode(tc.lnurl)
require.NoError(t, err, "test case %d decode failed", i)
require.Equal(t, tc.wburl, actual, "test case %d docde equality failed", i)
require.Equal(t, tc.wburl, actual, "test case %d decode equality failed", i)

actual, err = lnurl.Encode(tc.wburl)
require.NoError(t, err, "test case %d encode failed", i)
require.Equal(t, tc.lnurl, actual, "test case %d encode equality failed", i)
require.Equal(t, strings.ToUpper(tc.lnurl), actual, "test case %d encode equality failed", i)
}
}

func TestLNURLErrors(t *testing.T) {
testCases := []struct {
input string
err error
}{
{
"https://DP68GURN8GHJ7MMSV4H8VCTNWQH8GETNWSKKUET59E5K7TE3XGEN7ARPVU7KJMN3W45HY7GF5KZ53",
lnurl.ErrUnhandledScheme,
},
{
"lnurl1split1checkupstagehandshakeupstreamerranterredcaperred2y9e2w",
lnurl.ErrInvalidChecksum{"6gr7g4", "2y9e2w"},
},
{
"lnurl1s lit1checkupstagehandshakeupstreamerranterredcaperredp8hs2p",
lnurl.ErrInvalidCharacter(' '),
},
{
"lnurl1spl\x7Ft1checkupstagehandshakeupstreamerranterredcaperred2y9e3w",
lnurl.ErrInvalidCharacter(127),
},
{
"lnurl1split1cheo2y9e2w",
lnurl.ErrNonCharsetChar('o'),
},
{
"lnurl1split1a2y9w",
lnurl.ErrInvalidSeparatorIndex(1),
},
}

for i, tc := range testCases {
actual, err := lnurl.Decode(tc.input)
require.Error(t, err, "test case %d did not error", i)
require.ErrorIs(t, err, tc.err, "test case %d error did not match", i)
require.Empty(t, actual, "test case %d returned data", i)
}
}

0 comments on commit ab9e0b6

Please sign in to comment.