
Weirdness with tokenization in Phi-3 #12

@uogbuji

Description


Server:

toolio_server --model=mlx-community/Phi-3-mini-128k-instruct-4bit

Client:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow?'

You can run the above any number of times without trouble, but as soon as you run a version whose prompt shares a prefix with a prior one, so that the server tries to reuse the prompt cache:

toolio_request --apibase="http://localhost:8000" --prompt='What is the average airspeed of an unladen swallow? Where have I heard that before?'

It blows up. Server exception tail:

  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/cli/server.py", line 271, in post_v1_chat_completions_impl
    for result in app.state.model.completion(
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 296, in completion
    logits, cache = self._evaluate_prompt(
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/toolio/schema_helper.py", line 92, in _evaluate_prompt
    logits = self.model(mx.array(tokens)[None], cache)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 202, in __call__
    out = self.model(inputs, cache)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 184, in __call__
    h = layer(h, mask, c)
        ^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 148, in __call__
    r = self.self_attn(self.input_layernorm(x), mask, cache)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/uche/.local/venv/toolio/lib/python3.11/site-packages/mlx_lm/models/phi3.py", line 110, in __call__
    output = mx.fast.scaled_dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: Shapes (1,32,9,24) and (9,9) cannot be broadcast.
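For what it's worth, the broadcast failure itself is just a trailing-dimension mismatch: the attention scores come out as (1, 32, 9, 24) — batch, heads, 9 query positions, 24 key positions (consistent with 15 cached tokens plus 9 new ones, per the trace below) — while the additive mask was apparently built as (9, 9) for the 9 new tokens alone. A minimal NumPy sketch of the same failure (NumPy stands in for mlx here; the shapes are the ones from the traceback):

```python
import numpy as np

# Attention scores: (batch, heads, query_len, key_len) per the traceback.
# 24 key positions is consistent with 15 reused cache entries + 9 new tokens.
scores = np.zeros((1, 32, 9, 24))

# Additive mask built for the 9 new tokens only.
mask = np.zeros((9, 9))

try:
    scores + mask  # broadcasting aligns trailing dims: 24 vs. 9 -> error
except ValueError as e:
    print(e)
```

So whatever the tokenizer is doing, the immediate crash is the mask being sized for the suffix while the scores span the full (cache + suffix) key length.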

I modified `_evaluate_prompt` in schema_helper.py to print a trace:

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # Need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
            print('CACHED', tokens, prompt)
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
            print('UNCACHED', tokens)

        logits = self.model(mx.array(tokens)[None], cache)
        return logits, cache
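
To make the reuse logic easier to reason about, here's the prefix-matching step isolated as a pure function (my naming — `common_prefix_split` is not part of Toolio). It mirrors the loop above: reuse the longest shared prefix, but always leave at least one token to evaluate.

```python
def common_prefix_split(prompt: list[int], prior_prompt: list[int]):
    """Return (reuse_len, suffix): how many cached positions to keep, and
    the tokens that still need evaluation. Mirrors _evaluate_prompt: at
    least one token is always left over, since past logits aren't saved."""
    i = 0
    for i, t in enumerate(prior_prompt):
        if i >= len(prompt) - 1 or prompt[i] != t:
            break
    return i, prompt[i:]
```

With the traces below, the longer prompt diverges from the cached one at index 15 (6804 vs. 32007), which is exactly where its CACHED suffix starts.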

First run of the shorter prompt displays:

UNCACHED [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

Notice the repeated 32007, which is Phi-3's `<|end|>` token. That duplication is probably not good. An identical run again:

CACHED [32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 32007, 32007]

That's the expected cache behavior: the shared prefix is reused, and only that final end token is re-evaluated. Now the longer prompt:

CACHED [6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007] [32010, 1724, 338, 278, 6588, 4799, 19322, 310, 385, 443, 4528, 264, 2381, 9536, 29973, 6804, 505, 306, 6091, 393, 1434, 29973, 32007, 32007]

The end token is doubled again.

At this point I don't know whether this tokenizer oddness is what leads to the shape error, but it's a start for investigating.
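
One diagnostic experiment (not a fix, since the root cause of the duplication is still unknown) would be to collapse the trailing run of `<|end|>` tokens before evaluation and see whether the shape error goes away. A sketch, assuming 32007 should appear at most once at the end:

```python
PHI3_END = 32007  # Phi-3 '<|end|>' token id, per the traces above

def collapse_trailing_end(tokens: list[int], end_id: int = PHI3_END) -> list[int]:
    """Reduce a trailing run of end_id tokens to a single occurrence."""
    out = list(tokens)
    while len(out) >= 2 and out[-1] == end_id and out[-2] == end_id:
        out.pop()
    return out
```

If the crash persists with deduplicated tokens, the tokenizer oddness and the cache/mask shape mismatch are separate bugs.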
