diff --git a/common/arg.cpp b/common/arg.cpp index 26c790c7e0b..9c0e6fbe789 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -2255,7 +2255,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::vector split_arg{ it, {} }; if (split_arg.size() >= llama_max_devices()) { throw std::invalid_argument( - string_format("got %d input configs, but system only has %d devices", (int)split_arg.size(), (int)llama_max_devices()) + string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices()) ); } for (size_t i = 0; i < llama_max_devices(); ++i) { @@ -2295,10 +2295,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_FIT")); add_opt(common_arg( - { "-fitt", "--fit-target" }, "MiB", - string_format("target margin per device for --fit option, default: %zu", params.fit_params_target/(1024*1024)), - [](common_params & params, int value) { - params.fit_params_target = value * size_t(1024*1024); + { "-fitt", "--fit-target" }, "MiB0,MiB1,MiB2,...", + string_format("target margin per device for --fit, comma-separated list of values, " + "single value is broadcast across all devices, default: %zu", params.fit_params_target[0]/(1024*1024)), + [](common_params & params, const std::string & value) { + std::string arg_next = value; + + // split string by , and / + const std::regex regex{ R"([,/]+)" }; + std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; + std::vector split_arg{ it, {} }; + if (split_arg.size() >= llama_max_devices()) { + throw std::invalid_argument( + string_format("got %zu input configs, but system only has %zu devices", split_arg.size(), llama_max_devices()) + ); + } + if (split_arg.size() == 1) { + std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), std::stoul(split_arg[0]) * 1024*1024); + return; + } + for (size_t i = 0; i < split_arg.size(); i++) { + params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024*1024; + } } ).set_env("LLAMA_ARG_FIT_TARGET")); add_opt(common_arg( diff --git a/common/arg.h b/common/arg.h index a1b6a14e675..55782a158d7 100644 --- a/common/arg.h +++ b/common/arg.h @@ -129,11 +129,3 @@ void common_params_add_preset_options(std::vector & args); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); - -struct common_remote_params { - std::vector headers; - long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout - long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB -}; -// get remote file content, returns -std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/common.cpp b/common/common.cpp index 34fa3b5a422..744f0b4eeb4 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1097,7 +1097,7 @@ common_init_result::common_init_result(common_params & params) : if (params.fit_params) { LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } diff --git a/common/common.h b/common/common.h index d55a6b71fb7..7794c0268bd 100644 --- a/common/common.h +++ b/common/common.h @@ -332,12 +332,14 @@ struct common_params { // offload params std::vector devices; // devices to use for offloading - int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all - int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors - float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs - bool fit_params = true; // whether to fit unset model/context parameters to free device memory - size_t fit_params_target = 1024 * 1024*1024; // margin per device in bytes for fitting parameters to free memory - int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + int32_t n_gpu_layers = -1; // number of layers to store in VRAM, -1 is auto, <= -2 is all + int32_t main_gpu = 0; // the GPU that is used for scratch and small tensors + float tensor_split[128] = {0}; // how split tensors should be distributed across GPUs + bool fit_params = true; // whether to fit unset model/context parameters to free device memory + int32_t fit_params_min_ctx = 4096; // minimum context size to set when trying to reduce memory use + + // margin per device in bytes for fitting parameters to free memory: + std::vector fit_params_target = std::vector(llama_max_devices(), 1024 * 1024*1024); enum llama_split_mode split_mode = LLAMA_SPLIT_MODE_LAYER; // how to split the model across GPUs diff --git a/common/download.cpp b/common/download.cpp index ef874725607..6f56b5518f5 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -308,7 +308,8 @@ static bool common_download_head(CURL * curl, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { @@ -330,6 +331,11 @@ static bool common_download_file_single_online(const std::string & url, common_load_model_from_url_headers headers; curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); curl_slist_ptr http_headers; + + for (const auto & h : custom_headers) { + std::string s = h.first + ": " + h.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str()); + } const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); if (!was_perform_successful) { head_request_ok = false; @@ -454,8 +460,10 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); } http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { - http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + std::string header_ = header.first + ": " + header.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str()); } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); @@ -619,7 +627,8 @@ static bool common_pull_file(httplib::Client & cli, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -629,6 +638,9 @@ static bool common_download_file_single_online(const std::string & url, if (!bearer_token.empty()) { default_headers.insert({"Authorization", "Bearer " + bearer_token}); } + for (const auto & h : custom_headers) { + default_headers.emplace(h.first, h.second); + } cli.set_default_headers(default_headers); const bool file_exists = std::filesystem::exists(path); @@ -734,13 +746,9 @@ std::pair> common_remote_get_content(const std::string auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { - size_t pos = header.find(':'); - if (pos != std::string::npos) { - headers.emplace(header.substr(0, pos), header.substr(pos + 1)); - } else { - headers.emplace(header, ""); - } + headers.emplace(header.first, header.second); } if (params.timeout > 0) { @@ -772,9 +780,10 @@ std::pair> common_remote_get_content(const std::string static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, - bool offline) { + bool offline, + const common_header_list & headers) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token); + return common_download_file_single_online(url, path, bearer_token, headers); } if (!std::filesystem::exists(path)) { @@ -788,13 +797,24 @@ static bool common_download_file_single(const std::string & url, // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { +static bool common_download_file_multiple(const std::vector> & urls, + const std::string & bearer_token, + bool offline, + const common_header_list & headers) { // Prepare download in parallel std::vector> futures_download; + futures_download.reserve(urls.size()); + for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); - }, item)); + futures_download.push_back( + std::async( + std::launch::async, + [&bearer_token, offline, &headers](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline, headers); + }, + item + ) + ); } // Wait for all downloads to complete @@ -807,17 +827,17 @@ static bool common_download_file_multiple(const std::vector(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -893,10 +916,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; // headers - std::vector headers; - headers.push_back("Accept: application/json"); + common_header_list headers = custom_headers; + headers.push_back({"Accept", "application/json"}); if (!bearer_token.empty()) { - headers.push_back("Authorization: Bearer " + bearer_token); + headers.push_back({"Authorization", "Bearer " + bearer_token}); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here @@ -1031,9 +1054,10 @@ std::string common_docker_resolve_model(const std::string & docker) { const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; std::string manifest_url = url_prefix + "/manifests/" + tag; common_remote_params manifest_params; - manifest_params.headers.push_back("Authorization: Bearer " + token); - manifest_params.headers.push_back( - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + manifest_params.headers.push_back({"Authorization", "Bearer " + token}); + manifest_params.headers.push_back({"Accept", + "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json" + }); auto manifest_res = common_remote_get_content(manifest_url, manifest_params); if (manifest_res.first != 200) { throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); @@ -1070,7 +1094,7 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false)) { + if (!common_download_file_single(blob_url, local_path, token, false, {})) { throw std::runtime_error("Failed to download Docker Model"); } @@ -1084,11 +1108,11 @@ std::string common_docker_resolve_model(const std::string & docker) { #else -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { +common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } -bool common_download_model(const common_params_model &, const std::string &, bool) { +bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } diff --git a/common/download.h b/common/download.h index d1321e6e90e..9ea20939390 100644 --- a/common/download.h +++ b/common/download.h @@ -1,12 +1,21 @@ #pragma once #include +#include struct common_params_model; -// -// download functionalities -// +using common_header = std::pair; +using common_header_list = std::vector; + +struct common_remote_params { + common_header_list headers; + long timeout = 0; // in seconds, 0 means no timeout + long max_size = 0; // unlimited if 0 +}; + +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); struct common_cached_model_info { std::string manifest_path; @@ -41,13 +50,17 @@ struct common_hf_file_res { common_hf_file_res common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns true if download succeeded bool common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns list of cached models std::vector common_list_cached_models(); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 162d238ae44..d7a93848df8 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2541,27 +2541,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } -/** - * @brief Determines if a tensor operation should be offloaded to the CANN - * backend. - * - * This function checks if a given tensor operation should be offloaded to the - * CANN backend based on the operation type and the size of the tensor. It - * returns true if the second dimension (ne[1]) of the tensor is greater than or - * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS. - * - * @param backend Pointer to the CANN backend. - * @param op Pointer to the tensor operation to check. - * @return bool Returns true if the operation should be offloaded, otherwise - * false. - */ -static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - GGML_UNUSED(dev); - - return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; -} - /** * @brief Records an event on the CANN backend stream. * @@ -2637,6 +2616,7 @@ struct ggml_backend_cann_device_context { int device; std::string name; std::string description; + int op_offload_min_batch_size; }; static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) { @@ -2713,6 +2693,26 @@ static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type( return ggml_backend_cann_host_buffer_type(); } +/** + * @brief Determines if a tensor operation should be offloaded to the CANN + * backend. + * + * This function checks if a given tensor operation should be offloaded to the + * CANN backend based on the operation type and the size of the tensor. It + * returns true if the second dimension (ne[1]) of the tensor is greater than or + * equal to the minimum batch size and the operation is not GGML_OP_GET_ROWS. + * + * @param backend Pointer to the CANN backend. + * @param op Pointer to the tensor operation to check. + * @return bool Returns true if the operation should be offloaded, otherwise + * false. + */ +static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context; + + return op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS; +} + /** * @brief Creates a new event for the CANN backend device. * @@ -2829,12 +2829,14 @@ ggml_backend_reg_t ggml_backend_cann_reg() { if (!initialized) { aclInit(nullptr); ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_cann_info().device_count; i++) { ggml_backend_cann_device_context * dev_ctx = new ggml_backend_cann_device_context(); dev_ctx->description = aclrtGetSocName(); dev_ctx->device = i; dev_ctx->name = GGML_CANN_NAME + std::to_string(i); + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_cann_set_device(i); ggml_backend_dev_t dev = new ggml_backend_device{ /* .iface = */ ggml_backend_cann_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bac69cdd1c8..f021de1d745 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -4122,6 +4122,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string pci_bus_id; + int op_offload_min_batch_size; }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { @@ -4676,11 +4677,9 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_cuda_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - - return get_op_batch_size(op) >= min_batch_size; + ggml_backend_cuda_device_context * dev_ctx = (ggml_backend_cuda_device_context *) dev->context; - GGML_UNUSED(dev); + return get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_cuda_device_event_new(ggml_backend_dev_t dev) { @@ -4848,6 +4847,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; @@ -4861,6 +4861,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_cuda_device_interface, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index d983b666ca2..9c3b0014878 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -219,6 +219,8 @@ struct ggml_metal_device_props { bool use_shared_buffers; bool supports_gpu_family_apple7; + + int op_offload_min_batch_size; }; ggml_metal_device_t ggml_metal_device_init(void); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 59badd00431..ff899a81709 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -782,6 +782,8 @@ ggml_metal_device_t ggml_metal_device_init(void) { dev->props.supports_gpu_family_apple7 = [dev->mtl_device supportsFamily:MTLGPUFamilyApple7]; + dev->props.op_offload_min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; + dev->props.max_buffer_size = dev->mtl_device.maxBufferLength; dev->props.max_working_set_size = dev->mtl_device.recommendedMaxWorkingSetSize; dev->props.max_theadgroup_memory_size = dev->mtl_device.maxThreadgroupMemoryLength; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 70bf6f3d981..56b59f0afdf 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -625,14 +625,11 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_metal_device_t ctx_dev = (ggml_metal_device_t)dev->context; return (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_MUL_MAT_ID) && - get_op_batch_size(op) >= min_batch_size; - - GGML_UNUSED(dev); - GGML_UNUSED(op); + get_op_batch_size(op) >= ggml_metal_device_get_props(ctx_dev)->op_offload_min_batch_size; } static ggml_backend_device_i ggml_backend_metal_device_i = { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 67b30e0d93c..16d17d26af8 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -9148,6 +9148,7 @@ typedef decltype(kernel_mul_mm_id_map0<1>) kernel_mul_mm_id_map0_t; template [[host_name("kernel_mul_mm_id_map0_ne20_1" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<1>; template [[host_name("kernel_mul_mm_id_map0_ne20_2" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<2>; template [[host_name("kernel_mul_mm_id_map0_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; +template [[host_name("kernel_mul_mm_id_map0_ne20_5" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<5>; template [[host_name("kernel_mul_mm_id_map0_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>; diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index e996d98be8c..8f8176b678a 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4286,6 +4286,7 @@ struct ggml_backend_sycl_device_context { int device; std::string name; std::string description; + int op_offload_min_batch_size; }; static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) { @@ -4674,9 +4675,8 @@ static int64_t get_op_batch_size(const ggml_tensor * op) { } static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; - return get_op_batch_size(op) >= min_batch_size; - GGML_UNUSED(dev); + ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context; + return get_op_batch_size(op) >= sycl_ctx->op_offload_min_batch_size; } static ggml_backend_event_t @@ -4799,6 +4799,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context; + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_sycl_info().device_count; i++) { ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context; @@ -4812,6 +4813,7 @@ ggml_backend_reg_t ggml_backend_sycl_reg() { prop, dpct::dev_mgr::instance().get_device(i)))); dev_ctx->description = prop.get_name(); + dev_ctx->op_offload_min_batch_size = min_batch_size; ggml_backend_dev_t dev = new ggml_backend_device { /* .iface = */ ggml_backend_sycl_device_interface, diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index d68735a040a..b1a51a43658 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -570,6 +570,7 @@ struct vk_device_struct { bool uma; bool prefer_host_memory; bool float_controls_rte_fp16; + bool subgroup_basic; bool subgroup_arithmetic; bool subgroup_shuffle; bool subgroup_ballot; @@ -1504,6 +1505,11 @@ template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); } +struct vk_quantize_q8_1_push_constants { + uint32_t ne; + uint32_t num_blocks; +}; + // Allow pre-recording command buffers struct vk_staging_memcpy { vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} @@ -3340,12 +3346,12 @@ static void ggml_vk_load_shaders(vk_device& device) { GGML_ASSERT(device->subgroup_ballot); - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); #if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) if (device->coopmat_bf16_support) { - CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id); } #endif @@ -3453,9 +3459,9 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3497,9 +3503,9 @@ static void ggml_vk_load_shaders(vk_device& device) { } #endif } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3614,9 +3620,9 @@ static void ggml_vk_load_shaders(vk_device& device) { #endif if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size_16); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); @@ -3640,9 +3646,9 @@ static void ggml_vk_load_shaders(vk_device& device) { CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, mul_mat_subgroup_size); } else { - CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); - CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); @@ -3840,22 +3846,22 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); - ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3943,9 +3949,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); if (device->subgroup_clustered && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); } else { - ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, sizeof(vk_quantize_q8_1_push_constants), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); } for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { @@ -4153,9 +4159,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add1_f16_f32, "add1_f16_f32", add1_f16_f32_len, add1_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_add1_f32_f32, "add1_f32_f32", add1_f32_f32_len, add1_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); #define CREATE_GLU(name) \ if (device->float_controls_rte_fp16) { \ @@ -4301,8 +4307,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); } else { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); @@ -4638,6 +4644,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + device->subgroup_basic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBasic); device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); #ifdef __APPLE__ @@ -6097,6 +6105,7 @@ static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); + GGML_ASSERT(pipeline->push_constant_size == push_constant_size(push_constants)); vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; @@ -6879,7 +6888,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits::max()); const uint32_t elements = std::min(ne, static_cast(max_elements)); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ ne, num_blocks }, { elements, 1, 1 }); + const vk_quantize_q8_1_push_constants pc = { + ne, + num_blocks, + }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, { elements, 1, 1 }); ggml_vk_sync_buffers(ctx, subctx); } @@ -9870,8 +9884,9 @@ static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, std::array elements; - const int splitH = 16; - const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, splitH); + const uint32_t d_state = src0->ne[0]; + uint32_t num_subgroups = d_state / ctx->device->subgroup_size; + const uint32_t num_workgroups_x = CEIL_DIV(n_head * head_dim, num_subgroups); const uint32_t num_workgroups_y = n_seq; elements = { num_workgroups_x, num_workgroups_y, 1 }; @@ -14249,6 +14264,7 @@ struct ggml_backend_vk_device_context { std::string description; bool is_integrated_gpu; std::string pci_bus_id; + int op_offload_min_batch_size; }; static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { @@ -14776,11 +14792,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - const uint32_t SPLIT_H = 16; + size_t shmem_size = d_state * sizeof(float); - size_t stateC_size = SPLIT_H * d_state * sizeof(float); + if (shmem_size > device->properties.limits.maxComputeSharedMemorySize) { + return false; + } - if (stateC_size > device->properties.limits.maxComputeSharedMemorySize) { + if (!device->subgroup_basic) { return false; } @@ -14820,12 +14838,10 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba } static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { - const int min_batch_size = 32; + ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); - - UNUSED(dev); + return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); } static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { @@ -14951,6 +14967,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, static std::mutex mutex; std::lock_guard lock(mutex); if (!initialized) { + const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; char desc[256]; @@ -14960,6 +14977,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, ctx->description = desc; ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); + ctx->op_offload_min_batch_size = min_batch_size; devices.push_back(new ggml_backend_device { /* .iface = */ ggml_backend_vk_device_i, /* .reg = */ reg, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp index 8f67be97995..c7416206dbd 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,7 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#extension GL_KHR_shader_subgroup_basic : enable #if USE_SUBGROUP_ADD #extension GL_KHR_shader_subgroup_arithmetic : enable #endif @@ -9,7 +10,8 @@ layout(constant_id = 0) const uint D_STATE = 128; layout(constant_id = 1) const uint SUBGROUP_SIZE = 32; -layout(constant_id = 2) const uint SPLIT_H = 16; + +const uint32_t c_factor = D_STATE / SUBGROUP_SIZE; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; @@ -41,22 +43,28 @@ float softplus(float x) { } } -shared float stateC[SPLIT_H * D_STATE]; +#if !USE_SUBGROUP_ADD +shared float temp[D_STATE]; +#endif void main() { - const uint tid = gl_LocalInvocationID.x; - const uint head_idx = (gl_WorkGroupID.x * SPLIT_H) / d_head; - const uint head_off = ((gl_WorkGroupID.x * SPLIT_H) % d_head) * 4; - const uint seq_idx = gl_WorkGroupID.y; + const uint subgroup = gl_SubgroupID; + const uint lane = gl_SubgroupInvocationID; + const uint tid = gl_SubgroupID * SUBGROUP_SIZE + lane; + const uint subgroup_idx = gl_WorkGroupID.x * c_factor + subgroup; + + const uint head_idx = subgroup_idx / d_head; + const uint head_off = (subgroup_idx % d_head) * 4; + const uint seq_idx = gl_WorkGroupID.y; const uint group_off = (head_idx / (n_head / n_group)) * D_STATE * 4; const uint s0_base_idx = (uint(ids[seq_idx]) * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; - const uint x_base_idx = (seq_idx * nb13 + gl_WorkGroupID.x * SPLIT_H * 4) / 4; + const uint x_base_idx = (seq_idx * nb13 + subgroup_idx * 4) / 4; const uint dt_base_idx = (seq_idx * nb22 + head_idx * 4) / 4; const uint A_base_idx = (head_idx * nb31) / 4; const uint B_base_idx = (seq_idx * nb43 + group_off) / 4; const uint C_base_idx = (seq_idx * nb53 + group_off) / 4; - const uint y_base_idx = seq_idx * n_tok * n_head * d_head + gl_WorkGroupID.x * SPLIT_H; + const uint y_base_idx = seq_idx * n_tok * n_head * d_head + subgroup_idx; const uint s_base_idx = (s_off + seq_idx * nb03 + head_idx * nb02 + head_off * D_STATE) / 4; const uint stride_x = nb12 / 4; @@ -65,76 +73,52 @@ void main() { const uint stride_C = nb52 / 4; const uint stride_y = n_head * d_head; - float state[SPLIT_H]; - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - state[j] = s0[s0_base_idx + j * D_STATE + tid]; - } + float state[c_factor]; - for (uint i = 0; i < n_tok; i++) { - const float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); + [[unroll]] for (uint j = 0; j < c_factor; j++) { + state[j] = s0[s0_base_idx + SUBGROUP_SIZE * j + lane]; + } - const float dA = exp(dt_soft_plus * A[A_base_idx]); + float a = A[A_base_idx]; - const float B_val = B[B_base_idx + i * stride_B + tid]; - const float C_val = C[C_base_idx + i * stride_C + tid]; + for (uint i = 0; i < n_tok; i++) { + float dt_soft_plus = softplus(dt[dt_base_idx + i * stride_dt]); - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - const float x_dt = x[x_base_idx + i * stride_x + j] * dt_soft_plus; + float state_sum = 0.0f; + const float dA = exp(dt_soft_plus * a); + const float x_dt = x[x_base_idx + i * stride_x] * dt_soft_plus; + [[unroll]] for (uint j = 0; j < c_factor; j++) { + float B_val = B[B_base_idx + i * stride_B + SUBGROUP_SIZE * j + lane]; + float C_val = C[C_base_idx + i * stride_C + SUBGROUP_SIZE * j + lane]; state[j] = (state[j] * dA) + (B_val * x_dt); - - stateC[j * D_STATE + tid] = state[j] * C_val; + state_sum += state[j] * C_val; } +#if USE_SUBGROUP_ADD + state_sum = subgroupAdd(state_sum); +#else + temp[tid] = state_sum; barrier(); - [[unroll]] - for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); - if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + w]; - } + [[unroll]] for (uint s = SUBGROUP_SIZE / 2; s > 0; s >>= 1) { + if (lane < s) { + temp[tid] += temp[tid + s]; } barrier(); } - - [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { - const uint idx = (tid % SUBGROUP_SIZE) + - D_STATE * (tid / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - const uint max_idx = SUBGROUP_SIZE - 1 + - D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + - j * D_STATE * (D_STATE / SUBGROUP_SIZE); - - if (idx < SPLIT_H * D_STATE || - max_idx < SPLIT_H * D_STATE) { - float sc; -#if USE_SUBGROUP_ADD - sc = stateC[idx]; - sc = subgroupAdd(sc); -#else - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; - } - barrier(); - } - if (tid % SUBGROUP_SIZE == 0) { - sc = stateC[idx]; - } + // get the value from lane 0 + state_sum = temp[subgroup * SUBGROUP_SIZE]; + barrier(); #endif - if (tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = sc; - } - } + if (lane == 0) { + d[y_base_idx + i * stride_y] = state_sum; } - - barrier(); } - [[unroll]] for (uint j = 0; j < SPLIT_H; j++) { - d[s_base_idx + j * D_STATE + tid] = state[j]; + // write back the state + [[unroll]] + for (int j = 0; j < c_factor; j++) { + d[s_base_idx + SUBGROUP_SIZE * j + lane] = state[j]; } } diff --git a/include/llama.h b/include/llama.h index edc4c871a14..12e4e57d0e5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -495,7 +495,7 @@ extern "C" { struct llama_context_params * cparams, float * tensor_split, // writable buffer for tensor split, needs at least llama_max_devices elements struct llama_model_tensor_buft_override * tensor_buft_overrides, // writable buffer for overrides, needs at least llama_max_tensor_buft_overrides elements - size_t margin, // margin of memory to leave per device in bytes + size_t * margins, // margins of memory to leave per device in bytes uint32_t n_ctx_min, // minimum context size to set when trying to reduce memory use enum ggml_log_level log_level); // minimum log level to print during fitting, lower levels go to debug log diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index 36ccde2f34e..7970bec371d 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -9,6 +9,7 @@ # sample usage: # ./scripts/pr2wt.sh 12345 # ./scripts/pr2wt.sh 12345 opencode +# ./scripts/pr2wt.sh 12345 "cmake -B build && cmake --build build" function usage() { echo "usage: $0 [cmd]" @@ -46,7 +47,7 @@ head_ref=$(echo "$meta" | jq -r '.head.ref') echo "url: $url_remote" echo "head_ref: $head_ref" -git remote rm pr/${PR} +git remote rm pr/${PR} 2> /dev/null git remote add pr/${PR} $url_remote git fetch pr/${PR} $head_ref @@ -62,5 +63,5 @@ echo "git worktree created in $wt_path" # if a command was provided, execute it if [[ $# -eq 2 ]]; then cd ../$dir-pr-$PR - exec $2 + eval "$2" fi diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 637f4cdc186..ed6bf1bf4e2 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/src/llama.cpp b/src/llama.cpp index dfefb3d2b50..33f51a23890 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -147,9 +147,8 @@ class llama_params_fit_exception : public std::runtime_error { static void llama_params_fit_impl( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + size_t * margins_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { constexpr int64_t MiB = 1024*1024; - const int64_t margin = margin_s; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits typedef std::vector dmds_t; const llama_model_params default_mparams = llama_model_default_params(); @@ -168,6 +167,12 @@ static void llama_params_fit_impl( return; } + std::vector margins; // this function uses int64_t rather than size_t for memory sizes to more conveniently handle deficits + margins.reserve(nd); + for (size_t id = 0; id < nd; id++) { + margins.push_back(margins_s[id]); + } + std::vector dev_names; { dev_names.reserve(nd); @@ -187,9 +192,10 @@ static void llama_params_fit_impl( int64_t sum_free = 0; int64_t sum_projected_free = 0; - int64_t min_projected_free = INT64_MAX; int64_t sum_projected_used = 0; int64_t sum_projected_model = 0; + std::vector projected_free_per_device; + projected_free_per_device.reserve(nd); if (nd > 1) { LLAMA_LOG_INFO("%s: projected memory use with initial parameters [MiB]:\n", __func__); @@ -199,45 +205,63 @@ static void llama_params_fit_impl( const int64_t projected_used = dmd.mb.total(); const int64_t projected_free = dmd.free - projected_used; + projected_free_per_device.push_back(projected_free); sum_free += dmd.free; sum_projected_used += projected_used; sum_projected_free += projected_free; - min_projected_free = std::min(min_projected_free, projected_free); sum_projected_model += dmd.mb.model; if (nd > 1) { - LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " %s\n", - __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, std::abs(projected_free)/MiB, - projected_free >= 0 ? "surplus" : "deficit"); + LLAMA_LOG_INFO("%s: - %s: %6" PRId64 " total, %6" PRId64 " used, %6" PRId64 " free vs. target of %6" PRId64 "\n", + __func__, dev_names[id].c_str(), dmd.total/MiB, projected_used/MiB, projected_free/MiB, margins[id]/MiB); } } assert(sum_free >= 0 && sum_projected_used >= 0); LLAMA_LOG_INFO("%s: projected to use %" PRId64 " MiB of device memory vs. %" PRId64 " MiB of free device memory\n", __func__, sum_projected_used/MiB, sum_free/MiB); - if (min_projected_free >= margin) { - if (nd == 1) { + if (nd == 1) { + if (projected_free_per_device[0] >= margins[0]) { LLAMA_LOG_INFO("%s: will leave %" PRId64 " >= %" PRId64 " MiB of free device memory, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); + __func__, projected_free_per_device[0]/MiB, margins[0]/MiB); + return; + } + } else { + bool changes_needed = false; + for (size_t id = 0; id < nd; id++) { + if (projected_free_per_device[id] < margins[id]) { + changes_needed = true; + break; + } + } + if (!changes_needed) { + LLAMA_LOG_INFO("%s: targets for free memory can be met on all devices, no changes needed\n", __func__); return; } - LLAMA_LOG_INFO("%s: will leave at least %" PRId64 " >= %" PRId64 " MiB of free memory on all devices, no changes needed\n", - __func__, min_projected_free/MiB, margin/MiB); - return; } // step 2: try reducing memory use by reducing the context size { - int64_t global_surplus = sum_projected_free - int64_t(nd)*margin; + int64_t global_surplus = sum_projected_free; + for (size_t id = 0; id < nd; id++) { + global_surplus -= margins[id]; + } if (global_surplus < 0) { - LLAMA_LOG_INFO(nd == 1 ? - "%s: cannot fulfill margin of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n" : - "%s: cannot fulfill margin of %" PRId64 " MiB on all devices, need to use %" PRId64 " MiB less in total\n", - __func__, margin/MiB, -global_surplus/MiB); + if (nd == 1) { + LLAMA_LOG_INFO("%s: cannot meet free memory target of %" PRId64 " MiB, need to reduce device memory by %" PRId64 " MiB\n", + __func__, margins[0]/MiB, -global_surplus/MiB); + } else { + LLAMA_LOG_INFO( + "%s: cannot meet free memory targets on all devices, need to use %" PRId64 " MiB less in total\n", + __func__, -global_surplus/MiB); + } if (cparams->n_ctx == 0) { if (hp_nct > n_ctx_min) { - int64_t sum_used_target = sum_free - nd*margin_s; + int64_t sum_used_target = sum_free; + for (size_t id = 0; id < nd; id++) { + sum_used_target -= margins[id]; + } if (nd > 1) { // for multiple devices we need to be more conservative in terms of how much context we think can fit: // - for dense models only whole layers can be assigned to devices @@ -448,9 +472,9 @@ static void llama_params_fit_impl( const dmds_t dmds_cpu_moe = llama_get_device_memory_data( path_model, mparams, cparams, devs, hp_ngl, hp_nct, hp_nex, log_level); - for (const llama_device_memory_data & dmd : dmds_cpu_moe) { - global_surplus_cpu_moe += dmd.free; - global_surplus_cpu_moe -= int64_t(dmd.mb.total()) + margin; + for (size_t id = 0; id < nd; id++) { + global_surplus_cpu_moe += dmds_cpu_moe[id].free; + global_surplus_cpu_moe -= int64_t(dmds_cpu_moe[id].mb.total()) + margins[id]; } if (global_surplus_cpu_moe > 0) { @@ -469,7 +493,7 @@ static void llama_params_fit_impl( std::vector targets; // maximum acceptable memory use per device targets.reserve(nd); for (size_t id = 0; id < nd; id++) { - targets.push_back(dmds_full[id].free - margin); + targets.push_back(dmds_full[id].free - margins[id]); LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } @@ -701,11 +725,11 @@ static void llama_params_fit_impl( enum llama_params_fit_status llama_params_fit( const char * path_model, struct llama_model_params * mparams, struct llama_context_params * cparams, float * tensor_split, struct llama_model_tensor_buft_override * tensor_buft_overrides, - size_t margin_s, uint32_t n_ctx_min, enum ggml_log_level log_level) { + size_t * margins, uint32_t n_ctx_min, enum ggml_log_level log_level) { const int64_t t0_us = llama_time_us(); llama_params_fit_status status = LLAMA_PARAMS_FIT_STATUS_SUCCESS; try { - llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margin_s, n_ctx_min, log_level); + llama_params_fit_impl(path_model, mparams, cparams, tensor_split, tensor_buft_overrides, margins, n_ctx_min, log_level); LLAMA_LOG_INFO("%s: successfully fit params to free device memory\n", __func__); } catch (const llama_params_fit_exception & e) { LLAMA_LOG_WARN("%s: failed to fit params to free device memory: %s\n", __func__, e.what()); diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index e995974a2e7..c7be0021beb 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -1,5 +1,6 @@ #include "arg.h" #include "common.h" +#include "download.h" #include #include diff --git a/tools/fit-params/fit-params.cpp b/tools/fit-params/fit-params.cpp index c7e7748ca93..f9d9cb34c7d 100644 --- a/tools/fit-params/fit-params.cpp +++ b/tools/fit-params/fit-params.cpp @@ -27,7 +27,7 @@ int main(int argc, char ** argv) { auto mparams = common_model_params_to_llama(params); auto cparams = common_context_params_to_llama(params); const llama_params_fit_status status = llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target, params.fit_params_min_ctx, + params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); if (status != LLAMA_PARAMS_FIT_STATUS_SUCCESS) { LOG_ERR("%s: failed to fit CLI arguments to free memory, exiting...\n", __func__); diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e4a0be44ccb..16b0db2983e 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1,10 +1,10 @@ #include "common.h" +#include "download.h" #include "log.h" #include "llama.h" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" -#include "arg.h" // for common_remote_get_content; TODO: use download.h only #include "base64.hpp" #include "server-common.h" @@ -779,7 +779,7 @@ static void handle_media( // download remote image // TODO @ngxson : maybe make these params configurable common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.headers.push_back({"User-Agent", "llama.cpp/" + build_info}); params.max_size = 1024 * 1024 * 10; // 10MB params.timeout = 10; // seconds SRV_INF("downloading image from '%s'\n", url.c_str()); diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index b86e6a2310f..a437a36ed7d 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -9,7 +9,7 @@ namespace httplib { namespace detail { bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { + if (isdigit(c)) { v = c - '0'; return true; } else if ('A' <= c && c <= 'F') { @@ -49,6 +49,90 @@ std::string from_i_to_hex(size_t n) { return ret; } +std::string compute_etag(const FileStat &fs) { + if (!fs.is_file()) { return std::string(); } + + // If mtime cannot be determined (negative value indicates an error + // or sentinel), do not generate an ETag. Returning a neutral / fixed + // value like 0 could collide with a real file that legitimately has + // mtime == 0 (epoch) and lead to misleading validators. + auto mtime_raw = fs.mtime(); + if (mtime_raw < 0) { return std::string(); } + + auto mtime = static_cast(mtime_raw); + auto size = fs.size(); + + return std::string("W/\"") + from_i_to_hex(mtime) + "-" + + from_i_to_hex(size) + "\""; +} + +// Format time_t as HTTP-date (RFC 9110 Section 5.6.7): "Sun, 06 Nov 1994 +// 08:49:37 GMT" This implementation is defensive: it validates `mtime`, checks +// return values from `gmtime_r`/`gmtime_s`, and ensures `strftime` succeeds. +std::string file_mtime_to_http_date(time_t mtime) { + if (mtime < 0) { return std::string(); } + + struct tm tm_buf; +#ifdef _WIN32 + if (gmtime_s(&tm_buf, &mtime) != 0) { return std::string(); } +#else + if (gmtime_r(&mtime, &tm_buf) == nullptr) { return std::string(); } +#endif + char buf[64]; + if (strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", &tm_buf) == 0) { + return std::string(); + } + + return std::string(buf); +} + +// Parse HTTP-date (RFC 9110 Section 5.6.7) to time_t. Returns -1 on failure. +time_t parse_http_date(const std::string &date_str) { + struct tm tm_buf; + + // Create a classic locale object once for all parsing attempts + const std::locale classic_locale = std::locale::classic(); + + // Try to parse using std::get_time (C++11, cross-platform) + auto try_parse = [&](const char *fmt) -> bool { + std::istringstream ss(date_str); + ss.imbue(classic_locale); + + memset(&tm_buf, 0, sizeof(tm_buf)); + ss >> std::get_time(&tm_buf, fmt); + + return !ss.fail(); + }; + + // RFC 9110 preferred format (HTTP-date): "Sun, 06 Nov 1994 08:49:37 GMT" + if (!try_parse("%a, %d %b %Y %H:%M:%S")) { + // RFC 850 format: "Sunday, 06-Nov-94 08:49:37 GMT" + if (!try_parse("%A, %d-%b-%y %H:%M:%S")) { + // asctime format: "Sun Nov 6 08:49:37 1994" + if (!try_parse("%a %b %d %H:%M:%S %Y")) { + return static_cast(-1); + } + } + } + +#ifdef _WIN32 + return _mkgmtime(&tm_buf); +#else + return timegm(&tm_buf); +#endif +} + +bool is_weak_etag(const std::string &s) { + // Check if the string is a weak ETag (starts with 'W/"') + return s.size() > 3 && s[0] == 'W' && s[1] == '/' && s[2] == '"'; +} + +bool is_strong_etag(const std::string &s) { + // Check if the string is a strong ETag (starts and ends with '"', at least 2 + // chars) + return s.size() >= 2 && s[0] == '"' && s.back() == '"'; +} + size_t to_utf8(int code, char *buff) { if (code < 0x0080) { buff[0] = static_cast(code & 0x7F); @@ -168,6 +252,15 @@ bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } +time_t FileStat::mtime() const { + return ret_ >= 0 ? static_cast(st_.st_mtime) + : static_cast(-1); +} + +size_t FileStat::size() const { + return ret_ >= 0 ? static_cast(st_.st_size) : 0; +} + std::string encode_path(const std::string &s) { std::string result; result.reserve(s.size()); @@ -209,6 +302,149 @@ std::string file_extension(const std::string &path) { bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +template +bool parse_header(const char *beg, const char *end, T fn); + +template +bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_path_component(val)); + } + + return true; + } + + return false; +} + +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers) { + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // doesn't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + "transfer-encoding", + "content-length", + "host", + "authorization", + "www-authenticate", + "proxy-authenticate", + "proxy-authorization", + "cookie", + "set-cookie", + "cache-control", + "expect", + "max-forwards", + "pragma", + "range", + "te", + "age", + "expires", + "date", + "location", + "retry-after", + "vary", + "warning", + "content-encoding", + "content-type", + "content-range", + "trailer"}; + + case_ignore::unordered_set declared_trailers; + auto trailer_header = get_header_value(src_headers, "Trailer", "", 0); + if (trailer_header && std::strlen(trailer_header)) { + auto len = std::strlen(trailer_header); + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + const char *kbeg = b; + const char *kend = e; + while (kbeg < kend && (*kbeg == ' ' || *kbeg == '\t')) { + ++kbeg; + } + while (kend > kbeg && (kend[-1] == ' ' || kend[-1] == '\t')) { + --kend; + } + std::string key(kbeg, static_cast(kend - kbeg)); + if (!key.empty() && + prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + constexpr auto line_terminator_len = 2; + auto line_beg = line_reader.ptr(); + auto line_end = + line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_beg, line_end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != + declared_trailers.end()) { + dest.emplace(key, val); + trailer_header_count++; + } + })) { + return false; + } + + if (!line_reader.getline()) { return false; } + } + + return true; +} + std::pair trim(const char *b, const char *e, size_t left, size_t right) { while (b + left < e && is_space_or_tab(b[left])) { @@ -280,6 +516,42 @@ void split(const char *b, const char *e, char d, size_t m, } } +bool split_find(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + } + + return false; +} + +bool split_find(const char *b, const char *e, char d, + std::function fn) { + return split_find(b, e, d, (std::numeric_limits::max)(), + std::move(fn)); +} + stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) : strm_(strm), fixed_buffer_(fixed_buffer), @@ -1892,6 +2164,27 @@ bool zstd_decompressor::decompress(const char *data, size_t data_length, } #endif +std::unique_ptr +create_decompressor(const std::string &encoding) { + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding == "zstd" || encoding.find("zstd") != std::string::npos) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#endif + } + + return decompressor; +} + bool is_prohibited_header_name(const std::string &name) { using udl::operator""_t; @@ -1928,53 +2221,6 @@ const char *get_header_value(const Headers &headers, return def; } -template -bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } - - auto p = beg; - while (p < end && *p != ':') { - p++; - } - - auto name = std::string(beg, p); - if (!detail::fields::is_field_name(name)) { return false; } - - if (p == end) { return false; } - - auto key_end = p; - - if (*p++ != ':') { return false; } - - while (p < end && is_space_or_tab(*p)) { - p++; - } - - if (p <= end) { - auto key_len = key_end - beg; - if (!key_len) { return false; } - - auto key = std::string(beg, key_end); - auto val = std::string(p, end); - - if (!detail::fields::is_field_value(val)) { return false; } - - if (case_ignore::equal(key, "Location") || - case_ignore::equal(key, "Referer")) { - fn(key, val); - } else { - fn(key, decode_path_component(val)); - } - - return true; - } - - return false; -} - bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -2026,10 +2272,18 @@ bool read_content_with_length(Stream &strm, size_t len, ContentReceiverWithProgress out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; + detail::BodyReader br; + br.stream = &strm; + br.content_length = len; + br.chunked = false; + br.bytes_read = 0; + br.last_error = Error::Success; + size_t r = 0; while (r < len) { auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); + auto n = detail::read_body_content(&strm, br, buf, to_read); if (n <= 0) { return false; } if (!out(buf, static_cast(n), r, len)) { return false; } @@ -2089,125 +2343,35 @@ template ReadContentResult read_content_chunked(Stream &strm, T &x, size_t payload_max_length, ContentReceiverWithProgress out) { - const auto bufsiz = 16; - char buf[bufsiz]; - - stream_line_reader line_reader(strm, buf, bufsiz); - - if (!line_reader.getline()) { return ReadContentResult::Error; } + detail::ChunkedDecoder dec(strm); - unsigned long chunk_len; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; size_t total_len = 0; - while (true) { - char *end_ptr; - - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); - if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } - if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + for (;;) { + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = dec.read_payload(buf, sizeof(buf), chunk_offset, chunk_total); + if (n < 0) { return ReadContentResult::Error; } - if (chunk_len == 0) { break; } + if (n == 0) { + if (!dec.parse_trailers_into(x.trailers, x.headers)) { + return ReadContentResult::Error; + } + return ReadContentResult::Success; + } - // Check if adding this chunk would exceed the payload limit if (total_len > payload_max_length || - payload_max_length - total_len < chunk_len) { + payload_max_length - total_len < static_cast(n)) { return ReadContentResult::PayloadTooLarge; } - total_len += chunk_len; - - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + if (!out(buf, static_cast(n), chunk_offset, chunk_total)) { return ReadContentResult::Error; } - if (!line_reader.getline()) { return ReadContentResult::Error; } - - if (strcmp(line_reader.ptr(), "\r\n") != 0) { - return ReadContentResult::Error; - } - - if (!line_reader.getline()) { return ReadContentResult::Error; } + total_len += static_cast(n); } - - assert(chunk_len == 0); - - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked - // transfer coding is complete when a chunk with a chunk-size of zero is - // received, possibly followed by a trailer section, and finally terminated by - // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 - // - // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section - // does't care for the existence of the final CRLF. In other words, it seems - // to be ok whether the final CRLF exists or not in the chunked data. - // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 - // - // According to the reference code in RFC 9112, cpp-httplib now allows - // chunked transfer coding data without the final CRLF. - if (!line_reader.getline()) { return ReadContentResult::Success; } - - // RFC 7230 Section 4.1.2 - Headers prohibited in trailers - thread_local case_ignore::unordered_set prohibited_trailers = { - // Message framing - "transfer-encoding", "content-length", - - // Routing - "host", - - // Authentication - "authorization", "www-authenticate", "proxy-authenticate", - "proxy-authorization", "cookie", "set-cookie", - - // Request modifiers - "cache-control", "expect", "max-forwards", "pragma", "range", "te", - - // Response control - "age", "expires", "date", "location", "retry-after", "vary", "warning", - - // Payload processing - "content-encoding", "content-type", "content-range", "trailer"}; - - // Parse declared trailer headers once for performance - case_ignore::unordered_set declared_trailers; - if (has_header(x.headers, "Trailer")) { - auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); - auto len = std::strlen(trailer_header); - - split(trailer_header, trailer_header + len, ',', - [&](const char *b, const char *e) { - std::string key(b, e); - if (prohibited_trailers.find(key) == prohibited_trailers.end()) { - declared_trailers.insert(key); - } - }); - } - - size_t trailer_header_count = 0; - while (strcmp(line_reader.ptr(), "\r\n") != 0) { - if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { - return ReadContentResult::Error; - } - - // Check trailer header count limit - if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { - return ReadContentResult::Error; - } - - // Exclude line terminator - constexpr auto line_terminator_len = 2; - auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - - parse_header(line_reader.ptr(), end, - [&](const std::string &key, const std::string &val) { - if (declared_trailers.find(key) != declared_trailers.end()) { - x.trailers.emplace(key, val); - trailer_header_count++; - } - }); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - } - - return ReadContentResult::Success; } bool is_chunked_transfer_encoding(const Headers &headers) { @@ -2223,27 +2387,13 @@ bool prepare_content_receiver(T &x, int &status, std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; - if (encoding == "gzip" || encoding == "deflate") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding.find("br") != std::string::npos) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding == "zstd") { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif + if (!encoding.empty()) { + decompressor = detail::create_decompressor(encoding); + if (!decompressor) { + // Unsupported encoding or no support compiled in + status = StatusCode::UnsupportedMediaType_415; + return false; + } } if (decompressor) { @@ -2329,7 +2479,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { std::string s = method; - s += " "; + s += ' '; s += path; s += " HTTP/1.1\r\n"; return strm.write(s.data(), s.size()); @@ -2338,7 +2488,7 @@ ssize_t write_request_line(Stream &strm, const std::string &method, ssize_t write_response_line(Stream &strm, int status) { std::string s = "HTTP/1.1 "; s += std::to_string(status); - s += " "; + s += ' '; s += httplib::status_message(status); s += "\r\n"; return strm.write(s.data(), s.size()); @@ -2601,8 +2751,8 @@ bool redirect(T &cli, Request &req, Response &res, auto ret = cli.send(new_req, new_res, error); if (ret) { - req = new_req; - res = new_res; + req = std::move(new_req); + res = std::move(new_res); if (res.location.empty()) { res.location = location; } } @@ -2613,9 +2763,9 @@ std::string params_to_query_str(const Params ¶ms) { std::string query; for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } + if (it != params.begin()) { query += '&'; } query += encode_query_component(it->first); - query += "="; + query += '='; query += encode_query_component(it->second); } return query; @@ -2648,6 +2798,38 @@ void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } +// Normalize a query string by decoding and re-encoding each key/value pair +// while preserving the original parameter order. This avoids double-encoding +// and ensures consistent encoding without reordering (unlike Params which +// uses std::multimap and sorts keys). +std::string normalize_query_string(const std::string &query) { + std::string result; + split(query.data(), query.data() + query.size(), '&', + [&](const char *b, const char *e) { + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + auto dec_key = decode_query_component(key); + auto dec_val = decode_query_component(val); + + if (!result.empty()) { result += '&'; } + result += encode_query_component(dec_key); + if (!val.empty() || std::find(b, e, '=') != e) { + result += '='; + result += encode_query_component(dec_val); + } + } + }); + return result; +} + bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { auto boundary_keyword = "boundary="; @@ -2840,7 +3022,7 @@ bool parse_accept_header(const std::string &s, return; } - entries.push_back(accept_entry); + entries.push_back(std::move(accept_entry)); }); // Return false if any invalid entry was found @@ -2857,8 +3039,8 @@ bool parse_accept_header(const std::string &s, // Extract sorted media types content_types.reserve(entries.size()); - for (const auto &entry : entries) { - content_types.push_back(entry.media_type); + for (auto &entry : entries) { + content_types.push_back(std::move(entry.media_type)); } return true; @@ -2869,7 +3051,7 @@ class FormDataParser { FormDataParser() = default; void set_boundary(std::string &&boundary) { - boundary_ = boundary; + boundary_ = std::move(boundary); dash_boundary_crlf_ = dash_ + boundary_ + crlf_; crlf_dash_boundary_ = crlf_ + dash_ + boundary_; } @@ -3342,9 +3524,9 @@ std::string make_content_range_header_field( std::string field = "bytes "; field += std::to_string(st); - field += "-"; + field += '-'; field += std::to_string(ed); - field += "/"; + field += '/'; field += std::to_string(content_length); return field; } @@ -3721,7 +3903,7 @@ bool parse_www_authenticate(const Response &res, static_cast(m.length(2))) : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); - auth[key] = val; + auth[std::move(key)] = std::move(val); } return true; } @@ -3734,7 +3916,7 @@ class ContentProviderAdapter { public: explicit ContentProviderAdapter( ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + : content_provider_(std::move(content_provider)) {} bool operator()(size_t offset, size_t, DataSink &sink) { return content_provider_(offset, sink); @@ -3744,8 +3926,189 @@ class ContentProviderAdapter { ContentProviderWithoutLength content_provider_; }; +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +bool is_field_name(const std::string &s) { return is_token(s); } + +bool is_vchar(char c) { return c >= 33 && c <= 126; } + +bool is_obs_text(char c) { return 128 <= static_cast(c); } + +bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + } // namespace detail +const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Unknown: return "Unknown"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::ConnectionClosed: return "Connection closed by server"; + case Error::Timeout: return "Read timeout"; + case Error::ResourceExhaustion: return "Resource exhaustion"; + case Error::TooManyFormDataFiles: return "Too many form data files"; + case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; + case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; + case Error::ExceedMaxSocketDescriptorCount: + return "Exceeded maximum socket descriptor count"; + case Error::InvalidRequestLine: return "Invalid request line"; + case Error::InvalidHTTPMethod: return "Invalid HTTP method"; + case Error::InvalidHTTPVersion: return "Invalid HTTP version"; + case Error::InvalidHeaders: return "Invalid headers"; + case Error::MultipartParsing: return "Multipart parsing failed"; + case Error::OpenFile: return "Failed to open file"; + case Error::Listen: return "Failed to listen on socket"; + case Error::GetSockName: return "Failed to get socket name"; + case Error::UnsupportedAddressFamily: return "Unsupported address family"; + case Error::HTTPParsing: return "HTTP parsing failed"; + case Error::InvalidRangeHeader: return "Invalid Range header"; + default: break; + } + + return "Invalid"; +} + +std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + std::string hosted_at(const std::string &hostname) { std::vector addrs; hosted_at(hostname, addrs); @@ -3779,7 +4142,7 @@ void hosted_at(const std::string &hostname, auto dummy = -1; if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { - addrs.push_back(ip); + addrs.emplace_back(std::move(ip)); } } } @@ -4319,6 +4682,67 @@ ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } +// BodyReader implementation +ssize_t detail::BodyReader::read(char *buf, size_t len) { + if (!stream) { + last_error = Error::Connection; + return -1; + } + if (eof) { return 0; } + + if (!chunked) { + // Content-Length based reading + if (bytes_read >= content_length) { + eof = true; + return 0; + } + + auto remaining = content_length - bytes_read; + auto to_read = (std::min)(len, remaining); + auto n = stream->read(buf, to_read); + + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + if (n == 0) { + // Unexpected EOF before content_length + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return 0; + } + + bytes_read += static_cast(n); + if (bytes_read >= content_length) { eof = true; } + return n; + } + + // Chunked transfer encoding: delegate to shared decoder instance. + if (!chunked_decoder) { chunked_decoder.reset(new ChunkedDecoder(*stream)); } + + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = chunked_decoder->read_payload(buf, len, chunk_offset, chunk_total); + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + + if (n == 0) { + // Final chunk observed. Leave trailer parsing to the caller (StreamHandle). + eof = true; + return 0; + } + + bytes_read += static_cast(n); + return n; +} + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -4395,7 +4819,10 @@ ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!wait_readable()) { return -1; } + if (!wait_readable()) { + error_ = Error::Timeout; + return -1; + } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -4404,6 +4831,11 @@ ssize_t SocketStream::read(char *ptr, size_t size) { auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } return n; } else if (n <= static_cast(size)) { memcpy(ptr, read_buff_.data(), static_cast(n)); @@ -4415,7 +4847,15 @@ ssize_t SocketStream::read(char *ptr, size_t size) { return static_cast(size); } } else { - return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + auto n = read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } + } + return n; } } @@ -4579,19 +5019,22 @@ bool RegexMatcher::match(Request &request) const { return std::regex_match(request.path, request.matches, regex_); } -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl) { - std::string result; - +// Enclose IPv6 address in brackets if needed +std::string prepare_host_string(const std::string &host) { // Enclose IPv6 address in brackets (but not if already enclosed) if (host.find(':') == std::string::npos || (!host.empty() && host[0] == '[')) { // IPv4, hostname, or already bracketed IPv6 - result = host; + return host; } else { // IPv6 address without brackets - result = "[" + host + "]"; + return "[" + host + "]"; } +} + +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl) { + auto result = prepare_host_string(host); // Append port if not default if ((!is_ssl && port == 80) || (is_ssl && port == 443)) { @@ -4603,6 +5046,29 @@ std::string make_host_and_port_string(const std::string &host, int port, return result; } +// Create "host:port" string always including port number (for CONNECT method) +std::string +make_host_and_port_string_always_port(const std::string &host, int port) { + return prepare_host_string(host) + ":" + std::to_string(port); +} + +template +bool check_and_write_headers(Stream &strm, Headers &headers, + T header_writer, Error &error) { + for (const auto &h : headers) { + if (!detail::fields::is_field_name(h.first) || + !detail::fields::is_field_value(h.second)) { + error = Error::InvalidHeaders; + return false; + } + } + if (header_writer(strm, headers) <= 0) { + error = Error::Write; + return false; + } + return true; +} + } // namespace detail // HTTP server implementation @@ -4694,7 +5160,7 @@ bool Server::set_mount_point(const std::string &mount_point, if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.push_back({mnt, dir, std::move(headers)}); + base_dirs_.push_back({std::move(mnt), dir, std::move(headers)}); return true; } } @@ -5010,7 +5476,7 @@ bool Server::write_response_core(Stream &strm, bool close_connection, { detail::BufferStream bstrm; if (!detail::write_response_line(bstrm, res.status)) { return false; } - if (!header_writer_(bstrm, res.headers)) { return false; } + if (header_writer_(bstrm, res.headers) <= 0) { return false; } // Flush buffer auto &data = bstrm.get_buffer(); @@ -5103,7 +5569,16 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { strm, req, res, // Regular [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } + // Prevent arithmetic overflow when checking sizes. + // Avoid computing (req.body.size() + n) directly because + // adding two unsigned `size_t` values can wrap around and + // produce a small result instead of indicating overflow. + // Instead, check using subtraction: ensure `n` does not + // exceed the remaining capacity `max_size() - size()`. + if (req.body.size() >= req.body.max_size() || + n > req.body.max_size() - req.body.size()) { + return false; + } req.body.append(buf, n); return true; }, @@ -5182,14 +5657,43 @@ bool Server::read_content_core( out = [receiver](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) { return receiver(buf, n); }; } - - // RFC 7230 Section 3.3.3: If this is a request message and none of the above - // are true (no Transfer-Encoding and no Content-Length), then the message - // body length is zero (no message body is present). + + // RFC 7230 Section 3.3.3: If this is a request message and none of the above + // are true (no Transfer-Encoding and no Content-Length), then the message + // body length is zero (no message body is present). + // + // For non-SSL builds, peek into the socket to detect clients that send a + // body without a Content-Length header (raw HTTP over TCP). If there is + // pending data that exceeds the configured payload limit, treat this as an + // oversized request and fail early (causing connection close). For SSL + // builds we cannot reliably peek the decrypted application bytes, so keep + // the original behaviour. +#if !defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(_WIN32) + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { + socket_t s = strm.socket(); + if (s != INVALID_SOCKET) { + // Peek up to payload_max_length_ + 1 bytes. If more than + // payload_max_length_ bytes are pending, reject the request. + size_t to_peek = + (payload_max_length_ > 0) + ? (std::min)(payload_max_length_ + 1, static_cast(4096)) + : 1; + std::vector peekbuf(to_peek); + ssize_t n = ::recv(s, peekbuf.data(), to_peek, MSG_PEEK); + if (n > 0 && static_cast(n) > payload_max_length_) { + // Indicate failure so connection will be closed. + return false; + } + } + return true; + } +#else if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { return true; } +#endif if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { @@ -5207,7 +5711,7 @@ bool Server::read_content_core( return true; } -bool Server::handle_file_request(const Request &req, Response &res) { +bool Server::handle_file_request(Request &req, Response &res) { for (const auto &entry : base_dirs_) { // Prefix match if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { @@ -5228,6 +5732,20 @@ bool Server::handle_file_request(const Request &req, Response &res) { res.set_header(kv.first, kv.second); } + auto etag = detail::compute_etag(stat); + if (!etag.empty()) { res.set_header("ETag", etag); } + + auto mtime = stat.mtime(); + + auto last_modified = detail::file_mtime_to_http_date(mtime); + if (!last_modified.empty()) { + res.set_header("Last-Modified", last_modified); + } + + if (check_if_not_modified(req, res, etag, mtime)) { return true; } + + check_if_range(req, etag, mtime); + auto mm = std::make_shared(path.c_str()); if (!mm->is_open()) { output_error_log(Error::OpenFile, &req); @@ -5257,6 +5775,79 @@ bool Server::handle_file_request(const Request &req, Response &res) { return false; } +bool Server::check_if_not_modified(const Request &req, Response &res, + const std::string &etag, + time_t mtime) const { + // Handle conditional GET: + // 1. If-None-Match takes precedence (RFC 9110 Section 13.1.2) + // 2. If-Modified-Since is checked only when If-None-Match is absent + if (req.has_header("If-None-Match")) { + if (!etag.empty()) { + auto val = req.get_header_value("If-None-Match"); + + // NOTE: We use exact string matching here. This works correctly + // because our server always generates weak ETags (W/"..."), and + // clients typically send back the same ETag they received. + // RFC 9110 Section 8.8.3.2 allows weak comparison for + // If-None-Match, where W/"x" and "x" would match, but this + // simplified implementation requires exact matches. + auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', + [&](const char *b, const char *e) { + return std::equal(b, e, "*") || + std::equal(b, e, etag.begin()); + }); + + if (ret) { + res.status = StatusCode::NotModified_304; + return true; + } + } + } else if (req.has_header("If-Modified-Since")) { + auto val = req.get_header_value("If-Modified-Since"); + auto t = detail::parse_http_date(val); + + if (t != static_cast(-1) && mtime <= t) { + res.status = StatusCode::NotModified_304; + return true; + } + } + return false; +} + +bool Server::check_if_range(Request &req, const std::string &etag, + time_t mtime) const { + // Handle If-Range for partial content requests (RFC 9110 + // Section 13.1.5). If-Range is only evaluated when Range header is + // present. If the validator matches, serve partial content; otherwise + // serve full content. + if (!req.ranges.empty() && req.has_header("If-Range")) { + auto val = req.get_header_value("If-Range"); + + auto is_valid_range = [&]() { + if (detail::is_strong_etag(val)) { + // RFC 9110 Section 13.1.5: If-Range requires strong ETag + // comparison. + return (!etag.empty() && val == etag); + } else if (detail::is_weak_etag(val)) { + // Weak ETags are not valid for If-Range (RFC 9110 Section 13.1.5) + return false; + } else { + // HTTP-date comparison + auto t = detail::parse_http_date(val); + return (t != static_cast(-1) && mtime <= t); + } + }; + + if (!is_valid_range()) { + // Validator doesn't match: ignore Range and serve full content + req.ranges.clear(); + return false; + } + } + + return true; +} + socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, @@ -5524,10 +6115,13 @@ void Server::apply_ranges(const Request &req, Response &res, res.set_header("Transfer-Encoding", "chunked"); if (type == detail::EncodingType::Gzip) { res.set_header("Content-Encoding", "gzip"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Brotli) { res.set_header("Content-Encoding", "br"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Zstd) { res.set_header("Content-Encoding", "zstd"); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5586,6 +6180,7 @@ void Server::apply_ranges(const Request &req, Response &res, })) { res.body.swap(compressed); res.set_header("Content-Encoding", content_encoding); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5663,6 +6258,10 @@ Server::process_request(Stream &strm, const std::string &remote_addr, Request req; req.start_time_ = std::chrono::steady_clock::now(); + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.local_addr = local_addr; + req.local_port = local_port; Response res; res.version = "HTTP/1.1"; @@ -5908,7 +6507,6 @@ ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), - host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} ClientImpl::~ClientImpl() { @@ -6007,6 +6605,26 @@ bool ClientImpl::create_and_connect_socket(Socket &socket, return true; } +bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { + return create_and_connect_socket(socket, error); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} +#endif + void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any requests in flight from threads other than us, then it's @@ -6119,7 +6737,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } if (!is_alive) { - if (!create_and_connect_socket(socket_, error)) { + if (!ensure_socket_connection(socket_, error)) { output_error_log(error, &req); return false; } @@ -6137,9 +6755,11 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } } - if (!scli.initialize_ssl(socket_, error)) { - output_error_log(error, &req); - return false; + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, error)) { + output_error_log(error, &req); + return false; + } } } #endif @@ -6212,6 +6832,343 @@ Result ClientImpl::send_(Request &&req) { #endif } +void ClientImpl::prepare_default_headers(Request &r, bool for_stream, + const std::string &ct) { + (void)for_stream; + for (const auto &header : default_headers_) { + if (!r.has_header(header.first)) { r.headers.insert(header); } + } + + if (!r.has_header("Host")) { + if (address_family_ == AF_UNIX) { + r.headers.emplace("Host", "localhost"); + } else { + r.headers.emplace( + "Host", detail::make_host_and_port_string(host_, port_, is_ssl())); + } + } + + if (!r.has_header("Accept")) { r.headers.emplace("Accept", "*/*"); } + + if (!r.content_receiver) { + if (!r.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + r.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!r.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + r.set_header("User-Agent", agent); + } +#endif + } + + if (!r.body.empty()) { + if (!ct.empty() && !r.has_header("Content-Type")) { + r.headers.emplace("Content-Type", ct); + } + if (!r.has_header("Content-Length")) { + r.headers.emplace("Content-Length", std::to_string(r.body.size())); + } + } +} + +ClientImpl::StreamHandle +ClientImpl::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, + const std::string &content_type) { + StreamHandle handle; + handle.response = detail::make_unique(); + handle.error = Error::Success; + + auto query_path = params.empty() ? path : append_query_params(path, params); + handle.connection_ = detail::make_unique(); + + { + std::lock_guard guard(socket_mutex_); + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + if (!is_alive) { + shutdown_ssl(socket_, false); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!ensure_socket_connection(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + } + } +#endif + } + + transfer_socket_ownership_to_handle(handle); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && handle.connection_->ssl) { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + } else { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); + } +#else + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); +#endif + handle.stream_ = handle.socket_stream_.get(); + + Request req; + req.method = method; + req.path = query_path; + req.headers = headers; + req.body = body; + + prepare_default_headers(req, true, content_type); + + auto &strm = *handle.stream_; + if (detail::write_request_line(strm, req.method, req.path) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + + if (!detail::check_and_write_headers(strm, req.headers, header_writer_, + handle.error)) { + handle.response.reset(); + return handle; + } + + if (!body.empty()) { + if (strm.write(body.data(), body.size()) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + } + + if (!read_response_line(strm, req, *handle.response) || + !detail::read_headers(strm, handle.response->headers)) { + handle.error = Error::Read; + handle.response.reset(); + return handle; + } + + handle.body_reader_.stream = handle.stream_; + + auto content_length_str = handle.response->get_header_value("Content-Length"); + if (!content_length_str.empty()) { + handle.body_reader_.content_length = + static_cast(std::stoull(content_length_str)); + } + + auto transfer_encoding = + handle.response->get_header_value("Transfer-Encoding"); + handle.body_reader_.chunked = (transfer_encoding == "chunked"); + + auto content_encoding = handle.response->get_header_value("Content-Encoding"); + if (!content_encoding.empty()) { + handle.decompressor_ = detail::create_decompressor(content_encoding); + } + + return handle; +} + +ssize_t ClientImpl::StreamHandle::read(char *buf, size_t len) { + if (!is_valid() || !response) { return -1; } + + if (decompressor_) { return read_with_decompression(buf, len); } + auto n = detail::read_body_content(stream_, body_reader_, buf, len); + + if (n <= 0 && body_reader_.chunked && !trailers_parsed_ && stream_) { + trailers_parsed_ = true; + if (body_reader_.chunked_decoder) { + if (!body_reader_.chunked_decoder->parse_trailers_into( + response->trailers, response->headers)) { + return n; + } + } else { + detail::ChunkedDecoder dec(*stream_); + if (!dec.parse_trailers_into(response->trailers, response->headers)) { + return n; + } + } + } + + return n; +} + +ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, + size_t len) { + if (decompress_offset_ < decompress_buffer_.size()) { + auto available = decompress_buffer_.size() - decompress_offset_; + auto to_copy = (std::min)(len, available); + std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); + decompress_offset_ += to_copy; + return static_cast(to_copy); + } + + decompress_buffer_.clear(); + decompress_offset_ = 0; + + constexpr size_t kDecompressionBufferSize = 8192; + char compressed_buf[kDecompressionBufferSize]; + + while (true) { + auto n = detail::read_body_content(stream_, body_reader_, compressed_buf, + sizeof(compressed_buf)); + + if (n <= 0) { return n; } + + bool decompress_ok = + decompressor_->decompress(compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + return true; + }); + + if (!decompress_ok) { + body_reader_.last_error = Error::Read; + return -1; + } + + if (!decompress_buffer_.empty()) { break; } + } + + auto to_copy = (std::min)(len, decompress_buffer_.size()); + std::memcpy(buf, decompress_buffer_.data(), to_copy); + decompress_offset_ = to_copy; + return static_cast(to_copy); +} + +void ClientImpl::StreamHandle::parse_trailers_if_needed() { + if (!response || !stream_ || !body_reader_.chunked || trailers_parsed_) { + return; + } + + trailers_parsed_ = true; + + const auto bufsiz = 128; + char line_buf[bufsiz]; + detail::stream_line_reader line_reader(*stream_, line_buf, bufsiz); + + if (!line_reader.getline()) { return; } + + if (!detail::parse_trailers(line_reader, response->trailers, + response->headers)) { + return; + } +} + +// Inline method implementations for `ChunkedDecoder`. +namespace detail { + +ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} + +ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, + size_t &out_chunk_offset, + size_t &out_chunk_total) { + if (finished) { return 0; } + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + + char *endptr = nullptr; + unsigned long chunk_len = std::strtoul(lr.ptr(), &endptr, 16); + if (endptr == lr.ptr()) { return -1; } + if (chunk_len == ULONG_MAX) { return -1; } + + if (chunk_len == 0) { + chunk_remaining = 0; + finished = true; + out_chunk_offset = 0; + out_chunk_total = 0; + return 0; + } + + chunk_remaining = static_cast(chunk_len); + last_chunk_total = chunk_remaining; + last_chunk_offset = 0; + } + + auto to_read = (std::min)(chunk_remaining, len); + auto n = strm.read(buf, to_read); + if (n <= 0) { return -1; } + + auto offset_before = last_chunk_offset; + last_chunk_offset += static_cast(n); + chunk_remaining -= static_cast(n); + + out_chunk_offset = offset_before; + out_chunk_total = last_chunk_total; + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + if (std::strcmp(lr.ptr(), "\r\n") != 0) { return -1; } + } + + return n; +} + +bool ChunkedDecoder::parse_trailers_into(Headers &dest, + const Headers &src_headers) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return false; } + return parse_trailers(lr, dest, src_headers); +} + +} // namespace detail + +void +ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { + handle.connection_->sock = socket_.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + handle.connection_->ssl = socket_.ssl; + socket_.ssl = nullptr; +#endif + socket_.sock = INVALID_SOCKET; +} + bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { @@ -6227,9 +7184,11 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; + req2.path = "http://" + + detail::make_host_and_port_string(host_, port_, false) + + req.path; ret = process_request(strm, req2, res, close_connection, error); - req = req2; + req = std::move(req2); req.path = req_save.path; } else { ret = process_request(strm, req, res, close_connection, error); @@ -6253,7 +7212,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, } if (300 < res.status && res.status < 400 && follow_location_) { - req = req_save; + req = std::move(req_save); ret = redirect(req, res, error); } @@ -6281,7 +7240,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, Response new_res; ret = send(new_req, new_res, error); - if (ret) { res = new_res; } + if (ret) { res = std::move(new_res); } } } } @@ -6514,42 +7473,11 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!req.has_header("Host")) { - // For Unix socket connections, use "localhost" as Host header (similar to - // curl behavior) - if (address_family_ == AF_UNIX) { - req.set_header("Host", "localhost"); - } else { - req.set_header("Host", host_and_port_); - } + std::string ct_for_defaults; + if (!req.has_header("Content-Type") && !req.body.empty()) { + ct_for_defaults = "text/plain"; } - - if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } - - if (!req.content_receiver) { - if (!req.has_header("Accept-Encoding")) { - std::string accept_encoding; -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - accept_encoding = "br"; -#endif -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "gzip, deflate"; -#endif -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "zstd"; -#endif - req.set_header("Accept-Encoding", accept_encoding); - } - -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } -#endif - }; + prepare_default_headers(req, false, ct_for_defaults); if (req.body.empty()) { if (req.content_provider_) { @@ -6565,15 +7493,6 @@ bool ClientImpl::write_request(Stream &strm, Request &req, req.set_header("Content-Length", "0"); } } - } else { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length); - } } if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { @@ -6620,18 +7539,41 @@ bool ClientImpl::write_request(Stream &strm, Request &req, query_part = ""; } - // Encode path and query + // Encode path part. If the original `req.path` already contained a + // query component, preserve its raw query string (including parameter + // order) instead of reparsing and reassembling it which may reorder + // parameters due to container ordering (e.g. `Params` uses + // `std::multimap`). When there is no query in `req.path`, fall back to + // building a query from `req.params` so existing callers that pass + // `Params` continue to work. auto path_with_query = path_encode_ ? detail::encode_path(path_part) : path_part; - detail::parse_query_text(query_part, req.params); - if (!req.params.empty()) { - path_with_query = append_query_params(path_with_query, req.params); + if (!query_part.empty()) { + // Normalize the query string (decode then re-encode) while preserving + // the original parameter order. + auto normalized = detail::normalize_query_string(query_part); + if (!normalized.empty()) { path_with_query += '?' + normalized; } + + // Still populate req.params for handlers/users who read them. + detail::parse_query_text(query_part, req.params); + } else { + // No query in path; parse any query_part (empty) and append params + // from `req.params` when present (preserves prior behavior for + // callers who provide Params separately). + detail::parse_query_text(query_part, req.params); + if (!req.params.empty()) { + path_with_query = append_query_params(path_with_query, req.params); + } } // Write request line and headers detail::write_request_line(bstrm, req.method, path_with_query); - header_writer_(bstrm, req.headers); + if (!detail::check_and_write_headers(bstrm, req.headers, header_writer_, + error)) { + output_error_log(error, &req); + return false; + } // Flush buffer auto &data = bstrm.get_buffer(); @@ -8096,7 +9038,9 @@ bool SSLSocketStream::wait_writable() const { ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret == 0) { error_ = Error::ConnectionClosed; } + return ret; } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { @@ -8121,9 +9065,12 @@ ssize_t SSLSocketStream::read(char *ptr, size_t size) { } } assert(ret < 0); + } else if (ret == 0) { + error_ = Error::ConnectionClosed; } return ret; } else { + error_ = Error::Timeout; return -1; } } @@ -8499,7 +9446,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req2; req2.method = "CONNECT"; - req2.path = host_and_port_; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); if (max_timeout_msec_ > 0) { req2.start_time_ = std::chrono::steady_clock::now(); } @@ -8526,7 +9474,7 @@ bool SSLClient::connect_with_proxy( close_socket(socket); // Create a new socket for the authenticated CONNECT request - if (!create_and_connect_socket(socket, error)) { + if (!ensure_socket_connection(socket, error)) { success = false; output_error_log(error, nullptr); return false; @@ -8539,7 +9487,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; - req3.path = host_and_port_; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); req3.headers.insert(detail::make_digest_authentication_header( req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, @@ -9424,6 +10373,13 @@ Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } +ClientImpl::StreamHandle +Client::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type) { + return cli_->open_stream(method, path, params, headers, body, content_type); +} + bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index c9bd9fd86bf..43cdbc58326 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -1,15 +1,15 @@ // // httplib.h // -// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// Copyright (c) 2026 Yuji Hirose. All rights reserved. // MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.28.0" -#define CPPHTTPLIB_VERSION_NUM "0x001C00" +#define CPPHTTPLIB_VERSION "0.30.0" +#define CPPHTTPLIB_VERSION_NUM "0x001E00" /* * Platform compatibility check @@ -838,6 +838,50 @@ struct Response { std::string file_content_content_type_; }; +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + ConnectionClosed, + Timeout, + ResourceExhaustion, + TooManyFormDataFiles, + ExceedMaxPayloadSize, + ExceedUriMaxLength, + ExceedMaxSocketDescriptorCount, + InvalidRequestLine, + InvalidHTTPMethod, + InvalidHTTPVersion, + InvalidHeaders, + MultipartParsing, + OpenFile, + Listen, + GetSockName, + UnsupportedAddressFamily, + HTTPParsing, + InvalidRangeHeader, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + class Stream { public: virtual ~Stream() = default; @@ -856,6 +900,11 @@ class Stream { ssize_t write(const char *ptr); ssize_t write(const std::string &s); + + Error get_error() const { return error_; } + +protected: + Error error_ = Error::Success; }; class TaskQueue { @@ -873,6 +922,7 @@ class ThreadPool final : public TaskQueue { public: explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { + threads_.reserve(n); while (n) { threads_.emplace_back(worker(*this)); n--; @@ -961,27 +1011,21 @@ using ErrorLogger = std::function; using SocketOptions = std::function; -namespace detail { - -bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen); -bool set_socket_opt(socket_t sock, int level, int optname, int opt); -bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, - time_t usec); - -} // namespace detail - void default_socket_options(socket_t sock); const char *status_message(int status); +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + std::string get_bearer_token_auth(const Request &req); namespace detail { class MatcherBase { public: - MatcherBase(std::string pattern) : pattern_(pattern) {} + MatcherBase(std::string pattern) : pattern_(std::move(pattern)) {} virtual ~MatcherBase() = default; const std::string &pattern() const { return pattern_; } @@ -1051,10 +1095,9 @@ class RegexMatcher final : public MatcherBase { std::regex regex_; }; -ssize_t write_headers(Stream &strm, const Headers &headers); +int close_socket(socket_t sock); -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl); +ssize_t write_headers(Stream &strm, const Headers &headers); } // namespace detail @@ -1206,7 +1249,11 @@ class Server { bool listen_internal(); bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(const Request &req, Response &res); + bool handle_file_request(Request &req, Response &res); + bool check_if_not_modified(const Request &req, Response &res, + const std::string &etag, time_t mtime) const; + bool check_if_range(Request &req, const std::string &etag, + time_t mtime) const; bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; bool dispatch_request_for_content_reader( @@ -1290,48 +1337,6 @@ class Server { detail::write_headers; }; -enum class Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification, - SSLServerHostnameVerification, - UnsupportedMultipartBoundaryChars, - Compression, - ConnectionTimeout, - ProxyConnection, - ResourceExhaustion, - TooManyFormDataFiles, - ExceedMaxPayloadSize, - ExceedUriMaxLength, - ExceedMaxSocketDescriptorCount, - InvalidRequestLine, - InvalidHTTPMethod, - InvalidHTTPVersion, - InvalidHeaders, - MultipartParsing, - OpenFile, - Listen, - GetSockName, - UnsupportedAddressFamily, - HTTPParsing, - InvalidRangeHeader, - - // For internal use only - SSLPeerCouldBeClosed_, -}; - -std::string to_string(Error error); - -std::ostream &operator<<(std::ostream &os, const Error &obj); - class Result { public: Result() = default; @@ -1390,6 +1395,87 @@ class Result { #endif }; +struct ClientConnection { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + + ClientConnection() = default; + + ~ClientConnection() { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (ssl) { + SSL_free(ssl); + ssl = nullptr; + } +#endif + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; + } + } + + ClientConnection(const ClientConnection &) = delete; + ClientConnection &operator=(const ClientConnection &) = delete; + + ClientConnection(ClientConnection &&other) noexcept + : sock(other.sock) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + , + ssl(other.ssl) +#endif + { + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + + ClientConnection &operator=(ClientConnection &&other) noexcept { + if (this != &other) { + sock = other.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ssl = other.ssl; +#endif + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + return *this; + } +}; + +namespace detail { + +struct ChunkedDecoder; + +struct BodyReader { + Stream *stream = nullptr; + size_t content_length = 0; + size_t bytes_read = 0; + bool chunked = false; + bool eof = false; + std::unique_ptr chunked_decoder; + Error last_error = Error::Success; + + ssize_t read(char *buf, size_t len); + bool has_error() const { return last_error != Error::Success; } +}; + +inline ssize_t read_body_content(Stream *stream, BodyReader &br, char *buf, + size_t len) { + (void)stream; + return br.read(buf, len); +} + +class decompressor; + +} // namespace detail + class ClientImpl { public: explicit ClientImpl(const std::string &host); @@ -1404,6 +1490,43 @@ class ClientImpl { virtual bool is_valid() const; + struct StreamHandle { + std::unique_ptr response; + Error error = Error::Success; + + StreamHandle() = default; + StreamHandle(const StreamHandle &) = delete; + StreamHandle &operator=(const StreamHandle &) = delete; + StreamHandle(StreamHandle &&) = default; + StreamHandle &operator=(StreamHandle &&) = default; + ~StreamHandle() = default; + + bool is_valid() const { + return response != nullptr && error == Error::Success; + } + + ssize_t read(char *buf, size_t len); + void parse_trailers_if_needed(); + Error get_read_error() const { return body_reader_.last_error; } + bool has_read_error() const { return body_reader_.has_error(); } + + bool trailers_parsed_ = false; + + private: + friend class ClientImpl; + + ssize_t read_with_decompression(char *buf, size_t len); + + std::unique_ptr connection_; + std::unique_ptr socket_stream_; + Stream *stream_ = nullptr; + detail::BodyReader body_reader_; + + std::unique_ptr decompressor_; + std::string decompress_buffer_; + size_t decompress_offset_ = 0; + }; + // clang-format off Result Get(const std::string &path, DownloadProgress progress = nullptr); Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); @@ -1497,6 +1620,15 @@ class ClientImpl { Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + StreamHandle open_stream(const std::string &method, const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -1592,6 +1724,7 @@ class ClientImpl { }; virtual bool create_and_connect_socket(Socket &socket, Error &error); + virtual bool ensure_socket_connection(Socket &socket, Error &error); // All of: // shutdown_ssl @@ -1618,7 +1751,6 @@ class ClientImpl { // Socket endpoint information const std::string host_; const int port_; - const std::string host_and_port_; // Current open socket Socket socket_; @@ -1717,6 +1849,8 @@ class ClientImpl { Response &res) const; bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + void prepare_default_headers(Request &r, bool for_stream, + const std::string &ct); bool redirect(Request &req, Response &res, Error &error); bool create_redirect_client(const std::string &scheme, const std::string &host, int port, Request &req, @@ -1747,6 +1881,8 @@ class ClientImpl { std::chrono::time_point start_time, std::function callback); virtual bool is_ssl() const; + + void transfer_socket_ownership_to_handle(StreamHandle &handle); }; class Client { @@ -1865,6 +2001,16 @@ class Client { Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + ClientImpl::StreamHandle open_stream(const std::string &method, + const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -2027,6 +2173,7 @@ class SSLClient final : public ClientImpl { private: bool create_and_connect_socket(Socket &socket, Error &error) override; + bool ensure_socket_connection(Socket &socket, Error &error) override; void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); @@ -2163,82 +2310,6 @@ inline void default_socket_options(socket_t sock) { 1); } -inline const char *status_message(int status) { - switch (status) { - case StatusCode::Continue_100: return "Continue"; - case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; - case StatusCode::Processing_102: return "Processing"; - case StatusCode::EarlyHints_103: return "Early Hints"; - case StatusCode::OK_200: return "OK"; - case StatusCode::Created_201: return "Created"; - case StatusCode::Accepted_202: return "Accepted"; - case StatusCode::NonAuthoritativeInformation_203: - return "Non-Authoritative Information"; - case StatusCode::NoContent_204: return "No Content"; - case StatusCode::ResetContent_205: return "Reset Content"; - case StatusCode::PartialContent_206: return "Partial Content"; - case StatusCode::MultiStatus_207: return "Multi-Status"; - case StatusCode::AlreadyReported_208: return "Already Reported"; - case StatusCode::IMUsed_226: return "IM Used"; - case StatusCode::MultipleChoices_300: return "Multiple Choices"; - case StatusCode::MovedPermanently_301: return "Moved Permanently"; - case StatusCode::Found_302: return "Found"; - case StatusCode::SeeOther_303: return "See Other"; - case StatusCode::NotModified_304: return "Not Modified"; - case StatusCode::UseProxy_305: return "Use Proxy"; - case StatusCode::unused_306: return "unused"; - case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; - case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; - case StatusCode::BadRequest_400: return "Bad Request"; - case StatusCode::Unauthorized_401: return "Unauthorized"; - case StatusCode::PaymentRequired_402: return "Payment Required"; - case StatusCode::Forbidden_403: return "Forbidden"; - case StatusCode::NotFound_404: return "Not Found"; - case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; - case StatusCode::NotAcceptable_406: return "Not Acceptable"; - case StatusCode::ProxyAuthenticationRequired_407: - return "Proxy Authentication Required"; - case StatusCode::RequestTimeout_408: return "Request Timeout"; - case StatusCode::Conflict_409: return "Conflict"; - case StatusCode::Gone_410: return "Gone"; - case StatusCode::LengthRequired_411: return "Length Required"; - case StatusCode::PreconditionFailed_412: return "Precondition Failed"; - case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; - case StatusCode::UriTooLong_414: return "URI Too Long"; - case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; - case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; - case StatusCode::ExpectationFailed_417: return "Expectation Failed"; - case StatusCode::ImATeapot_418: return "I'm a teapot"; - case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; - case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; - case StatusCode::Locked_423: return "Locked"; - case StatusCode::FailedDependency_424: return "Failed Dependency"; - case StatusCode::TooEarly_425: return "Too Early"; - case StatusCode::UpgradeRequired_426: return "Upgrade Required"; - case StatusCode::PreconditionRequired_428: return "Precondition Required"; - case StatusCode::TooManyRequests_429: return "Too Many Requests"; - case StatusCode::RequestHeaderFieldsTooLarge_431: - return "Request Header Fields Too Large"; - case StatusCode::UnavailableForLegalReasons_451: - return "Unavailable For Legal Reasons"; - case StatusCode::NotImplemented_501: return "Not Implemented"; - case StatusCode::BadGateway_502: return "Bad Gateway"; - case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; - case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; - case StatusCode::HttpVersionNotSupported_505: - return "HTTP Version Not Supported"; - case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; - case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; - case StatusCode::LoopDetected_508: return "Loop Detected"; - case StatusCode::NotExtended_510: return "Not Extended"; - case StatusCode::NetworkAuthenticationRequired_511: - return "Network Authentication Required"; - - default: - case StatusCode::InternalServerError_500: return "Internal Server Error"; - } -} - inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); @@ -2272,55 +2343,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline std::string to_string(const Error error) { - switch (error) { - case Error::Success: return "Success (no error)"; - case Error::Unknown: return "Unknown"; - case Error::Connection: return "Could not establish connection"; - case Error::BindIPAddress: return "Failed to bind IP address"; - case Error::Read: return "Failed to read connection"; - case Error::Write: return "Failed to write connection"; - case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; - case Error::Canceled: return "Connection handling canceled"; - case Error::SSLConnection: return "SSL connection failed"; - case Error::SSLLoadingCerts: return "SSL certificate loading failed"; - case Error::SSLServerVerification: return "SSL server verification failed"; - case Error::SSLServerHostnameVerification: - return "SSL server hostname verification failed"; - case Error::UnsupportedMultipartBoundaryChars: - return "Unsupported HTTP multipart boundary characters"; - case Error::Compression: return "Compression failed"; - case Error::ConnectionTimeout: return "Connection timed out"; - case Error::ProxyConnection: return "Proxy connection failed"; - case Error::ResourceExhaustion: return "Resource exhaustion"; - case Error::TooManyFormDataFiles: return "Too many form data files"; - case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; - case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; - case Error::ExceedMaxSocketDescriptorCount: - return "Exceeded maximum socket descriptor count"; - case Error::InvalidRequestLine: return "Invalid request line"; - case Error::InvalidHTTPMethod: return "Invalid HTTP method"; - case Error::InvalidHTTPVersion: return "Invalid HTTP version"; - case Error::InvalidHeaders: return "Invalid headers"; - case Error::MultipartParsing: return "Multipart parsing failed"; - case Error::OpenFile: return "Failed to open file"; - case Error::Listen: return "Failed to listen on socket"; - case Error::GetSockName: return "Failed to get socket name"; - case Error::UnsupportedAddressFamily: return "Unsupported address family"; - case Error::HTTPParsing: return "HTTP parsing failed"; - case Error::InvalidRangeHeader: return "Invalid Range header"; - default: break; - } - - return "Invalid"; -} - -inline std::ostream &operator<<(std::ostream &os, const Error &obj) { - os << to_string(obj); - os << " (" << static_cast::type>(obj) << ')'; - return os; -} - inline size_t Result::get_request_header_value_u64(const std::string &key, size_t def, size_t id) const { @@ -2439,6 +2461,8 @@ struct FileStat { FileStat(const std::string &path); bool is_file() const; bool is_dir() const; + time_t mtime() const; + size_t size() const; private: #if defined(_WIN32) @@ -2449,6 +2473,9 @@ struct FileStat { int ret_ = -1; }; +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl); + std::string trim_copy(const std::string &s); void divide( @@ -2669,6 +2696,25 @@ class stream_line_reader { std::string growable_buffer_; }; +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers); + +struct ChunkedDecoder { + Stream &strm; + size_t chunk_remaining = 0; + bool finished = false; + char line_buf[64]; + size_t last_chunk_total = 0; + size_t last_chunk_offset = 0; + + explicit ChunkedDecoder(Stream &s); + + ssize_t read_payload(char *buf, size_t len, size_t &out_chunk_offset, + size_t &out_chunk_total); + + bool parse_trailers_into(Headers &dest, const Headers &src_headers); +}; + class mmap { public: mmap(const char *path); @@ -2696,58 +2742,668 @@ class mmap { // NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 namespace fields { -inline bool is_token_char(char c) { - return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || - c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || - c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; -} +bool is_token_char(char c); +bool is_token(const std::string &s); +bool is_field_name(const std::string &s); +bool is_vchar(char c); +bool is_obs_text(char c); +bool is_field_vchar(char c); +bool is_field_content(const std::string &s); +bool is_field_value(const std::string &s); + +} // namespace fields + +} // namespace detail -inline bool is_token(const std::string &s) { - if (s.empty()) { return false; } - for (auto c : s) { - if (!is_token_char(c)) { return false; } +namespace stream { + +class Result { +public: + Result() : chunk_size_(8192) {} + + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + + Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; + } + + Result &operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; + } + + Result(const Result &) = delete; + Result &operator=(const Result &) = delete; + + // Check if the result is valid (connection succeeded and response received) + bool is_valid() const { return handle_.is_valid(); } + explicit operator bool() const { return is_valid(); } + + // Response status code + int status() const { + return handle_.response ? handle_.response->status : -1; + } + + // Response headers + const Headers &headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; + } + + std::string get_header_value(const std::string &key, + const char *def = "") const { + return handle_.response ? handle_.response->get_header_value(key, def) + : def; + } + + bool has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; + } + + // Error information + Error error() const { return handle_.error; } + Error read_error() const { return handle_.get_read_error(); } + bool has_read_error() const { return handle_.has_read_error(); } + + // Streaming iteration API + // Call next() to read the next chunk, then access data via data()/size() + // Returns true if data was read, false when stream is exhausted + bool next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; } - return true; + + // Pointer to current chunk data (valid after next() returns true) + const char *data() const { return buffer_.data(); } + + // Size of current chunk (valid after next() returns true) + size_t size() const { return current_size_; } + + // Convenience method: read all remaining data into a string + std::string read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; + } + +private: + ClientImpl::StreamHandle handle_; + std::string buffer_; + size_t current_size_ = 0; + size_t chunk_size_; + bool finished_ = false; +}; + +// GET +template +inline Result Get(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path), chunk_size}; } -inline bool is_field_name(const std::string &s) { return is_token(s); } +template +inline Result Get(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, {}, headers), chunk_size}; +} -inline bool is_vchar(char c) { return c >= 33 && c <= 126; } +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params), chunk_size}; +} -inline bool is_obs_text(char c) { return 128 <= static_cast(c); } +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params, headers), chunk_size}; +} -inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } +// POST +template +inline Result Post(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, {}, body, content_type), + chunk_size}; +} -inline bool is_field_content(const std::string &s) { - if (s.empty()) { return true; } +template +inline Result Post(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, headers, body, content_type), + chunk_size}; +} - if (s.size() == 1) { - return is_field_vchar(s[0]); - } else if (s.size() == 2) { - return is_field_vchar(s[0]) && is_field_vchar(s[1]); - } else { - size_t i = 0; +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, params, {}, body, content_type), + chunk_size}; +} - if (!is_field_vchar(s[i])) { return false; } - i++; +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("POST", path, params, headers, body, content_type), + chunk_size}; +} - while (i < s.size() - 1) { - auto c = s[i++]; - if (c == ' ' || c == '\t' || is_field_vchar(c)) { - } else { - return false; +// PUT +template +inline Result Put(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PUT", path, params, headers, body, content_type), + chunk_size}; +} + +// PATCH +template +inline Result Patch(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PATCH", path, params, headers, body, content_type), + chunk_size}; +} + +// DELETE +template +inline Result Delete(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, params, headers, body, content_type), + chunk_size}; +} + +// HEAD +template +inline Result Head(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, {}, headers), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params, headers), chunk_size}; +} + +// OPTIONS +template +inline Result Options(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, {}, headers), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params, headers), chunk_size}; +} + +} // namespace stream + +namespace sse { + +struct SSEMessage { + std::string event; // Event type (default: "message") + std::string data; // Event payload + std::string id; // Event ID for Last-Event-ID header + + SSEMessage() : event("message") {} + + void clear() { + event = "message"; + data.clear(); + id.clear(); + } +}; + +class SSEClient { +public: + using MessageHandler = std::function; + using ErrorHandler = std::function; + using OpenHandler = std::function; + + SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + + SSEClient(Client &client, const std::string &path, const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + + ~SSEClient() { stop(); } + + SSEClient(const SSEClient &) = delete; + SSEClient &operator=(const SSEClient &) = delete; + + // Event handlers + SSEClient &on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; + } + + SSEClient &on_event(const std::string &type, MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; + } + + SSEClient &on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; + } + + SSEClient &on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; + } + + SSEClient &set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; + } + + SSEClient &set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; + } + + // State accessors + bool is_connected() const { return connected_.load(); } + const std::string &last_event_id() const { return last_event_id_; } + + // Blocking start - runs event loop with auto-reconnect + void start() { + running_.store(true); + run_event_loop(); + } + + // Non-blocking start - runs in background thread + void start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); + } + + // Stop the client (thread-safe) + void stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } + } + +private: + // Parse a single SSE field line + // Returns true if this line ends an event (blank line) + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + try { + retry_ms = std::stoi(value); + } catch (...) { + // Invalid retry value, ignore } } + // Unknown fields are ignored per SSE spec - return is_field_vchar(s[i]); + return false; } -} -inline bool is_field_value(const std::string &s) { return is_field_content(s); } + // Main event loop with auto-reconnect + void run_event_loop() { + auto reconnect_count = 0; -} // namespace fields + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + auto request_headers = headers_; + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } -} // namespace detail + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != 200) { + connected_.store(false); + // For certain errors, don't reconnect + if (result.status() == 204 || // No Content - server wants us to stop + result.status() == 404 || // Not Found + result.status() == 401 || // Unauthorized + result.status() == 403) { // Forbidden + if (on_error_) { on_error_(Error::Connection); } + break; + } + + if (on_error_) { on_error_(Error::Connection); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while ((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); + } + + // Dispatch event to appropriate handler + void dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } + } + + // Check if we should attempt to reconnect + bool should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; + } + + // Wait for reconnect interval + void wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } + } + + // Client and path + Client &client_; + std::string path_; + Headers headers_; + + // Callbacks + MessageHandler on_message_; + std::map event_handlers_; + OpenHandler on_open_; + ErrorHandler on_error_; + + // Configuration + int reconnect_interval_ms_ = 3000; + int max_reconnect_attempts_ = 0; // 0 = unlimited + + // State + std::atomic running_{false}; + std::atomic connected_{false}; + std::string last_event_id_; + + // Async support + std::thread async_thread_; +}; + +} // namespace sse