diff --git a/wallet/udb/addressdb.go b/wallet/udb/addressdb.go index da323b4b1..0c86f9eee 100644 --- a/wallet/udb/addressdb.go +++ b/wallet/udb/addressdb.go @@ -703,7 +703,7 @@ func fetchAccountByName(ns walletdb.ReadBucket, name string) (uint32, error) { // fetchAccountInfo loads information about the passed account from the // database. -func fetchAccountInfo(ns walletdb.ReadBucket, account uint32, dbVersion uint32) (*dbBIP0044AccountRow, error) { +func fetchAccountInfo(ns walletdb.ReadBucket, account uint32) (*dbAccountRow, error) { bucket := ns.NestedReadBucket(acctBucketName) accountID := uint32ToBytes(account) @@ -717,12 +717,7 @@ func fetchAccountInfo(ns walletdb.ReadBucket, account uint32, dbVersion uint32) return nil, err } - switch row.acctType { - case actBIP0044: - return deserializeBIP0044AccountRow(accountID, row, dbVersion) - } - - return nil, errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + return row, nil } // deleteAccountNameIndex deletes the given key from the account name index of the database. diff --git a/wallet/udb/addressmanager.go b/wallet/udb/addressmanager.go index e43a18f7f..b3bcc8117 100644 --- a/wallet/udb/addressmanager.go +++ b/wallet/udb/addressmanager.go @@ -425,13 +425,26 @@ func (m *Manager) GetMasterPubkey(ns walletdb.ReadBucket, account uint32) (strin // The account is either invalid or just wasn't cached, so attempt to // load the information from the database. - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { return "", err } + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return "", errors.E(errors.IO, err) + } + + default: + return "", errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + // Use the crypto public key to decrypt the account public extended key. - serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) + serializedKeyPub, err := m.cryptoKeyPub.Decrypt(acctRow.pubKeyEncrypted) if err != nil { return "", errors.E(errors.IO, err) } @@ -452,7 +465,7 @@ func (m *Manager) loadAccountInfo(ns walletdb.ReadBucket, account uint32) (*acco // The account is either invalid or just wasn't cached, so attempt to // load the information from the database. - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { if errors.Is(errors.NotExist, err) { return nil, err @@ -460,8 +473,21 @@ func (m *Manager) loadAccountInfo(ns walletdb.ReadBucket, account uint32) (*acco return nil, errors.E(errors.NotExist, errors.Errorf("no account %d", account)) } + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return nil, errors.E(errors.IO, err) + } + + default: + return nil, errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + // Use the crypto public key to decrypt the account public extended key. - serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) + serializedKeyPub, err := m.cryptoKeyPub.Decrypt(acctRow.pubKeyEncrypted) if err != nil { return nil, errors.E(errors.Crypto, errors.Errorf("decrypt account %d pubkey: %v", account, err)) } @@ -473,8 +499,8 @@ func (m *Manager) loadAccountInfo(ns walletdb.ReadBucket, account uint32) (*acco // Create the new account info with the known information. The rest // of the fields are filled out below. acctInfo := &accountInfo{ - acctName: row.name, - acctKeyEncrypted: row.privKeyEncrypted, + acctName: acctRow.name, + acctKeyEncrypted: acctRow.privKeyEncrypted, acctKeyPub: acctKeyPub, } @@ -528,14 +554,28 @@ func (m *Manager) AccountProperties(ns walletdb.ReadBucket, account uint32) (*Ac return nil, err } props.AccountName = acctInfo.acctName - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { return nil, errors.E(errors.IO, err) } - props.LastUsedExternalIndex = row.lastUsedExternalIndex - props.LastUsedInternalIndex = row.lastUsedInternalIndex - props.LastReturnedExternalIndex = row.lastReturnedExternalIndex - props.LastReturnedInternalIndex = row.lastReturnedInternalIndex + + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return nil, errors.E(errors.IO, err) + } + + default: + return nil, errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + + props.LastUsedExternalIndex = acctRow.lastUsedExternalIndex + props.LastUsedInternalIndex = acctRow.lastUsedInternalIndex + props.LastReturnedExternalIndex = acctRow.lastReturnedExternalIndex + props.LastReturnedInternalIndex = acctRow.lastReturnedInternalIndex } else { props.AccountName = ImportedAddrAccountName // reserved, nonchangable @@ -1444,12 +1484,27 @@ func (m *Manager) MarkUsed(ns walletdb.ReadWriteBucket, address dcrutil.Address) if !ok { return nil } - row, err := fetchAccountInfo(ns, bip0044Addr.account, DBVersion) + + row, err := fetchAccountInfo(ns, bip0044Addr.account) if err != nil { return errors.E(errors.IO, errors.Errorf("missing account %d", bip0044Addr.account)) } - lastUsedExtIndex := row.lastUsedExternalIndex - lastUsedIntIndex := row.lastUsedInternalIndex + + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(bip0044Addr.account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return errors.E(errors.IO, err) + } + + default: + return errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + + lastUsedExtIndex := acctRow.lastUsedExternalIndex + lastUsedIntIndex := acctRow.lastUsedInternalIndex switch bip0044Addr.branch { case ExternalBranch: lastUsedExtIndex = bip0044Addr.index @@ -1459,8 +1514,8 @@ func (m *Manager) MarkUsed(ns walletdb.ReadWriteBucket, address dcrutil.Address) return errors.E(errors.IO, errors.Errorf("invalid account branch %d", bip0044Addr.branch)) } - if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || - lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { + if lastUsedExtIndex+1 < acctRow.lastUsedExternalIndex+1 || + lastUsedIntIndex+1 < acctRow.lastUsedInternalIndex+1 { // More recent addresses have already been marked used, nothing to // update. return nil @@ -1469,25 +1524,39 @@ func (m *Manager) MarkUsed(ns walletdb.ReadWriteBucket, address dcrutil.Address) // The last returned indexes should never be less than the last used. The // weird addition and subtraction makes this calculation work correctly even // when any of of the indexes are ^uint32(0). - lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 - lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 + lastRetExtIndex := maxUint32(lastUsedExtIndex+1, acctRow.lastReturnedExternalIndex+1) - 1 + lastRetIntIndex := maxUint32(lastUsedIntIndex+1, acctRow.lastReturnedInternalIndex+1) - 1 - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, + acctRow = bip0044AccountInfo(acctRow.pubKeyEncrypted, acctRow.privKeyEncrypted, 0, 0, lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, - row.name, DBVersion) - return putAccountRow(ns, bip0044Addr.account, &row.dbAccountRow) + acctRow.name, DBVersion) + return putAccountRow(ns, bip0044Addr.account, &acctRow.dbAccountRow) } // MarkUsedChildIndex marks a BIP0044 account branch child as used. func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, child uint32) error { ns := tx.ReadWriteBucket(waddrmgrBucketKey) - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { return err } - lastUsedExtIndex := row.lastUsedExternalIndex - lastUsedIntIndex := row.lastUsedInternalIndex + + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return errors.E(errors.IO, err) + } + + default: + return errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + + lastUsedExtIndex := acctRow.lastUsedExternalIndex + lastUsedIntIndex := acctRow.lastUsedInternalIndex switch branch { case ExternalBranch: lastUsedExtIndex = child @@ -1497,8 +1566,8 @@ func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, c return errors.E(errors.Invalid, errors.Errorf("account branch %d", branch)) } - if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || - lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { + if lastUsedExtIndex+1 < acctRow.lastUsedExternalIndex+1 || + lastUsedIntIndex+1 < acctRow.lastUsedInternalIndex+1 { // More recent addresses have already been marked used, nothing to // update. return nil @@ -1507,13 +1576,13 @@ func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, c // The last returned indexes should never be less than the last used. The // weird addition and subtraction makes this calculation work correctly even // when any of of the indexes are ^uint32(0). - lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 - lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 + lastRetExtIndex := maxUint32(lastUsedExtIndex+1, acctRow.lastReturnedExternalIndex+1) - 1 + lastRetIntIndex := maxUint32(lastUsedIntIndex+1, acctRow.lastReturnedInternalIndex+1) - 1 - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, + acctRow = bip0044AccountInfo(acctRow.pubKeyEncrypted, acctRow.privKeyEncrypted, 0, 0, lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, - row.name, DBVersion) - return putAccountRow(ns, account, &row.dbAccountRow) + acctRow.name, DBVersion) + return putAccountRow(ns, account, &acctRow.dbAccountRow) } // MarkReturnedChildIndex marks a BIP0044 account branch child as returned to a @@ -1523,12 +1592,26 @@ func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, c func (m *Manager) MarkReturnedChildIndex(tx walletdb.ReadWriteTx, account, branch, child uint32) error { ns := tx.ReadWriteBucket(waddrmgrBucketKey) - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { return err } - lastRetExtIndex := row.lastReturnedExternalIndex - lastRetIntIndex := row.lastReturnedInternalIndex + + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return errors.E(errors.IO, err) + } + + default: + return errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + + lastRetExtIndex := acctRow.lastReturnedExternalIndex + lastRetIntIndex := acctRow.lastReturnedInternalIndex switch branch { case ExternalBranch: lastRetExtIndex = child @@ -1541,13 +1624,13 @@ func (m *Manager) MarkReturnedChildIndex(tx walletdb.ReadWriteTx, account, branc // The last returned indexes should never be less than the last used. The // weird addition and subtraction makes this calculation work correctly even // when any of of the indexes are ^uint32(0). - lastRetExtIndex = maxUint32(row.lastUsedExternalIndex+1, lastRetExtIndex+1) - 1 - lastRetIntIndex = maxUint32(row.lastUsedInternalIndex+1, lastRetIntIndex+1) - 1 + lastRetExtIndex = maxUint32(acctRow.lastUsedExternalIndex+1, lastRetExtIndex+1) - 1 + lastRetIntIndex = maxUint32(acctRow.lastUsedInternalIndex+1, lastRetIntIndex+1) - 1 - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, - row.lastUsedExternalIndex, row.lastUsedInternalIndex, - lastRetExtIndex, lastRetIntIndex, row.name, DBVersion) - return putAccountRow(ns, account, &row.dbAccountRow) + acctRow = bip0044AccountInfo(acctRow.pubKeyEncrypted, acctRow.privKeyEncrypted, 0, 0, + acctRow.lastUsedExternalIndex, acctRow.lastUsedInternalIndex, + lastRetExtIndex, lastRetIntIndex, acctRow.name, DBVersion) + return putAccountRow(ns, account, &acctRow.dbAccountRow) } // ChainParams returns the chain parameters for this address manager. @@ -1784,24 +1867,37 @@ func (m *Manager) RenameAccount(ns walletdb.ReadWriteBucket, account uint32, nam return err } - row, err := fetchAccountInfo(ns, account, DBVersion) + row, err := fetchAccountInfo(ns, account) if err != nil { return err } + var acctRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + acctRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return errors.E(errors.IO, err) + } + + default: + return errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + } + // Remove the old name key from the accout id index if err = deleteAccountIDIndex(ns, account); err != nil { return err } // Remove the old name key from the account name index - if err = deleteAccountNameIndex(ns, row.name); err != nil { + if err = deleteAccountNameIndex(ns, acctRow.name); err != nil { return err } - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, - 0, 0, row.lastUsedExternalIndex, row.lastUsedInternalIndex, - row.lastReturnedExternalIndex, row.lastReturnedInternalIndex, + acctRow = bip0044AccountInfo(acctRow.pubKeyEncrypted, acctRow.privKeyEncrypted, + 0, 0, acctRow.lastUsedExternalIndex, acctRow.lastUsedInternalIndex, + acctRow.lastReturnedExternalIndex, acctRow.lastReturnedInternalIndex, name, DBVersion) - err = putAccountInfo(ns, account, row) + err = putAccountInfo(ns, account, acctRow) if err != nil { return err } diff --git a/wallet/udb/upgrades.go b/wallet/udb/upgrades.go index 986aa8bda..d6d442e9c 100644 --- a/wallet/udb/upgrades.go +++ b/wallet/udb/upgrades.go @@ -191,14 +191,27 @@ func lastUsedAddressIndexUpgrade(tx walletdb.ReadWriteTx, publicPassphrase []byt // Perform account updates on all BIP0044 accounts created thus far. for account := uint32(0); account <= lastAccount; account++ { // Load the old account info. - row, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) + row, err := fetchAccountInfo(addrmgrBucket, account) if err != nil { return err } + var actRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + actRow, err = deserializeBIP0044AccountRow(accountID, row, oldVersion) + if err != nil { + return err + } + + default: + return errors.Errorf("unknown account type %d", row.acctType) + } + // Use the crypto public key to decrypt the account public extended key // and each branch key. - serializedKeyPub, err := cryptoPubKey.Decrypt(row.pubKeyEncrypted) + serializedKeyPub, err := cryptoPubKey.Decrypt(actRow.pubKeyEncrypted) if err != nil { return errors.E(errors.Crypto, errors.Errorf("decrypt extended pubkey: %v", err)) } @@ -261,9 +274,9 @@ func lastUsedAddressIndexUpgrade(tx walletdb.ReadWriteTx, publicPassphrase []byt // Convert account row values to the new serialization format that // replaces the next to use indexes with the last used indexes. - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, - 0, 0, lastUsedExtIndex, lastUsedIntIndex, 0, 0, row.name, newVersion) - err = putAccountInfo(addrmgrBucket, account, row) + actRow = bip0044AccountInfo(actRow.pubKeyEncrypted, actRow.privKeyEncrypted, + 0, 0, lastUsedExtIndex, lastUsedIntIndex, 0, 0, actRow.name, newVersion) + err = putAccountInfo(addrmgrBucket, account, actRow) if err != nil { return err } @@ -383,19 +396,32 @@ func lastReturnedAddressUpgrade(tx walletdb.ReadWriteTx, publicPassphrase []byte upgradeAcct := func(account uint32) error { // Load the old account info. - row, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) + row, err := fetchAccountInfo(addrmgrBucket, account) if err != nil { return err } + var actRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(account) + actRow, err = deserializeBIP0044AccountRow(accountID, row, oldVersion) + if err != nil { + return err + } + + default: + return errors.Errorf("unknown account type %d", row.acctType) + } + // Convert account row values to the new serialization format that adds // the last returned indexes. Assume that the last used address is also // the last returned address. - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, - 0, 0, row.lastUsedExternalIndex, row.lastUsedInternalIndex, - row.lastUsedExternalIndex, row.lastUsedInternalIndex, - row.name, newVersion) - return putAccountInfo(addrmgrBucket, account, row) + actRow = bip0044AccountInfo(actRow.pubKeyEncrypted, actRow.privKeyEncrypted, + 0, 0, actRow.lastUsedExternalIndex, actRow.lastUsedInternalIndex, + actRow.lastUsedExternalIndex, actRow.lastUsedInternalIndex, + actRow.name, newVersion) + return putAccountInfo(addrmgrBucket, account, actRow) } // Determine how many BIP0044 accounts have been created. Each of these diff --git a/wallet/udb/upgrades_test.go b/wallet/udb/upgrades_test.go index af3eee6da..b0d94be28 100644 --- a/wallet/udb/upgrades_test.go +++ b/wallet/udb/upgrades_test.go @@ -19,6 +19,7 @@ import ( "github.com/decred/dcrd/chaincfg/v2" "github.com/decred/dcrd/dcrutil/v2" "github.com/decred/dcrd/wire" + "github.com/decred/dcrwallet/errors" _ "github.com/decred/dcrwallet/wallet/v3/drivers/bdb" "github.com/decred/dcrwallet/wallet/v3/walletdb" ) @@ -267,25 +268,39 @@ func verifyV5Upgrade(t *testing.T, db walletdb.DB) { const dbVersion = 5 for _, d := range data { - row, err := fetchAccountInfo(ns, d.acct, dbVersion) + row, err := fetchAccountInfo(ns, d.acct) if err != nil { return err } - if row.lastUsedExternalIndex != d.lastUsedExtChild { + + var actRow *dbBIP0044AccountRow + switch row.acctType { + case actBIP0044: + accountID := uint32ToBytes(d.acct) + actRow, err = deserializeBIP0044AccountRow(accountID, row, DBVersion) + if err != nil { + return err + } + + default: + return errors.Errorf("unknown account type %d", row.acctType) + } + + if actRow.lastUsedExternalIndex != d.lastUsedExtChild { t.Errorf("Account %d last used ext child mismatch %d != %d", - d.acct, row.lastUsedExternalIndex, d.lastUsedExtChild) + d.acct, actRow.lastUsedExternalIndex, d.lastUsedExtChild) } - if row.lastReturnedExternalIndex != d.lastUsedExtChild { + if actRow.lastReturnedExternalIndex != d.lastUsedExtChild { t.Errorf("Account %d last returned ext child mismatch %d != %d", - d.acct, row.lastReturnedExternalIndex, d.lastUsedExtChild) + d.acct, actRow.lastReturnedExternalIndex, d.lastUsedExtChild) } - if row.lastUsedInternalIndex != d.lastUsedIntChild { + if actRow.lastUsedInternalIndex != d.lastUsedIntChild { t.Errorf("Account %d last used int child mismatch %d != %d", - d.acct, row.lastUsedInternalIndex, d.lastUsedIntChild) + d.acct, actRow.lastUsedInternalIndex, d.lastUsedIntChild) } - if row.lastReturnedInternalIndex != d.lastUsedIntChild { + if actRow.lastReturnedInternalIndex != d.lastUsedIntChild { t.Errorf("Account %d last returned int child mismatch %d != %d", - d.acct, row.lastReturnedInternalIndex, d.lastUsedIntChild) + d.acct, actRow.lastReturnedInternalIndex, d.lastUsedIntChild) } }