diff --git a/main.go b/main.go index d778a8b..97a6bcb 100644 --- a/main.go +++ b/main.go @@ -108,6 +108,12 @@ func test() { } } } +func RemoteAddr(r *http.Request) string { + if r.Header.Get("CF-Connecting-IP") != "" { + return r.Header.Get("CF-Connecting-IP") + } + return r.RemoteAddr +} func middleware(n httprouter.Handle) httprouter.Handle { return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { //remoteAddr := r.RemoteAddr diff --git a/routes.go b/routes.go index a653ad0..b926cd1 100644 --- a/routes.go +++ b/routes.go @@ -141,7 +141,7 @@ var funcMap = template.FuncMap{ } body = buf.String() body = strings.Replace(body, `$1$2") return template.HTML(body) }, @@ -1429,6 +1429,33 @@ func UserOp(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { } http.Redirect(w, r, r.URL.String(), 301) } +func GetLink(w http.ResponseWriter, r *http.Request, ps httprouter.Params) { + var dest *url.URL + m, _ := url.ParseQuery(r.URL.RawQuery) + if len(m["url"]) > 0 { + dest, _ = url.Parse(m["url"][0]) + } + if dest.Host == r.Host || !IsLemmy(dest.Host, RemoteAddr(r)) { + http.Redirect(w, r, dest.String(), 302) + return + } + if host := ps.ByName("host"); host != "" { + redirect := "/" + host + dest.Path + if host != dest.Host && !strings.Contains(redirect, "@") { + redirect += ("@" + dest.Host) + } + http.Redirect(w, r, redirect, 302) + return + } + if host := os.Getenv("LEMMY_DOMAIN"); host != "" { + redirect := dest.Path + if host != dest.Host && !strings.Contains(redirect, "@") { + redirect += ("@" + dest.Host) + } + http.Redirect(w, r, redirect, 302) + return + } +} func GetRouter() *httprouter.Router { host := os.Getenv("LEMMY_DOMAIN") router := httprouter.New() @@ -1467,6 +1494,7 @@ func GetRouter() *httprouter.Router { router.GET("/:host/create_community", middleware(GetCreateCommunity)) router.POST("/:host/create_community", middleware(UserOp)) router.GET("/:host/communities", middleware(GetCommunities)) + router.GET("/:host/link", middleware(GetLink)) } else { router.ServeFiles("/_/static/*filepath", http.Dir("public")) router.GET("/", middleware(GetFrontpage)) @@ -1499,6 +1527,7 @@ func GetRouter() *httprouter.Router { router.GET("/create_community", middleware(GetCreateCommunity)) router.POST("/create_community", middleware(UserOp)) router.GET("/communities", middleware(GetCommunities)) + router.GET("/link", middleware(GetLink)) } return router }