Skip to content

Commit

Permalink
Add more methods for converting to/from sender IDs (#412)
Browse files Browse the repository at this point in the history
Including some tests for sender IDs

Signed-off-by: `Sam Wedgwood <[email protected]>`
  • Loading branch information
swedgwood authored Sep 8, 2023
1 parent 740f742 commit 47bceff
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 0 deletions.
37 changes: 37 additions & 0 deletions spec/senderid.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ type CreateSenderID func(ctx context.Context, userID UserID, roomID RoomID, room
// StoreSenderIDFromPublicID is a function to store the mxid_mapping after receiving a join event over federation.
type StoreSenderIDFromPublicID func(ctx context.Context, senderID SenderID, userID string, id RoomID) error

// Create a new sender ID from a private room key
func SenderIDFromPseudoIDKey(key ed25519.PrivateKey) SenderID {
return SenderID(Base64Bytes(key.Public().(ed25519.PublicKey)).Encode())
}

// Create a new sender ID from a user ID
func SenderIDFromUserID(user UserID) SenderID {
return SenderID(user.String())
}

// Decodes this sender ID as base64, i.e. returns the raw bytes of the
// pseudo ID used to create this SenderID, assuming this SenderID was made
// using a pseudo ID.
func (s SenderID) RawBytes() (res Base64Bytes, err error) {
err = res.Decode(string(s))
if err != nil {
Expand All @@ -43,13 +52,41 @@ func (s SenderID) RawBytes() (res Base64Bytes, err error) {
return res, nil
}

// Returns true if this SenderID was made using a user ID
func (s SenderID) IsUserID() bool {
// Key is base64, @ is not a valid base64 char
// So if string starts with @, then this sender ID must
// be a user ID
return string(s)[0] == '@'
}

// Returns true if this SenderID was made using a pseudo ID
func (s SenderID) IsPseudoID() bool {
return !s.IsUserID()
}

// Returns the non-nil UserID used to create this SenderID, or nil
// if this SenderID was not created using a UserID
func (s SenderID) ToUserID() *UserID {
if s.IsUserID() {
uID, _ := NewUserID(string(s), true)
return uID
}

return nil
}

// Returns the non-nil room public key (pseudo ID) used to create this
// SenderID, or nil if this SenderID was not created using a pseudo ID
func (s SenderID) ToPseudoID() *ed25519.PublicKey {
if s.IsPseudoID() {
decoded, err := s.RawBytes()
if err != nil {
return nil
}
key := ed25519.PublicKey([]byte(decoded))
return &key
}

return nil
}
77 changes: 77 additions & 0 deletions spec/senderid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package spec

import (
"crypto/ed25519"
"reflect"
"testing"
)

func TestUserIDSenderIDs(t *testing.T) {
tests := map[string]UserID{
"basic": NewUserIDOrPanic("@localpart:domain", false),
"extensive_local": NewUserIDOrPanic("@abcdefghijklmnopqrstuvwxyz0123456789._=-/:domain", false),
"extensive_local_historic": NewUserIDOrPanic("@!\"#$%&'()*+,-./0123456789;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~:domain", true),
"domain_with_port": NewUserIDOrPanic("@localpart:domain.org:80", false),
"minimum_id": NewUserIDOrPanic("@a:1", false),
}

for name, userID := range tests {
t.Run(name, func(t *testing.T) {
senderID := SenderIDFromUserID(userID)

if string(senderID) != userID.String() {
t.Fatalf("Created sender ID did not match user ID string: senderID %s for user ID %s", string(senderID), userID.String())
}
if !senderID.IsUserID() {
t.Fatalf("IsUserID returned false for user ID: %s", userID.String())
}
if senderID.IsPseudoID() {
t.Fatalf("IsPseudoID returned true for user ID: %s", userID.String())
}
returnedUserID := senderID.ToUserID()
if returnedUserID == nil {
t.Fatalf("ToUserID returned nil value")
}
if !reflect.DeepEqual(userID, *returnedUserID) {
t.Fatalf("ToUserID returned different user ID than one used to created sender ID\ncreated with %s\nreturned %s", userID, *returnedUserID)
}
roomKey := senderID.ToPseudoID()
if roomKey != nil {
t.Fatalf("ToPseudoID returned non-nil value for user ID: %s, returned %s", userID.String(), roomKey)
}
})
}
}

func TestPseudoIDSenderIDs(t *testing.T) {
// Generate key from all zeroes seed
testKeySeed := make([]byte, 32)
testKey := ed25519.NewKeyFromSeed(testKeySeed)

t.Run("test pseudo ID", func(t *testing.T) {
senderID := SenderIDFromPseudoIDKey(testKey)
testPubkey := testKey.Public()
expectedSenderIDString := Base64Bytes(testPubkey.(ed25519.PublicKey)).Encode()

if string(senderID) != expectedSenderIDString {
t.Fatalf("Created sender ID did not match provided key: created sender ID %s, expected: %s", string(senderID), expectedSenderIDString)
}
if !senderID.IsPseudoID() {
t.Fatalf("IsPseudoID returned false for pseudo ID sender ID")
}
if senderID.IsUserID() {
t.Fatalf("IsUserID returned true for pseudo ID sender ID")
}
returnedKey := senderID.ToPseudoID()
if returnedKey == nil {
t.Fatal("ToPseudoID returned nil")
}
if !reflect.DeepEqual(testPubkey, *returnedKey) {
t.Fatalf("ToPseudoID returned different key to the one used to create the sender ID:\ncreated with %v\nreturned %v", testPubkey, *returnedKey)
}
userID := senderID.ToUserID()
if userID != nil {
t.Fatalf("ToUserID returned non-nil value %v", userID.String())
}
})
}
10 changes: 10 additions & 0 deletions spec/userid.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ type UserID struct {
domain string
}

// Creates a new UserID, returning an error if invalid
func NewUserID(id string, allowHistoricalIDs bool) (*UserID, error) {
return parseAndValidateUserID(id, allowHistoricalIDs)
}

// Creates a new UserID, panicing if invalid
func NewUserIDOrPanic(id string, allowHistoricalIDs bool) UserID {
userID, err := parseAndValidateUserID(id, allowHistoricalIDs)
if err != nil {
panic(fmt.Sprintf("NewUserIDOrPanic failed: invalid user ID %s: %s", id, err.Error()))
}
return *userID
}

// Returns the full userID string including leading sigil
func (user *UserID) String() string {
return user.raw
Expand Down

0 comments on commit 47bceff

Please sign in to comment.