Skip to content

Commit

Permalink
refactor: tx_multisign/tx_sign refactor to improved readability and m…
Browse files Browse the repository at this point in the history
…aintainability (cosmos#18451)
  • Loading branch information
pluveto committed Nov 13, 2023
1 parent e657752 commit f0753b4
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 90 deletions.
48 changes: 22 additions & 26 deletions x/auth/client/cli/tx_multisign.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ import (
"github.com/cosmos/cosmos-sdk/client/flags"
"github.com/cosmos/cosmos-sdk/client/tx"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/crypto/keyring"
kmultisig "github.com/cosmos/cosmos-sdk/crypto/keys/multisig"
"github.com/cosmos/cosmos-sdk/crypto/types/multisig"
sdk "github.com/cosmos/cosmos-sdk/types"
Expand Down Expand Up @@ -68,11 +67,15 @@ The SIGN_MODE_DIRECT sign mode is not supported.'

func makeMultiSignCmd() func(cmd *cobra.Command, args []string) (err error) {
return func(cmd *cobra.Command, args []string) (err error) {
file := args[0]
name := args[1]
sigsRaw := args[2:]

clientCtx, err := client.GetClientTxContext(cmd)
if err != nil {
return err
}
parsedTx, err := authclient.ReadTxFromFile(clientCtx, args[0])
parsedTx, err := authclient.ReadTxFromFile(clientCtx, file)
if err != nil {
return err
}
Expand All @@ -91,9 +94,9 @@ func makeMultiSignCmd() func(cmd *cobra.Command, args []string) (err error) {
return err
}

k, err := getMultisigRecord(clientCtx, args[1])
k, err := clientCtx.Keyring.Key(name)
if err != nil {
return err
return errorsmod.Wrap(err, "error getting keybase multisig account")
}
pubKey, err := k.GetPubKey()
if err != nil {
Expand All @@ -117,8 +120,8 @@ func makeMultiSignCmd() func(cmd *cobra.Command, args []string) (err error) {
}

// read each signature and add it to the multisig if valid
for i := 2; i < len(args); i++ {
sigs, err := unmarshalSignatureJSON(clientCtx, args[i])
for i := 0; i < len(sigsRaw); i++ {
sigs, err := unmarshalSignatureJSON(clientCtx, sigsRaw[i])
if err != nil {
return err
}
Expand Down Expand Up @@ -176,7 +179,7 @@ func makeMultiSignCmd() func(cmd *cobra.Command, args []string) (err error) {
sigOnly, _ := cmd.Flags().GetBool(flagSigOnly)

var json []byte
json, err = marshalSignatureJSON(txCfg, txBuilder, sigOnly)
json, err = marshalSignatureJSON(txCfg, txBuilder.GetTx(), sigOnly)
if err != nil {
return err
}
Expand Down Expand Up @@ -233,6 +236,9 @@ func makeBatchMultisignCmd() func(cmd *cobra.Command, args []string) error {
return func(cmd *cobra.Command, args []string) (err error) {
var clientCtx client.Context

file, name := args[0], args[1]
sigFiles := args[2:]

clientCtx, err = client.GetClientTxContext(cmd)
if err != nil {
return err
Expand All @@ -248,19 +254,19 @@ func makeBatchMultisignCmd() func(cmd *cobra.Command, args []string) error {
}

// reads tx from args[0]
scanner, err := authclient.ReadTxsFromInput(txCfg, args[0])
scanner, err := authclient.ReadTxsFromInput(txCfg, file)
if err != nil {
return err
}

k, err := getMultisigRecord(clientCtx, args[1])
k, err := clientCtx.Keyring.Key(name)
if err != nil {
return err
return errorsmod.Wrap(err, "error getting keybase multisig account")
}

var signatureBatch [][]signingtypes.SignatureV2
for i := 2; i < len(args); i++ {
sigs, err := readSignaturesFromFile(clientCtx, args[i])
for i := 0; i < len(sigFiles); i++ {
sigs, err := readSignaturesFromFile(clientCtx, sigFiles[i])
if err != nil {
return err
}
Expand Down Expand Up @@ -292,7 +298,7 @@ func makeBatchMultisignCmd() func(cmd *cobra.Command, args []string) error {
clientCtx.WithOutput(cmd.OutOrStdout())

for i := 0; scanner.Scan(); i++ {
txBldr, err := txCfg.WrapTxBuilder(scanner.Tx())
txBuilder, err := txCfg.WrapTxBuilder(scanner.Tx())
if err != nil {
return err
}
Expand All @@ -318,7 +324,7 @@ func makeBatchMultisignCmd() func(cmd *cobra.Command, args []string) error {
},
}

builtTx := txBldr.GetTx()
builtTx := txBuilder.GetTx()
adaptableTx, ok := builtTx.(signing.V2AdaptableTx)
if !ok {
return fmt.Errorf("expected Tx to be signing.V2AdaptableTx, got %T", builtTx)
Expand All @@ -343,14 +349,14 @@ func makeBatchMultisignCmd() func(cmd *cobra.Command, args []string) error {
Sequence: txFactory.Sequence(),
}

err = txBldr.SetSignatures(sigV2)
err = txBuilder.SetSignatures(sigV2)
if err != nil {
return err
}

sigOnly, _ := cmd.Flags().GetBool(flagSigOnly)
var json []byte
json, err = marshalSignatureJSON(txCfg, txBldr, sigOnly)
json, err = marshalSignatureJSON(txCfg, txBuilder.GetTx(), sigOnly)
if err != nil {
return err
}
Expand Down Expand Up @@ -398,13 +404,3 @@ func readSignaturesFromFile(ctx client.Context, filename string) (sigs []signing
}
return sigs, nil
}

func getMultisigRecord(clientCtx client.Context, name string) (*keyring.Record, error) {
kb := clientCtx.Keyring
multisigRecord, err := kb.Key(name)
if err != nil {
return nil, errorsmod.Wrap(err, "error getting keybase multisig account")
}

return multisigRecord, nil
}
117 changes: 53 additions & 64 deletions x/auth/client/cli/tx_sign.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ import (

"github.com/spf13/cobra"

errorsmod "cosmossdk.io/errors"
authclient "cosmossdk.io/x/auth/client"
"cosmossdk.io/x/auth/signing"

"github.com/cosmos/cosmos-sdk/client"
"github.com/cosmos/cosmos-sdk/client/flags"
Expand Down Expand Up @@ -73,9 +75,8 @@ func makeSignBatchCmd() func(cmd *cobra.Command, args []string) error {
return err
}
txCfg := clientCtx.TxConfig
printSignatureOnly, _ := cmd.Flags().GetBool(flagSigOnly)

ms, err := cmd.Flags().GetString(flagMultisig)
multisigKey, err := cmd.Flags().GetString(flagMultisig)
if err != nil {
return err
}
Expand All @@ -94,39 +95,35 @@ func makeSignBatchCmd() func(cmd *cobra.Command, args []string) error {
return err
}

if !clientCtx.Offline {
if ms == "" {
from, err := cmd.Flags().GetString(flags.FlagFrom)
if err != nil {
return err
}

addr, _, _, err := client.GetFromFields(clientCtx, txFactory.Keybase(), from)
if err != nil {
return err
}
if !clientCtx.Offline && multisigKey == "" {
from, err := cmd.Flags().GetString(flags.FlagFrom)
if err != nil {
return err
}

acc, err := txFactory.AccountRetriever().GetAccount(clientCtx, addr)
if err != nil {
return err
}
fromAddr, _, _, err := client.GetFromFields(clientCtx, txFactory.Keybase(), from)
if err != nil {
return err
}

txFactory = txFactory.WithAccountNumber(acc.GetAccountNumber()).WithSequence(acc.GetSequence())
} else {
txFactory = txFactory.WithAccountNumber(0).WithSequence(0)
fromAcc, err := txFactory.AccountRetriever().GetAccount(clientCtx, fromAddr)
if err != nil {
return err
}

txFactory = txFactory.WithAccountNumber(fromAcc.GetAccountNumber()).WithSequence(fromAcc.GetSequence())
}

appendMessagesToSingleTx, _ := cmd.Flags().GetBool(flagAppend)
// Combines all tx msgs and create single signed transaction
if appendMessagesToSingleTx {
txBuilder := clientCtx.TxConfig.NewTxBuilder()
txBuilder := txCfg.NewTxBuilder()
msgs := make([]sdk.Msg, 0)
newGasLimit := uint64(0)

for scanner.Scan() {
unsignedStdTx := scanner.Tx()
fe, err := clientCtx.TxConfig.WrapTxBuilder(unsignedStdTx)
fe, err := txCfg.WrapTxBuilder(unsignedStdTx)
if err != nil {
return err
}
Expand All @@ -151,18 +148,11 @@ func makeSignBatchCmd() func(cmd *cobra.Command, args []string) error {
txBuilder.SetGasLimit(newGasLimit)

// sign the txs
if ms == "" {
from, _ := cmd.Flags().GetString(flags.FlagFrom)
if err := sign(clientCtx, txBuilder, txFactory, from); err != nil {
return err
}
} else {
if err := multisigSign(clientCtx, txBuilder, txFactory, ms); err != nil {
return err
}
}
from, _ := cmd.Flags().GetString(flags.FlagFrom)
sigTxOrMultisig(clientCtx, txBuilder, txFactory, from, multisigKey)

json, err := marshalSignatureJSON(txCfg, txBuilder, printSignatureOnly)
sigOnly, _ := cmd.Flags().GetBool(flagSigOnly)
json, err := marshalSignatureJSON(txCfg, txBuilder.GetTx(), sigOnly)
if err != nil {
return err
}
Expand All @@ -179,33 +169,31 @@ func makeSignBatchCmd() func(cmd *cobra.Command, args []string) error {
}

// sign the txs
if ms == "" {
from, _ := cmd.Flags().GetString(flags.FlagFrom)
if err := sign(clientCtx, txBuilder, txFactory, from); err != nil {
return err
}
} else {
if err := multisigSign(clientCtx, txBuilder, txFactory, ms); err != nil {
return err
}
}
from, _ := cmd.Flags().GetString(flags.FlagFrom)
sigTxOrMultisig(clientCtx, txBuilder, txFactory, from, multisigKey)

json, err := marshalSignatureJSON(txCfg, txBuilder, printSignatureOnly)
printSigOnly, _ := cmd.Flags().GetBool(flagSigOnly)
json, err := marshalSignatureJSON(txCfg, txBuilder.GetTx(), printSigOnly)
if err != nil {
return err
}
cmd.Printf("%s\n", json)
}
}

if err := scanner.UnmarshalErr(); err != nil {
return err
}

return scanner.UnmarshalErr()
}
}

func sigTxOrMultisig(clientCtx client.Context, txBuilder client.TxBuilder, txFactory tx.Factory, from string, multisigKey string) (err error) {
if multisigKey == "" {
err = sign(clientCtx, txBuilder, txFactory, from)
} else {
err = multisigSign(clientCtx, txBuilder, txFactory, multisigKey)
}
return err
}

func sign(clientCtx client.Context, txBuilder client.TxBuilder, txFactory tx.Factory, from string) error {
_, fromName, _, err := client.GetFromFields(clientCtx, txFactory.Keybase(), from)
if err != nil {
Expand Down Expand Up @@ -322,20 +310,20 @@ func makeSignCmd() func(cmd *cobra.Command, args []string) error {
}
}

func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx sdk.Tx) error {
func signTx(cmd *cobra.Command, clientCtx client.Context, txFactory tx.Factory, newTx sdk.Tx) error {
f := cmd.Flags()
txCfg := clientCtx.TxConfig
txBuilder, err := txCfg.WrapTxBuilder(newTx)
if err != nil {
return err
}

printSignatureOnly, err := cmd.Flags().GetBool(flagSigOnly)
sigOnly, err := cmd.Flags().GetBool(flagSigOnly)
if err != nil {
return err
}

multisig, err := cmd.Flags().GetString(flagMultisig)
multisigKey, err := cmd.Flags().GetString(flagMultisig)
if err != nil {
return err
}
Expand All @@ -345,7 +333,7 @@ func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx
return err
}

_, fromName, _, err := client.GetFromFields(clientCtx, txF.Keybase(), from)
_, fromName, _, err := client.GetFromFields(clientCtx, txFactory.Keybase(), from)
if err != nil {
return fmt.Errorf("error getting account from keybase: %w", err)
}
Expand All @@ -355,15 +343,18 @@ func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx
return err
}

if multisig != "" {
if multisigKey != "" {
sigOnly = true

// get the multisig key by name
// Bech32 decode error, maybe it's a name, we try to fetch from keyring
multisigAddr, multisigName, _, err := client.GetFromFields(clientCtx, txF.Keybase(), multisig)
multisigAddr, multisigName, _, err := client.GetFromFields(clientCtx, txFactory.Keybase(), multisigKey)
if err != nil {
return fmt.Errorf("error getting account from keybase: %w", err)
}
multisigkey, err := getMultisigRecord(clientCtx, multisigName)
multisigkey, err := clientCtx.Keyring.Key(multisigName)
if err != nil {
return err
return errorsmod.Wrap(err, "error getting keybase multisig account")
}
multisigPubKey, err := multisigkey.GetPubKey()
if err != nil {
Expand All @@ -390,13 +381,12 @@ func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx
return fmt.Errorf("signing key is not a part of multisig key")
}
err = authclient.SignTxWithSignerAddress(
txF, clientCtx, multisigAddr, fromName, txBuilder, clientCtx.Offline, overwrite)
txFactory, clientCtx, multisigAddr, fromName, txBuilder, clientCtx.Offline, overwrite)
if err != nil {
return err
}
printSignatureOnly = true
} else {
err = authclient.SignTx(txF, clientCtx, clientCtx.FromName, txBuilder, clientCtx.Offline, overwrite)
err = authclient.SignTx(txFactory, clientCtx, clientCtx.FromName, txBuilder, clientCtx.Offline, overwrite)
}
if err != nil {
return err
Expand All @@ -412,7 +402,7 @@ func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx
clientCtx.WithOutput(cmd.OutOrStdout())

var json []byte
json, err = marshalSignatureJSON(txCfg, txBuilder, printSignatureOnly)
json, err = marshalSignatureJSON(txCfg, txBuilder.GetTx(), sigOnly)
if err != nil {
return err
}
Expand All @@ -422,15 +412,14 @@ func signTx(cmd *cobra.Command, clientCtx client.Context, txF tx.Factory, newTx
return err
}

func marshalSignatureJSON(txConfig client.TxConfig, txBldr client.TxBuilder, signatureOnly bool) ([]byte, error) {
parsedTx := txBldr.GetTx()
func marshalSignatureJSON(txConfig client.TxConfig, tx signing.Tx, signatureOnly bool) ([]byte, error) {
if signatureOnly {
sigs, err := parsedTx.GetSignaturesV2()
sigs, err := tx.GetSignaturesV2()
if err != nil {
return nil, err
}
return txConfig.MarshalSignatureJSON(sigs)
}

return txConfig.TxJSONEncoder()(parsedTx)
return txConfig.TxJSONEncoder()(tx)
}

0 comments on commit f0753b4

Please sign in to comment.