Bebop

templates.go at tip
Login

File generator/templates.go from the latest check-in


package generator

import (
	"bytes"
	_ "embed"
	"fmt"
	"io"
	"math"
	"path/filepath"
	"sort"
	"strings"
	"text/template"
)

// tpl holds the template source embedded at build time from templates.tpl.
// NewTemplates parses it into the set of named templates used for generation.
//
//go:embed templates.tpl
var tpl string

// Templates holds the parsed code-generation templates together with the
// per-file state needed to render Go source for a single .bop file.
// Construct it with NewTemplates.
type Templates struct {
	// Templates is the set of named templates parsed from the embedded
	// templates.tpl, with this generator's helper functions registered.
	Templates *template.Template
	// Writer receives all generated output.
	Writer    io.Writer

	// File is the parsed .bop file being generated.
	File               *File
	// PackagePath is the Go import path of the generated package. Its last
	// element is used as the package name, and imports of other generated
	// packages are resolved relative to it.
	PackagePath        string
	// FileName is the base name of File.RelativePath without its extension.
	// It is passed to templates and used as a suffix in generated helper
	// names (e.g. SizeBebopOf<...>_<FileName>).
	FileName           string
	// DeclarationsByName maps a type name to its declaration. It is used to
	// find the file that declares a custom type and to compute constant sizes.
	DeclarationsByName map[string]BebopElement
	// PackageNameForType is filled in by gatherImports (run from Package). It
	// maps each custom type reached from this file's fields to the package
	// qualifier to use for it, where "" means this file's own package. It is
	// exposed to templates through the PackageNames function.
	PackageNameForType map[Type]string
}

// NewTemplates prepares the templates for generating Go code from one .bop
// file.
//
// Parameters:
//   - writer receives all generated output.
//   - file is the parsed .bop file.
//   - packagePath is the Go import path of the generated package. Its last
//     element becomes the package name, and other generated packages are
//     imported relative to it.
//   - declarationsByName maps every known type name to its declaration.
//
// The embedded template source is parsed with the generator's helper
// functions registered. A parse error is returned unchanged, together with a
// nil *Templates.
func NewTemplates(writer io.Writer, file *File, packagePath string, declarationsByName map[string]BebopElement) (*Templates, error) {
	base := filepath.Base(file.RelativePath)
	generator := &Templates{
		Writer:             writer,
		File:               file,
		PackagePath:        packagePath,
		FileName:           strings.TrimSuffix(base, filepath.Ext(base)),
		DeclarationsByName: declarationsByName,
	}

	// The method values below are bound to generator, so template functions
	// read its fields at call time. That includes PackageNameForType, which
	// is filled in later by Package.
	helpers := template.FuncMap{
		"EncodeField":    generator.EncodeFieldFunc,
		"DecodeField":    generator.DecodeFieldFunc,
		"MarshalField":   generator.MarshalFieldFunc,
		"UnmarshalField": generator.UnmarshalFieldFunc,
		"ConstSize":      generator.ConstSizeFunc,
		"FieldSize":      generator.FieldSizeFunc,
		"IsLast":         IsLast,
		"Add":            Add,
		"PackageNames":   generator.PackageNamesFunc,
	}

	parsed, err := template.New(``).Funcs(helpers).Parse(tpl)
	if err != nil {
		return nil, err
	}
	generator.Templates = parsed
	return generator, nil
}

// IsLast reports whether idx is the final index of a sequence with the given
// length. It is registered as the "IsLast" template function.
func IsLast(idx, length int) bool {
	return length-idx == 1
}

// Add returns the sum of a and b. It is registered as the "Add" template
// function, since text/template has no arithmetic of its own.
func Add(a, b int) int {
	sum := b
	sum += a
	return sum
}

// Package renders the `package` template: the package clause plus the import
// list for the generated file.
//
// Standard-library and third-party imports are chosen from what the file
// uses and sorted. Imports of other generated bebop packages, found by
// gatherImports, are sorted separately and appended after them.
//
// As a side effect, gatherImports sets self.PackageNameForType. That map is
// what the PackageNames template function returns, so Package should run
// before any template that calls PackageNames.
func (self *Templates) Package() error {
	// Kept non-nil so an empty import list is an empty slice, not nil.
	imports := make([]string, 0, 5)
	candidates := []struct {
		needed bool
		paths  []string
	}{
		{self.needsRuntimeImport(), []string{`wellquite.org/bebop/runtime`, `io`}},
		{self.needsMathImport(), []string{`math`}},
		{self.needsGuidImport(), []string{`github.com/google/uuid`}},
		{self.needsTimeImport(), []string{`time`}},
	}
	for _, candidate := range candidates {
		if candidate.needed {
			imports = append(imports, candidate.paths...)
		}
	}
	sort.Strings(imports)

	bebopImports, err := self.gatherImports()
	if err != nil {
		return err
	}
	sort.Strings(bebopImports)

	data := struct {
		Package string
		Imports []string
	}{
		Package: filepath.Base(self.PackagePath),
		Imports: append(imports, bebopImports...),
	}
	return self.Templates.ExecuteTemplate(self.Writer, `package`, data)
}

// needsRuntimeImport reports whether the file declares at least one enum,
// struct, message or union. In that case Package adds the bebop runtime
// package and io to the imports.
func (self *Templates) needsRuntimeImport() bool {
	f := self.File
	for _, count := range []int{len(f.Enums), len(f.Structs), len(f.Messages), len(f.Unions)} {
		if count > 0 {
			return true
		}
	}
	return false
}

// needsMathImport reports whether any float32 or float64 constant in the
// file is NaN, +Inf or -Inf. Go has no literal for these values, so Package
// then adds the math package to the imports. Float constant values are
// stored as float64 regardless of their declared width.
func (self *Templates) needsMathImport() bool {
	for _, konst := range self.File.Consts {
		isFloat := konst.Type == Float32Type || konst.Type == Float64Type
		if !isFloat {
			continue
		}
		if v := konst.Value.(float64); math.IsNaN(v) || math.IsInf(v, 0) {
			return true
		}
	}
	return false
}

// needsGuidImport reports, via usesType, whether the file uses the guid
// type. In that case Package adds github.com/google/uuid to the imports.
func (self *Templates) needsGuidImport() bool {
	var target Type = GuidType
	return self.usesType(target)
}

// needsTimeImport reports, via usesType, whether the file uses the date
// type. In that case Package adds time to the imports.
func (self *Templates) needsTimeImport() bool {
	var target Type = DateType
	return self.usesType(target)
}

// usesType reports whether tipe appears in the file as the type of a
// constant, a struct field or a message field. Union declarations are not
// inspected here.
//
// Types are searched recursively: array element types and map key/value
// types are also checked. A field such as guid[] or map[guid, string] has a
// Go type built from its element/key/value types, so it needs the same import
// as a plain guid field. Comparing only the top-level field type would miss
// these and produce generated code with a missing import.
func (self *Templates) usesType(tipe Type) bool {
	var mentions func(t Type) bool
	mentions = func(t Type) bool {
		if t == tipe {
			return true
		}
		switch t := t.(type) {
		case ArrayType:
			return mentions(t.Of)
		case MapType:
			return mentions(t.Key) || mentions(t.Value)
		}
		return false
	}

	for _, konst := range self.File.Consts {
		if mentions(konst.Type) {
			return true
		}
	}

	for _, strukt := range self.File.Structs {
		for _, field := range strukt.Fields {
			if mentions(field.Type) {
				return true
			}
		}
	}

	for _, message := range self.File.Messages {
		for _, field := range message.Fields {
			if mentions(field.Type) {
				return true
			}
		}
	}

	return false
}

// gatherImports finds the other generated Go packages that this file's code
// must import in order to reference custom types declared in other .bop
// files.
//
// It walks every struct and message field type, descending into array
// element types and map key/value types. For each custom type reached, it
// looks up the declaring file.
//
// It returns the import paths, unsorted and without duplicates. As a side
// effect it replaces self.PackageNameForType with a map from each custom type
// reached to the package qualifier to use for it. Types declared in a .bop
// file in this file's own directory map to "", because they are generated
// into the same Go package.
//
// An error is returned if a referenced custom type has no declaration, or if
// a relative path between the two directories cannot be computed.
func (self *Templates) gatherImports() ([]string, error) {
	myFile := self.File
	myDir := filepath.Dir(myFile.RelativePath)

	var pending []Type
	for _, strukt := range myFile.Structs {
		for _, field := range strukt.Fields {
			pending = append(pending, field.Type)
		}
	}
	for _, message := range myFile.Messages {
		for _, field := range message.Fields {
			pending = append(pending, field.Type)
		}
	}

	packageNameForType := make(map[Type]string)
	var bebopImports []string

	// Keyed by directory rather than by file. Every .bop file in a directory
	// shares one Go package (PackagePath describes myDir). So a type from a
	// sibling file needs no import, since importing our own package would be
	// an import cycle. Likewise, two files from the same foreign directory
	// must share a single import line.
	packageNameForDir := map[string]string{
		myDir: "",
	}

	for len(pending) > 0 {
		tipe := pending[0]
		pending = pending[1:]

		switch tipe := tipe.(type) {
		case BaseType:
			// Built-in types never need a bebop import.

		case ArrayType:
			pending = append(pending, tipe.Of)

		case MapType:
			pending = append(pending, tipe.Key, tipe.Value)

		case CustomType:
			decl, found := self.DeclarationsByName[tipe.Name]
			if !found || decl.File == nil {
				return nil, fmt.Errorf("type %q used in %s has no declaration", tipe.Name, myFile.RelativePath)
			}

			otherDir := filepath.Dir(decl.File.RelativePath)
			packageName, seen := packageNameForDir[otherDir]
			if !seen {
				relativePath, err := filepath.Rel(myDir, otherDir)
				if err != nil {
					return nil, fmt.Errorf("resolving import path from %s to %s: %w", myDir, otherDir, err)
				}

				// Go import paths always use forward slashes, whatever the host OS.
				bebopImports = append(bebopImports, filepath.ToSlash(filepath.Join(self.PackagePath, relativePath)))

				packageName = filepath.Base(otherDir)
				packageNameForDir[otherDir] = packageName
			}

			packageNameForType[tipe] = packageName
		}
	}

	self.PackageNameForType = packageNameForType

	return bebopImports, nil
}

// Consts renders every constant declared in the file, choosing a template
// for each one:
//
//   - strings use `const string` and guids use `const guid`;
//   - float32/float64 values that are NaN, +Inf or -Inf use the matching
//     `const float32 ...` / `const float64 ...` template;
//   - everything else uses the plain `const` template.
//
// Output rendered through `const` or `const string` forms one group, and
// output from every other template forms a second group. A blank line is
// written each time the output switches between groups. One trailing blank
// line is written when there is at least one constant. The first template
// error is returned.
func (self *Templates) Consts() (err error) {
	previousPlain := true

	for i, konst := range self.File.Consts {
		name := `const`

		switch konst.Type {
		case StringType:
			name = `const string`
		case GuidType:
			name = `const guid`
		case Float64Type:
			switch v := konst.Value.(float64); {
			case math.IsNaN(v):
				name = `const float64 nan`
			case math.IsInf(v, 1):
				name = `const float64 +inf`
			case math.IsInf(v, -1):
				name = `const float64 -inf`
			}
		case Float32Type:
			// float32 constant values are held as float64 as well.
			switch v := konst.Value.(float64); {
			case math.IsNaN(v):
				name = `const float32 nan`
			case math.IsInf(v, 1):
				name = `const float32 +inf`
			case math.IsInf(v, -1):
				name = `const float32 -inf`
			}
		}

		plain := name == `const` || name == `const string`
		if i > 0 && plain != previousPlain {
			fmt.Fprintln(self.Writer)
		}
		previousPlain = plain

		if err = self.Templates.ExecuteTemplate(self.Writer, name, konst); err != nil {
			return err
		}
	}

	if len(self.File.Consts) > 0 {
		fmt.Fprintln(self.Writer)
	}
	return nil
}

// Enums renders every enum declared in the file. For each enum it writes, in
// order:
//   - the `enum` template;
//   - the encode/decode/marshal/unmarshal trampoline;
//   - the `encode enum` and `decode enum` templates.
//
// Rendering stops at the first error, which is returned unchanged.
func (self *Templates) Enums() (err error) {
	for _, enum := range self.File.Enums {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `enum`, enum) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(enum) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode enum`, enum) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode enum`, enum) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}

// Structs renders every struct declared in the file. For each struct it
// writes, in order:
//   - the `struct` template;
//   - the `opcode` template, only when the struct has an opcode;
//   - the encode/decode/marshal/unmarshal trampoline;
//   - the `encode struct` and `decode struct` templates.
//
// Rendering stops at the first error, which is returned unchanged.
func (self *Templates) Structs() (err error) {
	for _, strukt := range self.File.Structs {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `struct`, strukt) },
			func() error { return self.Opcode(strukt.Name, strukt.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(strukt) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode struct`, strukt) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode struct`, strukt) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}

// Messages renders every message declared in the file. For each message it
// writes, in order:
//   - the `message` template;
//   - the `opcode` template, only when the message has an opcode;
//   - the encode/decode/marshal/unmarshal trampoline;
//   - the `encode message` and `decode message` templates.
//
// Rendering stops at the first error, which is returned unchanged.
func (self *Templates) Messages() (err error) {
	for _, message := range self.File.Messages {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `message`, message) },
			func() error { return self.Opcode(message.Name, message.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(message) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode message`, message) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode message`, message) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}

// Unions renders every union declared in the file. For each union it writes,
// in order:
//   - the `union` template;
//   - the `opcode` template, only when the union has an opcode;
//   - the encode/decode/marshal/unmarshal trampoline;
//   - the `encode union` and `decode union` templates.
//
// Rendering stops at the first error, which is returned unchanged.
func (self *Templates) Unions() (err error) {
	for _, union := range self.File.Unions {
		steps := []func() error{
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `union`, union) },
			func() error { return self.Opcode(union.Name, union.Opcode) },
			func() error { return self.EncodeDecodeMarshalUnmarshalTrampoline(union) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `encode union`, union) },
			func() error { return self.Templates.ExecuteTemplate(self.Writer, `decode union`, union) },
		}
		for _, step := range steps {
			if err = step(); err != nil {
				return err
			}
		}
	}
	return nil
}

// ArraysAndMaps renders the helper code for the file's array and map types.
// It delegates entirely to the processor built by NewArraysAndMaps.
func (self *Templates) ArraysAndMaps() (err error) {
	processor := NewArraysAndMaps(self)
	err = processor.ProcessFile()
	return err
}

// Opcode renders the `opcode` template for the type named receiver. The
// template receives the receiver type name and the opcode. A nil opcode
// means the type declares none: nothing is written and nil is returned.
func (self *Templates) Opcode(receiver string, opcode *Opcode) error {
	if opcode != nil {
		return self.Templates.ExecuteTemplate(self.Writer, `opcode`, struct {
			Receiver string
			Opcode   *Opcode
		}{receiver, opcode})
	}
	return nil
}

// EncodeDecodeMarshalUnmarshalTrampoline renders the shared
// `encode decode marshal unmarshal trampoline` template for element to the
// output writer. Enums, Structs, Messages and Unions each pass their
// declarations here.
func (self *Templates) EncodeDecodeMarshalUnmarshalTrampoline(element interface{}) error {
	const name = `encode decode marshal unmarshal trampoline`
	return self.Templates.ExecuteTemplate(self.Writer, name, element)
}

// EncodeFieldFunc implements the "EncodeField" template function. It renders
// the Go code that encodes the value of the expression field, whose bebop
// type is tipe, and returns it as a string.
//
// For base types, isPtr means field is a pointer, so it is dereferenced as
// *(field). For other kinds the flag is ignored. Unlike the other kinds, the
// `encode customtype` template receives only the bare field expression
// rather than the Type/Field/FileName data. A type of any other kind renders
// as "". On a template error, "" and the error are returned.
func (self *Templates) EncodeFieldFunc(tipe Type, field string, isPtr bool) (str string, err error) {
	data := struct {
		Type     Type
		Field    string
		FileName string
	}{Type: tipe, Field: field, FileName: self.FileName}

	var (
		name    string
		payload interface{} = data
	)
	switch tipe.(type) {
	case BaseType:
		name = `encode basetype`
		if isPtr {
			data.Field = fmt.Sprintf(`*(%s)`, field)
			payload = data
		}
	case CustomType:
		name = `encode customtype`
		payload = field
	case ArrayType, MapType:
		name = `encode array or map`
	default:
		return "", nil
	}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, payload); err != nil {
		return "", err
	}
	return out.String(), nil
}

// MarshalFieldFunc implements the "MarshalField" template function. It
// renders the Go code that marshals the value of the expression field, whose
// bebop type is tipe, and returns it as a string.
//
// For base types, isPtr means field is a pointer, so it is dereferenced as
// *(field). For other kinds the flag is ignored. Every template receives the
// Type, Field and FileName data. A type of any other kind renders as "". On
// a template error, "" and the error are returned.
func (self *Templates) MarshalFieldFunc(tipe Type, field string, isPtr bool) (str string, err error) {
	data := struct {
		Type     Type
		Field    string
		FileName string
	}{Type: tipe, Field: field, FileName: self.FileName}

	var name string
	switch tipe.(type) {
	case BaseType:
		name = `marshal basetype`
		if isPtr {
			data.Field = fmt.Sprintf(`*(%s)`, field)
		}
	case CustomType:
		name = `marshal customtype`
	case ArrayType, MapType:
		name = `marshal array or map`
	default:
		return "", nil
	}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return out.String(), nil
}

// DecodeFieldFunc implements the "DecodeField" template function. It renders
// the Go code that decodes a value of bebop type tipe into the expression
// field.
//
// castOpen and castClose are passed through to the template. New is set
// when castOpen is "&". returnPrefix is passed through unchanged. Every
// newline in the rendered snippet is followed by indent, so multi-line
// output lines up with the surrounding generated code. A type of any other
// kind renders as "". On a template error, "" and the error are returned.
func (self *Templates) DecodeFieldFunc(tipe Type, field string, castOpen, castClose, indent, returnPrefix string) (str string, err error) {
	var name string
	switch tipe.(type) {
	case BaseType:
		name = `decode basetype`
	case CustomType:
		name = `decode customtype`
	case ArrayType, MapType:
		name = `decode array or map`
	default:
		return "", nil
	}

	data := struct {
		Type         Type
		Field        string
		CastOpen     string
		CastClose    string
		FileName     string
		New          bool
		ReturnPrefix string
	}{
		Type:         tipe,
		Field:        field,
		CastOpen:     castOpen,
		CastClose:    castClose,
		FileName:     self.FileName,
		New:          castOpen == `&`,
		ReturnPrefix: returnPrefix,
	}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return strings.ReplaceAll(out.String(), "\n", "\n"+indent), nil
}

// UnmarshalFieldFunc implements the "UnmarshalField" template function. It
// renders the Go code that unmarshals a value of bebop type tipe into the
// expression field.
//
// castOpen and castClose are passed through to the template. New is set
// when castOpen is "&". returnPrefix is passed through unchanged. Every
// newline in the rendered snippet is followed by indent, so multi-line
// output lines up with the surrounding generated code. A type of any other
// kind renders as "". On a template error, "" and the error are returned.
func (self *Templates) UnmarshalFieldFunc(tipe Type, field string, castOpen, castClose, indent, returnPrefix string) (str string, err error) {
	var name string
	switch tipe.(type) {
	case BaseType:
		name = `unmarshal basetype`
	case CustomType:
		name = `unmarshal customtype`
	case ArrayType, MapType:
		name = `unmarshal array or map`
	default:
		return "", nil
	}

	data := struct {
		Type         Type
		Field        string
		CastOpen     string
		CastClose    string
		FileName     string
		New          bool
		ReturnPrefix string
	}{
		Type:         tipe,
		Field:        field,
		CastOpen:     castOpen,
		CastClose:    castClose,
		FileName:     self.FileName,
		New:          castOpen == `&`,
		ReturnPrefix: returnPrefix,
	}

	var out bytes.Buffer
	if err = self.Templates.ExecuteTemplate(&out, name, data); err != nil {
		return "", err
	}
	return strings.ReplaceAll(out.String(), "\n", "\n"+indent), nil
}

// FieldSizeFunc implements the "FieldSize" template function. It returns a
// Go expression giving the encoded size, in bytes, of the value in field,
// whose bebop type is tipe:
//   - a literal number when ConstSize reports a constant size;
//   - the base type's own size expression otherwise (expected to be a
//     string here);
//   - a SizeBebop() method call for custom types;
//   - a call to the per-file SizeBebopOf<name>_<FileName> helper for arrays
//     and maps.
//
// A type of any other kind yields "".
func (self *Templates) FieldSizeFunc(tipe Type, field string) string {
	if size, ok := ConstSize(self.DeclarationsByName, tipe); ok {
		return fmt.Sprint(size)
	}

	var helper string
	switch t := tipe.(type) {
	case BaseType:
		return t.Size(field)
	case CustomType:
		return fmt.Sprintf(`%s.SizeBebop()`, field)
	case ArrayType:
		helper = t.FunctionName()
	case MapType:
		helper = t.FunctionName()
	default:
		return ``
	}
	return fmt.Sprintf(`SizeBebopOf%s_%s(%s)`, helper, self.FileName, field)
}

// ConstSizeFunc implements the "ConstSize" template function. It returns the
// constant encoded size of the custom type called name, or -1 if ConstSize
// reports that the type has no constant size.
func (self *Templates) ConstSizeFunc(name string) int {
	size, ok := ConstSize(self.DeclarationsByName, NewCustomType(name))
	if !ok {
		return -1
	}
	return size
}

// PackageNamesFunc implements the "PackageNames" template function. It
// returns the current self.PackageNameForType, the custom-type to
// package-qualifier map filled in by gatherImports. That map is nil until
// Package has run.
func (self *Templates) PackageNamesFunc() map[Type]string {
	names := self.PackageNameForType
	return names
}