diff --git a/main.go b/main.go index 034654e..3d65166 100644 --- a/main.go +++ b/main.go @@ -21,7 +21,9 @@ func ConvertFile(filename string) (string, error) { namespaces := visitor.NewNamespaceResolver() // scope := NewScope() - state := conversionState{} + state := conversionState{ + importPackages: make(map[string]struct{}), + } p, err := parser.NewParser([]byte(inputFile), "7.4") if err != nil { diff --git a/node.go b/node.go index 4fc76b1..6f6001e 100644 --- a/node.go +++ b/node.go @@ -39,6 +39,7 @@ type conversionState struct { currentClassName string currentClassParentName string currentErrHandler string + importPackages map[string]struct{} } // @@ -95,40 +96,71 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) // case *node.Root: - ret := "package main\n\n" // Hoist all declarations first, and put any top-level code into a generated main() function + declarations := []string{} statements := []string{} + packageName := `main` for _, s := range n.Stmts { - switch s.(type) { + switch s := s.(type) { case *stmt.Class, *stmt.Function, *stmt.Interface: this.currentErrHandler = "return nil, err\n" + sm, err := this.convert(s) + if err != nil { + return "", parseErr{s, err} + } + + // Declaration - emit immediately (hoist) + declarations = append(declarations, sm) + + case *stmt.Namespace: + if len(s.Stmts) > 0 { + return "", parseErr{s, fmt.Errorf("please use `namespace Foo;` instead of blocks")} + } + namespace, err := this.resolveName(s.NamespaceName) + if err != nil { + return "", err + } + packageName = namespace + default: this.currentErrHandler = "panic(err)\n" // top-level init/main behaviour - } - sm, err := this.convert(s) - if err != nil { - return "", parseErr{s, err} - } + sm, err := this.convert(s) + if err != nil { + return "", parseErr{s, err} + } - switch s.(type) { - case *stmt.Class, *stmt.Function, *stmt.Interface: - // Declaration - emit immediately (hoist) - ret += sm + "\n" - - default: // Top-level function code - deter emission statements = append(statements, sm) } + + } + + // Emit + ret := "package main\n\n" + if len(this.importPackages) > 0 { + ret += "import (\n" + for packageName, _ := range this.importPackages { + ret += "\t" + strconv.Quote(packageName) + "\n" + } + ret += ")\n" + } + + if len(declarations) > 0 { + ret += strings.Join(declarations, "\n") + "\n" } - // Emit deferred statements if len(statements) > 0 { - ret += "func init() {\n" + topFunc := `init` + if packageName == `main` { + topFunc = `main` + } + + ret += "func " + topFunc + "() {\n" ret += "\t" + strings.Join(statements, "\t") // Statements already added their own newline ret += "}\n" } @@ -568,6 +600,9 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) args = append(args, exprGo) } + + this.importPackages["fmt"] = struct{}{} + return "fmt.Print(" + strings.Join(args, ", ") + ")\n", nil // newline - standalone statement case *stmt.InlineHtml: @@ -582,6 +617,8 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) quoted = strconv.Quote(n.Value) } + this.importPackages["fmt"] = struct{}{} + return "fmt.Print(" + quoted + ")\n", nil // newline - standalone statement case *stmt.If: @@ -697,6 +734,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) catchVars = append(catchVars, tempCatchVar+"*"+typename) + this.importPackages["errors"] = struct{}{} catchStmt := "if errors.As(err, &" + tempCatchVar + ") {\n" if len(catch.Types) > 1 { catchStmt += catchVar + " := " + tempCatchVar + " // rename\n" @@ -969,6 +1007,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) return "", parseErr{n, err} } + this.importPackages["math"] = struct{}{} return `math.Mod(` + rval + `, ` + modulo + `)`, nil case *binary.Mul: @@ -996,6 +1035,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error) return "", parseErr{n, err} } + this.importPackages["math"] = struct{}{} return `math.Pow(` + base + `, ` + exponent + `)`, nil case *binary.ShiftLeft: