Add a /free route to unload models or free all memory.
A POST request to /free with {"unload_models": true}
will unload models from VRAM.

A POST request to /free with {"free_memory": true}
will unload models and free all cached data from the last run workflow.
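
For reference, a minimal client-side sketch (not part of this commit) of how the route can be called; the server address http://127.0.0.1:8188 and the helper name post_free are assumptions for illustration only:

import json
import urllib.request

def post_free(unload_models=False, free_memory=False, host="http://127.0.0.1:8188"):
    # Build the JSON body the /free route expects; both keys default to False on the server side.
    body = json.dumps({"unload_models": unload_models, "free_memory": free_memory}).encode("utf-8")
    req = urllib.request.Request(host + "/free", data=body,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req) as resp:
        return resp.status  # 200 on success

# post_free(unload_models=True)   # unload models from VRAM
# post_free(free_memory=True)     # also reset the executor's cached workflow data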
comfyanonymous committed Jan 4, 2024
1 parent 8c64935 commit 6d281b4
Showing 3 changed files with 44 additions and 2 deletions.
20 changes: 19 additions & 1 deletion execution.py
@@ -268,11 +268,14 @@ def recursive_output_delete_if_changed(prompt, old_prompt, outputs, current_item
 
 class PromptExecutor:
     def __init__(self, server):
+        self.server = server
+        self.reset()
+
+    def reset(self):
         self.outputs = {}
         self.object_storage = {}
         self.outputs_ui = {}
         self.old_prompt = {}
-        self.server = server
 
     def handle_execution_error(self, prompt_id, prompt, current_outputs, executed, error, ex):
         node_id = error["node_id"]
@@ -706,6 +709,7 @@ def __init__(self, server):
         self.queue = []
         self.currently_running = {}
         self.history = {}
+        self.flags = {}
         server.prompt_queue = self
 
     def put(self, item):
@@ -792,3 +796,17 @@ def wipe_history(self):
     def delete_history_item(self, id_to_delete):
         with self.mutex:
             self.history.pop(id_to_delete, None)
+
+    def set_flag(self, name, data):
+        with self.mutex:
+            self.flags[name] = data
+            self.not_empty.notify()
+
+    def get_flags(self, reset=True):
+        with self.mutex:
+            if reset:
+                ret = self.flags
+                self.flags = {}
+                return ret
+            else:
+                return self.flags.copy()
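
set_flag() and get_flags() form a small producer/consumer handshake between the web server thread and the prompt worker: the handler sets a flag under the queue mutex and wakes the worker, and the worker drains all pending flags in one call. A rough illustration (sketch only, assuming a PromptQueue instance named q):

q.set_flag("free_memory", True)          # server thread: record the request and notify the worker

flags = q.get_flags()                    # worker thread: take all pending flags
assert flags == {"free_memory": True}
assert q.get_flags() == {}               # reset=True (the default) clears flags once read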
15 changes: 14 additions & 1 deletion main.py
@@ -97,7 +97,7 @@ def prompt_worker(q, server):
     gc_collect_interval = 10.0
 
     while True:
-        timeout = None
+        timeout = 1000.0
         if need_gc:
             timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)
 
@@ -118,6 +118,19 @@ def prompt_worker(q, server):
             execution_time = current_time - execution_start_time
             print("Prompt executed in {:.2f} seconds".format(execution_time))
 
+        flags = q.get_flags()
+        free_memory = flags.get("free_memory", False)
+
+        if flags.get("unload_models", free_memory):
+            comfy.model_management.unload_all_models()
+            need_gc = True
+            last_gc_collect = 0
+
+        if free_memory:
+            e.reset()
+            need_gc = True
+            last_gc_collect = 0
+
         if need_gc:
             current_time = time.perf_counter()
             if (current_time - last_gc_collect) > gc_collect_interval:
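
Note that free_memory is used as the default for the unload_models lookup, so a request carrying only {"free_memory": true} also unloads models, while {"unload_models": true} alone leaves the executor caches untouched. Illustrative only:

flags = {"free_memory": True}                             # body of the incoming /free request
free_memory = flags.get("free_memory", False)             # -> True
unload_models = flags.get("unload_models", free_memory)   # key absent -> falls back to True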
11 changes: 11 additions & 0 deletions server.py
@@ -507,6 +507,17 @@ async def post_interrupt(request):
             nodes.interrupt_processing()
             return web.Response(status=200)
 
+        @routes.post("/free")
+        async def post_interrupt(request):
+            json_data = await request.json()
+            unload_models = json_data.get("unload_models", False)
+            free_memory = json_data.get("free_memory", False)
+            if unload_models:
+                self.prompt_queue.set_flag("unload_models", unload_models)
+            if free_memory:
+                self.prompt_queue.set_flag("free_memory", free_memory)
+            return web.Response(status=200)
+
         @routes.post("/history")
         async def post_history(request):
             json_data = await request.json()
