Skip to content

Commit 4e90c51

Browse files
authored
Support of new training process (#89)
* Support of new training process
1 parent f7129f6 commit 4e90c51

File tree

8 files changed

+596
-350
lines changed

8 files changed

+596
-350
lines changed

VSharp.API/VSharp.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ private static Statistics StartExploration(
191191
stopOnCoverageAchieved: 100,
192192
randomSeed: options.RandomSeed,
193193
stepsLimit: options.StepsLimit,
194-
aiAgentTrainingOptions: options.AIAgentTrainingOptions == null ? FSharpOption<AIAgentTrainingOptions>.None : FSharpOption<AIAgentTrainingOptions>.Some(options.AIAgentTrainingOptions),
194+
aiOptions: options.AIOptions == null ? FSharpOption<AIOptions>.None : FSharpOption<AIOptions>.Some(options.AIOptions),
195195
pathToModel: options.PathToModel == null ? FSharpOption<string>.None : FSharpOption<string>.Some(options.PathToModel),
196-
useGPU: options.UseGPU == null ? FSharpOption<bool>.None : FSharpOption<bool>.Some(options.UseGPU),
197-
optimize: options.Optimize == null ? FSharpOption<bool>.None : FSharpOption<bool>.Some(options.Optimize)
196+
useGPU: options.UseGPU,
197+
optimize: options.Optimize
198198
);
199199

200200
var fuzzerOptions =

VSharp.API/VSharpOptions.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ public readonly record struct VSharpOptions
113113
public readonly bool ReleaseBranches = DefaultReleaseBranches;
114114
public readonly int RandomSeed = DefaultRandomSeed;
115115
public readonly uint StepsLimit = DefaultStepsLimit;
116-
public readonly AIAgentTrainingOptions AIAgentTrainingOptions = null;
116+
public readonly AIOptions? AIOptions = null;
117117
public readonly string PathToModel = DefaultPathToModel;
118118
public readonly bool UseGPU = false;
119119
public readonly bool Optimize = false;
@@ -133,7 +133,7 @@ public readonly record struct VSharpOptions
133133
/// <param name="releaseBranches">If true and timeout is specified, a part of allotted time in the end is given to execute remaining states without branching.</param>
134134
/// <param name="randomSeed">Fixed seed for random operations. Used if greater than or equal to zero.</param>
135135
/// <param name="stepsLimit">Number of symbolic machine steps to stop execution after. Zero value means no limit.</param>
136-
/// <param name="aiAgentTrainingOptions">Settings for AI searcher training.</param>
136+
/// <param name="aiOptions">Settings for AI searcher training.</param>
137137
/// <param name="pathToModel">Path to ONNX file with model to use in AI searcher.</param>
138138
/// <param name="useGPU">Specifies whether the ONNX execution session should use a CUDA-enabled GPU.</param>
139139
/// <param name="optimize">Enabling options like parallel execution and various graph transformations to enhance performance of ONNX.</param>
@@ -150,7 +150,7 @@ public VSharpOptions(
150150
bool releaseBranches = DefaultReleaseBranches,
151151
int randomSeed = DefaultRandomSeed,
152152
uint stepsLimit = DefaultStepsLimit,
153-
AIAgentTrainingOptions aiAgentTrainingOptions = null,
153+
AIOptions? aiOptions = null,
154154
string pathToModel = DefaultPathToModel,
155155
bool useGPU = false,
156156
bool optimize = false)
@@ -167,7 +167,7 @@ public VSharpOptions(
167167
ReleaseBranches = releaseBranches;
168168
RandomSeed = randomSeed;
169169
StepsLimit = stepsLimit;
170-
AIAgentTrainingOptions = aiAgentTrainingOptions;
170+
AIOptions = aiOptions;
171171
PathToModel = pathToModel;
172172
UseGPU = useGPU;
173173
Optimize = optimize;

VSharp.Explorer/AISearcher.fs

Lines changed: 135 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,31 @@ namespace VSharp.Explorer
22

33
open System.Collections.Generic
44
open Microsoft.ML.OnnxRuntime
5+
open System
6+
open System.Text
7+
open System.Text.Json
58
open VSharp
69
open VSharp.IL.Serializer
710
open VSharp.ML.GameServer.Messages
811

9-
type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentTrainingOptions>) =
12+
/// Operating mode of the AI searcher, derived from the optional training mode it is
/// constructed with: `Runner` when no training mode is supplied, `TrainingSendModel`
/// when the training mode is `SendModel`, and `TrainingSendEachStep` when it is
/// `SendEachStep`.
type AIMode =
13+
| Runner
14+
| TrainingSendModel
15+
| TrainingSendEachStep
16+
17+
18+
type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrainingMode>) =
1019
let stepsToSwitchToAI =
11-
match aiAgentTrainingOptions with
20+
match aiAgentTrainingMode with
1221
| None -> 0u<step>
13-
| Some options -> options.stepsToSwitchToAI
22+
| Some(SendModel options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
23+
| Some(SendEachStep options) -> options.aiAgentTrainingOptions.stepsToSwitchToAI
1424

1525
let stepsToPlay =
16-
match aiAgentTrainingOptions with
26+
match aiAgentTrainingMode with
1727
| None -> 0u<step>
18-
| Some options -> options.stepsToPlay
28+
| Some(SendModel options) -> options.aiAgentTrainingOptions.stepsToPlay
29+
| Some(SendEachStep options) -> options.aiAgentTrainingOptions.stepsToPlay
1930

2031
let mutable lastCollectedStatistics = Statistics()
2132
let mutable defaultSearcherSteps = 0u<step>
@@ -25,14 +36,17 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
2536
let mutable incorrectPredictedStateId = false
2637

2738
let defaultSearcher =
28-
match aiAgentTrainingOptions with
29-
| None -> BFSSearcher() :> IForwardSearcher
30-
| Some options ->
31-
match options.defaultSearchStrategy with
39+
let pickSearcher =
40+
function
3241
| BFSMode -> BFSSearcher() :> IForwardSearcher
3342
| DFSMode -> DFSSearcher() :> IForwardSearcher
3443
| x -> failwithf $"Unexpected default searcher {x}. DFS and BFS supported for now."
3544

45+
match aiAgentTrainingMode with
46+
| None -> BFSSearcher() :> IForwardSearcher
47+
| Some(SendModel options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
48+
| Some(SendEachStep options) -> pickSearcher options.aiAgentTrainingOptions.aiBaseOptions.defaultSearchStrategy
49+
3650
let mutable stepsPlayed = 0u<step>
3751

3852
let isInAIMode () =
@@ -41,59 +55,6 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
4155
let q = ResizeArray<_>()
4256
let availableStates = HashSet<_>()
4357

44-
let updateGameState (delta: GameState) =
45-
match gameState with
46-
| None -> gameState <- Some delta
47-
| Some s ->
48-
let updatedBasicBlocks = delta.GraphVertices |> Array.map (fun b -> b.Id) |> HashSet
49-
let updatedStates = delta.States |> Array.map (fun s -> s.Id) |> HashSet
50-
51-
let vertices =
52-
s.GraphVertices
53-
|> Array.filter (fun v -> updatedBasicBlocks.Contains v.Id |> not)
54-
|> ResizeArray<_>
55-
56-
vertices.AddRange delta.GraphVertices
57-
58-
let edges =
59-
s.Map
60-
|> Array.filter (fun e -> updatedBasicBlocks.Contains e.VertexFrom |> not)
61-
|> ResizeArray<_>
62-
63-
edges.AddRange delta.Map
64-
let activeStates = vertices |> Seq.collect (fun v -> v.States) |> HashSet
65-
66-
let states =
67-
let part1 =
68-
s.States
69-
|> Array.filter (fun s -> activeStates.Contains s.Id && (not <| updatedStates.Contains s.Id))
70-
|> ResizeArray<_>
71-
72-
part1.AddRange delta.States
73-
74-
part1.ToArray()
75-
|> Array.map (fun s ->
76-
State(
77-
s.Id,
78-
s.Position,
79-
s.PathCondition,
80-
s.VisitedAgainVertices,
81-
s.VisitedNotCoveredVerticesInZone,
82-
s.VisitedNotCoveredVerticesOutOfZone,
83-
s.StepWhenMovedLastTime,
84-
s.InstructionsVisitedInCurrentBlock,
85-
s.History,
86-
s.Children |> Array.filter activeStates.Contains
87-
))
88-
89-
let pathConditionVertices =
90-
ResizeArray<PathConditionVertex> s.PathConditionVertices
91-
92-
pathConditionVertices.AddRange delta.PathConditionVertices
93-
94-
gameState <-
95-
Some
96-
<| GameState(vertices.ToArray(), states, pathConditionVertices.ToArray(), edges.ToArray())
9758

9859

9960
let init states =
@@ -128,15 +89,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
12889
for bb in state._history do
12990
bb.Key.AssociatedStates.Remove state |> ignore
13091

131-
let inTrainMode = aiAgentTrainingOptions.IsSome
132-
92+
// Collapses the optional training mode into the three-way AIMode tag:
// no training mode means the searcher only runs a model; otherwise the tag
// mirrors which training variant was supplied (its payload is ignored here).
let aiMode =
    match aiAgentTrainingMode with
    | None -> Runner
    | Some trainingMode ->
        match trainingMode with
        | SendModel _ -> TrainingSendModel
        | SendEachStep _ -> TrainingSendEachStep
13397
let pick selector =
13498
if useDefaultSearcher then
13599
defaultSearcherSteps <- defaultSearcherSteps + 1u<step>
136100

137101
if Seq.length availableStates > 0 then
138102
let gameStateDelta = collectGameStateDelta ()
139-
updateGameState gameStateDelta
103+
gameState <- AISearcher.updateGameState gameStateDelta gameState
140104
let statistics = computeStatistics gameState.Value
141105
Application.applicationGraphDelta.Clear()
142106
lastCollectedStatistics <- statistics
@@ -149,7 +113,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
149113
Some(Seq.head availableStates)
150114
else
151115
let gameStateDelta = collectGameStateDelta ()
152-
updateGameState gameStateDelta
116+
gameState <- AISearcher.updateGameState gameStateDelta gameState
153117
let statistics = computeStatistics gameState.Value
154118

155119
if isInAIMode () then
@@ -158,14 +122,18 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
158122

159123
Application.applicationGraphDelta.Clear()
160124

161-
if inTrainMode && stepsToPlay = stepsPlayed then
125+
if stepsToPlay = stepsPlayed then
162126
None
163127
else
164128
let toPredict =
165-
if inTrainMode && stepsPlayed > 0u<step> then
166-
gameStateDelta
167-
else
168-
gameState.Value
129+
match aiMode with
130+
| TrainingSendEachStep
131+
| TrainingSendModel ->
132+
if stepsPlayed > 0u<step> then
133+
gameStateDelta
134+
else
135+
gameState.Value
136+
| Runner -> gameState.Value
169137

170138
let stateId = oracle.Predict toPredict
171139
afterFirstAIPeek <- true
@@ -179,13 +147,77 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
179147
incorrectPredictedStateId <- true
180148
oracle.Feedback(Feedback.IncorrectPredictedStateId stateId)
181149
None
150+
/// Merges an incremental `delta` into the accumulated `gameState` and returns the result.
///
/// When there is no accumulated state yet, the delta itself becomes the state.
/// Otherwise:
///   * graph vertices whose id appears in the delta are replaced by the delta's versions;
///   * edges leaving an updated vertex are dropped and the delta's edges are appended;
///   * previously known states are kept only if some vertex of the merged graph still
///     lists them and the delta does not carry a newer version; the delta's states are appended;
///   * every resulting state is rebuilt with its children restricted to ids that are
///     still listed by some vertex;
///   * path-condition vertices of the delta are appended to the existing ones as is.
static member updateGameState (delta: GameState) (gameState: Option<GameState>) =
    match gameState with
    | None -> Some delta
    | Some previous ->
        // Ids of basic blocks and states for which the delta carries fresh data.
        let refreshedBlockIds =
            delta.GraphVertices |> Array.map (fun block -> block.Id) |> HashSet

        let refreshedStateIds =
            delta.States |> Array.map (fun state -> state.Id) |> HashSet

        let mergedVertices =
            let untouched =
                previous.GraphVertices
                |> Array.filter (fun vertex -> not (refreshedBlockIds.Contains vertex.Id))

            Array.append untouched delta.GraphVertices

        let mergedEdges =
            let untouched =
                previous.Map
                |> Array.filter (fun edge -> not (refreshedBlockIds.Contains edge.VertexFrom))

            Array.append untouched delta.Map

        // State ids referenced by at least one vertex of the merged graph.
        let liveStateIds =
            mergedVertices |> Seq.collect (fun vertex -> vertex.States) |> HashSet

        let mergedStates =
            let carriedOver =
                previous.States
                |> Array.filter (fun state ->
                    liveStateIds.Contains state.Id && not (refreshedStateIds.Contains state.Id))

            Array.append carriedOver delta.States
            |> Array.map (fun state ->
                // Rebuild the state so that its children only mention live states.
                State(
                    state.Id,
                    state.Position,
                    state.PathCondition,
                    state.VisitedAgainVertices,
                    state.VisitedNotCoveredVerticesInZone,
                    state.VisitedNotCoveredVerticesOutOfZone,
                    state.StepWhenMovedLastTime,
                    state.InstructionsVisitedInCurrentBlock,
                    state.History,
                    state.Children |> Array.filter liveStateIds.Contains
                ))

        let mergedPathConditionVertices: PathConditionVertex[] =
            [| yield! previous.PathConditionVertices
               yield! delta.PathConditionVertices |]

        Some(GameState(mergedVertices, mergedStates, mergedPathConditionVertices, mergedEdges))
201+
202+
/// Copies every tensor of an ONNX Runtime result collection into a managed `float32[]`,
/// one array per output, in output order.
///
/// The copy is performed eagerly: a lazily evaluated sequence would read the native
/// `OrtValue` buffers only when enumerated (and again on every re-enumeration), which
/// ties the result to the lifetime of `output`. The returned value is still typed as
/// `seq<float32[]>`, so existing callers are unaffected.
///
/// NOTE(review): despite the name, no JSON is produced here; the caller is presumably
/// expected to serialise the returned arrays — confirm.
static member convertOutputToJson (output: IDisposableReadOnlyCollection<OrtValue>) =
    Array.init output.Count (fun i -> output[i].GetTensorDataAsSpan<float32>().ToArray())
    :> seq<float32[]>
205+
206+
207+
208+
new
209+
(
210+
pathToONNX: string,
211+
useGPU: bool,
212+
optimize: bool,
213+
aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>
214+
) =
184215
let numOfVertexAttributes = 7
185216
let numOfStateAttributes = 7
186217
let numOfHistoryEdgeAttributes = 2
187218

188-
let createOracle (pathToONNX: string) =
219+
220+
let createOracleRunner (pathToONNX: string, aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>) =
189221
let sessionOptions =
190222
if useGPU then
191223
SessionOptions.MakeSessionOptionWithCudaProvider(0)
@@ -199,10 +231,21 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
199231
sessionOptions.GraphOptimizationLevel <- GraphOptimizationLevel.ORT_ENABLE_BASIC
200232

201233
let session = new InferenceSession(pathToONNX, sessionOptions)
234+
202235
let runOptions = new RunOptions()
203236
let feedback (x: Feedback) = ()
204237

205-
let predict (gameState: GameState) =
238+
let mutable stepsPlayed = 0
239+
let mutable currentGameState = None
240+
241+
let predict (gameStateOrDelta: GameState) =
242+
let _ =
243+
match aiAgentTrainingModelOptions with
244+
| Some _ when not (stepsPlayed = 0) ->
245+
currentGameState <- AISearcher.updateGameState gameStateOrDelta currentGameState
246+
| _ -> currentGameState <- Some gameStateOrDelta
247+
248+
let gameState = currentGameState.Value
206249
let stateIds = Dictionary<uint<stateId>, int>()
207250
let verticesIds = Dictionary<uint<basicBlockGlobalId>, int>()
208251

@@ -243,7 +286,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
243286
let j = i * numOfStateAttributes
244287
attributes.[j] <- float32 v.Position
245288
// TODO: Support path condition
246-
// attributes.[j + 1] <- float32 v.PathConditionSize
289+
// attributes.[j + 1] <- float32 v.PathConditionSize
247290
attributes.[j + 2] <- float32 v.VisitedAgainVertices
248291
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
249292
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
@@ -350,14 +393,30 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingOptions: Option<AIAgentT
350393
res
351394

352395
let output = session.Run(runOptions, networkInput, session.OutputNames)
396+
397+
let _ =
398+
match aiAgentTrainingModelOptions with
399+
| Some aiAgentOptions ->
400+
aiAgentOptions.stepSaver (
401+
AIGameStep(gameState = gameStateOrDelta, output = AISearcher.convertOutputToJson output)
402+
)
403+
| None -> ()
404+
405+
stepsPlayed <- stepsPlayed + 1
406+
353407
let weighedStates = output[0].GetTensorDataAsSpan<float32>().ToArray()
354408

355409
let id = weighedStates |> Array.mapi (fun i v -> i, v) |> Array.maxBy snd |> fst
356410
stateIds |> Seq.find (fun kvp -> kvp.Value = id) |> (fun x -> x.Key)
357411

358412
Oracle(predict, feedback)
359413

360-
AISearcher(createOracle pathToONNX, None)
414+
let aiAgentTrainingOptions =
415+
match aiAgentTrainingModelOptions with
416+
| Some aiAgentTrainingModelOptions -> Some(SendModel aiAgentTrainingModelOptions)
417+
| None -> None
418+
419+
AISearcher(createOracleRunner (pathToONNX, aiAgentTrainingModelOptions), aiAgentTrainingOptions)
361420

362421
interface IForwardSearcher with
363422
override x.Init states = init states

0 commit comments

Comments
 (0)