Skip to content

Commit e835cf4

Browse files
authored
Add more parameters to vLLM tutorial (#53)
* Remove unused UserData class * Add more vLLM parameters * Fix params, add others
1 parent 5d70fe7 commit e835cf4

2 files changed

Lines changed: 11 additions & 6 deletions

File tree

Quick_Deploy/vLLM/client.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,6 @@
3636
from tritonclient.utils import *
3737

3838

39-
class UserData:
40-
def __init__(self):
41-
self._completed_requests = queue.Queue()
42-
4339

4440
def create_request(prompt, stream, request_id, sampling_parameters, model_name, send_parameters_as_tensor=True):
4541
inputs = []

Quick_Deploy/vLLM/model_repository/vllm/1/model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,12 +127,21 @@ def get_sampling_params_dict(self, params_json):
127127
params_dict = json.loads(params_json)
128128

129129
# Special parsing for the supported sampling parameters
130-
# TODO: Add more parameters if needed
131-
float_keys = ["temperature", "top_p"]
130+
bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"]
131+
for k in bool_keys:
132+
if k in params_dict:
133+
params_dict[k] = bool(params_dict[k])
134+
135+
float_keys = ["frequency_penalty", "length_penalty", "presence_penalty", "temperature", "top_p"]
132136
for k in float_keys:
133137
if k in params_dict:
134138
params_dict[k] = float(params_dict[k])
135139

140+
int_keys = ["best_of", "max_tokens", "n", "top_k"]
141+
for k in int_keys:
142+
if k in params_dict:
143+
params_dict[k] = int(params_dict[k])
144+
136145
return params_dict
137146

138147
def create_response(self, vllm_output):

0 commit comments

Comments
 (0)