[ML] Add per allocation and per deployment memory metadata fields to …#6

Open
MitchLewis930 wants to merge 1 commit into pr_016_before from pr_016_after

Conversation

@MitchLewis930 MitchLewis930 commented Jan 29, 2026

User description

PR_016


PR Type

Enhancement


Description

  • Add per-deployment and per-allocation memory metadata fields to trained models

  • Extend memory usage estimation logic to use new metadata fields when available

  • Maintain backward compatibility with models lacking memory metadata

  • Update all task parameter constructors and serialization to include new fields


Diagram Walkthrough

flowchart LR
  A["TrainedModelConfig"] -->|stores metadata| B["per_deployment_memory_bytes<br/>per_allocation_memory_bytes"]
  B -->|passed to| C["TaskParams"]
  C -->|used in| D["estimateMemoryUsageBytes"]
  D -->|calculates| E["Required Native Memory"]
  F["TransportVersion V_8_500_064"] -->|enables| B

File Walkthrough

Relevant files
Configuration changes
1 file
TransportVersion.java
Register new transport version for memory metadata             
+3/-1     
Enhancement
7 files
StartTrainedModelDeploymentAction.java
Add memory metadata fields to TaskParams class                     
+92/-12 
TrainedModelConfig.java
Add memory metadata fields and version constant                   
+19/-0   
TrainedModelAssignment.java
Update builder to pass memory metadata fields                       
+3/-1     
TransportGetTrainedModelsStatsAction.java
Use memory metadata in stats calculation                                 
+22/-3   
TransportStartTrainedModelDeploymentAction.java
Extract and pass memory metadata from model config             
+10/-1   
TrainedModelAssignmentNodeService.java
Pass memory metadata when updating task params                     
+3/-1     
TrainedModelDeploymentTask.java
Include memory metadata in allocation updates                       
+3/-1     
Tests
15 files
StartTrainedModelDeploymentTaskParamsTests.java
Update test to include memory metadata parameters               
+3/-1     
TrainedModelAssignmentTests.java
Update test to include memory metadata parameters               
+3/-1     
PyTorchModelIT.java
Add integration test for memory estimation with metadata 
+65/-0   
PyTorchModelRestTestCase.java
Support creating models with memory metadata in tests       
+22/-2   
TransportGetDeploymentStatsActionTests.java
Update test to include memory metadata parameters               
+12/-1   
MlMemoryAutoscalingDeciderTests.java
Update all test cases with memory metadata fields               
+18/-6   
MlProcessorAutoscalingDeciderTests.java
Update test cases with memory metadata parameters               
+42/-14 
TrainedModelAssignmentClusterServiceTests.java
Update test helper to include memory metadata                       
+3/-1     
TrainedModelAssignmentMetadataTests.java
Update test to include memory metadata parameters               
+3/-1     
TrainedModelAssignmentNodeServiceTests.java
Update test helper with memory metadata fields                     
+3/-1     
TrainedModelAssignmentRebalancerTests.java
Update test helpers with memory metadata parameters           
+6/-2     
AllocationReducerTests.java
Update test to include memory metadata fields                       
+3/-1     
TrainedModelDeploymentTaskTests.java
Update test cases with memory metadata parameters               
+6/-2     
PyTorchBuilderTests.java
Update test cases with memory metadata parameters               
+3/-3     
NodeLoadDetectorTests.java
Update test with memory metadata parameters                           
+3/-1     

…the trained models config (elastic#98139)

To improve the required memory estimation of NLP models, this PR introduces two new metadata fields: per_deployment_memory_bytes and per_allocation_memory_bytes.

per_deployment_memory_bytes is the memory required to load the model in the deployment.
per_allocation_memory_bytes is the additional temporary memory used during inference for every allocation.

This PR extends the memory usage estimation logic while ensuring backward compatibility.

In a follow-up PR, I will adjust the assignment planner to use the refined memory usage information.
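
The three estimation cases listed in the comments of estimateMemoryUsageBytes can be sketched as follows. This is a minimal illustration, not the PR's code: the class name, the boolean isElserV1 parameter, and the constant names are assumptions, and the 240MB and 2004MB values are taken from those comments. The third branch follows the comment wording only, since the corresponding lines are clipped in the review excerpt and may differ.

```java
// Illustrative sketch only; names here are hypothetical and not taken from the PR.
final class MemoryEstimateSketch {
    static final long MB = 1024L * 1024L;
    // Values quoted in the code comments: 240MB overhead, 2004MB for ELSER v1.
    static final long MEMORY_OVERHEAD_BYTES = 240L * MB;
    static final long ELSER_1_MEMORY_USAGE_BYTES = 2004L * MB;

    static long estimate(
        boolean isElserV1,
        long modelSizeBytes,
        long perDeploymentBytes,
        long perAllocationBytes,
        int allocations
    ) {
        // 1. ELSER v1 uses a fixed figure.
        if (isElserV1) {
            return ELSER_1_MEMORY_USAGE_BYTES;
        }
        // 2. No memory metadata: fall back to overhead + 2 * model size.
        if (perDeploymentBytes == 0 && perAllocationBytes == 0) {
            return MEMORY_OVERHEAD_BYTES + 2 * modelSizeBytes;
        }
        // 3. Metadata present: static + dynamic * allocations + model size.
        return perDeploymentBytes + perAllocationBytes * allocations + modelSizeBytes;
    }
}
```

For example, a 100MB model with 300MB per deployment, 50MB per allocation, and 2 allocations comes to 300 + 2 * 50 + 100 = 500MB under case 3, while the same model without metadata comes to 240 + 2 * 100 = 440MB under case 2.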
@qodo-code-review

PR Compliance Guide 🔍

Below is a summary of compliance checks for this PR:

Security Compliance
Memory estimate overflow

Description: The new required-memory calculation uses unvalidated per_deployment_memory_bytes and
per_allocation_memory_bytes (plus perAllocationMemoryBytes * numberOfAllocations) which
can be negative or overflow long, potentially underestimating memory requirements and
enabling resource exhaustion/DoS by allowing deployments that should have been rejected.
StartTrainedModelDeploymentAction.java [707-734]

Referred Code
public static long estimateMemoryUsageBytes(
    String modelId,
    long totalDefinitionLength,
    long perDeploymentMemoryBytes,
    long perAllocationMemoryBytes,
    int numberOfAllocations
) {
    // While loading the model in the process we need twice the model size.

    // 1. If ELSER v1 then 2004MB
    // 2. If static memory and dynamic memory are not set then 240MB + 2 * model size
    // 3. Else static memory + dynamic memory * allocations + model size

    // The model size is still added in option 3 to account for the temporary requirement to hold the zip file in memory
    // in `pytorch_inference`.
    if (isElserV1Model(modelId)) {
        return ELSER_1_MEMORY_USAGE.getBytes();
    } else {
        long baseSize = MEMORY_OVERHEAD.getBytes() + 2 * totalDefinitionLength;
        if (perDeploymentMemoryBytes == 0 && perAllocationMemoryBytes == 0) {
            return baseSize;


 ... (clipped 7 lines)
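
One way to address the overflow and negative-value concern is to validate the inputs and use checked arithmetic for the metadata-based case. The sketch below is an assumption about how that could look, not code from the PR; the class and method names are hypothetical. It relies only on Math.multiplyExact and Math.addExact from the JDK, which throw ArithmeticException instead of silently wrapping.

```java
// Hypothetical helper, not part of the PR: guards the
// "static memory + dynamic memory * allocations + model size" case.
final class CheckedMemoryEstimateSketch {
    static long checkedEstimate(
        long modelSizeBytes,
        long perDeploymentBytes,
        long perAllocationBytes,
        int allocations
    ) {
        // Reject values that would make the estimate meaningless.
        if (modelSizeBytes < 0 || perDeploymentBytes < 0 || perAllocationBytes < 0) {
            throw new IllegalArgumentException("memory sizes must be non-negative");
        }
        if (allocations < 1) {
            throw new IllegalArgumentException("number of allocations must be positive");
        }
        // Checked arithmetic: an overflow raises ArithmeticException rather than
        // producing a small or negative estimate.
        long dynamicBytes = Math.multiplyExact(perAllocationBytes, (long) allocations);
        return Math.addExact(Math.addExact(perDeploymentBytes, dynamicBytes), modelSizeBytes);
    }
}
```

Whether an overflow should be rejected outright or clamped to Long.MAX_VALUE is a design choice that this sketch leaves open; it throws so that the caller has to decide.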
Ticket Compliance
🎫 No ticket provided
Codebase Duplication Compliance
Codebase context is not defined


Custom Compliance
🟢
Generic: Comprehensive Audit Trails

Objective: To create a detailed and reliable record of critical system actions for security analysis
and compliance.

Status: Passed


Generic: Meaningful Naming and Self-Documenting Code

Objective: Ensure all identifiers clearly express their purpose and intent, making code
self-documenting

Status: Passed


Generic: Secure Logging Practices

Objective: To ensure logs are useful for debugging and auditing without exposing sensitive
information like PII, PHI, or cardholder data.

Status: Passed


🔴
Generic: Robust Error Handling and Edge Case Management

Objective: Ensure comprehensive error handling that provides meaningful context and graceful
degradation

Status:
Unsafe metadata cast: The new getters cast metadata values to Number without validation, which can throw a
runtime exception (e.g., ClassCastException) if stored metadata is missing or malformed.

Referred Code
public long getPerDeploymentMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;
}

public long getPerAllocationMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;
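
A defensive variant of these getters could check the type and range before converting. The following is a sketch under assumptions: the class name, the helper name readNonNegativeLong, and the choice to throw IllegalArgumentException are all hypothetical and not taken from the PR. It treats a missing key or a null value as 0, which matches the default in the getters above.

```java
import java.util.Map;

// Hypothetical helper, not part of the PR.
final class MetadataMemorySketch {
    static long readNonNegativeLong(Map<String, Object> metadata, String key) {
        if (metadata == null) {
            return 0L;
        }
        Object value = metadata.get(key);
        if (value == null) {
            // Key absent or explicitly null: fall back to the "not set" default.
            return 0L;
        }
        if ((value instanceof Number) == false) {
            // Avoids a ClassCastException with no context on malformed metadata.
            throw new IllegalArgumentException(
                "[" + key + "] must be a number but was [" + value.getClass().getSimpleName() + "]"
            );
        }
        long bytes = ((Number) value).longValue();
        if (bytes < 0) {
            throw new IllegalArgumentException("[" + key + "] must be non-negative but was [" + bytes + "]");
        }
        return bytes;
    }
}
```

With such a helper, both getters would reduce to a single call each, passing the preferred name of the corresponding field as the key.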


Generic: Security-First Input Validation and Data Handling

Objective: Ensure all data inputs are validated, sanitized, and handled securely to prevent
vulnerabilities

Status:
Missing input validation: The new memory metadata fields are read from metadata without type/range checks (e.g.,
ensuring numeric and non-negative), allowing invalid external/stored inputs to trigger
errors or incorrect behavior.

Referred Code
public long getPerDeploymentMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;
}

public long getPerAllocationMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;


Generic: Secure Error Handling

Objective: To prevent the leakage of sensitive system information through error messages while
providing sufficient detail for internal debugging.

Status:
Potential error exposure: If malformed metadata causes a runtime exception, it is unclear from the diff whether the
resulting error returned to clients is sanitized to avoid exposing internal details.

Referred Code
public long getPerDeploymentMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;
}

public long getPerAllocationMemoryBytes() {
    return metadata != null && metadata.containsKey(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())
        ? ((Number) metadata.get(PER_ALLOCATION_MEMORY_BYTES.getPreferredName())).longValue()
        : 0L;


Compliance status legend 🟢 - Fully Compliant
🟡 - Partial Compliant
🔴 - Not Compliant
⚪ - Requires Further Human Verification
🏷️ - Compliance label

@qodo-code-review

PR Code Suggestions ✨

Explore these optional code suggestions:

Category    Suggestion    Impact
General
Remove unused fields from builder class

Remove the unused perDeploymentMemoryBytes and perAllocationMemoryBytes fields
from the TrainedModelConfig.Builder class, as their values are handled through
the metadata map.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/TrainedModelConfig.java [571-593]

 public static class Builder {
     private String modelId;
     private TrainedModelType modelType;
     private String createdBy;
     private MlConfigVersion version;
     private String description;
     private Instant createTime;
     private List<String> tags;
     private Map<String, Object> metadata;
     private TrainedModelInput input;
     private Long modelSize;
     private Long estimatedOperations;
     private LazyModelDefinition definition;
     private String licenseLevel;
     private Map<String, String> defaultFieldMap;
     private InferenceConfig inferenceConfig;
     private TrainedModelLocation location;
     private ModelPackageConfig modelPackageConfig;
-    private Long perDeploymentMemoryBytes;
-    private Long perAllocationMemoryBytes;
 
     public Builder() {}

[To ensure code accuracy, apply this suggestion manually]

Suggestion importance[1-10]: 7

Why: The suggestion correctly identifies that the perDeploymentMemoryBytes and perAllocationMemoryBytes fields in TrainedModelConfig.Builder are unused and can be removed. This is a good code cleanup suggestion that improves maintainability.

Medium
Allow testing with partial memory metadata

In the createPassThroughModel test helper, modify the condition to add memory
metadata if either perDeploymentMemoryBytes or perAllocationMemoryBytes is
greater than zero, instead of requiring both.

x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/PyTorchModelRestTestCase.java [197-225]

 protected void createPassThroughModel(String modelId, long perDeploymentMemoryBytes, long perAllocationMemoryBytes) throws IOException {
     Request request = new Request("PUT", "/_ml/trained_models/" + modelId);
     String metadata;
-    if (perDeploymentMemoryBytes > 0 && perAllocationMemoryBytes > 0) {
+    if (perDeploymentMemoryBytes > 0 || perAllocationMemoryBytes > 0) {
         metadata = Strings.format("""
             "metadata": {
               "per_deployment_memory_bytes": %d,
               "per_allocation_memory_bytes": %d
             },""", perDeploymentMemoryBytes, perAllocationMemoryBytes);
     } else {
         metadata = "";
     }
     request.setJsonEntity(Strings.format("""
         {
            "description": "simple model for testing",
            "model_type": "pytorch",
             %s
            "inference_config": {
               "pass_through": {
                 "tokenization": {
                   "bert": {
                     "with_special_tokens": false
                   }
                 }
               }
             }
          }""", metadata));
     client().performRequest(request);
 }
Suggestion importance[1-10]: 6

Why: The suggestion correctly identifies a limitation in a test helper method. Changing the condition from && to || improves test coverage for the new feature by allowing tests for valid scenarios where only one of the memory parameters is non-zero.
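
An alternative to switching the condition to || is to emit each field only when it is set, so a test with one value does not also write an explicit zero for the other. The sketch below is hypothetical: the class and method names are illustrative and it is not code from the PR. It builds the same "metadata" fragment that the helper interpolates into the request body.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical test helper, not part of the PR.
final class PartialMetadataSketch {
    static String metadataFragment(long perDeploymentMemoryBytes, long perAllocationMemoryBytes) {
        List<String> fields = new ArrayList<>();
        if (perDeploymentMemoryBytes > 0) {
            fields.add("\"per_deployment_memory_bytes\": " + perDeploymentMemoryBytes);
        }
        if (perAllocationMemoryBytes > 0) {
            fields.add("\"per_allocation_memory_bytes\": " + perAllocationMemoryBytes);
        }
        if (fields.isEmpty()) {
            // No metadata at all: the model is created exactly as before.
            return "";
        }
        // Trailing comma so the fragment can precede "inference_config" in the body.
        return "\"metadata\": {" + String.join(", ", fields) + "},";
    }
}
```

Because the getters default a missing field to 0, omitting a field and writing it as 0 should lead to the same estimate; the difference is only in what the stored model config contains.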

Low
Conditionally include memory fields

In TaskParams.toXContent, conditionally serialize per_deployment_memory_bytes
and per_allocation_memory_bytes only when their values are greater than zero to
avoid cluttering the output.

x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java [606-607]

-builder.field(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName(), perDeploymentMemoryBytes);
-builder.field(PER_ALLOCATION_MEMORY_BYTES.getPreferredName(), perAllocationMemoryBytes);
+if (perDeploymentMemoryBytes > 0L) {
+    builder.field(PER_DEPLOYMENT_MEMORY_BYTES.getPreferredName(), perDeploymentMemoryBytes);
+}
+if (perAllocationMemoryBytes > 0L) {
+    builder.field(PER_ALLOCATION_MEMORY_BYTES.getPreferredName(), perAllocationMemoryBytes);
+}
Suggestion importance[1-10]: 3

Why: The suggestion proposes a minor improvement to make the XContent output cleaner by omitting zero-value memory fields. While correct, this is a low-impact style preference.

Low