diff --git a/src/Tokenizer/BaseTokenizer.cs b/src/Tokenizer/BaseTokenizer.cs index 3d87111..b2cb163 100644 --- a/src/Tokenizer/BaseTokenizer.cs +++ b/src/Tokenizer/BaseTokenizer.cs @@ -109,7 +109,7 @@ public TokensWithOffsets TokenizeWithOffsets(string text) foreach (var token in tokens) { - texts.Add(token.Text); + texts.Add(new string(token.Text)); offsets.Add(token.ReferenceOffsets.Any() ? new Offset(token.ReferenceOffsets.First(), token.ReferenceOffsets.Last() + 1) : (Offset?)null); originalPositions.Add(token.ReferenceOffsets); masks.Add(token.Mask); @@ -174,7 +174,7 @@ public virtual List TokenizeToTokens(Token initialToken) return ownedToken; }) - .Where(token => !string.IsNullOrEmpty(token.Text)) + .Where(token => !string.IsNullOrEmpty(new string(token.Text))) .ToList(); return tokens; @@ -183,7 +183,7 @@ public virtual List TokenizeToTokens(Token initialToken) public void DecomposeNfkc(Token token) { // Perform NFKC normalization on the token text - var decomposedText = token.Text.Normalize(NormalizationForm.FormKC); + var decomposedText = new string(token.Text).Normalize(NormalizationForm.FormKC); // Calculate the new reference offsets var newReferenceOffsets = new List(); @@ -197,7 +197,7 @@ public void DecomposeNfkc(Token token) } // Update the token's properties - token.Text = decomposedText; + token.Text = decomposedText.ToCharArray(); token.ReferenceOffsets = newReferenceOffsets; token.Offset.Begin = newReferenceOffsets.FirstOrDefault(); token.Offset.End = newReferenceOffsets.LastOrDefault() + 1; @@ -357,11 +357,11 @@ private string ConvertTokensToString(List tokens) protected List SplitOnSpecialTokens(Token token, IVocab vocab) { - Func testSubstr = (s) => + Func testSubstr = (s) => { foreach (var specialValue in vocab.SpecialValues.Keys) { - if (s.StartsWith(specialValue)) + if (new string(s).StartsWith(specialValue)) { return ( specialValue.Length, @@ -400,7 +400,7 @@ private string CleanUpTokenization(string inputString) private List WhitespaceTokenize(Token initialToken) { - var parts = initialToken.Text.Split(new[] { ' ', '\t', '\n', '\r' }, StringSplitOptions.RemoveEmptyEntries); + var parts = new string(initialToken.Text).Split(new[] { ' ', '\t', '\n', '\r' }, StringSplitOptions.RemoveEmptyEntries); var tokens = new List(); foreach (var part in parts) { @@ -410,7 +410,7 @@ private List WhitespaceTokenize(Token initialToken) return tokens; } - private List SplitOnSubstr(Token token, Func testSubstr, bool addSeparators) + private List SplitOnSubstr(Token token, Func testSubstr, bool addSeparators) { var tokens = new List(); uint charBegin = 0; @@ -420,7 +420,7 @@ private List SplitOnSubstr(Token token, Func te if (token.Mask == Mask.None) { // Iterate over characters with byte indices - var itr = TokenizationUtils.Enumerate(TokenizationUtils.CharIndicesForRunes(token.Text)); + var itr = TokenizationUtils.Enumerate(TokenizationUtils.CharIndicesForRunes(new string(token.Text))); foreach (var (charIdx, (bytesIdx, _)) in itr) { charCount++; @@ -431,7 +431,7 @@ private List SplitOnSubstr(Token token, Func te if (charBegin < charIdx) { // Add previous token - var trimmedText = TokenizationUtils.SubstringRunes(token.Text, bytesBegin, bytesIdx - bytesBegin).TrimEnd(); + var trimmedText = new string(TokenizationUtils.SubstringRunes(token.Text, bytesBegin, bytesIdx - bytesBegin)).TrimEnd(); if (trimmedText.EnumerateRunes().Count() > 0) { tokens.Add(new Token(trimmedText) @@ -468,7 +468,7 @@ private List SplitOnSubstr(Token token, Func te var text = TokenizationUtils.SubstringRunes(token.Text, bytesBegin, bytesBegin + (bytesIdx - bytesBegin)); if (charCount == 0) { - charCount = token.Text.EnumerateRunes().Count(); + charCount = new string(token.Text).EnumerateRunes().Count(); } tokens.Add(new Token(text) { @@ -493,7 +493,7 @@ private List SplitOnPunct(Token token) if (char.IsPunctuation(charCurrent)) { var offsets = token.ReferenceOffsets.Skip(start).Take(1).ToArray(); - tokens.Add(new Token(text.Substring(start, 1), offsets) { Mask = Mask.Punctuation }); + tokens.Add(new Token(new string(text).Substring(start, 1), offsets) { Mask = Mask.Punctuation }); start++; } else @@ -504,7 +504,7 @@ private List SplitOnPunct(Token token) end++; } var offsets = token.ReferenceOffsets.Skip(start).Take(end - start).ToArray(); - tokens.Add(new Token(text.Substring(start, end - start), offsets)); + tokens.Add(new Token(new string(text).Substring(start, end - start), offsets)); start = end; } } @@ -523,7 +523,7 @@ private List TokenizeCjkChars(Token token) if (IsCjkChar(charCurrent)) { var offsets = token.ReferenceOffsets.Skip(start).Take(1).ToArray(); - tokens.Add(new Token(text.Substring(start, 1), offsets) { Mask = Mask.CJK }); + tokens.Add(new Token(new string(text).Substring(start, 1), offsets) { Mask = Mask.CJK }); start++; } else @@ -534,7 +534,7 @@ private List TokenizeCjkChars(Token token) end++; } var offsets = token.ReferenceOffsets.Skip(start).Take(end - start).ToArray(); - tokens.Add(new Token(text.Substring(start, end - start), offsets)); + tokens.Add(new Token(new string(text).Substring(start, end - start), offsets)); start = end; } } @@ -551,19 +551,19 @@ private void CleanText(Token token, bool removeControlCharacters) { if (removeControlCharacters) { - token.Text = Regex.Replace(token.Text, @"\p{C}+", ""); + token.Text = Regex.Replace(new string(token.Text), @"\p{C}+", "").ToCharArray(); } - token.Text = token.Text.Replace("``", "\"").Replace("''", "\""); + token.Text = new string(token.Text).Replace("``", "\"").Replace("''", "\"").ToCharArray(); } private void Lowercase(Token token) { - token.Text = token.Text.ToLowerInvariant(); + token.Text = new string(token.Text).ToLowerInvariant().ToCharArray(); } private void StripAccents(Token token) { - token.Text = RemoveDiacritics(token.Text); + token.Text = RemoveDiacritics(new string(token.Text)).ToCharArray(); } private string RemoveDiacritics(string text) diff --git a/src/Tokenizer/Token.cs b/src/Tokenizer/Token.cs index 09c6682..b031d64 100644 --- a/src/Tokenizer/Token.cs +++ b/src/Tokenizer/Token.cs @@ -8,7 +8,7 @@ public class Token : IToken /// /// String representation /// - public string Text { get; set; } + public char[] Text { get; set; } /// /// Start and end positions of the token with respect to the original text @@ -30,9 +30,9 @@ public class Token : IToken /// Creates a new owned token from a `String`. /// /// text reference - public Token(string text) + public Token(ReadOnlySpan text) { - Text = text; + Text = text.ToArray(); var text_size = (uint)text.Length; Offset = new Offset(0, text_size); ReferenceOffsets = Enumerable.Range(0, (int)text_size).Select(i => (uint)i).ToList(); @@ -44,17 +44,17 @@ public Token(string text) /// /// text reference /// reference positions with respect to the original text - public Token(string text, uint[] offsets) + public Token(ReadOnlySpan text, uint[] offsets) { - Text = text; + Text = text.ToArray(); Offset = new Offset(0, (uint)offsets.Length); ReferenceOffsets = offsets; Mask = Mask.None; } - public Token(string text, Offset offset, IReadOnlyList referenceOffsets, Mask mask) + public Token(ReadOnlySpan text, Offset offset, IReadOnlyList referenceOffsets, Mask mask) { - Text = text; + Text = text.ToArray(); Offset = offset; ReferenceOffsets = referenceOffsets; Mask = mask; @@ -62,7 +62,7 @@ public Token(string text, Offset offset, IReadOnlyList referenceOffsets, M public override string ToString() { - return Text; + return new string(Text); } public static Token From(string text) diff --git a/src/Tokenizer/TokenizationUtils.cs b/src/Tokenizer/TokenizationUtils.cs index 2d3f04e..25de862 100644 --- a/src/Tokenizer/TokenizationUtils.cs +++ b/src/Tokenizer/TokenizationUtils.cs @@ -11,38 +11,31 @@ namespace Lokad.Tokenizers.Tokenizer; internal static class TokenizationUtils { + /// /// Substring Runes (characters) /// - public static string SubstringRunes(string text, int start, int length) + public static char[] SubstringRunes(ReadOnlySpan text, int start, int length) { var sb = new StringBuilder(); - text.EnumerateRunes().Skip(start).Take(length).ToList().ForEach(r => sb.Append(r)); - return sb.ToString(); + text.EnumerateRunes().ToList().Skip(start).Take(length).ToList().ForEach(r => sb.Append(r)); + return sb.ToString().ToCharArray(); } /// /// Substring Runes (characters) /// - public static string SubstringRunes(string text, int start) + public static char[] SubstringRunes(ReadOnlySpan text, int start) { var sb = new StringBuilder(); - text.EnumerateRunes().Skip(start).ToList().ForEach(r => sb.Append(r)); - return sb.ToString(); - } - - /// - /// Get String Info - /// - public static StringInfo GetStringInfo(string text) - { - return new System.Globalization.StringInfo(text); + text.EnumerateRunes().ToList().Skip(start).ToList().ForEach(r => sb.Append(r)); + return sb.ToString().ToCharArray(); } /// /// Get UTF 8 Bytes Count /// - public static int GetUtf8BytesCount(string text) + public static int GetUtf8BytesCount(ReadOnlySpan text) { return Encoding.UTF8.GetByteCount(text); } @@ -66,7 +59,7 @@ public static int GetUtf8BytesCount(string text) /// NFKC decomposition /// public static IEnumerable<(Rune Character, int ExtraCharSize)> NFKC(string str) - { + { var runes = str.EnumerateRunes().ToList(); for (var i = 0; i < runes.Count; i++) { @@ -89,18 +82,18 @@ public static int GetUtf8BytesCount(string text) /// /// Substring by byte offset /// - public static string SubstringByByteOffset(string s, int start) + public static char[] SubstringByByteOffset(char[] s, int start) { var bytes = Encoding.UTF8.GetBytes(s); var substringBytes = new byte[bytes.Length - start]; Array.Copy(bytes, start, substringBytes, 0, bytes.Length - start); - return Encoding.UTF8.GetString(substringBytes); + return Encoding.UTF8.GetChars(substringBytes); } /// /// Substring by byte offset /// - public static string SubstringByByteOffset(string s, int start, int end) + public static char[] SubstringByByteOffset(char[] s, int start, int end) { var bytes = Encoding.UTF8.GetBytes(s); if (end > bytes.Length || start > end) @@ -109,7 +102,7 @@ public static string SubstringByByteOffset(string s, int start, int end) } var substringBytes = new byte[end - start]; Array.Copy(bytes, start, substringBytes, 0, end - start); - return Encoding.UTF8.GetString(substringBytes); + return Encoding.UTF8.GetChars(substringBytes); } /// @@ -120,7 +113,7 @@ public static void CleanText(Token token, bool strict) var cleanedString = new StringBuilder(token.Text.Length); var characterMapping = new List(token.Text.Length); - foreach (var (character, position) in token.Text.EnumerateRunes().Zip(token.ReferenceOffsets)) + foreach (var (character, position) in token.Text.AsSpan().EnumerateRunes().ToList().Zip(token.ReferenceOffsets)) { if (IsControl(character, strict) || character == new Rune('\x00') || character == new Rune('\uFFFD')) { @@ -131,7 +124,7 @@ public static void CleanText(Token token, bool strict) characterMapping.Add(position); } - token.Text = cleanedString.ToString(); + token.Text = cleanedString.ToString().ToCharArray(); token.ReferenceOffsets = characterMapping; token.Offset = new Offset(token.ReferenceOffsets.FirstOrDefault(), token.ReferenceOffsets.LastOrDefault() + 1); } @@ -191,7 +184,7 @@ public static void Lowercase(Token token) } } - token.Text = lowerCasedString.ToString(); + token.Text = lowerCasedString.ToString().ToCharArray(); token.ReferenceOffsets = characterMapping; token.Offset = new Offset(token.ReferenceOffsets.FirstOrDefault(), token.ReferenceOffsets.LastOrDefault() + 1); } @@ -205,13 +198,13 @@ public static void DecomposeNfkc(Token token) var decomposedString = new StringBuilder(capacity); var characterMapping = new List(capacity); var curPosition = 0; - var normalizedString = token.Text.Normalize(NormalizationForm.FormKC); + var normalizedString = new string(token.Text).Normalize(NormalizationForm.FormKC); foreach (var (character, currentExtraCharSize) in TokenizationUtils.NFKC(normalizedString)) { var extraCharSize = 0; //HINT: [@eslam] check if character is removed from the original text after normalization - if (!token.Text.EnumerateRunes().Contains(character)) + if (!token.Text.AsSpan().EnumerateRunes().ToList().Contains(character)) extraCharSize -= currentExtraCharSize; decomposedString.Append(character); @@ -236,7 +229,7 @@ public static void DecomposeNfkc(Token token) curPosition += 1; // Adjust based on Unicode character width if needed } - token.Text = decomposedString.ToString();//.Normalize(NormalizationForm.FormKC); + token.Text = decomposedString.ToString().ToCharArray();//.Normalize(NormalizationForm.FormKC); token.ReferenceOffsets = characterMapping; token.Offset.Begin = token.ReferenceOffsets.FirstOrDefault(); token.Offset.End = token.ReferenceOffsets.LastOrDefault() + 1; @@ -475,4 +468,17 @@ public static (TokenIdsWithOffsets, TokenIdsWithOffsets?, List, List + /// extension method to enumerate span runes to list + /// + public static List ToList(this SpanRuneEnumerator enumerator) + { + var runes = new List(); + foreach (var rune in enumerator) + { + runes.Add(rune); + } + return runes; + } + } \ No newline at end of file diff --git a/src/Tokenizer/XLMRobertaTokenizer.cs b/src/Tokenizer/XLMRobertaTokenizer.cs index ccaf379..10b7b24 100644 --- a/src/Tokenizer/XLMRobertaTokenizer.cs +++ b/src/Tokenizer/XLMRobertaTokenizer.cs @@ -75,15 +75,15 @@ public override List TokenizeToTokens(Token tokenRef) // Manually replacing whitespace characters var newText = new StringBuilder(); - foreach (var c in token.Text.EnumerateRunes()) + foreach (var c in token.Text.AsSpan().EnumerateRunes().ToList()) { newText.Append(TokenizationUtils.IsWhitespace(c) ? new Rune(Constants.LowerOneEighthBlock) : c.ToString()); } - token.Text = newText.ToString(); + token.Text = newText.ToString().ToCharArray(); - if (!token.Text.StartsWith(Constants.LowerOneEighthBlock)) + if (!new string(token.Text).StartsWith(Constants.LowerOneEighthBlock)) { - token.Text = Constants.LowerOneEighthBlock + token.Text; + token.Text = (Constants.LowerOneEighthBlock.ToString() + new string(token.Text)).ToCharArray(); var newReferenceOffsets = new List { 0 }; newReferenceOffsets.AddRange(token.ReferenceOffsets); token.ReferenceOffsets = newReferenceOffsets; diff --git a/src/Vocab/SentencePieceUnigramModel.cs b/src/Vocab/SentencePieceUnigramModel.cs index 8423ec5..d49d2ee 100644 --- a/src/Vocab/SentencePieceUnigramModel.cs +++ b/src/Vocab/SentencePieceUnigramModel.cs @@ -125,7 +125,7 @@ public List CommonPrefixSearch(string text) { var charPositions = new List(); - var runes = TokenizationUtils.CharIndicesForRunes(token.Text).ToList(); + var runes = TokenizationUtils.CharIndicesForRunes(new string(token.Text)).ToList(); runes.ForEach((i => charPositions.Add(i.Index))); charPositions.Add(TokenizationUtils.GetUtf8BytesCount(token.Text)); @@ -136,7 +136,7 @@ public List CommonPrefixSearch(string text) for (var charStart = 0; charStart < charPositions.Count - 1; charStart++) { var prefix = TokenizationUtils.SubstringByByteOffset(token.Text, charPositions[charStart]); - var matches = CommonPrefixSearch(prefix.ToString()); + var matches = CommonPrefixSearch(new string(prefix)); foreach (var node in matches) { @@ -148,7 +148,7 @@ public List CommonPrefixSearch(string text) var t = TokenizationUtils.SubstringByByteOffset(token.Text, charPositions[charStart], charPositions[charEnd]); results[charEnd] = new Node ( - text: t, + text: new string(t), score: localScore, index: node.Index, start: charStart, @@ -165,7 +165,7 @@ public List CommonPrefixSearch(string text) var t = TokenizationUtils.SubstringByByteOffset(token.Text, charPositions[charStart], charPositions[charStart + 1]); results[charStart + 1] = new Node ( - text: t, + text: new string(t), score: float.MinValue, index: 0, start: charStart, @@ -217,13 +217,13 @@ public List ParseNodesToTokens(List nodes) if (isPrevUnknown && (node.Index == 0)) { var prevToken = output.Last(); - var text = new StringBuilder(prevToken.Text); + var text = new StringBuilder(new string(prevToken.Text)); text.Append(node.Text); var referenceOffsets = new List(); referenceOffsets.AddRange(node.ReferenceOffsets); var consolidatedUnknown = new Token(text.ToString()) { - Text = text.ToString(), + Text = text.ToString().ToCharArray(), Offset = new Offset(0, 0), ReferenceOffsets = referenceOffsets, Mask = Mask.Unknown, @@ -235,7 +235,7 @@ public List ParseNodesToTokens(List nodes) { output.Add(new Token(node.Text) { - Text = node.Text, + Text = node.Text.ToCharArray(), Offset = new Offset(0, 0), ReferenceOffsets = node.ReferenceOffsets.ToList(), Mask = Mask.None, @@ -278,7 +278,7 @@ public void PopulateMasks(List tokens, char whitespaceToken) } } - if (!token.Text.StartsWith(whitespaceToken) && previousMask != Mask.Punctuation && previousMask != Mask.Whitespace) + if (!new string(token.Text).StartsWith(whitespaceToken) && previousMask != Mask.Punctuation && previousMask != Mask.Whitespace) { token.Mask = Mask.Continuation; previousMask = Mask.Continuation; diff --git a/test/Tokenizer/TokenizationUtilsTests.cs b/test/Tokenizer/TokenizationUtilsTests.cs index 7d2274c..183e9b0 100644 --- a/test/Tokenizer/TokenizationUtilsTests.cs +++ b/test/Tokenizer/TokenizationUtilsTests.cs @@ -135,10 +135,10 @@ public void TestSubString_01() var s = "▁tokénized"; - var prefix = TokenizationUtils.SubstringByByteOffset(s, 0); + var prefix = TokenizationUtils.SubstringByByteOffset(s.ToCharArray(), 0); // Expected positions - var expected = "▁tokénized"; + var expected = "▁tokénized".ToCharArray(); // Assert that the character positions match the expected positions Assert.Equal(expected, prefix); } @@ -149,10 +149,10 @@ public void TestSubString_02() var s = "▁tokénized"; - var prefix = TokenizationUtils.SubstringByByteOffset(s, 3); + var prefix = TokenizationUtils.SubstringByByteOffset(s.ToCharArray(), 3); // Expected positions - var expected = "tokénized"; + var expected = "tokénized".ToCharArray(); // Assert that the character positions match the expected positions Assert.Equal(expected, prefix); }