Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add more methods for converting to/from sender IDs #412

Merged
merged 5 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions spec/senderid.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ type CreateSenderID func(ctx context.Context, userID UserID, roomID RoomID, room
// StoreSenderIDFromPublicID is a function to store the mxid_mapping after receiving a join event over federation.
type StoreSenderIDFromPublicID func(ctx context.Context, senderID SenderID, userID string, id RoomID) error

// Create a new sender ID from a private room key
func SenderIDFromPseudoIDKey(key ed25519.PrivateKey) SenderID {
return SenderID(Base64Bytes(key.Public().(ed25519.PublicKey)).Encode())
}

// Create a new sender ID from a user ID
func SenderIDFromUserID(user UserID) SenderID {
return SenderID(user.String())
}

// Decodes this sender ID as base64, i.e. returns the raw bytes of the
// pseudo ID used to create this SenderID, assuming this SenderID was made
// using a pseudo ID.
func (s SenderID) RawBytes() (res Base64Bytes, err error) {
err = res.Decode(string(s))
if err != nil {
Expand All @@ -43,13 +52,41 @@ func (s SenderID) RawBytes() (res Base64Bytes, err error) {
return res, nil
}

// Returns true if this SenderID was made using a user ID
func (s SenderID) IsUserID() bool {
// Key is base64, @ is not a valid base64 char
// So if string starts with @, then this sender ID must
// be a user ID
return string(s)[0] == '@'
}

// Returns true if this SenderID was made using a pseudo ID
func (s SenderID) IsPseudoID() bool {
return !s.IsUserID()
}

// Returns the non-nil UserID used to create this SenderID, or nil
// if this SenderID was not created using a UserID
func (s SenderID) ToUserID() *UserID {
if s.IsUserID() {
uID, _ := NewUserID(string(s), true)
return uID
}

return nil
}

// Returns the non-nil room public key (pseudo ID) used to create this
// SenderID, or nil if this SenderID was not created using a pseudo ID
func (s SenderID) ToPseudoID() *ed25519.PublicKey {
if s.IsPseudoID() {
decoded, err := s.RawBytes()
if err != nil {
return nil
}
key := ed25519.PublicKey([]byte(decoded))
return &key
}

return nil
}
77 changes: 77 additions & 0 deletions spec/senderid_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package spec

import (
"crypto/ed25519"
"reflect"
"testing"
)

func TestUserIDSenderIDs(t *testing.T) {
tests := map[string]UserID{
"basic": NewUserIDOrPanic("@localpart:domain", false),
"extensive_local": NewUserIDOrPanic("@abcdefghijklmnopqrstuvwxyz0123456789._=-/:domain", false),
"extensive_local_historic": NewUserIDOrPanic("@!\"#$%&'()*+,-./0123456789;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~:domain", true),
"domain_with_port": NewUserIDOrPanic("@localpart:domain.org:80", false),
"minimum_id": NewUserIDOrPanic("@a:1", false),
}

for name, userID := range tests {
t.Run(name, func(t *testing.T) {
senderID := SenderIDFromUserID(userID)

if string(senderID) != userID.String() {
t.Fatalf("Created sender ID did not match user ID string: senderID %s for user ID %s", string(senderID), userID.String())
}
if !senderID.IsUserID() {
t.Fatalf("IsUserID returned false for user ID: %s", userID.String())
}
if senderID.IsPseudoID() {
t.Fatalf("IsPseudoID returned true for user ID: %s", userID.String())
}
returnedUserID := senderID.ToUserID()
if returnedUserID == nil {
t.Fatalf("ToUserID returned nil value")
}
if !reflect.DeepEqual(userID, *returnedUserID) {
t.Fatalf("ToUserID returned different user ID than one used to created sender ID\ncreated with %s\nreturned %s", userID, *returnedUserID)
}
roomKey := senderID.ToPseudoID()
if roomKey != nil {
t.Fatalf("ToPseudoID returned non-nil value for user ID: %s, returned %s", userID.String(), roomKey)
}
})
}
}

func TestPseudoIDSenderIDs(t *testing.T) {
// Generate key from all zeroes seed
testKeySeed := make([]byte, 32)
testKey := ed25519.NewKeyFromSeed(testKeySeed)

t.Run("test pseudo ID", func(t *testing.T) {
senderID := SenderIDFromPseudoIDKey(testKey)
testPubkey := testKey.Public()
expectedSenderIDString := Base64Bytes(testPubkey.(ed25519.PublicKey)).Encode()

if string(senderID) != expectedSenderIDString {
t.Fatalf("Created sender ID did not match provided key: created sender ID %s, expected: %s", string(senderID), expectedSenderIDString)
}
if !senderID.IsPseudoID() {
t.Fatalf("IsPseudoID returned false for pseudo ID sender ID")
}
if senderID.IsUserID() {
t.Fatalf("IsUserID returned true for pseudo ID sender ID")
}
returnedKey := senderID.ToPseudoID()
if returnedKey == nil {
t.Fatal("ToPseudoID returned nil")
}
if !reflect.DeepEqual(testPubkey, *returnedKey) {
t.Fatalf("ToPseudoID returned different key to the one used to create the sender ID:\ncreated with %v\nreturned %v", testPubkey, *returnedKey)
}
userID := senderID.ToUserID()
if userID != nil {
t.Fatalf("ToUserID returned non-nil value %v", userID.String())
}
})
}
10 changes: 10 additions & 0 deletions spec/userid.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,20 @@ type UserID struct {
domain string
}

// Creates a new UserID, returning an error if invalid
func NewUserID(id string, allowHistoricalIDs bool) (*UserID, error) {
return parseAndValidateUserID(id, allowHistoricalIDs)
}

// Creates a new UserID, panicing if invalid
func NewUserIDOrPanic(id string, allowHistoricalIDs bool) UserID {
userID, err := parseAndValidateUserID(id, allowHistoricalIDs)
if err != nil {
panic(fmt.Sprintf("NewUserIDOrPanic failed: invalid user ID %s: %s", id, err.Error()))
}
return *userID
}

// Returns the full userID string including leading sigil
func (user *UserID) String() string {
return user.raw
Expand Down