node: support namespace, support tracking package imports

This commit is contained in:
mappu 2020-04-08 19:57:59 +12:00
parent 48c58898ab
commit daf5000404
2 changed files with 58 additions and 16 deletions

View File

@ -21,7 +21,9 @@ func ConvertFile(filename string) (string, error) {
namespaces := visitor.NewNamespaceResolver() namespaces := visitor.NewNamespaceResolver()
// scope := NewScope() // scope := NewScope()
state := conversionState{} state := conversionState{
importPackages: make(map[string]struct{}),
}
p, err := parser.NewParser([]byte(inputFile), "7.4") p, err := parser.NewParser([]byte(inputFile), "7.4")
if err != nil { if err != nil {

70
node.go
View File

@ -39,6 +39,7 @@ type conversionState struct {
currentClassName string currentClassName string
currentClassParentName string currentClassParentName string
currentErrHandler string currentErrHandler string
importPackages map[string]struct{}
} }
// //
@ -95,40 +96,71 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
// //
case *node.Root: case *node.Root:
ret := "package main\n\n"
// Hoist all declarations first, and put any top-level code into a generated main() function // Hoist all declarations first, and put any top-level code into a generated main() function
declarations := []string{}
statements := []string{} statements := []string{}
packageName := `main`
for _, s := range n.Stmts { for _, s := range n.Stmts {
switch s.(type) { switch s := s.(type) {
case *stmt.Class, *stmt.Function, *stmt.Interface: case *stmt.Class, *stmt.Function, *stmt.Interface:
this.currentErrHandler = "return nil, err\n" this.currentErrHandler = "return nil, err\n"
sm, err := this.convert(s)
if err != nil {
return "", parseErr{s, err}
}
// Declaration - emit immediately (hoist)
declarations = append(declarations, sm)
case *stmt.Namespace:
if len(s.Stmts) > 0 {
return "", parseErr{s, fmt.Errorf("please use `namespace Foo;` instead of blocks")}
}
namespace, err := this.resolveName(s.NamespaceName)
if err != nil {
return "", err
}
packageName = namespace
default: default:
this.currentErrHandler = "panic(err)\n" // top-level init/main behaviour this.currentErrHandler = "panic(err)\n" // top-level init/main behaviour
}
sm, err := this.convert(s) sm, err := this.convert(s)
if err != nil { if err != nil {
return "", parseErr{s, err} return "", parseErr{s, err}
} }
switch s.(type) {
case *stmt.Class, *stmt.Function, *stmt.Interface:
// Declaration - emit immediately (hoist)
ret += sm + "\n"
default:
// Top-level function code - deter emission // Top-level function code - deter emission
statements = append(statements, sm) statements = append(statements, sm)
} }
}
// Emit
ret := "package main\n\n"
if len(this.importPackages) > 0 {
ret += "import (\n"
for packageName, _ := range this.importPackages {
ret += "\t" + strconv.Quote(packageName) + "\n"
}
ret += ")\n"
}
if len(declarations) > 0 {
ret += strings.Join(declarations, "\n") + "\n"
} }
// Emit deferred statements
if len(statements) > 0 { if len(statements) > 0 {
ret += "func init() {\n" topFunc := `init`
if packageName == `main` {
topFunc = `main`
}
ret += "func " + topFunc + "() {\n"
ret += "\t" + strings.Join(statements, "\t") // Statements already added their own newline ret += "\t" + strings.Join(statements, "\t") // Statements already added their own newline
ret += "}\n" ret += "}\n"
} }
@ -568,6 +600,9 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
args = append(args, exprGo) args = append(args, exprGo)
} }
this.importPackages["fmt"] = struct{}{}
return "fmt.Print(" + strings.Join(args, ", ") + ")\n", nil // newline - standalone statement return "fmt.Print(" + strings.Join(args, ", ") + ")\n", nil // newline - standalone statement
case *stmt.InlineHtml: case *stmt.InlineHtml:
@ -582,6 +617,8 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
quoted = strconv.Quote(n.Value) quoted = strconv.Quote(n.Value)
} }
this.importPackages["fmt"] = struct{}{}
return "fmt.Print(" + quoted + ")\n", nil // newline - standalone statement return "fmt.Print(" + quoted + ")\n", nil // newline - standalone statement
case *stmt.If: case *stmt.If:
@ -697,6 +734,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
catchVars = append(catchVars, tempCatchVar+"*"+typename) catchVars = append(catchVars, tempCatchVar+"*"+typename)
this.importPackages["errors"] = struct{}{}
catchStmt := "if errors.As(err, &" + tempCatchVar + ") {\n" catchStmt := "if errors.As(err, &" + tempCatchVar + ") {\n"
if len(catch.Types) > 1 { if len(catch.Types) > 1 {
catchStmt += catchVar + " := " + tempCatchVar + " // rename\n" catchStmt += catchVar + " := " + tempCatchVar + " // rename\n"
@ -969,6 +1007,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
return "", parseErr{n, err} return "", parseErr{n, err}
} }
this.importPackages["math"] = struct{}{}
return `math.Mod(` + rval + `, ` + modulo + `)`, nil return `math.Mod(` + rval + `, ` + modulo + `)`, nil
case *binary.Mul: case *binary.Mul:
@ -996,6 +1035,7 @@ func (this *conversionState) convertNoFreeFloating(n_ node.Node) (string, error)
return "", parseErr{n, err} return "", parseErr{n, err}
} }
this.importPackages["math"] = struct{}{}
return `math.Pow(` + base + `, ` + exponent + `)`, nil return `math.Pow(` + base + `, ` + exponent + `)`, nil
case *binary.ShiftLeft: case *binary.ShiftLeft: