diff --git a/internal/dependencymanager/dependencyinstaller.go b/internal/dependencymanager/dependencyinstaller.go index 5faebe292..ad5225e87 100644 --- a/internal/dependencymanager/dependencyinstaller.go +++ b/internal/dependencymanager/dependencyinstaller.go @@ -103,12 +103,14 @@ func (f *dependencyManagerFlagsCollection) AddToCommand(cmd *cobra.Command) { } type DependencyInstaller struct { - Gateways map[string]gateway.Gateway - Logger output.Logger - State *flowkit.State - SkipDeployments bool - SkipAlias bool - logs categorizedLogs + Gateways map[string]gateway.Gateway + Logger output.Logger + State *flowkit.State + SkipDeployments bool + SkipAlias bool + logs categorizedLogs + initialContractsState config.Contracts + dependencies map[string]config.Dependency } // NewDependencyInstaller creates a new instance of DependencyInstaller @@ -135,11 +137,13 @@ func NewDependencyInstaller(logger output.Logger, state *flowkit.State, flags de } return &DependencyInstaller{ - Gateways: gateways, - Logger: logger, - State: state, - SkipDeployments: flags.skipDeployments, - SkipAlias: flags.skipAlias, + Gateways: gateways, + Logger: logger, + State: state, + SkipDeployments: flags.skipDeployments, + SkipAlias: flags.skipAlias, + initialContractsState: *state.Contracts(), // Copy at this point in time + dependencies: make(map[string]config.Dependency), }, nil } @@ -152,6 +156,8 @@ func (di *DependencyInstaller) Install() error { } } + di.checkForConflictingContracts() + err := di.State.SaveDefault() if err != nil { return fmt.Errorf("error saving state: %w", err) @@ -188,6 +194,8 @@ func (di *DependencyInstaller) Add(depSource, customName string) error { return fmt.Errorf("error processing dependency: %w", err) } + di.checkForConflictingContracts() + err = di.State.SaveDefault() if err != nil { return fmt.Errorf("error saving state: %w", err) @@ -198,12 +206,48 @@ func (di *DependencyInstaller) Add(depSource, customName string) error { return nil } +func (di *DependencyInstaller) addDependency(dep config.Dependency) error { + if _, exists := di.dependencies[dep.Source.Address.String()]; exists { + return nil + } + + di.dependencies[dep.Source.Address.String()] = dep + + return nil + +} + +// checkForConflictingContracts checks if any of the dependencies conflict with contracts already in the state +func (di *DependencyInstaller) checkForConflictingContracts() { + for _, dependency := range di.dependencies { + _, err := di.initialContractsState.ByName(dependency.Name) + if err != nil { + if !isCoreContract(dependency.Name) { + msg := util.MessageWithEmojiPrefix("❌", fmt.Sprintf("Contract named %s already exists in flow.json", dependency.Name)) + di.logs.issues = append(di.logs.issues, msg) + } + } + } +} + func (di *DependencyInstaller) processDependency(dependency config.Dependency) error { depAddress := flowsdk.HexToAddress(dependency.Source.Address.String()) return di.fetchDependencies(dependency.Source.NetworkName, depAddress, dependency.Name, dependency.Source.ContractName) } func (di *DependencyInstaller) fetchDependencies(networkName string, address flowsdk.Address, assignedName, contractName string) error { + err := di.addDependency(config.Dependency{ + Name: assignedName, + Source: config.Source{ + NetworkName: networkName, + Address: address, + ContractName: contractName, + }, + }) + if err != nil { + return fmt.Errorf("error adding dependency: %w", err) + } + ctx := context.Background() account, err := di.Gateways[networkName].GetAccount(ctx, address) if err != nil { @@ -307,20 +351,6 @@ func isCoreContract(contractName string) bool { return false } -// checkForContractConflicts checks if a contract with the same name already exists in the state and adds a warning -func (di *DependencyInstaller) checkForContractConflicts(contractName string) error { - _, err := di.State.Contracts().ByName(contractName) - if err != nil { - return nil - } else { - if !isCoreContract(contractName) { - msg := util.MessageWithEmojiPrefix("❌", fmt.Sprintf("Contract named %s already exists in flow.json", contractName)) - di.logs.issues = append(di.logs.issues, msg) - } - return nil - } -} - func (di *DependencyInstaller) handleFoundContract(networkName, contractAddr, assignedName, contractName string, program *project.Program) error { hash := sha256.New() hash.Write(program.CodeWithUnprocessedImports()) @@ -348,14 +378,16 @@ func (di *DependencyInstaller) handleFoundContract(networkName, contractAddr, as } } - //// This needs to happen before dependency state is updated - err := di.checkForContractConflicts(assignedName) - if err != nil { - di.Logger.Error(fmt.Sprintf("Error checking for contract conflicts: %v", err)) - return err + // Needs to happen before handleFileSystem + if !di.contractFileExists(contractAddr, contractName) { + err := di.handleAdditionalDependencyTasks(networkName, contractName) + if err != nil { + di.Logger.Error(fmt.Sprintf("Error handling additional dependency tasks: %v", err)) + return err + } } - err = di.handleFileSystem(contractAddr, contractName, contractData, networkName) + err := di.handleFileSystem(contractAddr, contractName, contractData, networkName) if err != nil { return fmt.Errorf("error handling file system: %w", err) } @@ -366,9 +398,13 @@ func (di *DependencyInstaller) handleFoundContract(networkName, contractAddr, as return err } + return nil +} + +func (di *DependencyInstaller) handleAdditionalDependencyTasks(networkName, contractName string) error { // If the contract is not a core contract and the user does not want to skip deployments, then prompt for a deployment if !di.SkipDeployments && !isCoreContract(contractName) { - err = di.updateDependencyDeployment(contractName) + err := di.updateDependencyDeployment(contractName) if err != nil { di.Logger.Error(fmt.Sprintf("Error updating deployment: %v", err)) return err @@ -380,7 +416,7 @@ func (di *DependencyInstaller) handleFoundContract(networkName, contractAddr, as // If the contract is not a core contract and the user does not want to skip aliasing, then prompt for an alias if !di.SkipAlias && !isCoreContract(contractName) { - err = di.updateDependencyAlias(contractName, networkName) + err := di.updateDependencyAlias(contractName, networkName) if err != nil { di.Logger.Error(fmt.Sprintf("Error updating alias: %v", err)) return err diff --git a/internal/dependencymanager/dependencyinstaller_test.go b/internal/dependencymanager/dependencyinstaller_test.go index 92d64427e..57be1be5c 100644 --- a/internal/dependencymanager/dependencyinstaller_test.go +++ b/internal/dependencymanager/dependencyinstaller_test.go @@ -78,6 +78,7 @@ func TestDependencyInstallerInstall(t *testing.T) { State: state, SkipDeployments: true, SkipAlias: true, + dependencies: make(map[string]config.Dependency), } err := di.Install() @@ -122,6 +123,7 @@ func TestDependencyInstallerAdd(t *testing.T) { State: state, SkipDeployments: true, SkipAlias: true, + dependencies: make(map[string]config.Dependency), } sourceStr := fmt.Sprintf("emulator://%s.%s", serviceAddress.String(), tests.ContractHelloString.Name)