diff --git a/spec/go/expr_to_i_trailing_test.go b/spec/go/expr_to_i_trailing_test.go new file mode 100644 index 000000000..c5beefa2d --- /dev/null +++ b/spec/go/expr_to_i_trailing_test.go @@ -0,0 +1,51 @@ +// Autogenerated from KST: please remove this line if doing any edits by hand! + +package spec + +import ( + "runtime/debug" + "os" + "testing" + "github.com/kaitai-io/kaitai_struct_go_runtime/kaitai" + . "test_formats" + "github.com/stretchr/testify/assert" +) + +func TestExprToITrailing(t *testing.T) { + defer func() { + if r := recover(); r != nil { + debug.PrintStack() + t.Fatal("unexpected panic:", r) + } + }() + f, err := os.Open("../../src/term_strz.bin") + if err != nil { + t.Fatal(err) + } + s := kaitai.NewStream(f) + var r ExprToITrailing + err = r.Read(s, &r, &r) + if err != nil { + t.Fatal(err) + } + + { + tmp1, err := r.ToIR10() + assert.Error(t, err) + var wantErr strconv.NumError + assert.ErrorAs(t, err, &wantErr) + assert.EqualValues(t, 0, tmp1) + } + tmp2, err := r.ToIR16() + if err != nil { + t.Fatal(err) + } + assert.EqualValues(t, 152517308, tmp2) + { + tmp3, err := r.ToIGarbage() + assert.Error(t, err) + var wantErr strconv.NumError + assert.ErrorAs(t, err, &wantErr) + assert.EqualValues(t, 0, tmp3) + } +} diff --git a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala index 705c59bfe..ad9de8632 100644 --- a/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala +++ b/translator/src/main/scala/io/kaitai/struct/testtranslator/specgenerators/GoSG.scala @@ -1,38 +1,50 @@ package io.kaitai.struct.testtranslator.specgenerators import _root_.io.kaitai.struct.datatype.{DataType, EndOfStreamError, KSError} +import _root_.io.kaitai.struct.datatype.DataType._ import _root_.io.kaitai.struct.exprlang.Ast import _root_.io.kaitai.struct.languages.GoCompiler import _root_.io.kaitai.struct.testtranslator.{Main, TestAssert, TestEquals, TestSpec} -import _root_.io.kaitai.struct.translators.GoTranslator -import _root_.io.kaitai.struct.{ClassTypeProvider, RuntimeConfig, StringLanguageOutputWriter, Utils} +import _root_.io.kaitai.struct.translators.{GoTranslator, TypeProvider} +import _root_.io.kaitai.struct.{ClassTypeProvider, ImportList, RuntimeConfig, StringLanguageOutputWriter, Utils} class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(spec) { /** * Special wrapper around StringLanguageOutputWriter that catches all attempts * to access "this.INIT_OBJ_NAME" and replaces it with "r." */ - class GoOutputWriter(out: StringLanguageOutputWriter) extends StringLanguageOutputWriter(indentStr) { - override def inc: Unit = out.inc - override def dec: Unit = out.dec - override def indentNow: String = out.indentNow - - override def add(other: StringLanguageOutputWriter): Unit = out.add(other) + class GoOutputWriter(indentStr: String) extends StringLanguageOutputWriter(indentStr) { override def puts(s: String): Unit = { - val mangled = s.replace(REPLACER, "r.").replaceAll("return err$", "t.Fatal(err)") - out.puts(mangled) + super.puts(s.replace(REPLACER, "r.")) } - override def puts = out.puts - override def close = out.close - override def putsLines(prefix: String, lines: String, hanging: String): Unit = - out.putsLines(prefix, lines, hanging) + } - override def result: String = out.result + /** + * Special wrapper around translator that catches all attempts to write error + * check and turns it into assertion. + */ + class GoTestTranslator( + out: StringLanguageOutputWriter, + provider: TypeProvider, + importList: ImportList, + ) extends GoTranslator(out, provider, importList) { + var doErrCheck = true + + override def outAddErrCheck(): Unit = { + if (doErrCheck) { + out.puts("if err != nil {") + out.inc + out.puts("t.Fatal(err)") + out.dec + out.puts("}") + } + } } + override val out = new GoOutputWriter(indentStr) val compiler = new GoCompiler(provider, RuntimeConfig()) val className = GoCompiler.types2class(List(spec.id)) - val translator = new GoTranslator(new GoOutputWriter(out), provider, importList) + val translator = new GoTestTranslator(out, provider, importList) override def fileName(name: String): String = s"${name}_test.go" @@ -68,17 +80,7 @@ class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(sp override def runParseExpectError(exception: KSError): Unit = { out.puts("err = r.Read(s, &r, &r)") - importList.add("\"github.com/stretchr/testify/assert\"") - out.puts("assert.Error(t, err)") - exception match { - case EndOfStreamError => - importList.add("\"io\"") - out.puts("assert.ErrorIs(t, err, io.ErrUnexpectedEOF)") - case _ => - val errorName = GoCompiler.ksErrorName(exception) - out.puts(s"var wantErr ${errorName}") - out.puts("assert.ErrorAs(t, err, &wantErr)") - } + checkErr(exception) } override def footer() = { @@ -109,6 +111,33 @@ class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(sp def trueArrayEquality(check: TestEquals, elType: DataType, elts: Seq[Ast.expr]): Unit = simpleEquality(check) + override def testException(actual: Ast.expr, exception: KSError): Unit = { + // We need a scope otherwise we got redeclaration error from Go in case of + // several assertions, because we use the same name for expected exception + out.puts("{") + out.inc + + // We do not want error check because we expect an error + translator.doErrCheck = false + val actStr = translateAct(actual) + translator.doErrCheck = true + + checkErr(exception) + + // translateAct generates unused variable which not allowed in Go, + // so we use it by checking its value + translator.detectType(actual) match { + case _: FloatType => out.puts(s"assert.InDelta(t, 0, $actStr, $FLOAT_DELTA)") + case _: NumericType => out.puts(s"assert.EqualValues(t, 0, $actStr)") + case _: BooleanType => out.puts(s"assert.EqualValues(t, false, $actStr)") + case _: StrType => out.puts(s"assert.EqualValues(t, \"\", $actStr)") + case _ => out.puts(s"assert.Nil(t, $actStr)") + } + + out.dec + out.puts("}") + } + override def indentStr: String = "\t" override def results: String = { @@ -131,4 +160,19 @@ class GoSG(spec: TestSpec, provider: ClassTypeProvider) extends BaseGenerator(sp def translateAct(x: Ast.expr) = translator.translate(x).replace(REPLACER, "r.") + + /** Generates code to check returned Go error to match of specified `exception`. */ + def checkErr(exception: KSError): Unit = { + importList.add("\"github.com/stretchr/testify/assert\"") + out.puts("assert.Error(t, err)") + exception match { + case EndOfStreamError => + importList.add("\"io\"") + out.puts("assert.ErrorIs(t, err, io.ErrUnexpectedEOF)") + case _ => + val errorName = GoCompiler.ksErrorName(exception) + out.puts(s"var wantErr ${errorName}") + out.puts("assert.ErrorAs(t, err, &wantErr)") + } + } }