Skip to content

Commit

Permalink
Update parallel for topdown
Browse files Browse the repository at this point in the history
  • Loading branch information
XKTZ committed Sep 27, 2024
1 parent 406fca7 commit 9e7f846
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/rank_llm/rerank/listwise/reorder/top_down_reorder_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def perform(self):
# base
base = indices[: min(window_size, len(indices))]
request = ReorderRequest(self._pad(base), None)
yield request
yield [request]
base = [base[i] for i in self._unpad(base, request.result)]

if len(base) < window_size:
Expand All @@ -95,11 +95,20 @@ def perform(self):
for i in range(pivot - 1):
result.append(base[i])

requests = []
req_inds = []

# then sort others
for i in range(window_size, len(indices), window_size - 1):
request_indices = indices[i : i + window_size - 1] + [piv_item]
req_inds.append(request_indices)
request = ReorderRequest(self._pad(request_indices), None)
yield request
requests.append(request)

yield requests

for request, request_indices, i \
in zip(requests, req_inds, range(window_size, len(indices), window_size - 1)):
request_indices = [
request_indices[i]
for i in self._unpad(request_indices, request.result)
Expand All @@ -121,7 +130,7 @@ def perform(self):
# here len(indices) == top_k
request_indices = indices
request = ReorderRequest(self._pad(request_indices), None)
yield request
yield [request]
indices = [
request_indices[i] for i in self._unpad(request_indices, request.result)
]
Expand Down Expand Up @@ -152,8 +161,8 @@ def multiple_sort(
finish_requests = []
for idx in left_not_sorted:
try:
req = next(progress[idx])
perm_request.append((idx, req))
reqs = next(progress[idx])
perm_request.extend([(idx, req) for req in reqs])
except StopIteration as e:
result[idx] = e.value
finish_requests.append(idx)
Expand Down

0 comments on commit 9e7f846

Please sign in to comment.