forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Introduce LLM class for offline inference (#115) #28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
MitchLewis930
wants to merge
1
commit into
request_id_before
Choose a base branch
from
request_id_after
base: request_id_before
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,19 +1,15 @@ | ||
| from cacheflow.entrypoints.llm import LLM | ||
| from cacheflow.outputs import RequestOutput | ||
| from cacheflow.sampling_params import SamplingParams | ||
| from cacheflow.server.arg_utils import ( | ||
| add_server_arguments, | ||
| create_server_configs_from_args, | ||
| initialize_server_from_args, | ||
| ) | ||
| from cacheflow.server.arg_utils import ServerArgs | ||
| from cacheflow.server.llm_server import LLMServer | ||
| from cacheflow.server.ray_utils import initialize_cluster | ||
|
|
||
| __all__ = [ | ||
| "RequestOutput", | ||
| "LLM", | ||
| "SamplingParams", | ||
| "RequestOutput", | ||
| "LLMServer", | ||
| "add_server_arguments", | ||
| "create_server_configs_from_args", | ||
| "initialize_server_from_args", | ||
| "ServerArgs", | ||
| "initialize_cluster", | ||
| ] |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,62 @@ | ||
| from typing import List, Optional | ||
|
|
||
| from tqdm import tqdm | ||
|
|
||
| from cacheflow.outputs import RequestOutput | ||
| from cacheflow.sampling_params import SamplingParams | ||
| from cacheflow.server.arg_utils import ServerArgs | ||
| from cacheflow.server.llm_server import LLMServer | ||
| from cacheflow.utils import Counter | ||
|
|
||
|
|
||
| class LLM: | ||
|
|
||
| def __init__( | ||
| self, | ||
| model: str, | ||
| tensor_parallel_size: int = 1, | ||
| dtype: str = "default", | ||
| seed: int = 0, | ||
| **kwargs, | ||
| ) -> None: | ||
| if "disable_log_stats" not in kwargs: | ||
| kwargs["disable_log_stats"] = True | ||
| server_args = ServerArgs( | ||
| model=model, | ||
| tensor_parallel_size=tensor_parallel_size, | ||
| dtype=dtype, | ||
| seed=seed, | ||
| **kwargs, | ||
| ) | ||
| self.llm_server = LLMServer.from_server_args(server_args) | ||
| self.request_counter = Counter() | ||
|
|
||
| def generate( | ||
| self, | ||
| prompts: List[str], | ||
| sampling_params: Optional[SamplingParams] = None, | ||
| use_tqdm: bool = True, | ||
| ) -> List[RequestOutput]: | ||
| if sampling_params is None: | ||
| sampling_params = SamplingParams() | ||
| # Initialize tqdm. | ||
| if use_tqdm: | ||
| pbar = tqdm(total=len(prompts), desc="Processed prompts") | ||
|
|
||
| # Add requests to the server. | ||
| for prompt in prompts: | ||
| request_id = str(next(self.request_counter)) | ||
| self.llm_server.add_request(request_id, prompt, sampling_params) | ||
|
|
||
| # Run the server. | ||
| outputs: List[RequestOutput] = [] | ||
| while self.llm_server.has_unfinished_requests(): | ||
| step_outputs = self.llm_server.step() | ||
| for output in step_outputs: | ||
| if output.done: | ||
| outputs.append(output) | ||
| if use_tqdm: | ||
| pbar.update(1) | ||
| if use_tqdm: | ||
| pbar.close() | ||
| return outputs | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,74 +1,117 @@ | ||
| import argparse | ||
| from typing import Tuple | ||
| import dataclasses | ||
| from dataclasses import dataclass | ||
| from typing import Optional, Tuple | ||
|
|
||
| from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig, | ||
| SchedulerConfig) | ||
| from cacheflow.server.llm_server import LLMServer | ||
| from cacheflow.server.ray_utils import initialize_cluster | ||
|
|
||
| _GiB = 1 << 30 | ||
|
|
||
| @dataclass | ||
| class ServerArgs: | ||
| model: str | ||
| download_dir: Optional[str] = None | ||
| use_np_weights: bool = False | ||
| use_dummy_weights: bool = False | ||
| dtype: str = "default" | ||
| seed: int = 0 | ||
| use_ray: bool = False | ||
| pipeline_parallel_size: int = 1 | ||
| tensor_parallel_size: int = 1 | ||
| block_size: int = 16 | ||
| swap_space: int = 4 # GiB | ||
| gpu_memory_utilization: float = 0.95 | ||
| max_num_batched_tokens: int = 2560 | ||
| max_num_seqs: int = 256 | ||
| disable_log_stats: bool = False | ||
|
|
||
| def add_server_arguments(parser: argparse.ArgumentParser): | ||
| """Shared arguments for CacheFlow servers.""" | ||
| def __post_init__(self): | ||
| self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens) | ||
|
|
||
| @staticmethod | ||
| def add_cli_args( | ||
| parser: argparse.ArgumentParser, | ||
| ) -> argparse.ArgumentParser: | ||
| return _add_server_arguments(parser) | ||
|
|
||
| @classmethod | ||
| def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs": | ||
| # Get the list of attributes of this dataclass. | ||
| attrs = [attr.name for attr in dataclasses.fields(cls)] | ||
| # Set the attributes from the parsed arguments. | ||
| server_args = cls(**{attr: getattr(args, attr) for attr in attrs}) | ||
| return server_args | ||
|
|
||
| def create_server_configs( | ||
| self, | ||
| ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: | ||
| # Initialize the configs. | ||
| model_config = ModelConfig( | ||
| self.model, self.download_dir, self.use_np_weights, | ||
| self.use_dummy_weights, self.dtype, self.seed) | ||
| cache_config = CacheConfig(self.block_size, self.gpu_memory_utilization, | ||
| self.swap_space) | ||
| parallel_config = ParallelConfig(self.pipeline_parallel_size, | ||
| self.tensor_parallel_size, | ||
| self.use_ray) | ||
| scheduler_config = SchedulerConfig(self.max_num_batched_tokens, | ||
| self.max_num_seqs) | ||
| return model_config, cache_config, parallel_config, scheduler_config | ||
|
|
||
|
|
||
| def _add_server_arguments( | ||
| parser: argparse.ArgumentParser, | ||
| )-> argparse.ArgumentParser: | ||
| """Shared CLI arguments for CacheFlow servers.""" | ||
| # Model arguments | ||
| parser.add_argument('--model', type=str, default='facebook/opt-125m', help='model name') | ||
| parser.add_argument('--download-dir', type=str, default=None, | ||
| parser.add_argument('--model', type=str, default='facebook/opt-125m', | ||
| help='name or path of the huggingface model to use') | ||
| parser.add_argument('--download-dir', type=str, | ||
| default=ServerArgs.download_dir, | ||
| help='directory to download and load the weights, ' | ||
| 'default to the default cache dir of huggingface') | ||
| parser.add_argument('--use-np-weights', action='store_true', | ||
| help='save a numpy copy of model weights for faster loading') | ||
| parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights') | ||
| help='save a numpy copy of model weights for faster ' | ||
| 'loading. This can increase the disk usage by up ' | ||
| 'to 2x.') | ||
| parser.add_argument('--use-dummy-weights', action='store_true', | ||
| help='use dummy values for model weights') | ||
| # TODO(woosuk): Support FP32. | ||
| parser.add_argument('--dtype', type=str, default='default', choices=['default', 'half', 'bfloat16'], | ||
| parser.add_argument('--dtype', type=str, default=ServerArgs.dtype, | ||
| choices=['default', 'half', 'bfloat16'], | ||
| help=('data type for model weights and activations. ' | ||
| 'The "default" option will use FP16 precision ' | ||
| 'for FP32 and FP16 models, and BF16 precision ' | ||
| 'for BF16 models.')) | ||
| # Parallel arguments | ||
| parser.add_argument('--use-ray', action='store_true', help='use Ray for distributed serving, will be automatically set when using more than 1 GPU') | ||
| parser.add_argument('--pipeline-parallel-size', '-pp', type=int, default=1, help='number of pipeline stages') | ||
| parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1, help='number of tensor parallel replicas') | ||
| parser.add_argument('--use-ray', action='store_true', | ||
| help='use Ray for distributed serving, will be ' | ||
| 'automatically set when using more than 1 GPU') | ||
| parser.add_argument('--pipeline-parallel-size', '-pp', type=int, | ||
| default=ServerArgs.pipeline_parallel_size, | ||
| help='number of pipeline stages') | ||
| parser.add_argument('--tensor-parallel-size', '-tp', type=int, | ||
| default=ServerArgs.tensor_parallel_size, | ||
| help='number of tensor parallel replicas') | ||
| # KV cache arguments | ||
| parser.add_argument('--block-size', type=int, default=16, choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], help='token block size') | ||
| parser.add_argument('--block-size', type=int, default=ServerArgs.block_size, | ||
| choices=[1, 2, 4, 8, 16, 32, 64, 128, 256], | ||
| help='token block size') | ||
| # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). | ||
| parser.add_argument('--seed', type=int, default=0, help='random seed') | ||
| parser.add_argument('--swap-space', type=int, default=4, help='CPU swap space size (GiB) per GPU') | ||
| parser.add_argument('--gpu-memory-utilization', type=float, default=0.95, help='the percentage of GPU memory to be used for the model executor') | ||
| parser.add_argument('--max-num-batched-tokens', type=int, default=2560, help='maximum number of batched tokens per iteration') | ||
| parser.add_argument('--max-num-seqs', type=int, default=256, help='maximum number of sequences per iteration') | ||
| parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') | ||
| parser.add_argument('--seed', type=int, default=ServerArgs.seed, | ||
| help='random seed') | ||
| parser.add_argument('--swap-space', type=int, default=ServerArgs.swap_space, | ||
| help='CPU swap space size (GiB) per GPU') | ||
| parser.add_argument('--gpu-memory-utilization', type=float, | ||
| default=ServerArgs.gpu_memory_utilization, | ||
| help='the percentage of GPU memory to be used for the ' | ||
| 'model executor') | ||
| parser.add_argument('--max-num-batched-tokens', type=int, | ||
| default=ServerArgs.max_num_batched_tokens, | ||
| help='maximum number of batched tokens per iteration') | ||
| parser.add_argument('--max-num-seqs', type=int, | ||
| default=ServerArgs.max_num_seqs, | ||
| help='maximum number of sequences per iteration') | ||
| parser.add_argument('--disable-log-stats', action='store_true', | ||
| help='disable logging statistics') | ||
| return parser | ||
|
|
||
|
|
||
| def create_server_configs_from_args( | ||
| args: argparse.Namespace, | ||
| ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: | ||
| # Post-process the parsed arguments. | ||
| args.swap_space = args.swap_space * _GiB | ||
| args.max_num_seqs = min(args.max_num_seqs, args.max_num_batched_tokens) | ||
|
|
||
| # Initialize the configs. | ||
| model_config = ModelConfig( | ||
| args.model, args.download_dir, args.use_np_weights, | ||
| args.use_dummy_weights, args.dtype, args.seed) | ||
| cache_config = CacheConfig(args.block_size, args.gpu_memory_utilization, | ||
| args.swap_space) | ||
| parallel_config = ParallelConfig(args.pipeline_parallel_size, | ||
| args.tensor_parallel_size, args.use_ray) | ||
| scheduler_config = SchedulerConfig(args.max_num_batched_tokens, | ||
| args.max_num_seqs) | ||
| return model_config, cache_config, parallel_config, scheduler_config | ||
|
|
||
|
|
||
| def initialize_server_from_args(args: argparse.Namespace) -> LLMServer: | ||
| server_configs = create_server_configs_from_args(args) | ||
| parallel_config = server_configs[2] | ||
|
|
||
| # Initialize the cluster. | ||
| distributed_init_method, devices = initialize_cluster(parallel_config) | ||
|
|
||
| # Create the LLM server. | ||
| server = LLMServer(*server_configs, distributed_init_method, devices, | ||
| log_stats=not args.disable_log_stats) | ||
| return server |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🧩 Analysis chain
🏁 Script executed:
Repository: Signal65/vllm-code-review
Length of output: 94
🏁 Script executed:
Repository: Signal65/vllm-code-review
Length of output: 1098
🏁 Script executed:
Repository: Signal65/vllm-code-review
Length of output: 1963
🏁 Script executed:
Repository: Signal65/vllm-code-review
Length of output: 256
🏁 Script executed:
Repository: Signal65/vllm-code-review
Length of output: 2635
Add
tqdmto dependencies or use lazy import.tqdmis imported unconditionally at module level but is not declared inrequirements.txtorsetup.py. This breaks module imports unlesstqdmis installed separately, even whenuse_tqdm=False. Either addtqdmto dependencies or defer the import until needed.💡 Optional lazy import approach
🤖 Prompt for AI Agents