package main
import (
htemplate "html/template"
"io"
"io/fs"
"io/ioutil"
"os"
pathpkg "path"
"path/filepath"
"strings"
"text/template"
"text/template/parse"
)
// Template represents a template.
//
// It is the common interface over the text and HTML template wrappers
// declared in this file, which lets Templates keep both kinds in one map.
type Template interface {
// AddParseTree adds the given parse tree to the template. Both wrappers in
// this file install the tree under the template's own name.
AddParseTree(*parse.Tree) error
// Execute applies the template to the given data and writes the output to
// the given writer.
Execute(io.Writer, interface{}) error
// Tree returns the parse tree held by the underlying template.
Tree() *parse.Tree
}
// textTemplate adapts a text/template template to the Template interface.
type textTemplate struct {
	tmpl *template.Template
}

// AddParseTree associates tree with the wrapped template, registering it
// under the wrapped template's own name.
func (t textTemplate) AddParseTree(tree *parse.Tree) error {
	if _, err := t.tmpl.AddParseTree(t.tmpl.Name(), tree); err != nil {
		return err
	}
	return nil
}

// Execute runs the wrapped template against data and writes the result to w.
func (t textTemplate) Execute(w io.Writer, data interface{}) error {
	err := t.tmpl.Execute(w, data)
	return err
}

// Tree returns the parse tree currently held by the wrapped template.
func (t textTemplate) Tree() *parse.Tree {
	tree := t.tmpl.Tree
	return tree
}
// htmlTemplate adapts an html/template template to the Template interface.
type htmlTemplate struct {
	tmpl *htemplate.Template
}

// AddParseTree associates tree with the wrapped template, registering it
// under the wrapped template's own name.
func (t htmlTemplate) AddParseTree(tree *parse.Tree) error {
	if _, err := t.tmpl.AddParseTree(t.tmpl.Name(), tree); err != nil {
		return err
	}
	return nil
}

// Execute runs the wrapped template against data and writes the result to w.
func (t htmlTemplate) Execute(w io.Writer, data interface{}) error {
	err := t.tmpl.Execute(w, data)
	return err
}

// Tree returns the Tree field of the wrapped html/template template.
func (t htmlTemplate) Tree() *parse.Tree {
	tree := t.tmpl.Tree
	return tree
}
// Templates contains site templates.
//
// Templates are stored by name; the loaders in this file pick the text or
// HTML implementation based on the name's extension.
type Templates struct {
// tmpls maps a template name to the loaded template.
tmpls map[string]Template
// funcs is the function map passed to each template the loaders create.
funcs map[string]interface{}
}
// NewTemplates returns a new, empty Templates whose template map is
// initialized and ready for LoadTemplate and Load. No function map is set;
// call Funcs before loading if templates need custom functions.
func NewTemplates() *Templates {
	return &Templates{
		tmpls: make(map[string]Template),
	}
}
// Funcs sets the functions available to newly created templates.
//
// The map is stored as given (it is not copied here) and is handed to each
// template the loaders create afterwards. This method only records the map;
// it does not touch templates that are already in the set.
func (t *Templates) Funcs(funcs map[string]interface{}) {
t.funcs = funcs
}
// LoadTemplate parses the given files into a single template and stores it
// under name. Names ending in ".html" or ".xml" are loaded with html/template;
// every other name is loaded with text/template.
func (t *Templates) LoadTemplate(name string, filenames ...string) error {
	switch pathpkg.Ext(name) {
	case ".html", ".xml":
		return t.loadHTMLTemplate(name, filenames...)
	default:
		return t.loadTextTemplate(name, filenames...)
	}
}
// loadTextTemplate creates a text/template named name with the configured
// function map, parses each file into it in the order given, and stores the
// result under name. On any read or parse error the error is returned and
// the template set is left unchanged.
func (t *Templates) loadTextTemplate(name string, filenames ...string) error {
	parsed := template.New(name).Funcs(t.funcs)
	for _, filename := range filenames {
		src, err := ioutil.ReadFile(filename)
		if err != nil {
			return err
		}
		_, err = parsed.Parse(string(src))
		if err != nil {
			return err
		}
	}
	t.tmpls[name] = textTemplate{tmpl: parsed}
	return nil
}
// loadHTMLTemplate creates an html/template named name with the configured
// function map, parses each file into it in the order given, and stores the
// result under name. On any read or parse error the error is returned and
// the template set is left unchanged.
func (t *Templates) loadHTMLTemplate(name string, filenames ...string) error {
	parsed := htemplate.New(name).Funcs(t.funcs)
	for _, filename := range filenames {
		src, err := ioutil.ReadFile(filename)
		if err != nil {
			return err
		}
		_, err = parsed.Parse(string(src))
		if err != nil {
			return err
		}
	}
	t.tmpls[name] = htmlTemplate{tmpl: parsed}
	return nil
}
// Load loads templates from the provided directory.
//
// Every regular file found beneath dir is loaded with LoadTemplate. Its
// template name is the file's path relative to dir, written with forward
// slashes and a leading "/" (for example "/_default/page.html"), which is
// the form FindTemplate and FindPartial look names up by. The name does not
// depend on how dir is spelled ("templates", "templates/", "./templates")
// or on the operating system's path separator.
//
// The first error returned by LoadTemplate stops the walk and is returned.
// A walk that ends with a "does not exist" error (typically because dir is
// missing) is not treated as a failure.
//
// After loading, every template in the set whose extension is listed in exts
// (extensions include the leading dot, as returned by path.Ext) and that has
// a sibling template named "base"+ext in the same directory receives that
// base template's parse tree via AddParseTree.
func (t *Templates) Load(dir string, exts []string) error {
	// WalkDir reports each path as the root joined with the entry's relative
	// path. Walking the cleaned root therefore makes the prefix to strip
	// predictable regardless of how the caller wrote dir.
	root := filepath.Clean(dir)
	prefix := root
	if !strings.HasSuffix(prefix, string(filepath.Separator)) {
		prefix += string(filepath.Separator)
	}

	err := filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error {
		if err != nil {
			return err
		}
		if !d.Type().IsRegular() {
			return nil
		}
		// Strip "root/" rather than just "root" so a root of "." cannot eat
		// the leading dot of a hidden entry, then anchor the slash-separated
		// remainder at "/".
		rel := strings.TrimPrefix(path, prefix)
		name := pathpkg.Join("/", filepath.ToSlash(rel))
		// Surface load failures instead of silently dropping the template.
		return t.LoadTemplate(name, path)
	})
	if err != nil && !os.IsNotExist(err) {
		return err
	}

	// Add base templates
	extSet := make(map[string]struct{}, len(exts))
	for _, ext := range exts {
		extSet[ext] = struct{}{}
	}
	for name, tmpl := range t.tmpls {
		ext := pathpkg.Ext(name)
		if _, ok := extSet[ext]; !ok {
			continue
		}
		base, ok := t.tmpls[pathpkg.Join(pathpkg.Dir(name), "base"+ext)]
		if !ok {
			continue
		}
		if err := tmpl.AddParseTree(base.Tree()); err != nil {
			return err
		}
	}
	return nil
}
// FindTemplate looks up the template named tmpl for the given path. It first
// tries the template stored under path joined with tmpl and, if there is
// none, falls back to the one stored under "/_default" joined with tmpl.
// The boolean result reports whether a template was found.
func (t *Templates) FindTemplate(path string, tmpl string) (Template, bool) {
	candidates := [...]string{
		pathpkg.Join(path, tmpl),
		pathpkg.Join("/_default", tmpl),
	}
	for _, name := range candidates {
		if found, ok := t.tmpls[name]; ok {
			return found, true
		}
	}
	// Neither the specific nor the default template exists.
	return nil, false
}
// FindPartial looks up the partial template stored under "/_partials" joined
// with name. The boolean result reports whether it was found.
func (t *Templates) FindPartial(name string) (Template, bool) {
	partial, ok := t.tmpls[pathpkg.Join("/_partials", name)]
	if !ok {
		return nil, false
	}
	return partial, true
}