Skip to content

Commit

Permalink
Do a garbage collect after the interval even if nothing is running.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 30, 2023
1 parent 7f46920 commit 6b769bc
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 18 deletions.
6 changes: 4 additions & 2 deletions execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,10 +700,12 @@ def put(self, item):
self.server.queue_updated()
self.not_empty.notify()

def get(self):
def get(self, timeout=None):
with self.not_empty:
while len(self.queue) == 0:
self.not_empty.wait()
self.not_empty.wait(timeout=timeout)
if timeout is not None and len(self.queue) == 0:
return None
item = heapq.heappop(self.queue)
i = self.task_counter
self.currently_running[i] = copy.deepcopy(item)
Expand Down
45 changes: 29 additions & 16 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,23 +89,36 @@ def cuda_malloc_warning():
def prompt_worker(q, server):
e = execution.PromptExecutor(server)
last_gc_collect = 0
need_gc = False
gc_collect_interval = 10.0

while True:
item, item_id = q.get()
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
q.task_done(item_id, e.outputs_ui)
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))
if (current_time - last_gc_collect) > 10.0:
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
print("gc collect")
timeout = None
if need_gc:
timeout = max(gc_collect_interval - (current_time - last_gc_collect), 0.0)

queue_item = q.get(timeout=timeout)
if queue_item is not None:
item, item_id = queue_item
execution_start_time = time.perf_counter()
prompt_id = item[1]
e.execute(item[2], prompt_id, item[3], item[4])
need_gc = True
q.task_done(item_id, e.outputs_ui)
if server.client_id is not None:
server.send_sync("executing", { "node": None, "prompt_id": prompt_id }, server.client_id)

current_time = time.perf_counter()
execution_time = current_time - execution_start_time
print("Prompt executed in {:.2f} seconds".format(execution_time))

if need_gc:
current_time = time.perf_counter()
if (current_time - last_gc_collect) > gc_collect_interval:
gc.collect()
comfy.model_management.soft_empty_cache()
last_gc_collect = current_time
need_gc = False

async def run(server, address='', port=8188, verbose=True, call_on_start=None):
await asyncio.gather(server.start(address, port, verbose, call_on_start), server.publish_loop())
Expand Down

0 comments on commit 6b769bc

Please sign in to comment.