Skip to content

Commit

Permalink
Merge pull request #430 from multiversx/shard-id-for-system-account
Browse files Browse the repository at this point in the history
Shard id parameter for system account
  • Loading branch information
miiu96 authored Mar 12, 2024
2 parents 62c486f + cc5505a commit 34c7d95
Show file tree
Hide file tree
Showing 6 changed files with 56 additions and 33 deletions.
31 changes: 18 additions & 13 deletions api/groups/baseAccountsGroup.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func NewAccountsGroup(facadeHandler data.FacadeHandler) (*accountsGroup, error)
func (group *accountsGroup) respondWithAccount(c *gin.Context, transform func(*data.AccountModel) gin.H) {
address := c.Param("address")

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, address)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrBadUrlParams, err)
return
Expand Down Expand Up @@ -113,7 +113,7 @@ func (group *accountsGroup) getNonce(c *gin.Context) {
// getCodeHash returns the code hash for the address parameter
func (group *accountsGroup) getCodeHash(c *gin.Context) {
address := c.Param("address")
options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, address)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrBadUrlParams, err)
return
Expand All @@ -137,7 +137,12 @@ func (group *accountsGroup) getAccounts(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
addr := ""
if len(addresses) > 0 {
addr = addresses[0]
}

options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrInvalidFields, err)
return
Expand Down Expand Up @@ -171,7 +176,7 @@ func (group *accountsGroup) getKeyValuePairs(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetKeyValuePairs, err)
return
Expand All @@ -194,7 +199,7 @@ func (group *accountsGroup) getValueForKey(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetValueForKey, err)
return
Expand Down Expand Up @@ -246,7 +251,7 @@ func (group *accountsGroup) getESDTTokenData(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetESDTTokenData, err)
return
Expand Down Expand Up @@ -274,7 +279,7 @@ func (group *accountsGroup) getESDTsRoles(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetRolesForAccount, err)
return
Expand All @@ -297,7 +302,7 @@ func (group *accountsGroup) getESDTsWithRole(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetESDTsWithRole, err)
return
Expand Down Expand Up @@ -326,7 +331,7 @@ func (group *accountsGroup) getRegisteredNFTs(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetNFTTokenIDsRegisteredByAddress, err)
return
Expand All @@ -349,7 +354,7 @@ func (group *accountsGroup) getESDTNftTokenData(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetESDTTokenData, err)
return
Expand Down Expand Up @@ -383,7 +388,7 @@ func (group *accountsGroup) getGuardianData(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetGuardianData, err)
return
Expand All @@ -406,7 +411,7 @@ func (group *accountsGroup) getESDTTokens(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrGetESDTTokenData, err)
return
Expand All @@ -427,7 +432,7 @@ func (group *accountsGroup) isDataTrieMigrated(c *gin.Context) {
return
}

options, err := parseAccountQueryOptions(c)
options, err := parseAccountQueryOptions(c, addr)
if err != nil {
shared.RespondWithValidationError(c, errors.ErrIsDataTrieMigrated, err)
return
Expand Down
3 changes: 3 additions & 0 deletions api/groups/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@ var ErrHandlerDoesNotExist = errors.New("handler does not exist")

// ErrWrongTypeAssertion signals that a wrong type assertion issue was found during the execution
var ErrWrongTypeAssertion = errors.New("wrong type assertion")

// ErrForcedShardIDCannotBeProvided signals that the forced shard id cannot be provided for a different address other than the system account address
var ErrForcedShardIDCannotBeProvided = errors.New("forced shard id parameter can only be provided for system accounts")
15 changes: 14 additions & 1 deletion api/groups/urlParams.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
"github.com/multiversx/mx-chain-proxy-go/common"
)

// SystemAccountAddressBech is the const for the system account address
const SystemAccountAddressBech = "erd1lllllllllllllllllllllllllllllllllllllllllllllllllllsckry7t"

func parseBlockQueryOptions(c *gin.Context) (common.BlockQueryOptions, error) {
withTxs, err := parseBoolUrlParam(c, common.UrlParameterWithTransactions)
if err != nil {
Expand Down Expand Up @@ -56,7 +59,7 @@ func parseHyperblockQueryOptions(c *gin.Context) (common.HyperblockQueryOptions,
}, nil
}

func parseAccountQueryOptions(c *gin.Context) (common.AccountQueryOptions, error) {
func parseAccountQueryOptions(c *gin.Context, address string) (common.AccountQueryOptions, error) {
onFinalBlock, err := parseBoolUrlParam(c, common.UrlParameterOnFinalBlock)
if err != nil {
return common.AccountQueryOptions{}, err
Expand Down Expand Up @@ -87,13 +90,23 @@ func parseAccountQueryOptions(c *gin.Context) (common.AccountQueryOptions, error
return common.AccountQueryOptions{}, err
}

shardID, err := parseUint32UrlParam(c, common.UrlParameterForcedShardID)
if err != nil {
return common.AccountQueryOptions{}, err
}

if shardID.HasValue && address != SystemAccountAddressBech {
return common.AccountQueryOptions{}, ErrForcedShardIDCannotBeProvided
}

options := common.AccountQueryOptions{
OnFinalBlock: onFinalBlock,
OnStartOfEpoch: onStartOfEpoch,
BlockNonce: blockNonce,
BlockHash: blockHash,
BlockRootHash: blockRootHash,
HintEpoch: hintEpoch,
ForcedShardID: shardID,
}

return options, nil
Expand Down
6 changes: 3 additions & 3 deletions api/groups/urlParams_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,15 @@ func TestParseHyperblockQueryOptions(t *testing.T) {
}

func TestParseAccountQueryOptions(t *testing.T) {
options, err := parseAccountQueryOptions(createDummyGinContextWithQuery("onFinalBlock=true"))
options, err := parseAccountQueryOptions(createDummyGinContextWithQuery("onFinalBlock=true"), "")
require.Nil(t, err)
require.Equal(t, common.AccountQueryOptions{OnFinalBlock: true}, options)

options, err = parseAccountQueryOptions(createDummyGinContextWithQuery(""))
options, err = parseAccountQueryOptions(createDummyGinContextWithQuery(""), "")
require.Nil(t, err)
require.Empty(t, options)

options, err = parseAccountQueryOptions(createDummyGinContextWithQuery("onFinalBlock=foobar"))
options, err = parseAccountQueryOptions(createDummyGinContextWithQuery("onFinalBlock=foobar"), "")
require.NotNil(t, err)
require.Empty(t, options)
}
Expand Down
3 changes: 3 additions & 0 deletions common/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ const (
UrlParameterWithResults = "withResults"
// UrlParameterShardID represents the name of an URL parameter
UrlParameterShardID = "shard-id"
// UrlParameterForcedShardID represents the name of an URL parameter
UrlParameterForcedShardID = "forced-shard-id"
// UrlParameterSender represents the name of an URL parameter
UrlParameterSender = "by-sender"
// UrlParameterFields represents the name of an URL parameter
Expand Down Expand Up @@ -105,6 +107,7 @@ func BuildUrlWithBlockQueryOptions(path string, options BlockQueryOptions) strin
type AccountQueryOptions struct {
OnFinalBlock bool
OnStartOfEpoch core.OptionalUint32
ForcedShardID core.OptionalUint32
BlockNonce core.OptionalUint64
BlockHash []byte
BlockRootHash []byte
Expand Down
31 changes: 15 additions & 16 deletions process/accountProcessor.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (ap *AccountProcessor) GetShardIDForAddress(address string) (uint32, error)
// GetAccount resolves the request by sending the request to the right observer and returns the response
func (ap *AccountProcessor) GetAccount(address string, options common.AccountQueryOptions) (*data.AccountModel, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (ap *AccountProcessor) getAccountsInShard(addresses []string, shardID uint3
// GetValueForKey returns the value for the given address and key
func (ap *AccountProcessor) GetValueForKey(address string, key string, options common.AccountQueryOptions) (string, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return "", err
}
Expand Down Expand Up @@ -193,7 +193,7 @@ func (ap *AccountProcessor) GetValueForKey(address string, key string, options c
// GetESDTTokenData returns the token data for a token with the given name
func (ap *AccountProcessor) GetESDTTokenData(address string, key string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -325,7 +325,7 @@ func (ap *AccountProcessor) GetNFTTokenIDsRegisteredByAddress(address string, op
// GetESDTNftTokenData returns the nft token data for a token with the given identifier and nonce
func (ap *AccountProcessor) GetESDTNftTokenData(address string, key string, nonce uint64, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -359,7 +359,7 @@ func (ap *AccountProcessor) GetESDTNftTokenData(address string, key string, nonc
// GetAllESDTTokens returns all the tokens for a given address
func (ap *AccountProcessor) GetAllESDTTokens(address string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -391,7 +391,7 @@ func (ap *AccountProcessor) GetAllESDTTokens(address string, options common.Acco
// GetKeyValuePairs returns all the key-value pairs for a given address
func (ap *AccountProcessor) GetKeyValuePairs(address string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -423,7 +423,7 @@ func (ap *AccountProcessor) GetKeyValuePairs(address string, options common.Acco
// GetGuardianData returns the guardian data for the given address
func (ap *AccountProcessor) GetGuardianData(address string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -464,7 +464,7 @@ func (ap *AccountProcessor) GetTransactions(address string) ([]data.DatabaseTran
// GetCodeHash returns the code hash for a given address
func (ap *AccountProcessor) GetCodeHash(address string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
availability := ap.availabilityProvider.AvailabilityForAccountQueryOptions(options)
observers, err := ap.getObserversForAddress(address, availability)
observers, err := ap.getObserversForAddress(address, availability, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -502,23 +502,22 @@ func (ap *AccountProcessor) getShardIfOdAddress(address string) (uint32, error)
return ap.proc.ComputeShardId(addressBytes)
}

func (ap *AccountProcessor) getObserversForAddress(address string, availability data.ObserverDataAvailabilityType) ([]*data.NodeData, error) {
addressBytes, err := ap.pubKeyConverter.Decode(address)
if err != nil {
return nil, err
func (ap *AccountProcessor) getObserversForAddress(address string, availability data.ObserverDataAvailabilityType, forcedShardID core.OptionalUint32) ([]*data.NodeData, error) {
if forcedShardID.HasValue {
return ap.proc.GetObservers(forcedShardID.Value, availability)
}

shardID, err := ap.proc.ComputeShardId(addressBytes)
addressBytes, err := ap.pubKeyConverter.Decode(address)
if err != nil {
return nil, err
}

observers, err := ap.proc.GetObservers(shardID, availability)
shardID, err := ap.proc.ComputeShardId(addressBytes)
if err != nil {
return nil, err
}

return observers, nil
return ap.proc.GetObservers(shardID, availability)
}

// GetBaseProcessor returns the base processor
Expand All @@ -528,7 +527,7 @@ func (ap *AccountProcessor) GetBaseProcessor() Processor {

// IsDataTrieMigrated returns true if the data trie for the given address is migrated
func (ap *AccountProcessor) IsDataTrieMigrated(address string, options common.AccountQueryOptions) (*data.GenericAPIResponse, error) {
observers, err := ap.getObserversForAddress(address, data.AvailabilityRecent)
observers, err := ap.getObserversForAddress(address, data.AvailabilityRecent, options.ForcedShardID)
if err != nil {
return nil, err
}
Expand Down

0 comments on commit 34c7d95

Please sign in to comment.