diff --git a/core/cli/api/p2p.go b/core/cli/api/p2p.go new file mode 100644 index 00000000000..a2ecfe3febd --- /dev/null +++ b/core/cli/api/p2p.go @@ -0,0 +1,80 @@ +package cli_api + +import ( + "context" + "fmt" + "net" + "os" + "strings" + + "github.com/mudler/LocalAI/core/p2p" + "github.com/mudler/edgevpn/pkg/node" + + "github.com/rs/zerolog/log" +) + +func StartP2PStack(ctx context.Context, address, token, networkID string, federated bool) error { + var n *node.Node + // Here we are avoiding creating multiple nodes: + // - if the federated mode is enabled, we create a federated node and expose a service + // - exposing a service creates a node with specific options, and we don't want to create another node + + // If the federated mode is enabled, we expose a service to the local instance running + // at r.Address + if federated { + _, port, err := net.SplitHostPort(address) + if err != nil { + return err + } + + // Here a new node is created and started + // and a service is exposed by the node + node, err := p2p.ExposeService(ctx, "localhost", port, token, p2p.NetworkID(networkID, p2p.FederatedID)) + if err != nil { + return err + } + + if err := p2p.ServiceDiscoverer(ctx, node, token, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil { + return err + } + + n = node + } + + // If the p2p mode is enabled, we start the service discovery + if token != "" { + // If a node wasn't created previously, create it + if n == nil { + node, err := p2p.NewNode(token) + if err != nil { + return err + } + err = node.Start(ctx) + if err != nil { + return fmt.Errorf("starting new node: %w", err) + } + n = node + } + + // Attach a ServiceDiscoverer to the p2p node + log.Info().Msg("Starting P2P server discovery...") + if err := p2p.ServiceDiscoverer(ctx, n, token, p2p.NetworkID(networkID, p2p.WorkerID), func(serviceID string, node p2p.NodeData) { + var tunnelAddresses []string + for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.WorkerID)) { + if v.IsOnline() { + tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) + } else { + log.Info().Msgf("Node %s is offline", v.ID) + } + } + tunnelEnvVar := strings.Join(tunnelAddresses, ",") + + os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) + log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar) + }, true); err != nil { + return err + } + } + + return nil +} diff --git a/core/cli/run.go b/core/cli/run.go index 4fbcd73c93a..4f24182e8c3 100644 --- a/core/cli/run.go +++ b/core/cli/run.go @@ -3,11 +3,10 @@ package cli import ( "context" "fmt" - "net" - "os" "strings" "time" + cli_api "github.com/mudler/LocalAI/core/cli/api" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" @@ -115,52 +114,12 @@ func (r *RunCMD) Run(ctx *cliContext.Context) error { fmt.Printf("export TOKEN=\"%s\"\nlocal-ai worker p2p-llama-cpp-rpc\n", token) } opts = append(opts, config.WithP2PToken(token)) - - node, err := p2p.NewNode(token) - if err != nil { - return err - } - nodeContext := context.Background() - - err = node.Start(nodeContext) - if err != nil { - return fmt.Errorf("starting new node: %w", err) - } - - log.Info().Msg("Starting P2P server discovery...") - if err := p2p.ServiceDiscoverer(nodeContext, node, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID), func(serviceID string, node p2p.NodeData) { - var tunnelAddresses []string - for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(r.Peer2PeerNetworkID, p2p.WorkerID)) { - if v.IsOnline() { - tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) - } else { - log.Info().Msgf("Node %s is offline", v.ID) - } - } - tunnelEnvVar := strings.Join(tunnelAddresses, ",") - - os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) - log.Debug().Msgf("setting LLAMACPP_GRPC_SERVERS to %s", tunnelEnvVar) - }, true); err != nil { - return err - } } - if r.Federated { - _, port, err := net.SplitHostPort(r.Address) - if err != nil { - return err - } - fedCtx := context.Background() - - node, err := p2p.ExposeService(fedCtx, "localhost", port, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.FederatedID)) - if err != nil { - return err - } + backgroundCtx := context.Background() - if err := p2p.ServiceDiscoverer(fedCtx, node, token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.FederatedID), nil, false); err != nil { - return err - } + if err := cli_api.StartP2PStack(backgroundCtx, r.Address, token, r.Peer2PeerNetworkID, r.Federated); err != nil { + return err } idleWatchDog := r.EnableWatchdogIdle