package generator
import (
"bytes"
_ "embed"
"fmt"
"io"
"math"
"path/filepath"
"sort"
"strings"
"text/template"
)
// tpl is the text of templates.tpl, embedded at build time. NewTemplates
// parses it into the set of named templates (`package`, `const`, `enum`,
// `struct`, ...) that the Templates methods execute by name.
//
//go:embed templates.tpl
var tpl string
// Templates renders the Go source generated for a single bebop File.
type Templates struct {
	// Templates is the parsed template set built from the embedded tpl.
	Templates *template.Template
	// Writer receives all generated output.
	Writer io.Writer
	// File is the bebop file being generated.
	File *File
	// PackagePath is the import path of the generated package. Its last
	// element becomes the package name, and relative directory paths are
	// joined onto it to form imports of other generated packages.
	PackagePath string
	// FileName is File.RelativePath's base name with its extension removed.
	// It is passed to templates and used as a suffix on generated helper names.
	FileName string
	// DeclarationsByName maps a declared type name to its element; used to
	// find the declaring file and to compute constant sizes.
	DeclarationsByName map[string]BebopElement
	// PackageNameForType is populated by gatherImports. It maps each custom
	// type referenced by a field to its package qualifier ("" when local).
	PackageNameForType map[Type]string
}
// NewTemplates builds a Templates that will render file into writer.
//
// It derives FileName from the file's relative path (base name without
// extension), registers the helper functions the templates call, and parses
// the embedded template text. A parse failure is returned unchanged.
func NewTemplates(writer io.Writer, file *File, packagePath string, declarationsByName map[string]BebopElement) (*Templates, error) {
	base := filepath.Base(file.RelativePath)
	t := &Templates{
		Writer:             writer,
		File:               file,
		PackagePath:        packagePath,
		FileName:           strings.TrimSuffix(base, filepath.Ext(base)),
		DeclarationsByName: declarationsByName,
	}

	// Method values are bound to t, so templates see its state at call time.
	helpers := template.FuncMap{
		"EncodeField":    t.EncodeFieldFunc,
		"DecodeField":    t.DecodeFieldFunc,
		"MarshalField":   t.MarshalFieldFunc,
		"UnmarshalField": t.UnmarshalFieldFunc,
		"ConstSize":      t.ConstSizeFunc,
		"FieldSize":      t.FieldSizeFunc,
		"IsLast":         IsLast,
		"Add":            Add,
		"PackageNames":   t.PackageNamesFunc,
	}

	parsed, err := template.New(``).Funcs(helpers).Parse(tpl)
	if err != nil {
		return nil, err
	}
	t.Templates = parsed
	return t, nil
}
// IsLast reports whether idx is the final index of a sequence of the given
// length. Templates use it to decide whether to emit a separator.
func IsLast(idx, length int) bool {
	lastIndex := length - 1
	return idx == lastIndex
}
// Add returns the sum of a and b; exposed to templates for index arithmetic.
func Add(a, b int) int {
	sum := a
	sum += b
	return sum
}
// Package renders the `package` template: the package clause plus the import
// list of the generated file.
//
// Standard-library/runtime imports are chosen from what the file's
// declarations need and sorted. Imports of other generated bebop packages
// come from gatherImports, which also fills PackageNameForType. They are
// sorted separately and appended after the first group.
func (self *Templates) Package() error {
	conditional := []struct {
		needed bool
		paths  []string
	}{
		{self.needsRuntimeImport(), []string{`wellquite.org/bebop/runtime`, `io`}},
		{self.needsMathImport(), []string{`math`}},
		{self.needsGuidImport(), []string{`github.com/google/uuid`}},
		{self.needsTimeImport(), []string{`time`}},
	}

	imports := []string{}
	for _, entry := range conditional {
		if entry.needed {
			imports = append(imports, entry.paths...)
		}
	}
	sort.Strings(imports)

	local, err := self.gatherImports()
	if err != nil {
		return err
	}
	sort.Strings(local)

	data := struct {
		Package string
		Imports []string
	}{
		Package: filepath.Base(self.PackagePath),
		Imports: append(imports, local...),
	}
	return self.Templates.ExecuteTemplate(self.Writer, `package`, data)
}
// needsRuntimeImport reports whether the file declares any enum, struct,
// message or union. Code generated for those needs the bebop runtime and io.
func (self *Templates) needsRuntimeImport() bool {
	f := self.File
	for _, count := range []int{len(f.Enums), len(f.Structs), len(f.Messages), len(f.Unions)} {
		if count > 0 {
			return true
		}
	}
	return false
}
// needsMathImport reports whether any float32/float64 const is NaN or
// infinite. Only those consts are rendered through the math package.
// Float constant values are assumed to be stored as float64.
func (self *Templates) needsMathImport() bool {
	for _, konst := range self.File.Consts {
		isFloat := konst.Type == Float32Type || konst.Type == Float64Type
		if !isFloat {
			continue
		}
		if v := konst.Value.(float64); math.IsNaN(v) || math.IsInf(v, 0) {
			return true
		}
	}
	return false
}
// needsGuidImport reports whether the file uses the guid type and so must
// import the uuid package.
func (self *Templates) needsGuidImport() (needed bool) {
	needed = self.usesType(GuidType)
	return needed
}
// needsTimeImport reports whether the file uses the date type and so must
// import the time package.
func (self *Templates) needsTimeImport() (needed bool) {
	needed = self.usesType(DateType)
	return needed
}
// usesType reports whether tipe appears in this file's consts, struct fields
// or message fields.
//
// Field types are searched recursively through array element types and map
// key/value types. A field such as `guid[]` or `map[date, string]` therefore
// still counts as using guid/date, and the generated Go type for it
// (e.g. a slice of uuid.UUID) needs the corresponding import.
//
// Callers pass base types (GuidType, DateType); comparing against those never
// compares two non-comparable dynamic values.
func (self *Templates) usesType(tipe Type) bool {
	// mentions reports whether t is tipe, or contains it at any depth.
	var mentions func(t Type) bool
	mentions = func(t Type) bool {
		if t == tipe {
			return true
		}
		switch inner := t.(type) {
		case ArrayType:
			return mentions(inner.Of)
		case MapType:
			return mentions(inner.Key) || mentions(inner.Value)
		}
		return false
	}

	for _, konst := range self.File.Consts {
		if mentions(konst.Type) {
			return true
		}
	}
	for _, strukt := range self.File.Structs {
		for _, field := range strukt.Fields {
			if mentions(field.Type) {
				return true
			}
		}
	}
	for _, message := range self.File.Messages {
		for _, field := range message.Fields {
			if mentions(field.Type) {
				return true
			}
		}
	}
	return false
}
// gatherImports walks every struct and message field type (descending into
// array element types and map key/value types). For each custom type it finds
// the directory of the declaring file. Types declared in this file's own
// directory are local: qualifier "" and no import. Any other directory becomes
// an import path, formed by joining PackagePath with the directory's path
// relative to this file's directory. The qualifier is that directory's base
// name.
//
// Deduplication is per directory, not per file, for two reasons:
//   - another .bop file in the same directory compiles into the same Go
//     package, so importing it would be a self-import (an import cycle);
//   - two files in one foreign directory would otherwise emit the same
//     import twice.
//
// It records the qualifier of every custom type in self.PackageNameForType
// and returns the (unsorted) import paths, always using forward slashes as Go
// import paths require. An error from filepath.Rel is returned unchanged.
func (self *Templates) gatherImports() ([]string, error) {
	myFile := self.File

	var pending []Type
	for _, strukt := range myFile.Structs {
		for _, field := range strukt.Fields {
			pending = append(pending, field.Type)
		}
	}
	for _, message := range myFile.Messages {
		for _, field := range message.Fields {
			pending = append(pending, field.Type)
		}
	}

	packageNameForType := make(map[Type]string)
	var bebopImports []string
	myDir := filepath.Dir(myFile.RelativePath)
	// Directory -> package qualifier. Seeding with myDir makes every type
	// declared alongside this file (including in this file) local.
	packageForDir := map[string]string{myDir: ""}

	for len(pending) > 0 {
		tipe := pending[0]
		pending = pending[1:]
		switch tipe := tipe.(type) {
		case BaseType:
			// Built-in types need no bebop import.
		case ArrayType:
			pending = append(pending, tipe.Of)
		case MapType:
			pending = append(pending, tipe.Key, tipe.Value)
		case CustomType:
			dir := filepath.Dir(self.DeclarationsByName[tipe.Name].File.RelativePath)
			packageName, found := packageForDir[dir]
			if !found {
				relativePath, err := filepath.Rel(myDir, dir)
				if err != nil {
					return nil, err
				}
				// Import paths are slash-separated on every OS.
				bebopImports = append(bebopImports, filepath.ToSlash(filepath.Join(self.PackagePath, relativePath)))
				packageName = filepath.Base(dir)
				packageForDir[dir] = packageName
			}
			packageNameForType[tipe] = packageName
		}
	}
	self.PackageNameForType = packageNameForType
	return bebopImports, nil
}
// Consts renders every const of the file.
//
// Each const gets a template chosen by type: `const string`, `const guid`,
// dedicated templates for NaN/±Inf float32/float64 values, and plain `const`
// otherwise. Only `const` and `const string` count as Go constants. A blank
// line is written whenever output switches between those and the other
// forms, and one trailing blank line follows a non-empty list. Write errors
// from the blank-line Fprintln calls are ignored; template errors are
// returned.
func (self *Templates) Consts() (err error) {
	// special picks the NaN/+Inf/-Inf template for v, or `const` if v is finite.
	special := func(v float64, nan, posInf, negInf string) string {
		switch {
		case math.IsNaN(v):
			return nan
		case math.IsInf(v, 1):
			return posInf
		case math.IsInf(v, -1):
			return negInf
		}
		return `const`
	}

	previousWasConst := true
	for i, konst := range self.File.Consts {
		var name string
		switch konst.Type {
		case StringType:
			name = `const string`
		case GuidType:
			name = `const guid`
		case Float64Type:
			name = special(konst.Value.(float64), `const float64 nan`, `const float64 +inf`, `const float64 -inf`)
		case Float32Type:
			name = special(konst.Value.(float64), `const float32 nan`, `const float32 +inf`, `const float32 -inf`)
		default:
			name = `const`
		}

		isConst := name == `const` || name == `const string`
		if i > 0 && isConst != previousWasConst {
			fmt.Fprintln(self.Writer)
		}
		previousWasConst = isConst

		if err = self.Templates.ExecuteTemplate(self.Writer, name, konst); err != nil {
			return err
		}
	}

	if len(self.File.Consts) != 0 {
		fmt.Fprintln(self.Writer)
	}
	return nil
}
// Enums renders, for each enum in the file, its declaration, the
// encode/decode/marshal/unmarshal trampoline, and its encode and decode
// methods. Rendering stops at the first error.
func (self *Templates) Enums() (err error) {
	for _, enum := range self.File.Enums {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `enum`, enum) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(enum) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode enum`, enum) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode enum`, enum) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}
// Structs renders, for each struct in the file, its declaration, its opcode
// (if any), the encode/decode/marshal/unmarshal trampoline, and its encode
// and decode methods. Rendering stops at the first error.
func (self *Templates) Structs() (err error) {
	for _, strukt := range self.File.Structs {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `struct`, strukt) },
			func() error { return self.Opcode(strukt.Name, strukt.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(strukt) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode struct`, strukt) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode struct`, strukt) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}
// Messages renders, for each message in the file, its declaration, its
// opcode (if any), the encode/decode/marshal/unmarshal trampoline, and its
// encode and decode methods. Rendering stops at the first error.
func (self *Templates) Messages() (err error) {
	for _, message := range self.File.Messages {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `message`, message) },
			func() error { return self.Opcode(message.Name, message.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(message) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode message`, message) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode message`, message) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}
// Unions renders, for each union in the file, its declaration, its opcode
// (if any), the encode/decode/marshal/unmarshal trampoline, and its encode
// and decode methods. Rendering stops at the first error.
func (self *Templates) Unions() (err error) {
	for _, union := range self.File.Unions {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `union`, union) },
			func() error { return self.Opcode(union.Name, union.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(union) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode union`, union) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode union`, union) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}
// ArraysAndMaps delegates generation of the array and map helpers for this
// file to the processor returned by NewArraysAndMaps.
func (self *Templates) ArraysAndMaps() (err error) {
	processor := NewArraysAndMaps(self)
	err = processor.ProcessFile()
	return err
}
// Opcode renders the `opcode` template for the named receiver type. When the
// declaration has no opcode (nil), nothing is written and nil is returned.
func (self *Templates) Opcode(receiver string, opcode *Opcode) error {
	if opcode != nil {
		data := struct {
			Receiver string
			Opcode   *Opcode
		}{receiver, opcode}
		return self.Templates.ExecuteTemplate(self.Writer, `opcode`, data)
	}
	return nil
}
// EncodeDecodeMarshalUnmarshalTrampoline renders the shared trampoline
// template for element (an enum, struct, message or union).
func (self *Templates) EncodeDecodeMarshalUnmarshalTrampoline(element interface{}) error {
	const name = `encode decode marshal unmarshal trampoline`
	return self.Templates.ExecuteTemplate(self.Writer, name, element)
}
// EncodeFieldFunc returns the generated code that encodes the expression
// field of type tipe. It is exposed to templates as "EncodeField".
//
// For base types, isPtr makes the field expression dereferenced. The custom
// type template receives only the field expression string; the other
// templates receive {Type, Field, FileName}. Unrecognised types yield "".
func (self *Templates) EncodeFieldFunc(tipe Type, field string, isPtr bool) (str string, err error) {
	data := struct {
		Type     Type
		Field    string
		FileName string
	}{tipe, field, self.FileName}

	var (
		name    string
		payload interface{}
	)
	switch tipe.(type) {
	case BaseType:
		if isPtr {
			data.Field = fmt.Sprintf(`*(%s)`, field)
		}
		name, payload = `encode basetype`, data
	case CustomType:
		name, payload = `encode customtype`, field
	case ArrayType, MapType:
		name, payload = `encode array or map`, data
	default:
		return "", nil
	}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, payload); err != nil {
		return "", err
	}
	return out.String(), nil
}
// MarshalFieldFunc returns the generated code that marshals the expression
// field of type tipe. It is exposed to templates as "MarshalField".
//
// For base types, isPtr makes the field expression dereferenced. Every
// template receives {Type, Field, FileName}. Unrecognised types yield "".
func (self *Templates) MarshalFieldFunc(tipe Type, field string, isPtr bool) (str string, err error) {
	fieldExpr := field
	var name string
	switch tipe.(type) {
	case BaseType:
		if isPtr {
			fieldExpr = fmt.Sprintf(`*(%s)`, field)
		}
		name = `marshal basetype`
	case CustomType:
		name = `marshal customtype`
	case ArrayType, MapType:
		name = `marshal array or map`
	default:
		return "", nil
	}

	data := struct {
		Type     Type
		Field    string
		FileName string
	}{Type: tipe, Field: fieldExpr, FileName: self.FileName}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return out.String(), nil
}
// DecodeFieldFunc returns the generated code that decodes a value of type
// tipe into field. It is exposed to templates as "DecodeField".
//
// castOpen/castClose wrap the decoded value. New is set when castOpen is
// "&". returnPrefix is passed through to the template. Every line after the
// first is prefixed with indent so the snippet nests at the caller's
// indentation. Unrecognised types yield "".
func (self *Templates) DecodeFieldFunc(tipe Type, field string, castOpen, castClose, indent, returnPrefix string) (str string, err error) {
	var name string
	switch tipe.(type) {
	case BaseType:
		name = `decode basetype`
	case CustomType:
		name = `decode customtype`
	case ArrayType, MapType:
		name = `decode array or map`
	default:
		return "", nil
	}

	data := struct {
		Type         Type
		Field        string
		CastOpen     string
		CastClose    string
		FileName     string
		New          bool
		ReturnPrefix string
	}{tipe, field, castOpen, castClose, self.FileName, castOpen == `&`, returnPrefix}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return strings.ReplaceAll(out.String(), "\n", "\n"+indent), nil
}
// UnmarshalFieldFunc returns the generated code that unmarshals a value of
// type tipe into field. It is exposed to templates as "UnmarshalField".
//
// castOpen/castClose wrap the unmarshalled value. New is set when castOpen is
// "&". returnPrefix is passed through to the template. Every line after the
// first is prefixed with indent so the snippet nests at the caller's
// indentation. Unrecognised types yield "".
func (self *Templates) UnmarshalFieldFunc(tipe Type, field string, castOpen, castClose, indent, returnPrefix string) (str string, err error) {
	var name string
	switch tipe.(type) {
	case BaseType:
		name = `unmarshal basetype`
	case CustomType:
		name = `unmarshal customtype`
	case ArrayType, MapType:
		name = `unmarshal array or map`
	default:
		return "", nil
	}

	data := struct {
		Type         Type
		Field        string
		CastOpen     string
		CastClose    string
		FileName     string
		New          bool
		ReturnPrefix string
	}{tipe, field, castOpen, castClose, self.FileName, castOpen == `&`, returnPrefix}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return strings.ReplaceAll(out.String(), "\n", "\n"+indent), nil
}
// FieldSizeFunc returns a Go expression for the encoded size of field.
//
// When ConstSize reports a fixed size, that number is returned literally.
// Otherwise the expression depends on the kind of type:
//   - base types delegate to BaseType.Size;
//   - custom types call the value's SizeBebop method;
//   - arrays and maps call the per-file generated SizeBebopOf helper.
//
// Unrecognised types yield "".
func (self *Templates) FieldSizeFunc(tipe Type, field string) string {
	if size, fixed := ConstSize(self.DeclarationsByName, tipe); fixed {
		return fmt.Sprint(size)
	}
	switch t := tipe.(type) {
	case BaseType: // only variable-size base type here is string
		return t.Size(field)
	case CustomType:
		return fmt.Sprintf(`%s.SizeBebop()`, field)
	case ArrayType:
		return fmt.Sprintf(`SizeBebopOf%s_%s(%s)`, t.FunctionName(), self.FileName, field)
	case MapType:
		return fmt.Sprintf(`SizeBebopOf%s_%s(%s)`, t.FunctionName(), self.FileName, field)
	default:
		return ``
	}
}
// ConstSizeFunc returns the fixed encoded size of the named custom type, or
// -1 when its size is not constant. It is exposed to templates as "ConstSize".
func (self *Templates) ConstSizeFunc(name string) int {
	size, fixed := ConstSize(self.DeclarationsByName, NewCustomType(name))
	if !fixed {
		return -1
	}
	return size
}
// PackageNamesFunc exposes PackageNameForType (filled in by gatherImports) to
// templates as "PackageNames".
func (self *Templates) PackageNamesFunc() (names map[Type]string) {
	names = self.PackageNameForType
	return names
}