diff --git a/main.go b/main.go
index d778a8b..97a6bcb 100644
--- a/main.go
+++ b/main.go
@@ -108,6 +108,12 @@ func test() {
}
}
}
+func RemoteAddr(r *http.Request) string {
+ if r.Header.Get("CF-Connecting-IP") != "" {
+ return r.Header.Get("CF-Connecting-IP")
+ }
+ return r.RemoteAddr
+}
func middleware(n httprouter.Handle) httprouter.Handle {
return func(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
//remoteAddr := r.RemoteAddr
diff --git a/routes.go b/routes.go
index a653ad0..b926cd1 100644
--- a/routes.go
+++ b/routes.go
@@ -141,7 +141,7 @@ var funcMap = template.FuncMap{
}
body = buf.String()
body = strings.Replace(body, `$1$2")
return template.HTML(body)
},
@@ -1429,6 +1429,33 @@ func UserOp(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
}
http.Redirect(w, r, r.URL.String(), 301)
}
+func GetLink(w http.ResponseWriter, r *http.Request, ps httprouter.Params) {
+ var dest *url.URL
+ m, _ := url.ParseQuery(r.URL.RawQuery)
+ if len(m["url"]) > 0 {
+ dest, _ = url.Parse(m["url"][0])
+ }
+ if dest.Host == r.Host || !IsLemmy(dest.Host, RemoteAddr(r)) {
+ http.Redirect(w, r, dest.String(), 302)
+ return
+ }
+ if host := ps.ByName("host"); host != "" {
+ redirect := "/" + host + dest.Path
+ if host != dest.Host && !strings.Contains(redirect, "@") {
+ redirect += ("@" + dest.Host)
+ }
+ http.Redirect(w, r, redirect, 302)
+ return
+ }
+ if host := os.Getenv("LEMMY_DOMAIN"); host != "" {
+ redirect := dest.Path
+ if host != dest.Host && !strings.Contains(redirect, "@") {
+ redirect += ("@" + dest.Host)
+ }
+ http.Redirect(w, r, redirect, 302)
+ return
+ }
+}
func GetRouter() *httprouter.Router {
host := os.Getenv("LEMMY_DOMAIN")
router := httprouter.New()
@@ -1467,6 +1494,7 @@ func GetRouter() *httprouter.Router {
router.GET("/:host/create_community", middleware(GetCreateCommunity))
router.POST("/:host/create_community", middleware(UserOp))
router.GET("/:host/communities", middleware(GetCommunities))
+ router.GET("/:host/link", middleware(GetLink))
} else {
router.ServeFiles("/_/static/*filepath", http.Dir("public"))
router.GET("/", middleware(GetFrontpage))
@@ -1499,6 +1527,7 @@ func GetRouter() *httprouter.Router {
router.GET("/create_community", middleware(GetCreateCommunity))
router.POST("/create_community", middleware(UserOp))
router.GET("/communities", middleware(GetCommunities))
+ router.GET("/link", middleware(GetLink))
}
return router
}