Skip to content

Commit 690daed

Browse files
authored
Convert PathConditionVertices from GameState to tensor. (#95)
1 parent 010e5fe commit 690daed

File tree

1 file changed

+68
-5
lines changed

1 file changed

+68
-5
lines changed

VSharp.Explorer/AISearcher.fs

Lines changed: 68 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
213213
aiAgentTrainingModelOptions: Option<AIAgentTrainingModelOptions>
214214
) =
215215
let numOfVertexAttributes = 7
216-
let numOfStateAttributes = 7
216+
let numOfStateAttributes = 6
217+
let numOfPathConditionVertexAttributes = 49
217218
let numOfHistoryEdgeAttributes = 2
218219

219220

@@ -248,10 +249,33 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
248249
let gameState = currentGameState.Value
249250
let stateIds = Dictionary<uint<stateId>, int>()
250251
let verticesIds = Dictionary<uint<basicBlockGlobalId>, int>()
252+
let pathConditionVerticesIds = Dictionary<uint<pathConditionVertexId>, int>()
251253

252254
let networkInput =
253255
let res = Dictionary<_, _>()
254256

257+
let pathConditionVertices, numOfPcToPcEdges =
258+
let mutable numOfPcToPcEdges = 0
259+
260+
let shape =
261+
[| int64 gameState.PathConditionVertices.Length
262+
numOfPathConditionVertexAttributes |]
263+
264+
let attributes =
265+
Array.zeroCreate (
266+
gameState.PathConditionVertices.Length * numOfPathConditionVertexAttributes
267+
)
268+
269+
for i in 0 .. gameState.PathConditionVertices.Length - 1 do
270+
let v = gameState.PathConditionVertices.[i]
271+
numOfPcToPcEdges <- numOfPcToPcEdges + v.Children.Length * 2
272+
pathConditionVerticesIds.Add(v.Id, i)
273+
let j = i * numOfPathConditionVertexAttributes
274+
attributes.[j + int v.Type] <- float32 1u
275+
276+
OrtValue.CreateTensorValueFromMemory(attributes, shape), numOfPcToPcEdges
277+
278+
255279
let gameVertices =
256280
let shape = [| int64 gameState.GraphVertices.Length; numOfVertexAttributes |]
257281

@@ -285,8 +309,6 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
285309
stateIds.Add(v.Id, i)
286310
let j = i * numOfStateAttributes
287311
attributes.[j] <- float32 v.Position
288-
// TODO: Support path condition
289-
// attributes.[j + 1] <- float32 v.PathConditionSize
290312
attributes.[j + 2] <- float32 v.VisitedAgainVertices
291313
attributes.[j + 3] <- float32 v.VisitedNotCoveredVerticesInZone
292314
attributes.[j + 4] <- float32 v.VisitedNotCoveredVerticesOutOfZone
@@ -295,6 +317,31 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
295317

296318
OrtValue.CreateTensorValueFromMemory(attributes, shape), numOfParentOfEdges, numOfHistoryEdges
297319

320+
let pcToPcEdgeIndex =
321+
let shapeOfIndex = [| 2L; numOfPcToPcEdges |]
322+
let index = Array.zeroCreate (2 * numOfPcToPcEdges)
323+
let mutable firstFreePositionOfIndex = 0
324+
325+
for v in gameState.PathConditionVertices do
326+
for child in v.Children do
327+
// from vertex to child
328+
index.[firstFreePositionOfIndex] <- pathConditionVerticesIds.[v.Id]
329+
330+
index.[firstFreePositionOfIndex + 2 * numOfPcToPcEdges] <-
331+
pathConditionVerticesIds.[child]
332+
333+
firstFreePositionOfIndex <- firstFreePositionOfIndex + 1
334+
// from child to vertex
335+
index.[firstFreePositionOfIndex] <- pathConditionVerticesIds.[child]
336+
337+
index.[firstFreePositionOfIndex + 2 * numOfPcToPcEdges] <-
338+
pathConditionVerticesIds.[v.Id]
339+
340+
firstFreePositionOfIndex <- firstFreePositionOfIndex + 1
341+
342+
OrtValue.CreateTensorValueFromMemory(index, shapeOfIndex)
343+
344+
298345
let vertexToVertexEdgesIndex, vertexToVertexEdgesAttributes =
299346
let shapeOfIndex = [| 2L; gameState.Map.Length |]
300347
let shapeOfAttributes = [| int64 gameState.Map.Length |]
@@ -310,11 +357,13 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
310357
OrtValue.CreateTensorValueFromMemory(index, shapeOfIndex),
311358
OrtValue.CreateTensorValueFromMemory(attributes, shapeOfAttributes)
312359

313-
let historyEdgesIndex_vertexToState, historyEdgesAttributes, parentOfEdges =
360+
let historyEdgesIndex_vertexToState, historyEdgesAttributes, parentOfEdges, edgeIndex_pcToState =
314361
let shapeOfParentOf = [| 2L; numOfParentOfEdges |]
315362
let parentOf = Array.zeroCreate (2 * numOfParentOfEdges)
316363
let shapeOfHistory = [| 2L; numOfHistoryEdges |]
317364
let historyIndex_vertexToState = Array.zeroCreate (2 * numOfHistoryEdges)
365+
let shapeOfPcToState = [| 2L; gameState.States.Length |]
366+
let index_pcToState = Array.zeroCreate (2 * gameState.States.Length)
318367

319368
let shapeOfHistoryAttributes =
320369
[| int64 numOfHistoryEdges; int64 numOfHistoryEdgeAttributes |]
@@ -323,6 +372,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
323372
let mutable firstFreePositionInParentsOf = 0
324373
let mutable firstFreePositionInHistoryIndex = 0
325374
let mutable firstFreePositionInHistoryAttributes = 0
375+
let mutable firstFreePositionInPcToState = 0
326376

327377
gameState.States
328378
|> Array.iter (fun state ->
@@ -334,6 +384,14 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
334384

335385
firstFreePositionInParentsOf <- firstFreePositionInParentsOf + state.Children.Length
336386

387+
index_pcToState.[firstFreePositionInPcToState] <-
388+
int64 pathConditionVerticesIds[state.PathCondition.Id]
389+
390+
index_pcToState.[firstFreePositionInPcToState + gameState.States.Length] <-
391+
int64 stateIds[state.Id]
392+
393+
firstFreePositionInPcToState <- firstFreePositionInPcToState + 1
394+
337395
state.History
338396
|> Array.iteri (fun i historyElem ->
339397
let j = firstFreePositionInHistoryIndex + i
@@ -352,7 +410,8 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
352410

353411
OrtValue.CreateTensorValueFromMemory(historyIndex_vertexToState, shapeOfHistory),
354412
OrtValue.CreateTensorValueFromMemory(historyAttributes, shapeOfHistoryAttributes),
355-
OrtValue.CreateTensorValueFromMemory(parentOf, shapeOfParentOf)
413+
OrtValue.CreateTensorValueFromMemory(parentOf, shapeOfParentOf),
414+
OrtValue.CreateTensorValueFromMemory(index_pcToState, shapeOfPcToState)
356415

357416
let statePosition_stateToVertex, statePosition_vertexToState =
358417
let data_stateToVertex = Array.zeroCreate (2 * gameState.States.Length)
@@ -380,6 +439,7 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
380439

381440
res.Add("game_vertex", gameVertices)
382441
res.Add("state_vertex", states)
442+
res.Add("path_condition_vertex", pathConditionVertices)
383443

384444
res.Add("gamevertex_to_gamevertex_index", vertexToVertexEdgesIndex)
385445
res.Add("gamevertex_to_gamevertex_type", vertexToVertexEdgesAttributes)
@@ -390,6 +450,9 @@ type internal AISearcher(oracle: Oracle, aiAgentTrainingMode: Option<AIAgentTrai
390450
res.Add("gamevertex_in_statevertex", statePosition_vertexToState)
391451
res.Add("statevertex_parentof_statevertex", parentOfEdges)
392452

453+
res.Add("pathconditionvertex_to_pathconditionvertex", pcToPcEdgeIndex)
454+
res.Add("pathconditionvertex_to_statevertex", edgeIndex_pcToState)
455+
393456
res
394457

395458
let output = session.Run(runOptions, networkInput, session.OutputNames)

0 commit comments

Comments
 (0)