@@ -2,20 +2,31 @@ namespace VSharp.Explorer
22
33open System.Collections .Generic
44open Microsoft.ML .OnnxRuntime
5+ open System
6+ open System.Text
7+ open System.Text .Json
58open VSharp
69open VSharp.IL .Serializer
710open VSharp.ML .GameServer .Messages
811
9- type internal AISearcher ( oracle : Oracle , aiAgentTrainingOptions : Option < AIAgentTrainingOptions >) =
12+ type AIMode =
13+ | Runner
14+ | TrainingSendModel
15+ | TrainingSendEachStep
16+
17+
18+ type internal AISearcher ( oracle : Oracle , aiAgentTrainingMode : Option < AIAgentTrainingMode >) =
1019 let stepsToSwitchToAI =
11- match aiAgentTrainingOptions with
20+ match aiAgentTrainingMode with
1221 | None -> 0 u< step>
13- | Some options -> options.stepsToSwitchToAI
22+ | Some( SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
23+ | Some( SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
1424
1525 let stepsToPlay =
16- match aiAgentTrainingOptions with
26+ match aiAgentTrainingMode with
1727 | None -> 0 u< step>
18- | Some options -> options.stepsToPlay
28+ | Some( SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
29+ | Some( SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
1930
2031 let mutable lastCollectedStatistics = Statistics()
2132 let mutable defaultSearcherSteps = 0 u< step>
@@ -25,14 +36,17 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
2536 let mutable incorrectPredictedStateId = false
2637
2738 let defaultSearcher =
28- match aiAgentTrainingOptions with
29- | None -> BFSSearcher() :> IForwardSearcher
30- | Some options ->
31- match options.defaultSearchStrategy with
39+ let pickSearcher =
40+ function
3241 | BFSMode -> BFSSearcher() :> IForwardSearcher
3342 | DFSMode -> DFSSearcher() :> IForwardSearcher
3443 | x -> failwithf $" Unexpected default searcher {x}. DFS and BFS supported for now."
3544
45+ match aiAgentTrainingMode with
46+ | None -> BFSSearcher() :> IForwardSearcher
47+ | Some( SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
48+ | Some( SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
49+
3650 let mutable stepsPlayed = 0 u< step>
3751
3852 let isInAIMode () =
@@ -41,59 +55,6 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
4155 let q = ResizeArray<_>()
4256 let availableStates = HashSet<_>()
4357
44- let updateGameState ( delta : GameState ) =
45- match gameState with
46- | None -> gameState <- Some delta
47- | Some s ->
48- let updatedBasicBlocks = delta.GraphVertices |> Array.map ( fun b -> b.Id) |> HashSet
49- let updatedStates = delta.States |> Array.map ( fun s -> s.Id) |> HashSet
50-
51- let vertices =
52- s.GraphVertices
53- |> Array.filter ( fun v -> updatedBasicBlocks.Contains v.Id |> not )
54- |> ResizeArray<_>
55-
56- vertices.AddRange delta.GraphVertices
57-
58- let edges =
59- s.Map
60- |> Array.filter ( fun e -> updatedBasicBlocks.Contains e.VertexFrom |> not )
61- |> ResizeArray<_>
62-
63- edges.AddRange delta.Map
64- let activeStates = vertices |> Seq.collect ( fun v -> v.States) |> HashSet
65-
66- let states =
67- let part1 =
68- s.States
69- |> Array.filter ( fun s -> activeStates.Contains s.Id && ( not <| updatedStates.Contains s.Id))
70- |> ResizeArray<_>
71-
72- part1.AddRange delta.States
73-
74- part1.ToArray()
75- |> Array.map ( fun s ->
76- State(
77- s.Id,
78- s.Position,
79- s.PathCondition,
80- s.VisitedAgainVertices,
81- s.VisitedNotCoveredVerticesInZone,
82- s.VisitedNotCoveredVerticesOutOfZone,
83- s.StepWhenMovedLastTime,
84- s.InstructionsVisitedInCurrentBlock,
85- s.History,
86- s.Children |> Array.filter activeStates.Contains
87- ))
88-
89- let pathConditionVertices =
90- ResizeArray< PathConditionVertex> s.PathConditionVertices
91-
92- pathConditionVertices.AddRange delta.PathConditionVertices
93-
94- gameState <-
95- Some
96- <| GameState( vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
9758
9859
9960 let init states =
@@ -128,15 +89,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
12889 for bb in state._ history do
12990 bb.Key.AssociatedStates.Remove state |> ignore
13091
131- let inTrainMode = aiAgentTrainingOptions.IsSome
132-
92+ let aiMode =
93+ match aiAgentTrainingMode with
94+ | Some( SendEachStep _) -> TrainingSendEachStep
95+ | Some( SendModel _) -> TrainingSendModel
96+ | None -> Runner
13397 let pick selector =
13498 if useDefaultSearcher then
13599 defaultSearcherSteps <- defaultSearcherSteps + 1 u< step>
136100
137101 if Seq.length availableStates > 0 then
138102 let gameStateDelta = collectGameStateDelta ()
139- updateGameState gameStateDelta
103+ gameState <- AISearcher. updateGameState gameStateDelta gameState
140104 let statistics = computeStatistics gameState.Value
141105 Application.applicationGraphDelta.Clear()
142106 lastCollectedStatistics <- statistics
@@ -149,7 +113,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
149113 Some( Seq.head availableStates)
150114 else
151115 let gameStateDelta = collectGameStateDelta ()
152- updateGameState gameStateDelta
116+ gameState <- AISearcher. updateGameState gameStateDelta gameState
153117 let statistics = computeStatistics gameState.Value
154118
155119 if isInAIMode () then
@@ -158,14 +122,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
158122
159123 Application.applicationGraphDelta.Clear()
160124
161- if inTrainMode && stepsToPlay = stepsPlayed then
125+ if stepsToPlay = stepsPlayed then
162126 None
163127 else
164128 let toPredict =
165- if inTrainMode && stepsPlayed > 0 u< step> then
166- gameStateDelta
167- else
168- gameState.Value
129+ match aiMode with
130+ | TrainingSendEachStep
131+ | TrainingSendModel ->
132+ if stepsPlayed > 0 u< step> then
133+ gameStateDelta
134+ else
135+ gameState.Value
136+ | Runner -> gameState.Value
169137
170138 let stateId = oracle.Predict toPredict
171139 afterFirstAIPeek <- true
@@ -179,13 +147,77 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
179147 incorrectPredictedStateId <- true
180148 oracle.Feedback( Feedback.IncorrectPredictedStateId stateId)
181149 None
150+ static member updateGameState ( delta : GameState ) ( gameState : Option < GameState >) =
151+ match gameState with
152+ | None -> Some delta
153+ | Some s ->
154+ let updatedBasicBlocks = delta.GraphVertices |> Array.map ( fun b -> b.Id) |> HashSet
155+ let updatedStates = delta.States |> Array.map ( fun s -> s.Id) |> HashSet
156+
157+ let vertices =
158+ s.GraphVertices
159+ |> Array.filter ( fun v -> updatedBasicBlocks.Contains v.Id |> not )
160+ |> ResizeArray<_>
161+
162+ vertices.AddRange delta.GraphVertices
163+
164+ let edges =
165+ s.Map
166+ |> Array.filter ( fun e -> updatedBasicBlocks.Contains e.VertexFrom |> not )
167+ |> ResizeArray<_>
168+
169+ edges.AddRange delta.Map
170+ let activeStates = vertices |> Seq.collect ( fun v -> v.States) |> HashSet
171+
172+ let states =
173+ let part1 =
174+ s.States
175+ |> Array.filter ( fun s -> activeStates.Contains s.Id && ( not <| updatedStates.Contains s.Id))
176+ |> ResizeArray<_>
177+
178+ part1.AddRange delta.States
179+
180+ part1.ToArray()
181+ |> Array.map ( fun s ->
182+ State(
183+ s.Id,
184+ s.Position,
185+ s.PathCondition,
186+ s.VisitedAgainVertices,
187+ s.VisitedNotCoveredVerticesInZone,
188+ s.VisitedNotCoveredVerticesOutOfZone,
189+ s.StepWhenMovedLastTime,
190+ s.InstructionsVisitedInCurrentBlock,
191+ s.History,
192+ s.Children |> Array.filter activeStates.Contains
193+ ))
194+
195+ let pathConditionVertices = ResizeArray< PathConditionVertex> s.PathConditionVertices
182196
183- new ( pathToONNX: string, useGPU: bool, optimize: bool) =
197+ pathConditionVertices.AddRange delta.PathConditionVertices
198+
199+ Some
200+ <| GameState( vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
201+
202+ static member convertOutputToJson ( output : IDisposableReadOnlyCollection < OrtValue >) =
203+ seq { 0 .. output.Count - 1 }
204+ |> Seq.map ( fun i -> output[ i]. GetTensorDataAsSpan< float32>() .ToArray())
205+
206+
207+
208+ new
209+ (
210+ pathToONNX: string,
211+ useGPU: bool,
212+ optimize: bool,
213+ aiAgentTrainingModelOptions: Option< AIAgentTrainingModelOptions>
214+ ) =
184215 let numOfVertexAttributes = 7
185216 let numOfStateAttributes = 7
186217 let numOfHistoryEdgeAttributes = 2
187218
188- let createOracle ( pathToONNX : string ) =
219+
220+ let createOracleRunner ( pathToONNX : string , aiAgentTrainingModelOptions : Option < AIAgentTrainingModelOptions >) =
189221 let sessionOptions =
190222 if useGPU then
191223 SessionOptions.MakeSessionOptionWithCudaProvider( 0 )
@@ -199,10 +231,21 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
199231 sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ ENABLE_ BASIC
200232
201233 let session = new InferenceSession( pathToONNX, sessionOptions)
234+
202235 let runOptions = new RunOptions()
203236 let feedback ( x : Feedback ) = ()
204237
205- let predict ( gameState : GameState ) =
238+ let mutable stepsPlayed = 0
239+ let mutable currentGameState = None
240+
241+ let predict ( gameStateOrDelta : GameState ) =
242+ let _ =
243+ match aiAgentTrainingModelOptions with
244+ | Some _ when not ( stepsPlayed = 0 ) ->
245+ currentGameState <- AISearcher.updateGameState gameStateOrDelta currentGameState
246+ | _ -> currentGameState <- Some gameStateOrDelta
247+
248+ let gameState = currentGameState.Value
206249 let stateIds = Dictionary< uint< stateId>, int>()
207250 let verticesIds = Dictionary< uint< basicBlockGlobalId>, int>()
208251
@@ -243,7 +286,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
243286 let j = i * numOfStateAttributes
244287 attributes.[ j] <- float32 v.Position
245288 // TODO: Support path condition
246- // attributes.[j + 1] <- float32 v.PathConditionSize
289+ // attributes.[j + 1] <- float32 v.PathConditionSize
247290 attributes.[ j + 2 ] <- float32 v.VisitedAgainVertices
248291 attributes.[ j + 3 ] <- float32 v.VisitedNotCoveredVerticesInZone
249292 attributes.[ j + 4 ] <- float32 v.VisitedNotCoveredVerticesOutOfZone
@@ -350,14 +393,30 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
350393 res
351394
352395 let output = session.Run( runOptions, networkInput, session.OutputNames)
396+
397+ let _ =
398+ match aiAgentTrainingModelOptions with
399+ | Some aiAgentOptions ->
400+ aiAgentOptions.stepSaver (
401+ AIGameStep( gameState = gameStateOrDelta, output = AISearcher.convertOutputToJson output)
402+ )
403+ | None -> ()
404+
405+ stepsPlayed <- stepsPlayed + 1
406+
353407 let weighedStates = output[ 0 ]. GetTensorDataAsSpan< float32>() .ToArray()
354408
355409 let id = weighedStates |> Array.mapi ( fun i v -> i, v) |> Array.maxBy snd |> fst
356410 stateIds |> Seq.find ( fun kvp -> kvp.Value = id) |> ( fun x -> x.Key)
357411
358412 Oracle( predict, feedback)
359413
360- AISearcher( createOracle pathToONNX, None)
414+ let aiAgentTrainingOptions =
415+ match aiAgentTrainingModelOptions with
416+ | Some aiAgentTrainingModelOptions -> Some( SendModel aiAgentTrainingModelOptions)
417+ | None -> None
418+
419+ AISearcher( createOracleRunner ( pathToONNX, aiAgentTrainingModelOptions), aiAgentTrainingOptions)
361420
362421 interface IForwardSearcher with
363422 override x.Init states = init states
0 commit comments