Skip to content

Commit 4c32479

Browse files
committed
@traceInduct clustering implementation
1 parent 9574673 commit 4c32479

File tree

5 files changed

+307
-9
lines changed

5 files changed

+307
-9
lines changed

core/src/main/scala/stainless/Component.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ object optFunctions extends inox.OptionDef[Seq[String]] {
3131
val usageRhs = "f1,f2,..."
3232
}
3333

34+
object optCompareFuns extends inox.OptionDef[Seq[String]] {
35+
val name = "comparefuns"
36+
val default = Seq[String]()
37+
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
38+
val usageRhs = "f1,f2,..."
39+
}
40+
41+
object optModels extends inox.OptionDef[Seq[String]] {
42+
val name = "models"
43+
val default = Seq[String]()
44+
val parser = inox.OptionParsers.seqParser(inox.OptionParsers.stringParser)
45+
val usageRhs = "f1,f2,..."
46+
}
47+
3448
trait ComponentRun { self =>
3549
val component: Component
3650
val trees: ast.Trees

core/src/main/scala/stainless/MainHelpers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ trait MainHelpers extends inox.MainHelpers { self =>
2525
optVersion -> Description(General, "Display the version number"),
2626
optConfigFile -> Description(General, "Path to configuration file, set to false to disable (default: stainless.conf or .stainless.conf)"),
2727
optFunctions -> Description(General, "Only consider functions f1,f2,..."),
28+
optModels -> Description(General, "Consider functions f1, f2, ... as model functions for @traceInduct"),
29+
optCompareFuns -> Description(General, "Only consider @traceInduct functions f1,f2,..."),
2830
extraction.utils.optDebugObjects -> Description(General, "Only print debug output for functions/adts named o1,o2,..."),
2931
extraction.utils.optDebugPhases -> Description(General, {
3032
"Only print debug output for phases p1,p2,...\nAvailable: " +

core/src/main/scala/stainless/Report.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,22 @@ trait AbstractReport[SelfType <: AbstractReport[SelfType]] { self: SelfType =>
100100
case Level.Error => Console.RED
101101
}
102102

103+
def isError(identifier: Identifier)(implicit ctx: inox.Context): Boolean = {
104+
val res = for {
105+
RecordRow(id, pos, level, extra, time) <- annotatedRows
106+
if(level == Level.Error && id == identifier)
107+
}yield (id, level)
108+
!res.isEmpty
109+
}
110+
111+
def isUnknown(identifier: Identifier)(implicit ctx: inox.Context): Boolean = {
112+
val res = for {
113+
RecordRow(id, pos, level, extra, time) <- annotatedRows
114+
if(level == Level.Warning && id == identifier)
115+
}yield (id, level)
116+
!res.isEmpty
117+
}
118+
103119
// Emit the report table, with all VCs when full is true, otherwise only with unknown/invalid VCs.
104120
private def emitTable(full: Boolean)(implicit ctx: inox.Context): Table = {
105121
val rows = processRows(full)

core/src/main/scala/stainless/extraction/trace/Trace.scala

Lines changed: 251 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,87 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
2424
override val t: self.t.type = self.t
2525
}
2626

27+
override protected def extractSymbols(context: TransformerContext, symbols: s.Symbols): t.Symbols = {
28+
import symbols._
29+
import trees._
30+
31+
val models = symbols.functions.values.toList.filter(elem => isModel(elem.id)).map(elem => elem.id)
32+
val functions = symbols.functions.values.toList.filter(elem => shouldBeChecked(elem.id)).map(elem => elem.id)
33+
34+
if (Trace.getModels.isEmpty) {
35+
Trace.setModels(models)
36+
Trace.nextModel
37+
}
38+
if (Trace.getFunctions.isEmpty) {
39+
Trace.setFunctions(functions)
40+
Trace.nextFunction
41+
}
42+
43+
var localCounter = 0
44+
45+
def freshId(a: Identifier, b: Identifier): Identifier = {
46+
localCounter = localCounter + 1
47+
new Identifier(fixedFullName(a)+"$"+fixedFullName(b),localCounter,localCounter)
48+
}
49+
50+
//if (fd1 != fd2) && (fd1.params.size == fd2.params.size)
51+
def checkPair(fd1: s.FunDef, fd2: s.FunDef): s.FunDef = {
52+
53+
val newParams = fd1.params.map{param => param.freshen}
54+
val newParamVars = newParams.map{param => param.toVariable}
55+
val newParamTypes = fd1.tparams.map{tparam => tparam.freshen}
56+
val newParamTps = newParamTypes.map{tparam => tparam.tp}
57+
58+
val vd = s.ValDef.fresh("holds", s.BooleanType())
59+
val post = s.Lambda(Seq(vd), vd.toVariable)
60+
61+
val body = s.Ensuring(s.Equals(s.FunctionInvocation(fd1.id, newParamTps, newParamVars), s.FunctionInvocation(fd2.id, newParamTps, newParamVars)), post)
62+
val flags: Seq[s.Flag] = Seq(s.Derived(fd1.id), s.Annotation("traceInduct",List(StringLiteral(fd1.id.name))))
63+
64+
new s.FunDef(freshId(fd1.id, fd2.id), newParamTypes, newParams, s.BooleanType(), body, flags)
65+
}
66+
67+
def newFuns: List[s.FunDef] = (Trace.getModel, Trace.getFunction) match {
68+
case (Some(model), Some(function)) => {
69+
val m = symbols.functions.filter(elem => elem._2.id == model).head._2
70+
val f = symbols.functions.filter(elem => elem._2.id == function).head._2
71+
val newFun = checkPair(m, f)
72+
Trace.setTrace(newFun.id)
73+
List(newFun)
74+
}
75+
case _ => Nil
76+
}
77+
78+
/*
79+
def newFuns: List[s.FunDef] = Trace.nextModel match {
80+
case Some(model) => toCheck.map(f => checkPair(
81+
82+
symbols.functions.filter(elem => elem._2.id == model).head._2
83+
84+
, f))
85+
case None => check(toCheck, toCheck, Nil)
86+
}
87+
88+
def check(funs1: List[s.FunDef], funs2: List[s.FunDef], acc: List[s.FunDef]): List[s.FunDef] = {
89+
funs1 match {
90+
case Nil => acc
91+
case fd1::xs1 => {
92+
funs2 match {
93+
case Nil => check(xs1, toCheck, acc)
94+
//todo: check if both funs have same arg list
95+
case fd2::xs2 if (fd1 != fd2) && (fd1.params.size == fd2.params.size) =>
96+
check(funs1, xs2, checkPair(fd1, fd2)::acc)
97+
case _ => check(funs1, funs2.tail, acc)
98+
}
99+
}
100+
}
101+
}
102+
*/
103+
val extracted = super.extractSymbols(context, symbols)
104+
//newFuns(toCheck, toCheck, Nil).map(f => extractFunction(symbols, f))
105+
registerFunctions(extracted, newFuns.map(f => extractFunction(symbols, f)))
106+
}
107+
27108
override protected def extractFunction(symbols: Symbols, fd: FunDef): t.FunDef = {
28109
import symbols._
29110
var funInv: Option[FunctionInvocation] = None
@@ -33,13 +114,12 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
33114
case Annotation("traceInduct", fun) => {
34115
exprOps.preTraversal {
35116
case _ if funInv.isDefined => // do nothing
36-
case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral("")))
37-
=> {
117+
case fi @ FunctionInvocation(tfd, _, args) if symbols.isRecursive(tfd) && (fun.contains(StringLiteral(tfd.name)) || fun.contains(StringLiteral(""))) => {
38118
val paramVars = fd.params.map(_.toVariable)
39119
val argCheck = args.forall(paramVars.contains) && args.toSet.size == args.size
40120
if (argCheck)
41121
funInv = Some(fi)
42-
}
122+
}
43123
case _ =>
44124
}(fd.fullBody)
45125
}
@@ -51,6 +131,7 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
51131
case Some(finv) => createTactFun(symbols, fd, finv)
52132
})
53133

134+
54135
identity.transform(result.copy(flags = result.flags filterNot (f => f == TraceInduct)))
55136
}
56137

@@ -105,8 +186,8 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
105186
val argsMap = callee.params.map(_.toVariable).zip(finv.args).toMap
106187
val tparamMap = callee.typeArgs.zip(finv.tfd.tps).toMap
107188
val inlinedBody = typeOps.instantiateType(exprOps.replaceFromSymbols(argsMap, callee.body.get), tparamMap)
108-
val inductScheme = inductPattern(inlinedBody)
109189

190+
val inductScheme = inductPattern(inlinedBody)
110191
val prevBody = function.fullBody match {
111192
case Ensuring(body, pred) => body
112193
case _ => function.fullBody
@@ -115,19 +196,95 @@ trait Trace extends CachingPhase with SimpleFunctions with IdentitySorts { self
115196
// body, pre and post for the tactFun
116197

117198
val body = andJoin(Seq(inductScheme, prevBody))
118-
val precondition = function.precondition
119-
val postcondition = function.postcondition
120-
199+
val precondition = exprOps.preconditionOf(function.fullBody) //function.precondition
200+
val postcondition = exprOps.postconditionOf(function.fullBody) //function.postcondition
121201
val bodyPre = exprOps.withPrecondition(body, precondition)
122202
val bodyPost = exprOps.withPostcondition(bodyPre,postcondition)
203+
function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags)
204+
205+
}
206+
207+
//from CheckFilter.scala
208+
type Path = Seq[String]
209+
private def fullNameToPath(fullName: String): Path = (fullName split '.').toSeq
210+
211+
// TODO this is probably done somewhere else in a cleaner fasion...
212+
private def fixedFullName(id: Identifier): String = id.fullName
213+
.replaceAllLiterally("$bar", "|")
214+
.replaceAllLiterally("$up", "^")
215+
.replaceAllLiterally("$eq", "=")
216+
.replaceAllLiterally("$plus", "+")
217+
.replaceAllLiterally("$minus", "-")
218+
.replaceAllLiterally("$times", "*")
219+
.replaceAllLiterally("$div", "/")
220+
.replaceAllLiterally("$less", "<")
221+
.replaceAllLiterally("$geater", ">")
222+
.replaceAllLiterally("$colon", ":")
223+
.replaceAllLiterally("$amp", "&")
224+
.replaceAllLiterally("$tilde", "~")
225+
226+
227+
private lazy val pathsOpt: Option[Seq[Path]] = context.options.findOption(optCompareFuns) map { functions =>
228+
functions map fullNameToPath
229+
}
230+
231+
private def shouldBeChecked(fid: Identifier): Boolean = pathsOpt match {
232+
case None => false
233+
234+
case Some(paths) =>
235+
// Support wildcard `_` as specified in the documentation.
236+
// A leading wildcard is always assumes.
237+
val path: Path = fullNameToPath(fixedFullName(fid))
238+
paths exists { p =>
239+
if (p endsWith Seq("_")) path containsSlice p.init
240+
else path endsWith p
241+
}
242+
}
123243

124-
function.copy(function.id, function.tparams, function.params, function.returnType, bodyPost, function.flags)
244+
private lazy val pathsOptModels: Option[Seq[Path]] = context.options.findOption(optModels) map { functions =>
245+
functions map fullNameToPath
246+
}
247+
248+
private def isModel(fid: Identifier): Boolean = pathsOptModels match {
249+
case None => false
250+
251+
case Some(paths) =>
252+
// Support wildcard `_` as specified in the documentation.
253+
// A leading wildcard is always assumes.
254+
val path: Path = fullNameToPath(fixedFullName(fid))
255+
paths exists { p =>
256+
if (p endsWith Seq("_")) path containsSlice p.init
257+
else path endsWith p
258+
}
125259
}
126260

127261
}
128262

129263

130264
object Trace {
265+
var boxes: Map[Identifier, List[Identifier]] = Map()
266+
var errors: List[Identifier] = List()
267+
var unknowns: List[Identifier] = List()
268+
269+
def printEverything() = {
270+
System.out.println("boxes")
271+
System.out.println(boxes)
272+
System.out.println("errors")
273+
System.out.println(errors)
274+
System.out.println("unknowns")
275+
System.out.println(unknowns)
276+
}
277+
278+
var allModels: List[Identifier] = List()
279+
var tmpModels: List[Identifier] = List()
280+
281+
var allFunctions: List[Identifier] = List()
282+
var tmpFunctions: List[Identifier] = List()
283+
284+
var model: Option[Identifier] = None
285+
var function: Option[Identifier] = None
286+
var trace: Option[Identifier] = None
287+
131288
def apply(ts: Trees, tt: termination.Trees)(implicit ctx: inox.Context): ExtractionPipeline {
132289
val s: ts.type
133290
val t: tt.type
@@ -136,4 +293,90 @@ object Trace {
136293
override val t: tt.type = tt
137294
override val context = ctx
138295
}
296+
297+
def setModels(m: List[Identifier]) = {
298+
allModels = m
299+
tmpModels = m
300+
boxes = (m zip m.map(_ => Nil)).toMap
301+
}
302+
303+
def setFunctions(f: List[Identifier]) = {
304+
allFunctions = f
305+
tmpFunctions = f
306+
}
307+
308+
def getModels = allModels
309+
310+
def getFunctions = allFunctions
311+
312+
313+
//model for the current iteration
314+
def getModel = model
315+
316+
//function to check in the current iteration
317+
def getFunction = function
318+
319+
def setTrace(t: Identifier) = trace = Some(t)
320+
def getTrace = trace
321+
322+
//iterate model for the current function
323+
def nextModel = (tmpModels, allModels) match {
324+
case (x::xs, _) => { // check the next model for the current function
325+
tmpModels = xs
326+
model = Some(x)
327+
}
328+
case (Nil, x::xs) => {
329+
tmpModels = allModels
330+
model = Some(x)
331+
tmpModels = xs
332+
function = tmpFunctions match {
333+
case x::xs => {
334+
tmpFunctions = xs
335+
Some(x)
336+
}
337+
case Nil => None
338+
}
339+
}
340+
case _ => model = None
341+
}
342+
343+
//iterate function to check; reset model
344+
def nextFunction = tmpFunctions match {
345+
case x::xs => {
346+
tmpFunctions = xs
347+
function = Some(x)
348+
tmpModels = allModels
349+
tmpModels match {
350+
case Nil => model = None
351+
case x::xs => {
352+
model = Some(x)
353+
tmpModels = xs
354+
}
355+
}
356+
function
357+
}
358+
case Nil => {
359+
function = None
360+
}
361+
}
362+
363+
def isDone = function == None
364+
365+
def reportError = {
366+
errors = function.get::errors
367+
nextFunction
368+
}
369+
370+
def reportUnknown = {
371+
nextModel
372+
if(model == None){
373+
unknowns = function.get::unknowns
374+
nextFunction
375+
}
376+
}
377+
378+
def reportValid = {
379+
boxes = boxes + (model.get -> (function.get::boxes(model.get)))
380+
nextFunction
381+
}
139382
}

0 commit comments

Comments
 (0)