Skip to content

Commit

Permalink
Merge pull request #1705 from onflow/supun/support-type-requirements
Browse files Browse the repository at this point in the history
Extract type requirements from old code for staged contracts
  • Loading branch information
SupunS authored Aug 28, 2024
2 parents 763e01c + bd87b9f commit 5dd8fdd
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 54 deletions.
147 changes: 94 additions & 53 deletions internal/migrate/staging_validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ import (
"fmt"
"strings"

"github.com/rs/zerolog"
"golang.org/x/exp/slices"

"github.com/onflow/cadence"
"github.com/onflow/cadence/runtime"
"github.com/onflow/cadence/runtime/ast"
Expand All @@ -35,12 +38,14 @@ import (
"github.com/onflow/cadence/runtime/pretty"
"github.com/onflow/cadence/runtime/sema"
"github.com/onflow/cadence/runtime/stdlib"
"github.com/onflow/contract-updater/lib/go/templates"
flowsdk "github.com/onflow/flow-go-sdk"

"github.com/onflow/flow-go/cmd/util/ledger/migrations"
"github.com/onflow/flow-go/cmd/util/ledger/reporters"
"github.com/onflow/flow-go/model/flow"

"github.com/onflow/contract-updater/lib/go/templates"
flowsdk "github.com/onflow/flow-go-sdk"
"github.com/onflow/flowkit/v2"
"golang.org/x/exp/slices"

"github.com/onflow/flow-cli/internal/util"
)
Expand All @@ -58,9 +63,13 @@ type stagingValidatorImpl struct {

// Cache for account contract names so we don't have to fetch them multiple times
accountContractNames map[common.Address][]string

// All resolved contract code
contracts map[common.Location][]byte

// Contract codes that are not updated/staged
oldCodes map[common.Location][]byte

// Dependency graph for staged contracts
// This root level map holds all nodes
graph map[common.Location]node
Expand Down Expand Up @@ -172,6 +181,7 @@ func newStagingValidator(flow flowkit.Services) *stagingValidatorImpl {
checkingCache: make(map[common.Location]*cachedCheckingResult),
accountContractNames: make(map[common.Address][]string),
graph: make(map[common.Location]node),
oldCodes: make(map[common.Location][]byte),
}
}

Expand All @@ -186,10 +196,12 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)

v.stagedContracts = make(map[common.AddressLocation]stagedContractUpdate)
for _, stagedContract := range stagedContracts {
v.stagedContracts[stagedContract.DeployLocation] = stagedContract
stagedContractLocation := stagedContract.DeployLocation

v.stagedContracts[stagedContractLocation] = stagedContract

// Add the contract code to the contracts map for pretty printing
v.contracts[stagedContract.SourceLocation] = stagedContract.Code
v.contracts[stagedContractLocation] = stagedContract.Code
}

// Load system contracts
Expand All @@ -198,24 +210,72 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Parse and check all staged contracts
errs := v.checkAllStaged()

typeRequirements := &migrations.LegacyTypeRequirements{}

// Extract type requirements from the old codes for all staged contracts.
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[location] != nil {
continue
}

// Get the account for the contract
address := flowsdk.Address(location.Address)

var account *flowsdk.Account
var err error

err = withRetry(func() error {
account, err = v.flow.GetAccount(context.Background(), address)
return err
})
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}

// Get the target contract old code
contractName := location.Name
oldCode, ok := account.Contracts[contractName]
if !ok {
return fmt.Errorf("old contract code not found for contract: %s", contractName)
}
v.oldCodes[location] = oldCode

migrations.ExtractTypeRequirements(
migrations.AddressContract{
Location: location,
Code: oldCode,
},
zerolog.Nop(),
reporters.ReportNilWriter{},
typeRequirements,
)
}

// Validate all contract updates
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Don't validate contracts with existing errors
if errs[contract.SourceLocation] != nil {
if errs[location] != nil {
continue
}

// Validate the contract update
checker := v.checkingCache[contract.SourceLocation].checker
err := v.validateContractUpdate(contract, checker)
checker := v.checkingCache[location].checker
err := v.validateContractUpdate(contract, checker, typeRequirements)
if err != nil {
errs[contract.SourceLocation] = err
errs[location] = err
}
}

// Check for any upstream contract update failures
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
location := contract.DeployLocation

err := errs[location]

// We will override any errors other than those related
// to missing dependencies, since they are more specific
Expand All @@ -233,19 +293,14 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)

badDeps := make([]common.Location, 0)
v.forEachDependency(contract, func(dependency common.Location) {
strLocation, ok := dependency.(common.StringLocation)
if !ok {
return
}

if errs[strLocation] != nil {
if errs[dependency] != nil {
badDeps = append(badDeps, dependency)
}
})

if len(badDeps) > 0 {
errs[contract.SourceLocation] = &upstreamValidationError{
Location: contract.SourceLocation,
errs[location] = &upstreamValidationError{
Location: location,
BadDependencies: badDeps,
}
}
Expand All @@ -256,7 +311,7 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
// Map errors to address locations
errsByAddress := make(map[common.AddressLocation]error)
for _, contract := range v.stagedContracts {
err := errs[contract.SourceLocation]
err := errs[contract.DeployLocation]
if err != nil {
errsByAddress[contract.DeployLocation] = err
}
Expand All @@ -266,12 +321,13 @@ func (v *stagingValidatorImpl) Validate(stagedContracts []stagedContractUpdate)
return nil
}

func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error {
errs := make(map[common.StringLocation]error)
func (v *stagingValidatorImpl) checkAllStaged() map[common.Location]error {
errs := make(map[common.Location]error)
for _, contract := range v.stagedContracts {
_, err := v.checkContract(contract.SourceLocation)
location := contract.DeployLocation
_, err := v.checkContract(location)
if err != nil {
errs[contract.SourceLocation] = err
errs[location] = err
}
}

Expand All @@ -280,6 +336,8 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
// Note: nodes are not visited more than once so cyclic imports are not an issue
// They will be reported, however, by the checker, if they do exist
for _, contract := range v.stagedContracts {
location := contract.DeployLocation

// Create a set of all dependencies
missingDependencies := make([]common.AddressLocation, 0)
v.forEachDependency(contract, func(dependency common.Location) {
Expand All @@ -293,15 +351,15 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error

if len(missingDependencies) > 0 {
// If an error exists, only overwrite if it is a checking error
existingErr, ok := errs[contract.SourceLocation]
existingErr, ok := errs[location]
if ok {
var existingCheckingErr *sema.CheckerError
if !errors.As(existingErr, &existingCheckingErr) {
continue
}
}

errs[contract.SourceLocation] = &missingDependenciesError{
errs[location] = &missingDependenciesError{
MissingContracts: missingDependencies,
}
}
Expand All @@ -310,29 +368,23 @@ func (v *stagingValidatorImpl) checkAllStaged() map[common.StringLocation]error
return errs
}

func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpdate, checker *sema.Checker) (err error) {
func (v *stagingValidatorImpl) validateContractUpdate(
contract stagedContractUpdate,
checker *sema.Checker,
typeRequirements *migrations.LegacyTypeRequirements,
) (err error) {
// Gracefully recover from panics
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during contract update validation: %v", r)
}
}()

// Get the account for the contract
address := flowsdk.Address(contract.DeployLocation.Address)

var account *flowsdk.Account
err = withRetry(func() error {
account, err = v.flow.GetAccount(context.Background(), address)
return err
})
if err != nil {
return fmt.Errorf("failed to get account: %w", err)
}
location := contract.DeployLocation
contractName := location.Name

// Get the target contract old code
contractName := contract.DeployLocation.Name
contractCode, ok := account.Contracts[contractName]
contractCode, ok := v.oldCodes[location]
if !ok {
return fmt.Errorf("old contract code not found for contract: %s", contractName)
}
Expand All @@ -348,7 +400,7 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd

// Check if contract code is valid according to Cadence V1 Update Checker
validator := stdlib.NewCadenceV042ToV1ContractUpdateValidator(
contract.SourceLocation,
location,
contractName,
&accountContractNamesProviderImpl{
resolverFunc: func(address common.Address) ([]string, error) {
Expand All @@ -371,9 +423,6 @@ func (v *stagingValidatorImpl) validateContractUpdate(contract stagedContractUpd
return fmt.Errorf("unsupported network: %s", v.flow.Network().Name)
}

// TODO: extract type requirements from the old contracts
typeRequirements := &migrations.LegacyTypeRequirements{}

validator.WithUserDefinedTypeChangeChecker(
migrations.NewUserDefinedTypeChangeCheckerFunc(chainId, typeRequirements),
)
Expand Down Expand Up @@ -583,19 +632,11 @@ func (v *stagingValidatorImpl) resolveLocation(
for i := range resolvedLocations {
identifier := identifiers[i]

var resolvedLocation common.Location
resovledAddrLocation := common.AddressLocation{
resolvedLocation := common.AddressLocation{
Address: addressLocation.Address,
Name: identifier.Identifier,
}

// If the contract one of our staged contract updates, use the source location
if stagedUpdate, ok := v.stagedContracts[resovledAddrLocation]; ok {
resolvedLocation = stagedUpdate.SourceLocation
} else {
resolvedLocation = resovledAddrLocation
}

resolvedLocations[i] = runtime.ResolvedLocation{
Location: resolvedLocation,
Identifiers: []runtime.Identifier{identifier},
Expand Down Expand Up @@ -765,7 +806,7 @@ func (v *stagingValidatorImpl) forEachDependency(
}
}
}
traverse(contract.SourceLocation)
traverse(contract.DeployLocation)
}

// Helper for pretty printing errors
Expand Down
Loading

0 comments on commit 5dd8fdd

Please sign in to comment.