Skip to content

[PD disaggregate] optimize splitwise scheduler #2685

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions fastdeploy/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ def _zmq_send_generated_tokens(self):
while self.running:
try:
results = self.scheduler.get_results()
if len(results) == 0:
time.sleep(0.001)
for request_id, contents in results.items():
for result in contents:
self.zmq_server.send_multipart(request_id, result)
Expand Down
35 changes: 24 additions & 11 deletions fastdeploy/scheduler/splitwise_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,12 +260,13 @@ class ResultReader(object):
ResultReader uses an async thread to continuously fetch infer results from redis
"""

def __init__(self, client, idx, batch=200, ttl=900):
def __init__(self, client, idx, batch=200, ttl=900, group=""):
self.idx = idx
self.batch = batch
self.client = client
self.data = deque()
self.ttl = ttl
self.group = group

self.reqs = dict()
self.out_buffer = dict()
Expand Down Expand Up @@ -380,15 +381,18 @@ def sync_results(self, keys):
fetch infer results from redis for the given keys
"""
total = 0
if self.group != "":
keys = [self.group]
for key in keys:
#logger.info(f"Sync Results from Redis {key}")
results = self.client.rpop(key, self.batch)
if results is None or len(results) == 0:
continue
#logger.info(f"Rpop {self.idx}: {len(results)}")
#logger.info(f"Rpop {key} {self.idx}: {len(results)}")
total += len(results)
for result in results:
try:
#logger.info(f"Scheduler Get Results: {result}")
# logger.info(f"Scheduler Get Results: {result.request_id}")
data = orjson.loads(result)
result = RequestOutput.from_dict(data)
self.data.appendleft(result)
Expand Down Expand Up @@ -425,8 +429,9 @@ def start(self):
start backup threads
"""
for i in range(self.reader_parallel):
group = f"{self.nodeid}-{i}"
reader = ResultReader(self.client, i, self.reader_batch_size,
self.ttl)
self.ttl, group)
self.readers.append(reader)

self.clear_expired_nodes_thread = threading.Thread(
Expand Down Expand Up @@ -481,15 +486,16 @@ def loop_schedule(self):

reader = self.readers[reader_idx]
reader.add_req(req)
group = self.readers[reader_idx].group
reader_idx = (reader_idx + 1) % len(self.readers)

self.schedule(req, pnodes, dnodes, mnodes)
self.schedule(req, pnodes, dnodes, mnodes, group)
except IndexError:
continue
except Exception as e:
logger.error(f"APIScheduler Schedule req error: {str(e)}")

def schedule(self, req, pnodes, dnodes, mnodes):
def schedule(self, req, pnodes, dnodes, mnodes, group=""):
"""
schedule a req onto the corresponding redis node queue
"""
Expand All @@ -498,7 +504,9 @@ def schedule(self, req, pnodes, dnodes, mnodes):
pnode = self.select_pd(req, pnodes, "prefill")
if pnode.role == "mixed":
req.disaggregate_info = None
req_str = orjson.dumps(req.to_dict())
req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
pkey = f"ReqQ_{pnode.nodeid}"
#logger.info(f"Schedule Req {req_str} to Mixed")
self.client.lpush(pkey, req_str)
Expand All @@ -518,7 +526,9 @@ def schedule(self, req, pnodes, dnodes, mnodes):
disaggregated["transfer_protocol"] = transfer_protocol[0]
req.disaggregate_info = disaggregated
pkey, dkey = f"ReqQ_{pnode.nodeid}", f"ReqQ_{dnode.nodeid}"
req_str = orjson.dumps(req.to_dict())
req_dict = req.to_dict()
req_dict["group"] = group
req_str = orjson.dumps(req_dict)
#logger.info(f"Schedule Req {req_str}")
self.client.lpush(dkey, req_str)
self.client.lpush(pkey, req_str)
Expand Down Expand Up @@ -634,7 +644,9 @@ def run(self):
size = len(self.data)
if size == 0:
self.cond.wait()
#qsize = size
size = min(size, self.batch)
#logger.info(f"Writer {self.idx} Queue Size: {qsize}, Cur Size: {size}")
groups = dict()
for i in range(size):
key, item = self.data.pop()
Expand Down Expand Up @@ -749,12 +761,13 @@ def select_writer(req):

for req_str in reqs:
req = orjson.loads(req_str)
group = req.get("group", "")
req = Request.from_dict(req)
writer_idx = select_writer(req)
logger.info(
f"Infer Scheduler Get Req: {req.request_id} writer idx {writer_idx}"
)
req.request_id = f"{req.request_id}#{writer_idx}"
req.request_id = f"{req.request_id}#{writer_idx}#{group}"
if self.role == "prefill" or self.role == "mixed":
self.reqs_queue.append(req)
self.node.add_req(req.request_id,
Expand Down Expand Up @@ -813,10 +826,10 @@ def put_results(self, results):

req_ids.add(result.request_id)

req_id, idx = result.request_id.split("#")
req_id, idx, group = result.request_id.split("#")
result.request_id = req_id

key = (req_id, int(idx))
key = (req_id if group == "" else group, int(idx))
if key not in groups:
groups[key] = list()

Expand Down
Loading