diff --git a/templates.go b/templates.go index 6314c28..ff00ee7 100644 --- a/templates.go +++ b/templates.go @@ -9,50 +9,69 @@ import ( pathpkg "path" "strings" "text/template" - "text/template/parse" ) // Template represents a template. type Template interface { - AddParseTree(*parse.Tree) error + Clone() (Template, error) + AddTemplates(Template) error Execute(io.Writer, interface{}) error - Tree() *parse.Tree } type textTemplate struct { tmpl *template.Template } -func (t textTemplate) AddParseTree(tree *parse.Tree) error { - _, err := t.tmpl.AddParseTree(t.tmpl.Name(), tree) - return err +func (t textTemplate) Clone() (Template, error) { + clone, err := t.tmpl.Clone() + return textTemplate{clone}, err +} + +func (t textTemplate) AddTemplates(other Template) error { + otherTmpl := other.(textTemplate).tmpl + for _, def := range otherTmpl.Templates() { + if def.Name() == otherTmpl.Name() { + continue + } + _, err := t.tmpl.AddParseTree(def.Name(), def.Tree) + if err != nil { + return err + } + } + return nil } func (t textTemplate) Execute(w io.Writer, data interface{}) error { return t.tmpl.Execute(w, data) } -func (t textTemplate) Tree() *parse.Tree { - return t.tmpl.Tree -} - type htmlTemplate struct { tmpl *htemplate.Template } -func (t htmlTemplate) AddParseTree(tree *parse.Tree) error { - _, err := t.tmpl.AddParseTree(t.tmpl.Name(), tree) - return err +func (t htmlTemplate) Clone() (Template, error) { + clone, err := t.tmpl.Clone() + return htmlTemplate{clone}, err +} + +func (t htmlTemplate) AddTemplates(other Template) error { + otherTmpl := other.(htmlTemplate).tmpl + for _, def := range otherTmpl.Templates() { + if def.Name() == otherTmpl.Name() { + continue + } + _, err := t.tmpl.AddParseTree(def.Name(), def.Tree) + if err != nil { + return err + } + } + return nil } func (t htmlTemplate) Execute(w io.Writer, data interface{}) error { return t.tmpl.Execute(w, data) } -func (t htmlTemplate) Tree() *parse.Tree { - return t.tmpl.Tree -} - // Templates contains site templates. type Templates struct { tmpls map[string]Template @@ -129,12 +148,20 @@ func (t *Templates) Load(dir string, exts []string) error { if _, ok := extsMap[ext]; !ok { continue } - base := pathpkg.Join(pathpkg.Dir(path), "base"+ext) - if tmpl, ok := t.tmpls[base]; ok { - err := t.tmpls[path].AddParseTree(tmpl.Tree()) + basePath := pathpkg.Join(pathpkg.Dir(path), "base"+ext) + if path == basePath { + continue + } + if base, ok := t.tmpls[basePath]; ok { + tmpl, err := base.Clone() if err != nil { return err } + // Load customized template definitions + if err := tmpl.AddTemplates(t.tmpls[path]); err != nil { + return err + } + t.tmpls[path] = tmpl } } return nil