42 changes: 21 additions & 21 deletions tests/mask_store/test_byte_tokenizer.py
@@ -65,13 +65,13 @@ def test_raw_tokenizer(self):
"""Test ByteTokenizer with a RAW (tiktoken-style) tokenizer."""
# Create mock vocabulary for a raw tokenizer
vocab = {
"hello": 1,
"world": 2,
"!": 3,
"你": 4,
"好": 5,
"吗": 6,
"?": 7
b"hello": 1,
b"world": 2,
b"!": 3,
b"\xE4\xBD\xA0": 4, # 你
b"\xE5\xA5\xBD": 5, # 好
b"\xE5\x90": 6, # first two bytes of 吗
b"\x97": 7, # last byte of 吗
}

mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
@@ -84,10 +84,10 @@ def test_raw_tokenizer(self):
         mock_tokenizer.encode.return_value = expected_ids

         # Test decoding
-        token_ids = [4, 5, 6, 7]  # 你, 好, 吗, ?
+        token_ids = [4, 5, 6]  # 你, 好, 吗 (first two bytes).
         mock_tokenizer.decode.return_value = "你好吗?"
         result = byte_tokenizer.decode(token_ids)
-        self.assertEqual(result.decode('utf-8'), "你好吗?")
+        self.assertEqual(result, b"\xE4\xBD\xA0\xE5\xA5\xBD\xE5\x90")

     def test_byte_fallback_tokenizer(self):
         """Test ByteTokenizer with a BYTE_FALLBACK (Llama-2-style) tokenizer."""
@@ -151,11 +151,11 @@ def test_byte_level_tokenizer(self):
     def test_batched_decoding(self):
         """Test batched decoding capabilities."""
         vocab = {
-            "hello": 1,
-            "world": 2,
-            "!": 3,
-            "<s>": 4,  # special token
-            "</s>": 5,  # special token
+            b"hello": 1,
+            b"world": 2,
+            b"!": 3,
+            b"<s>": 4,  # special token
+            b"</s>": 5,  # special token
         }

         mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
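With bytes keys throughout, special tokens such as b"&lt;s&gt;" can still be filtered by id during batched decoding. A rough sketch of the semantics this test exercises (the `batch_decode` helper is a hypothetical stand-in, not the repo's API):

# Hypothetical sketch of batched decoding with special tokens skipped.
id_to_bytes = {1: b"hello", 2: b"world", 3: b"!", 4: b"<s>", 5: b"</s>"}
special_ids = {4, 5}

def batch_decode(batches, skip_special_tokens=True):
    return [
        b"".join(
            id_to_bytes[i] for i in ids
            if not (skip_special_tokens and i in special_ids)
        )
        for ids in batches
    ]

print(batch_decode([[4, 1, 2, 5], [4, 3, 5]]))  # [b'helloworld', b'!']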
@@ -196,10 +196,10 @@ def test_auto_detection(self):
     def test_decoding_performance(self):
         """Test basic decoding performance."""
         # Create a larger vocabulary for more realistic testing
-        vocab = {f"token{i}": i for i in range(1000)}
+        vocab = {bytes(f"token{i}".encode('utf-8')): i for i in range(1000)}
         # Add some special tokens
-        vocab["<s>"] = 1000
-        vocab["</s>"] = 1001
+        vocab[b"<s>"] = 1000
+        vocab[b"</s>"] = 1001

         mock_tokenizer = self.create_mock_tokenizer(vocab, VocabType.RAW)
         mock_tokenizer.all_special_ids = [1000, 1001]
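One incidental note on the comprehension above: `str.encode('utf-8')` already returns `bytes`, so the outer `bytes(...)` call is a no-op wrapper; both spellings build identical keys:

# str.encode() already yields bytes; wrapping it in bytes() changes nothing.
assert bytes("token1".encode("utf-8")) == "token1".encode("utf-8") == b"token1"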
@@ -313,10 +313,10 @@ def test_roundtrip_encoding_decoding(self):
         """Test encoding and decoding round-trip."""
         # Create a simple vocabulary for testing
         raw_vocab = {
-            "hello": 1,
-            " ": 2,
-            "world": 3,
-            "!": 4,
+            b"hello": 1,
+            b" ": 2,
+            b"world": 3,
+            b"!": 4,
         }

         mock_tokenizer = self.create_mock_tokenizer(raw_vocab, VocabType.RAW)
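The round trip being tested is lossless at the byte level. As a sketch of the idea, with a greedy longest-match encoder standing in for the real tokenizer (hypothetical, not the repo's actual algorithm):

# Hypothetical round-trip sketch; the greedy longest-match encoder is
# illustrative only.
raw_vocab = {b"hello": 1, b" ": 2, b"world": 3, b"!": 4}
id_to_bytes = {i: tok for tok, i in raw_vocab.items()}

def greedy_encode(data: bytes) -> list[int]:
    ids, pos = [], 0
    while pos < len(data):
        # Take the longest vocab entry matching at the current position.
        match = max(
            (tok for tok in raw_vocab if data.startswith(tok, pos)),
            key=len, default=None,
        )
        if match is None:
            raise ValueError(f"no token covers byte {data[pos:pos+1]!r}")
        ids.append(raw_vocab[match])
        pos += len(match)
    return ids

def decode(ids: list[int]) -> bytes:
    return b"".join(id_to_bytes[i] for i in ids)

text = b"hello world!"
ids = greedy_encode(text)   # [1, 2, 3, 4]
assert decode(ids) == text  # byte-exact round trip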