diff --git a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs index 6b6ec7a234..405676c0de 100644 --- a/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/BPETokenizer.cs @@ -757,11 +757,10 @@ public OperationStatus Decode(IEnumerable ids, Span destination, bool /// Read the given files to extract the vocab and merges internal static async ValueTask<(Dictionary?, Vec<(string, string)>)> ReadModelDataAsync(Stream vocab, Stream? merges, bool useAsync, CancellationToken cancellationToken = default) { - JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } }; + Dictionary? dic = useAsync + ? await JsonSerializer.DeserializeAsync(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32, cancellationToken).ConfigureAwait(false) + : JsonSerializer.Deserialize(vocab, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32); - Dictionary? dic = useAsync ? - await JsonSerializer.DeserializeAsync>(vocab, options, cancellationToken).ConfigureAwait(false) as Dictionary : - JsonSerializer.Deserialize>(vocab, options) as Dictionary; var m = useAsync ? await ConvertMergesToHashmapAsync(merges, useAsync, cancellationToken).ConfigureAwait(false) : ConvertMergesToHashmapAsync(merges, useAsync).GetAwaiter().GetResult(); diff --git a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs index c1fd6bb1ca..a8b4577ea5 100644 --- a/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/CodeGenTokenizer.cs @@ -1764,11 +1764,10 @@ void TryMerge(int left, int right, ReadOnlySpan textSpan) private static Dictionary GetVocabulary(Stream vocabularyStream) { - Dictionary? vocab; + Vocabulary? vocab; try { - JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyCustomConverter.Instance } }; - vocab = JsonSerializer.Deserialize>(vocabularyStream, options) as Dictionary; + vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.Vocabulary); } catch (Exception e) { diff --git a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs index 85f921ff0f..4557508c73 100644 --- a/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs +++ b/src/Microsoft.ML.Tokenizers/Model/EnglishRobertaTokenizer.cs @@ -169,8 +169,7 @@ private static Dictionary GetVocabulary(Stream vocabu Dictionary? vocab; try { - JsonSerializerOptions options = new() { Converters = { StringSpanOrdinalKeyConverter.Instance } }; - vocab = JsonSerializer.Deserialize>(vocabularyStream, options) as Dictionary; + vocab = JsonSerializer.Deserialize(vocabularyStream, ModelSourceGenerationContext.Default.DictionaryStringSpanOrdinalKeyInt32); } catch (Exception e) { diff --git a/src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs b/src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs new file mode 100644 index 0000000000..e3075d92e5 --- /dev/null +++ b/src/Microsoft.ML.Tokenizers/Model/ModelSourceGenerationContext.cs @@ -0,0 +1,12 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using System.Collections.Generic; +using System.Text.Json.Serialization; + +namespace Microsoft.ML.Tokenizers; + +[JsonSerializable(typeof(Dictionary))] +[JsonSerializable(typeof(Vocabulary))] +internal partial class ModelSourceGenerationContext : JsonSerializerContext; diff --git a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs index 69f8c31b85..0ae599381e 100644 --- a/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs +++ b/src/Microsoft.ML.Tokenizers/Utils/StringSpanOrdinalKey.cs @@ -15,6 +15,7 @@ namespace Microsoft.ML.Tokenizers /// This should only be used with a Ptr/Length for querying. For storing in a dictionary, this should /// always be used with a string. /// + [JsonConverter(typeof(StringSpanOrdinalKeyConverter))] internal readonly unsafe struct StringSpanOrdinalKey : IEquatable { public readonly char* Ptr; @@ -124,12 +125,14 @@ internal void Set(string k, TValue v) } } + [JsonConverter(typeof(VocabularyConverter))] + internal sealed class Vocabulary : Dictionary; + /// /// Custom JSON converter for . /// internal sealed class StringSpanOrdinalKeyConverter : JsonConverter { - public static StringSpanOrdinalKeyConverter Instance { get; } = new StringSpanOrdinalKeyConverter(); public override StringSpanOrdinalKey ReadAsPropertyName(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) => new StringSpanOrdinalKey(reader.GetString()!); @@ -140,13 +143,11 @@ public override void WriteAsPropertyName(Utf8JsonWriter writer, StringSpanOrdina public override void Write(Utf8JsonWriter writer, StringSpanOrdinalKey value, JsonSerializerOptions options) => writer.WriteStringValue(value.Data!); } - internal class StringSpanOrdinalKeyCustomConverter : JsonConverter> + internal class VocabularyConverter : JsonConverter { - public static StringSpanOrdinalKeyCustomConverter Instance { get; } = new StringSpanOrdinalKeyCustomConverter(); - - public override Dictionary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) + public override Vocabulary Read(ref Utf8JsonReader reader, Type typeToConvert, JsonSerializerOptions options) { - var dictionary = new Dictionary(); + var dictionary = new Vocabulary(); while (reader.Read()) { if (reader.TokenType == JsonTokenType.EndObject) @@ -165,7 +166,7 @@ internal class StringSpanOrdinalKeyCustomConverter : JsonConverter value, JsonSerializerOptions options) => throw new NotImplementedException(); + public override void Write(Utf8JsonWriter writer, Vocabulary value, JsonSerializerOptions options) => throw new NotImplementedException(); } ///