diff --git a/wallet/udb/addressdb.go b/wallet/udb/addressdb.go index da323b4b1..c2f1523e6 100644 --- a/wallet/udb/addressdb.go +++ b/wallet/udb/addressdb.go @@ -57,16 +57,29 @@ const ( actBIP0044 accountType = 0 // not iota as they need to be stable for db ) +type accountRow interface { + actType() accountType + actData() []byte +} + // dbAccountRow houses information stored about an account in the database. type dbAccountRow struct { acctType accountType rawData []byte // Varies based on account type field. } +func (row *dbAccountRow) actType() accountType { + return row.acctType +} + +func (row *dbAccountRow) actData() []byte { + return row.rawData +} + // dbBIP0044AccountRow houses additional information stored about a BIP0044 // account in the database. type dbBIP0044AccountRow struct { - dbAccountRow + accountRow pubKeyEncrypted []byte privKeyEncrypted []byte nextExternalIndex uint32 // Removed by version 2 @@ -78,6 +91,21 @@ type dbBIP0044AccountRow struct { name string } +// parseAccountRow returns an account row from the provided serialized bytes. +func parseAccountRow(accountID []byte, serializedRow []byte, dbVersion uint32) (accountRow, error) { + row := dbAccountRow{ + acctType: accountType(serializedRow[0]), + rawData: serializedRow, + } + + switch row.actType() { + case actBIP0044: + return deserializeBIP0044AccountRow(accountID, serializedRow, dbVersion) + default: + return nil, errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } +} + // dbAddressRow houses common information stored about an address in the // database. type dbAddressRow struct { @@ -465,66 +493,37 @@ func putWatchingOnly(ns walletdb.ReadWriteBucket, watchingOnly bool) error { return nil } -// deserializeAccountRow deserializes the passed serialized account information. -// This is used as a common base for the various account types to deserialize -// the common parts. -func deserializeAccountRow(accountID []byte, serializedAccount []byte) (*dbAccountRow, error) { - // The serialized account format is: - // - // - // 1 byte acctType + 4 bytes raw data length + raw data - - // Given the above, the length of the entry must be at a minimum - // the constant value sizes. - if len(serializedAccount) < 5 { - return nil, errors.E(errors.IO, errors.Errorf("bad account len %d", len(serializedAccount))) - } - - row := dbAccountRow{} - row.acctType = accountType(serializedAccount[0]) - rdlen := binary.LittleEndian.Uint32(serializedAccount[1:5]) - row.rawData = make([]byte, rdlen) - copy(row.rawData, serializedAccount[5:5+rdlen]) - - return &row, nil -} - -// serializeAccountRow returns the serialization of the passed account row. -func serializeAccountRow(row *dbAccountRow) []byte { - // The serialized account format is: - // - // - // 1 byte acctType + 4 bytes raw data length + raw data - rdlen := len(row.rawData) - buf := make([]byte, 5+rdlen) - buf[0] = byte(row.acctType) - binary.LittleEndian.PutUint32(buf[1:5], uint32(rdlen)) - copy(buf[5:5+rdlen], row.rawData) - return buf -} - // deserializeBIP0044AccountRow deserializes the raw data from the passed // account row as a BIP0044 account. -func deserializeBIP0044AccountRow(accountID []byte, row *dbAccountRow, dbVersion uint32) (*dbBIP0044AccountRow, error) { +func deserializeBIP0044AccountRow(accountID []byte, account []byte, dbVersion uint32) (*dbBIP0044AccountRow, error) { // The serialized BIP0044 account raw data format is: - // - // + // + // + // // - // 4 bytes encrypted pubkey len + encrypted pubkey + 4 bytes encrypted - // privkey len + encrypted privkey + 4 bytes last used external index + - // 4 bytes last used internal index + 4 bytes last returned external + + // 1 byte acctType + 4 bytes raw data length + 4 bytes encrypted + // pubkey len + encrypted pubkey + 4 bytes encrypted privkey len + + // encrypted privkey + 4 bytes last used external index + 4 bytes + // last used internal index + 4 bytes last returned external + // 4 bytes last returned internal + 4 bytes name len + name // Given the above, the length of the entry must be at a minimum // the constant value sizes. + dataLen := len(account) switch { - case dbVersion < 5 && len(row.rawData) < 20, - dbVersion >= 5 && len(row.rawData) < 28: - return nil, errors.E(errors.IO, errors.Errorf("bip0044 account %x bad len %d", accountID, len(row.rawData))) + case dbVersion < 5 && dataLen < 25, + dbVersion >= 5 && dataLen < 33: + return nil, errors.E(errors.IO, errors.Errorf("bip0044 account %x bad len %d", accountID, dataLen)) } + row := &dbAccountRow{} + row.acctType = accountType(account[0]) + rdlen := binary.LittleEndian.Uint32(account[1:5]) + row.rawData = make([]byte, rdlen) + copy(row.rawData, account[5:5+rdlen]) + retRow := dbBIP0044AccountRow{ - dbAccountRow: *row, + accountRow: row, } pubLen := binary.LittleEndian.Uint32(row.rawData[0:4]) @@ -564,25 +563,33 @@ func deserializeBIP0044AccountRow(accountID []byte, row *dbAccountRow, dbVersion // for a BIP0044 account. func serializeBIP0044AccountRow(row *dbBIP0044AccountRow, dbVersion uint32) []byte { // The serialized BIP0044 account raw data format is: - // - // + // + // + // // - // 4 bytes encrypted pubkey len + encrypted pubkey + 4 bytes encrypted - // privkey len + encrypted privkey + 4 bytes last used external index + - // 4 bytes last used internal index + 4 bytes last returned external + + // 1 byte acctType + 4 bytes raw data length + 4 bytes encrypted + // pubkey len + encrypted pubkey + 4 bytes encrypted privkey len + + // encrypted privkey + 4 bytes last used external index + 4 bytes + // last used internal index + 4 bytes last returned external + // 4 bytes last returned internal + 4 bytes name len + name pubLen := uint32(len(row.pubKeyEncrypted)) privLen := uint32(len(row.privKeyEncrypted)) nameLen := uint32(len(row.name)) - rowSize := 28 + pubLen + privLen + nameLen + rdLen := 28 + pubLen + privLen + nameLen switch { case dbVersion < 5: - rowSize -= 8 + rdLen -= 8 } + rowSize := 1 + 4 + rdLen rawData := make([]byte, rowSize) - binary.LittleEndian.PutUint32(rawData[0:4], pubLen) - copy(rawData[4:4+pubLen], row.pubKeyEncrypted) - offset := 4 + pubLen + rawData[0] = byte(actBIP0044) + offset := uint32(1) + binary.LittleEndian.PutUint32(rawData[offset:offset+4], rdLen) + offset += 4 + binary.LittleEndian.PutUint32(rawData[offset:offset+4], pubLen) + offset += 4 + copy(rawData[offset:offset+pubLen], row.pubKeyEncrypted) + offset += pubLen binary.LittleEndian.PutUint32(rawData[offset:offset+4], privLen) offset += 4 copy(rawData[offset:offset+privLen], row.privKeyEncrypted) @@ -615,10 +622,6 @@ func bip0044AccountInfo(pubKeyEnc, privKeyEnc []byte, nextExtIndex, nextIntIndex name string, dbVersion uint32) *dbBIP0044AccountRow { row := &dbBIP0044AccountRow{ - dbAccountRow: dbAccountRow{ - acctType: actBIP0044, - rawData: nil, - }, pubKeyEncrypted: pubKeyEnc, privKeyEncrypted: privKeyEnc, nextExternalIndex: 0, @@ -642,7 +645,10 @@ func bip0044AccountInfo(pubKeyEnc, privKeyEnc []byte, nextExtIndex, nextIntIndex row.lastReturnedExternalIndex = lastRetExtIndex row.lastReturnedInternalIndex = lastRetIntIndex } - row.rawData = serializeBIP0044AccountRow(row, dbVersion) + row.accountRow = &dbAccountRow{ + acctType: actBIP0044, + rawData: serializeBIP0044AccountRow(row, dbVersion), + } return row } @@ -703,7 +709,7 @@ func fetchAccountByName(ns walletdb.ReadBucket, name string) (uint32, error) { // fetchAccountInfo loads information about the passed account from the // database. -func fetchAccountInfo(ns walletdb.ReadBucket, account uint32, dbVersion uint32) (*dbBIP0044AccountRow, error) { +func fetchAccountInfo(ns walletdb.ReadBucket, account uint32, dbVersion uint32) (accountRow, error) { bucket := ns.NestedReadBucket(acctBucketName) accountID := uint32ToBytes(account) @@ -712,17 +718,7 @@ func fetchAccountInfo(ns walletdb.ReadBucket, account uint32, dbVersion uint32) return nil, errors.E(errors.NotExist, errors.Errorf("no account %d", account)) } - row, err := deserializeAccountRow(accountID, serializedRow) - if err != nil { - return nil, err - } - - switch row.acctType { - case actBIP0044: - return deserializeBIP0044AccountRow(accountID, row, dbVersion) - } - - return nil, errors.E(errors.IO, errors.Errorf("unknown account type %d", row.acctType)) + return parseAccountRow(accountID, serializedRow, dbVersion) } // deleteAccountNameIndex deletes the given key from the account name index of the database. @@ -797,28 +793,39 @@ func putAddrAccountIndex(ns walletdb.ReadWriteBucket, account uint32, addrHash [ // putAccountRow stores the provided account information to the database. This // is used a common base for storing the various account types. -func putAccountRow(ns walletdb.ReadWriteBucket, account uint32, row *dbAccountRow) error { +func putAccountRow(ns walletdb.ReadWriteBucket, account uint32, row accountRow) error { bucket := ns.NestedReadWriteBucket(acctBucketName) - // Write the serialized value keyed by the account number. - err := bucket.Put(uint32ToBytes(account), serializeAccountRow(row)) - if err != nil { - return errors.E(errors.IO, err) + switch acctRow := row.(type) { + case *dbBIP0044AccountRow: + // Write the serialized value keyed by the account number. + err := bucket.Put(uint32ToBytes(account), serializeBIP0044AccountRow(acctRow, DBVersion)) + if err != nil { + return errors.E(errors.IO, err) + } + default: + return errors.E(errors.Invalid, errors.Errorf("invalid account type %d", row.actType())) } return nil } // putAccountInfo stores the provided account information to the database. -func putAccountInfo(ns walletdb.ReadWriteBucket, account uint32, row *dbBIP0044AccountRow) error { - if err := putAccountRow(ns, account, &row.dbAccountRow); err != nil { +func putAccountInfo(ns walletdb.ReadWriteBucket, account uint32, row accountRow) error { + if err := putAccountRow(ns, account, row); err != nil { return err } - // Update account id index - if err := putAccountIDIndex(ns, account, row.name); err != nil { - return err + + switch acctRow := row.(type) { + case *dbBIP0044AccountRow: + // Update account id index + if err := putAccountIDIndex(ns, account, acctRow.name); err != nil { + return err + } + // Update account name index + return putAccountNameIndex(ns, account, acctRow.name) + default: + return errors.E(errors.Invalid, errors.Errorf("invalid account type %d", acctRow.actType())) } - // Update account name index - return putAccountNameIndex(ns, account, row.name) } // putLastAccount stores the provided metadata - last account - to the database. @@ -1232,7 +1239,7 @@ func deletePrivateKeys(ns walletdb.ReadWriteBucket, dbVersion uint32) error { return errors.E(errors.IO, err) } - BIP0044Set := map[string]*dbAccountRow{} + BIP0044Set := map[string]*dbBIP0044AccountRow{} // Fetch all BIP0044 accounts. bucket = ns.NestedReadWriteBucket(acctBucketName) @@ -1243,35 +1250,29 @@ func deletePrivateKeys(ns walletdb.ReadWriteBucket, dbVersion uint32) error { continue } - // Deserialize the account row first to determine the type. - row, err := deserializeAccountRow(k, v) + row, err := parseAccountRow(k, v, dbVersion) if err != nil { c.Close() return err } - switch row.acctType { - case actBIP0044: - BIP0044Set[string(k)] = row + switch acctRow := row.(type) { + case *dbBIP0044AccountRow: + BIP0044Set[string(k)] = acctRow } } c.Close() // Delete the account extended private key for all BIP0044 accounts. for k, row := range BIP0044Set { - arow, err := deserializeBIP0044AccountRow([]byte(k), row, dbVersion) - if err != nil { - return err - } - // Reserialize the account without the private key and // store it. - row := bip0044AccountInfo(arow.pubKeyEncrypted, nil, - arow.nextExternalIndex, arow.nextInternalIndex, - arow.lastUsedExternalIndex, arow.lastUsedInternalIndex, - arow.lastReturnedExternalIndex, arow.lastReturnedInternalIndex, - arow.name, dbVersion) - err = bucket.Put([]byte(k), serializeAccountRow(&row.dbAccountRow)) + row := bip0044AccountInfo(row.pubKeyEncrypted, nil, + row.nextExternalIndex, row.nextInternalIndex, + row.lastUsedExternalIndex, row.lastUsedInternalIndex, + row.lastReturnedExternalIndex, row.lastReturnedInternalIndex, + row.name, dbVersion) + err := bucket.Put([]byte(k), serializeBIP0044AccountRow(row, dbVersion)) if err != nil { return errors.E(errors.IO, err) } diff --git a/wallet/udb/addressmanager.go b/wallet/udb/addressmanager.go index e43a18f7f..5c04cf0ab 100644 --- a/wallet/udb/addressmanager.go +++ b/wallet/udb/addressmanager.go @@ -425,18 +425,23 @@ func (m *Manager) GetMasterPubkey(ns walletdb.ReadBucket, account uint32) (strin // The account is either invalid or just wasn't cached, so attempt to // load the information from the database. - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { return "", err } - // Use the crypto public key to decrypt the account public extended key. - serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) - if err != nil { - return "", errors.E(errors.IO, err) - } + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + // Use the crypto public key to decrypt the account public extended key. + serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) + if err != nil { + return "", errors.E(errors.IO, err) + } - return string(serializedKeyPub), nil + return string(serializedKeyPub), nil + default: + return "", errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } } // loadAccountInfo attempts to load and cache information about the given @@ -452,7 +457,7 @@ func (m *Manager) loadAccountInfo(ns walletdb.ReadBucket, account uint32) (*acco // The account is either invalid or just wasn't cached, so attempt to // load the information from the database. - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { if errors.Is(errors.NotExist, err) { return nil, err @@ -460,42 +465,48 @@ func (m *Manager) loadAccountInfo(ns walletdb.ReadBucket, account uint32) (*acco return nil, errors.E(errors.NotExist, errors.Errorf("no account %d", account)) } - // Use the crypto public key to decrypt the account public extended key. - serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) - if err != nil { - return nil, errors.E(errors.Crypto, errors.Errorf("decrypt account %d pubkey: %v", account, err)) - } - acctKeyPub, err := hdkeychain.NewKeyFromString(string(serializedKeyPub), m.chainParams) - if err != nil { - return nil, errors.E(errors.IO, err) - } - - // Create the new account info with the known information. The rest - // of the fields are filled out below. - acctInfo := &accountInfo{ - acctName: row.name, - acctKeyEncrypted: row.privKeyEncrypted, - acctKeyPub: acctKeyPub, - } - - if !m.locked { - // Use the crypto private key to decrypt the account private - // extended keys. - decrypted, err := m.cryptoKeyPriv.Decrypt(acctInfo.acctKeyEncrypted) + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + // Use the crypto public key to decrypt the account public extended key. + serializedKeyPub, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) if err != nil { - return nil, errors.E(errors.Crypto, errors.Errorf("decrypt account %d privkey: %v", account, err)) + return nil, errors.E(errors.Crypto, errors.Errorf("decrypt account %d pubkey: %v", account, err)) } - - acctKeyPriv, err := hdkeychain.NewKeyFromString(string(decrypted), m.chainParams) + acctKeyPub, err := hdkeychain.NewKeyFromString(string(serializedKeyPub), m.chainParams) if err != nil { return nil, errors.E(errors.IO, err) } - acctInfo.acctKeyPriv = acctKeyPriv - } - // Add it to the cache and return it when everything is successful. - m.acctInfo[account] = acctInfo - return acctInfo, nil + // Create the new account info with the known information. The rest + // of the fields are filled out below. + acctInfo := &accountInfo{ + acctName: row.name, + acctKeyEncrypted: row.privKeyEncrypted, + acctKeyPub: acctKeyPub, + } + + if !m.locked { + // Use the crypto private key to decrypt the account private + // extended keys. + decrypted, err := m.cryptoKeyPriv.Decrypt(acctInfo.acctKeyEncrypted) + if err != nil { + return nil, errors.E(errors.Crypto, errors.Errorf("decrypt account %d privkey: %v", account, err)) + } + + acctKeyPriv, err := hdkeychain.NewKeyFromString(string(decrypted), m.chainParams) + if err != nil { + return nil, errors.E(errors.IO, err) + } + acctInfo.acctKeyPriv = acctKeyPriv + } + + // Add it to the cache and return it when everything is successful. + m.acctInfo[account] = acctInfo + return acctInfo, nil + + default: + return nil, errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } } // AccountProperties returns properties associated with the account, such as the @@ -528,14 +539,21 @@ func (m *Manager) AccountProperties(ns walletdb.ReadBucket, account uint32) (*Ac return nil, err } props.AccountName = acctInfo.acctName - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { return nil, errors.E(errors.IO, err) } - props.LastUsedExternalIndex = row.lastUsedExternalIndex - props.LastUsedInternalIndex = row.lastUsedInternalIndex - props.LastReturnedExternalIndex = row.lastReturnedExternalIndex - props.LastReturnedInternalIndex = row.lastReturnedInternalIndex + + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + props.LastUsedExternalIndex = row.lastUsedExternalIndex + props.LastUsedInternalIndex = row.lastUsedInternalIndex + props.LastReturnedExternalIndex = row.lastReturnedExternalIndex + props.LastReturnedInternalIndex = row.lastReturnedInternalIndex + + default: + return nil, errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } } else { props.AccountName = ImportedAddrAccountName // reserved, nonchangable @@ -733,21 +751,23 @@ func (m *Manager) UpgradeToSLIP0044CoinType(dbtx walletdb.ReadWriteTx) error { return errors.E(errors.IO, "missing SLIP0044 coin type account row") } accountID := uint32ToBytes(0) - row, err := deserializeAccountRow(accountID, serializedRow) + acctRow, err := parseAccountRow(accountID, serializedRow, initialVersion) if err != nil { return err } - if row.acctType != actBIP0044 { - return errors.E(errors.IO, errors.Errorf("invalid SLIP0044 account 0 row type %d", row.acctType)) - } - bip0044Row, err := deserializeBIP0044AccountRow(accountID, row, initialVersion) - if err != nil { - return err + + row, ok := acctRow.(*dbBIP0044AccountRow) + if !ok { + return errors.E(errors.IO, errors.Errorf("invalid SLIP0044 account 0 row type %d", row.actType())) } + // Keep previous name of account 0 - bip0044Row.name = acctProps.AccountName - bip0044Row.rawData = serializeBIP0044AccountRow(bip0044Row, DBVersion) - err = putAccountRow(ns, 0, &bip0044Row.dbAccountRow) + row.name = acctProps.AccountName + row.accountRow = &dbAccountRow{ + acctType: actBIP0044, + rawData: serializeBIP0044AccountRow(row, DBVersion), + } + err = putAccountRow(ns, 0, row) if err != nil { return err } @@ -770,7 +790,7 @@ func (m *Manager) UpgradeToSLIP0044CoinType(dbtx walletdb.ReadWriteTx) error { // Decrypt the SLIP0044 coin type account xpub so the in memory account // information can be updated. - acctExtPubKeyStr, err := m.cryptoKeyPub.Decrypt(bip0044Row.pubKeyEncrypted) + acctExtPubKeyStr, err := m.cryptoKeyPub.Decrypt(row.pubKeyEncrypted) if err != nil { return errors.E(errors.Crypto, errors.Errorf("decrypt SLIP0044 account 0 xpub: %v", err)) } @@ -782,7 +802,7 @@ func (m *Manager) UpgradeToSLIP0044CoinType(dbtx walletdb.ReadWriteTx) error { // When unlocked, decrypt the SLIP0044 coin type account xpriv as well. var acctExtPrivKey *hdkeychain.ExtendedKey if !m.locked { - acctExtPrivKeyStr, err := m.cryptoKeyPriv.Decrypt(bip0044Row.privKeyEncrypted) + acctExtPrivKeyStr, err := m.cryptoKeyPriv.Decrypt(row.privKeyEncrypted) if err != nil { return errors.E(errors.Crypto, errors.Errorf("decrypt SLIP0044 account 0 xpriv: %v", err)) } @@ -792,7 +812,7 @@ func (m *Manager) UpgradeToSLIP0044CoinType(dbtx walletdb.ReadWriteTx) error { } } - acctInfo.acctKeyEncrypted = bip0044Row.privKeyEncrypted + acctInfo.acctKeyEncrypted = row.privKeyEncrypted acctInfo.acctKeyPriv = acctExtPrivKey acctInfo.acctKeyPub = acctExtPubKey @@ -1444,76 +1464,91 @@ func (m *Manager) MarkUsed(ns walletdb.ReadWriteBucket, address dcrutil.Address) if !ok { return nil } - row, err := fetchAccountInfo(ns, bip0044Addr.account, DBVersion) + acctRow, err := fetchAccountInfo(ns, bip0044Addr.account, DBVersion) if err != nil { return errors.E(errors.IO, errors.Errorf("missing account %d", bip0044Addr.account)) } - lastUsedExtIndex := row.lastUsedExternalIndex - lastUsedIntIndex := row.lastUsedInternalIndex - switch bip0044Addr.branch { - case ExternalBranch: - lastUsedExtIndex = bip0044Addr.index - case InternalBranch: - lastUsedIntIndex = bip0044Addr.index - default: - return errors.E(errors.IO, errors.Errorf("invalid account branch %d", bip0044Addr.branch)) - } - if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || - lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { - // More recent addresses have already been marked used, nothing to - // update. - return nil - } + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + lastUsedExtIndex := row.lastUsedExternalIndex + lastUsedIntIndex := row.lastUsedInternalIndex + switch bip0044Addr.branch { + case ExternalBranch: + lastUsedExtIndex = bip0044Addr.index + case InternalBranch: + lastUsedIntIndex = bip0044Addr.index + default: + return errors.E(errors.IO, errors.Errorf("invalid account branch %d", bip0044Addr.branch)) + } + + if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || + lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { + // More recent addresses have already been marked used, nothing to + // update. + return nil + } + + // The last returned indexes should never be less than the last used. The + // weird addition and subtraction makes this calculation work correctly even + // when any of of the indexes are ^uint32(0). + lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 + lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 - // The last returned indexes should never be less than the last used. The - // weird addition and subtraction makes this calculation work correctly even - // when any of of the indexes are ^uint32(0). - lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 - lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 + row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, + lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, + row.name, DBVersion) + return putAccountRow(ns, bip0044Addr.account, row) + + default: + return errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, - lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, - row.name, DBVersion) - return putAccountRow(ns, bip0044Addr.account, &row.dbAccountRow) } // MarkUsedChildIndex marks a BIP0044 account branch child as used. func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, child uint32) error { ns := tx.ReadWriteBucket(waddrmgrBucketKey) - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { return err } - lastUsedExtIndex := row.lastUsedExternalIndex - lastUsedIntIndex := row.lastUsedInternalIndex - switch branch { - case ExternalBranch: - lastUsedExtIndex = child - case InternalBranch: - lastUsedIntIndex = child - default: - return errors.E(errors.Invalid, errors.Errorf("account branch %d", branch)) - } - if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || - lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { - // More recent addresses have already been marked used, nothing to - // update. - return nil - } + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + lastUsedExtIndex := row.lastUsedExternalIndex + lastUsedIntIndex := row.lastUsedInternalIndex + switch branch { + case ExternalBranch: + lastUsedExtIndex = child + case InternalBranch: + lastUsedIntIndex = child + default: + return errors.E(errors.Invalid, errors.Errorf("account branch %d", branch)) + } - // The last returned indexes should never be less than the last used. The - // weird addition and subtraction makes this calculation work correctly even - // when any of of the indexes are ^uint32(0). - lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 - lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 + if lastUsedExtIndex+1 < row.lastUsedExternalIndex+1 || + lastUsedIntIndex+1 < row.lastUsedInternalIndex+1 { + // More recent addresses have already been marked used, nothing to + // update. + return nil + } - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, - lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, - row.name, DBVersion) - return putAccountRow(ns, account, &row.dbAccountRow) + // The last returned indexes should never be less than the last used. The + // weird addition and subtraction makes this calculation work correctly even + // when any of of the indexes are ^uint32(0). + lastRetExtIndex := maxUint32(lastUsedExtIndex+1, row.lastReturnedExternalIndex+1) - 1 + lastRetIntIndex := maxUint32(lastUsedIntIndex+1, row.lastReturnedInternalIndex+1) - 1 + + row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, + lastUsedExtIndex, lastUsedIntIndex, lastRetExtIndex, lastRetIntIndex, + row.name, DBVersion) + return putAccountRow(ns, account, row) + + default: + return errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } } // MarkReturnedChildIndex marks a BIP0044 account branch child as returned to a @@ -1523,31 +1558,38 @@ func (m *Manager) MarkUsedChildIndex(tx walletdb.ReadWriteTx, account, branch, c func (m *Manager) MarkReturnedChildIndex(tx walletdb.ReadWriteTx, account, branch, child uint32) error { ns := tx.ReadWriteBucket(waddrmgrBucketKey) - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { return err } - lastRetExtIndex := row.lastReturnedExternalIndex - lastRetIntIndex := row.lastReturnedInternalIndex - switch branch { - case ExternalBranch: - lastRetExtIndex = child - case InternalBranch: - lastRetIntIndex = child - default: - return errors.E(errors.Invalid, errors.Errorf("account branch %d", branch)) - } - // The last returned indexes should never be less than the last used. The - // weird addition and subtraction makes this calculation work correctly even - // when any of of the indexes are ^uint32(0). - lastRetExtIndex = maxUint32(row.lastUsedExternalIndex+1, lastRetExtIndex+1) - 1 - lastRetIntIndex = maxUint32(row.lastUsedInternalIndex+1, lastRetIntIndex+1) - 1 + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + lastRetExtIndex := row.lastReturnedExternalIndex + lastRetIntIndex := row.lastReturnedInternalIndex + switch branch { + case ExternalBranch: + lastRetExtIndex = child + case InternalBranch: + lastRetIntIndex = child + default: + return errors.E(errors.Invalid, errors.Errorf("account branch %d", branch)) + } + + // The last returned indexes should never be less than the last used. The + // weird addition and subtraction makes this calculation work correctly even + // when any of of the indexes are ^uint32(0). + lastRetExtIndex = maxUint32(row.lastUsedExternalIndex+1, lastRetExtIndex+1) - 1 + lastRetIntIndex = maxUint32(row.lastUsedInternalIndex+1, lastRetIntIndex+1) - 1 + + row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, + row.lastUsedExternalIndex, row.lastUsedInternalIndex, + lastRetExtIndex, lastRetIntIndex, row.name, DBVersion) + return putAccountRow(ns, account, row) - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, 0, 0, - row.lastUsedExternalIndex, row.lastUsedInternalIndex, - lastRetExtIndex, lastRetIntIndex, row.name, DBVersion) - return putAccountRow(ns, account, &row.dbAccountRow) + default: + return errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) + } } // ChainParams returns the chain parameters for this address manager. @@ -1784,34 +1826,40 @@ func (m *Manager) RenameAccount(ns walletdb.ReadWriteBucket, account uint32, nam return err } - row, err := fetchAccountInfo(ns, account, DBVersion) + acctRow, err := fetchAccountInfo(ns, account, DBVersion) if err != nil { return err } - // Remove the old name key from the accout id index - if err = deleteAccountIDIndex(ns, account); err != nil { - return err - } - // Remove the old name key from the account name index - if err = deleteAccountNameIndex(ns, row.name); err != nil { - return err - } - row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, - 0, 0, row.lastUsedExternalIndex, row.lastUsedInternalIndex, - row.lastReturnedExternalIndex, row.lastReturnedInternalIndex, - name, DBVersion) - err = putAccountInfo(ns, account, row) - if err != nil { - return err - } + switch row := acctRow.(type) { + case *dbBIP0044AccountRow: + // Remove the old name key from the accout id index + if err = deleteAccountIDIndex(ns, account); err != nil { + return err + } + // Remove the old name key from the account name index + if err = deleteAccountNameIndex(ns, row.name); err != nil { + return err + } + row = bip0044AccountInfo(row.pubKeyEncrypted, row.privKeyEncrypted, + 0, 0, row.lastUsedExternalIndex, row.lastUsedInternalIndex, + row.lastReturnedExternalIndex, row.lastReturnedInternalIndex, + name, DBVersion) + err = putAccountInfo(ns, account, row) + if err != nil { + return err + } - // Update in-memory account info with new name if cached and the db - // write was successful. - if acctInfo, ok := m.acctInfo[account]; ok { - acctInfo.acctName = name + // Update in-memory account info with new name if cached and the db + // write was successful. + if acctInfo, ok := m.acctInfo[account]; ok { + acctInfo.acctName = name + } + return nil + + default: + return errors.E(errors.IO, errors.Errorf("invalid account type %d", row.actType())) } - return nil } // AccountName returns the account name for the given account number @@ -2535,7 +2583,7 @@ func createAddressManager(ns walletdb.ReadWriteBucket, seed, pubPassphrase, priv slip0044Account0Row := bip0044AccountInfo(acctPubSLIP0044Enc, acctPrivSLIP0044Enc, 0, 0, 0, 0, 0, 0, defaultAccountName, initialVersion) mainBucket := ns.NestedReadWriteBucket(mainBucketName) - err = mainBucket.Put(slip0044Account0RowName, serializeAccountRow(&slip0044Account0Row.dbAccountRow)) + err = mainBucket.Put(slip0044Account0RowName, serializeBIP0044AccountRow(slip0044Account0Row, initialVersion)) if err != nil { return errors.E(errors.IO, err) } diff --git a/wallet/udb/upgrades.go b/wallet/udb/upgrades.go index 986aa8bda..7c9680ed3 100644 --- a/wallet/udb/upgrades.go +++ b/wallet/udb/upgrades.go @@ -191,11 +191,16 @@ func lastUsedAddressIndexUpgrade(tx walletdb.ReadWriteTx, publicPassphrase []byt // Perform account updates on all BIP0044 accounts created thus far. for account := uint32(0); account <= lastAccount; account++ { // Load the old account info. - row, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) + acctRow, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) if err != nil { return err } + row, ok := acctRow.(*dbBIP0044AccountRow) + if !ok { + return errors.E(errors.IO, errors.Errorf("unexpected account type %d", row.actType())) + } + // Use the crypto public key to decrypt the account public extended key // and each branch key. serializedKeyPub, err := cryptoPubKey.Decrypt(row.pubKeyEncrypted) @@ -383,11 +388,16 @@ func lastReturnedAddressUpgrade(tx walletdb.ReadWriteTx, publicPassphrase []byte upgradeAcct := func(account uint32) error { // Load the old account info. - row, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) + acctRow, err := fetchAccountInfo(addrmgrBucket, account, oldVersion) if err != nil { return err } + row, ok := acctRow.(*dbBIP0044AccountRow) + if !ok { + return errors.E(errors.IO, errors.Errorf("unexpected account type %d", row.actType())) + } + // Convert account row values to the new serialization format that adds // the last returned indexes. Assume that the last used address is also // the last returned address. diff --git a/wallet/udb/upgrades_test.go b/wallet/udb/upgrades_test.go index af3eee6da..3468ae2e0 100644 --- a/wallet/udb/upgrades_test.go +++ b/wallet/udb/upgrades_test.go @@ -267,10 +267,16 @@ func verifyV5Upgrade(t *testing.T, db walletdb.DB) { const dbVersion = 5 for _, d := range data { - row, err := fetchAccountInfo(ns, d.acct, dbVersion) + acctRow, err := fetchAccountInfo(ns, d.acct, dbVersion) if err != nil { return err } + + row, ok := acctRow.(*dbBIP0044AccountRow) + if !ok { + t.Errorf("unexpected account type %d", row.actType()) + } + if row.lastUsedExternalIndex != d.lastUsedExtChild { t.Errorf("Account %d last used ext child mismatch %d != %d", d.acct, row.lastUsedExternalIndex, d.lastUsedExtChild)