# SPDX-License-Identifier: Apache-2.0
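"""
Benchmark per-batch execution latency and power of the vLLM model executor by
feeding it dummy `ExecuteModelRequest`s directly, bypassing the scheduler and
the OpenAI frontend. Power is sampled via NVML and per-iteration stats are
persisted to CSV.
"""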
import argparse
import asyncio
import contextlib
import gc
import os
import random
import time
from collections.abc import Iterable
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List

import tqdm
import uvloop
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from vllm import SamplingParams
from vllm.core.block_manager import SelfAttnBlockSpaceManager
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import PerfMetricCSVLogger
from vllm.engine.metrics_types import Stats
from vllm.entrypoints.openai.api_server import (
    build_async_engine_client_from_engine_args)
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.platforms.nvml_power_monitor import measure_power
from vllm.platforms.nvml_utils import nvml_get_available_freq, nvml_lock_freq
from vllm.sequence import (ExecuteModelRequest, SequenceData,
                           SequenceGroupMetadata)
from vllm.utils import FlexibleArgumentParser, cdiv, random_uuid

logger = init_logger(__name__)


@contextlib.contextmanager
def disable_python_gc():
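    """Disable the Python garbage collector for the duration of the context
    so GC pauses do not distort latency measurements; restore it (and run a
    collection) on exit."""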
    was_enabled = gc.isenabled()
    gc.disable()
    try:
        yield
    finally:
        if was_enabled:
            gc.enable()
            gc.collect()


@contextlib.contextmanager
def log_perf_metric(filename: str):
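    """Yield a `PerfMetricCSVLogger` that buffers stats in memory and writes
    the CSV to disk only once, when the context exits."""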
    perf_logger = None
    try:
        perf_logger = PerfMetricCSVLogger(
            filename=filename, disable_periodic_persist_to_disk=True)
        yield perf_logger
    finally:
        if perf_logger:
            perf_logger.persist_to_disk()


def cyclic_generator(lst: Iterable):
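    """Yield items from `lst` indefinitely, restarting from the beginning
    once it is exhausted."""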
    while True:
        yield from lst


@dataclass
class BenchmarkBatchParam:
    # Token counts of the dummy prefill and decode sequences in one batch.
    prefill_input_lens: List[int]
    decode_input_lens: List[int]
    log_dir: str
    gpu_freq_mhz: int
    delay_time_s: float = 0.0  # Delay before issuing each batch.

    # The run terminates only once both of these minimums are reached.
    min_num_iters: int = 7
    min_seconds: int = 1


async def benchmark_batch(vllm_args: argparse.Namespace,
                          params: Iterable[BenchmarkBatchParam],
                          latencies: List):
    """
    Feed the executor with `ExecuteModelRequest`s directly, similar to how
    `AsyncLLMEngine` does it internally.
    """
    random.seed(vllm_args.seed)

    engine_args = AsyncEngineArgs.from_cli_args(vllm_args)
    disable_frontend_multiprocessing = True
    assert disable_frontend_multiprocessing, \
        '''
        disable_frontend_multiprocessing must remain True: setting it to
        False would use MQLLMEngineClient instead of AsyncLLMEngine, which is
        not supported for now.
        '''

    tokenizer = AutoTokenizer.from_pretrained(
        vllm_args.model, trust_remote_code=vllm_args.trust_remote_code)

    async with build_async_engine_client_from_engine_args(
            engine_args, disable_frontend_multiprocessing) as llm:
        assert isinstance(llm, AsyncLLMEngine)

        executor = llm.engine.model_executor
        pipeline_parallel_size \
            = llm.engine.parallel_config.pipeline_parallel_size

        # Keep `pipeline_parallel_size` instances of `execute_model_async()`
        # running concurrently.
        for param in tqdm.tqdm(params):
            # Construct requests eagerly so request creation does not block
            # the critical path. Create more than `param.min_num_iters`
            # requests so the cyclic generator does not wrap around and
            # resend the same request, which would skew the cache hit rate.
            requests = [
                build_dummy_execute_model_request(llm, tokenizer, param)
                for _ in range(param.min_num_iters * 2)
            ]
            request_gen = cyclic_generator(requests)

            initial_requests = [
                next(request_gen) for _ in range(pipeline_parallel_size)
            ]
            requests_in_progress = [
                asyncio.create_task(executor.execute_model_async(req))
                for req in initial_requests
            ]

            # The `PerfMetricCSVLogger` inside `LLMEngine` is not invoked
            # when we call the executor directly, so we create a separate
            # logger out here.
            energy_log = os.path.join(param.log_dir, 'power_log.csv')
            perf_log = os.path.join(param.log_dir, 'perf_metric.csv')
            with disable_python_gc(), \
                    measure_power(energy_log), \
                    log_perf_metric(perf_log) as perf_metric_logger, \
                    nvml_lock_freq(param.gpu_freq_mhz):
                time_start = time.perf_counter()
                num_iters = 0
                sample_latencies = []
                while True:
                    done, _ = await asyncio.wait(
                        requests_in_progress,
                        return_when=asyncio.FIRST_COMPLETED)
                    # Yield to the event loop so the other in-flight
                    # executor calls can make progress.
                    for _ in range(pipeline_parallel_size):
                        await asyncio.sleep(0)
                    for task in done:
                        output = task.result()
                        stats = get_stats(llm, output)
                        time_range = (
                            stats.batch_execute_timing_iter.time_ranges[0])
                        sample_latencies.append(time_range.end -
                                                time_range.start)
                        perf_metric_logger.log(stats)

                        # Replace the finished task with a new request on the
                        # same virtual engine.
                        virtual_engine = requests_in_progress.index(task)
                        req = next(request_gen)
                        if param.delay_time_s > 0:
                            await asyncio.sleep(param.delay_time_s)
                        requests_in_progress[
                            virtual_engine] = asyncio.create_task(
                                executor.execute_model_async(req))

                        num_iters += 1

                    if (num_iters >= param.min_num_iters
                            and time.perf_counter() - time_start
                            > param.min_seconds):
                        logger.info(
                            'Run terminated after reaching the minimum of '
                            '%d iters and %d seconds', param.min_num_iters,
                            param.min_seconds)
                        break

            # Cleanup: wait for the remaining in-flight requests to finish.
            _ = await asyncio.wait(requests_in_progress,
                                   return_when=asyncio.ALL_COMPLETED)
            latencies.append(sample_latencies)


def build_dummy_execute_model_request(
        llm: AsyncLLMEngine, tokenizer: PreTrainedTokenizerBase,
        benchmark_batch_param: BenchmarkBatchParam):
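    """Assemble an `ExecuteModelRequest` containing one dummy sequence group
    per requested prefill and decode length."""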
    seq_group_metadata_list: List[SequenceGroupMetadata] = []
    for input_len in benchmark_batch_param.prefill_input_lens:
        seq_group_metadata_list.append(
            build_dummy_seq_group_metadata(llm,
                                           tokenizer,
                                           input_len,
                                           is_prompt=True))
    for input_len in benchmark_batch_param.decode_input_lens:
        seq_group_metadata_list.append(
            build_dummy_seq_group_metadata(llm,
                                           tokenizer,
                                           input_len,
                                           is_prompt=False))
    return ExecuteModelRequest(
        seq_group_metadata_list=seq_group_metadata_list,
        # All the remaining fields keep their defaults.
    )


@lru_cache(maxsize=16384)
def build_dummy_seq_group_metadata(
    llm: AsyncLLMEngine,
    tokenizer: PreTrainedTokenizerBase,
    input_len: int,
    is_prompt: bool,
) -> SequenceGroupMetadata:
    """
    Build a dummy sequence group of `input_len` random tokens. Requests are
    sent as new every time (no `SequenceGroupMetadataDelta`).
    """
    seq = SequenceData.from_seqs([
        random.randint(0, tokenizer.vocab_size - 1) for _ in range(input_len)
    ])
    if not is_prompt:
        seq.update_num_computed_tokens(input_len - 1)

    seq_data: Dict[int, SequenceData] = {0: seq}

    # Same as in `benchmark_throughput.py`
    sampling_params = SamplingParams(
        n=1,
        temperature=1.0,
        top_p=1.0,
        ignore_eos=True,
        max_tokens=2048,  # TODO: remove this hardcoded value
    )

    # Build a random block mapping
    # TODO: try sequential block tables
    block_manager = llm.engine.scheduler[0].block_manager
    assert isinstance(block_manager, SelfAttnBlockSpaceManager)
    block_size = block_manager.block_size
    num_required_blocks = cdiv(input_len, block_size)
    block_tables: Dict[int, List[int]] = {
        0: [
            # `randint` is inclusive on both ends, so cap at the last valid
            # block index.
            random.randint(0, block_manager.num_total_gpu_blocks - 1)
            for _ in range(num_required_blocks)
        ]
    }

    # For simplicity, assume every prefill and decode requires sampling. In
    # practice, if a prefill is chunked, only its last chunk requires
    # sampling.
    do_sample = True

    ret = SequenceGroupMetadata(
        request_id=random_uuid(),
        is_prompt=is_prompt,
        seq_data=seq_data,
        sampling_params=sampling_params,
        block_tables=block_tables,
        do_sample=do_sample,
        # Assume the rest doesn't matter and uses defaults
    )
    return ret


def get_stats(llm: AsyncLLMEngine, model_output: List[SamplerOutput]) -> Stats:
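    """Pull a `Stats` snapshot for `model_output` straight from the engine;
    scheduler-related fields are left unset since the scheduler is bypassed."""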
    return llm.engine._get_stats(
        scheduler_outputs=None,
        model_output=model_output,
        finished_before=None,
        skip=None,
    )


if __name__ == '__main__':
    parser = FlexibleArgumentParser(description="Benchmark per-batch.")
    parser = AsyncEngineArgs.add_cli_args(parser)
    vllm_args = ("--model meta-llama/Llama-3.1-8B-Instruct "
                 f"-tp {1} "
                 f"-pp {1} "
                 "--collect-detailed-traces worker").split()
    vllm_args = parser.parse_args(vllm_args)
    vllm_args.max_model_len = 10000

    benchmark_batch_param = BenchmarkBatchParam(
        prefill_input_lens=[1024, 1024],
        decode_input_lens=[128 for _ in range(512)],
        log_dir='./logs',
        gpu_freq_mhz=nvml_get_available_freq()[0],
        delay_time_s=0.0,
    )

    latencies: List[List[float]] = []
    uvloop.run(benchmark_batch(vllm_args, [benchmark_batch_param], latencies))
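
    # A possible extension (sketch, not exercised here): sweep every GPU core
    # clock that `nvml_get_available_freq()` reports (assumed to be the list
    # of lockable frequencies in MHz, as the indexing above suggests) and
    # collect the per-batch latencies for each setting:
    #
    #   sweep_latencies: List[List[float]] = []
    #   sweep_params = [
    #       BenchmarkBatchParam(
    #           prefill_input_lens=[1024, 1024],
    #           decode_input_lens=[128 for _ in range(512)],
    #           log_dir=f'./logs/freq_{freq}',
    #           gpu_freq_mhz=freq,
    #       ) for freq in nvml_get_available_freq()
    #   ]
    #   uvloop.run(benchmark_batch(vllm_args, sweep_params, sweep_latencies))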