mirror of
https://github.com/LemmyNet/lemmy.git
synced 2024-11-08 09:24:17 +00:00
some cleanup/refactoring
This commit is contained in:
parent
fe5702e714
commit
84565be252
|
@ -1,4 +1,4 @@
|
|||
use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr, LemmyError};
|
||||
use crate::{settings::structs::RateLimitConfig, utils::get_ip, IpAddr};
|
||||
use actix_web::{
|
||||
dev::{Service, ServiceRequest, ServiceResponse, Transform},
|
||||
HttpResponse,
|
||||
|
@ -67,15 +67,8 @@ impl RateLimit {
|
|||
}
|
||||
|
||||
impl RateLimited {
|
||||
/// Returns None if the request was rejected due to hitting rate limit.
|
||||
pub async fn wrap<T, E>(
|
||||
self,
|
||||
ip_addr: IpAddr,
|
||||
fut: impl Future<Output = Result<T, E>>,
|
||||
) -> Result<Option<T>, E>
|
||||
where
|
||||
E: From<LemmyError>,
|
||||
{
|
||||
/// Returns true if the request passed the rate limit, false if it failed and should be rejected.
|
||||
pub async fn check(self, ip_addr: IpAddr) -> bool {
|
||||
// Does not need to be blocking because the RwLock in settings never held across await points,
|
||||
// and the operation here locks only long enough to clone
|
||||
let rate_limit = self.rate_limit_config;
|
||||
|
@ -89,14 +82,7 @@ impl RateLimited {
|
|||
RateLimitType::Image => (rate_limit.image, rate_limit.image_per_second),
|
||||
RateLimitType::Comment => (rate_limit.comment, rate_limit.comment_per_second),
|
||||
};
|
||||
let passed = limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)?;
|
||||
|
||||
drop(limiter);
|
||||
if passed {
|
||||
fut.await.map(|f| Some(f))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
limiter.check_rate_limit_full(self.type_, &ip_addr, kind, interval)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -141,10 +127,7 @@ where
|
|||
let service = self.service.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let opt = rate_limited
|
||||
.wrap(ip_addr, async move { Ok(()) as Result<(), LemmyError> })
|
||||
.await?;
|
||||
if let Some(()) = opt {
|
||||
if rate_limited.check(ip_addr).await {
|
||||
service.call(req).await
|
||||
} else {
|
||||
let (http_req, _) = req.into_parts();
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
use crate::{IpAddr, LemmyError};
|
||||
use std::{collections::HashMap, time::SystemTime};
|
||||
use crate::IpAddr;
|
||||
use std::{collections::HashMap, time::Instant};
|
||||
use strum::IntoEnumIterator;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct RateLimitBucket {
|
||||
last_checked: SystemTime,
|
||||
last_checked: Instant,
|
||||
allowance: f64,
|
||||
}
|
||||
|
||||
|
@ -36,7 +36,7 @@ impl RateLimiter {
|
|||
bucket.insert(
|
||||
ip.clone(),
|
||||
RateLimitBucket {
|
||||
last_checked: SystemTime::now(),
|
||||
last_checked: Instant::now(),
|
||||
allowance: -2f64,
|
||||
},
|
||||
);
|
||||
|
@ -55,12 +55,12 @@ impl RateLimiter {
|
|||
ip: &IpAddr,
|
||||
rate: i32,
|
||||
per: i32,
|
||||
) -> Result<bool, LemmyError> {
|
||||
) -> bool {
|
||||
self.insert_ip(ip);
|
||||
if let Some(bucket) = self.buckets.get_mut(&type_) {
|
||||
if let Some(rate_limit) = bucket.get_mut(ip) {
|
||||
let current = SystemTime::now();
|
||||
let time_passed = current.duration_since(rate_limit.last_checked)?.as_secs() as f64;
|
||||
let current = Instant::now();
|
||||
let time_passed = current.duration_since(rate_limit.last_checked).as_secs() as f64;
|
||||
|
||||
// The initial value
|
||||
if rate_limit.allowance == -2f64 {
|
||||
|
@ -81,16 +81,16 @@ impl RateLimiter {
|
|||
time_passed,
|
||||
rate_limit.allowance
|
||||
);
|
||||
Ok(false)
|
||||
false
|
||||
} else {
|
||||
rate_limit.allowance -= 1.0;
|
||||
Ok(true)
|
||||
true
|
||||
}
|
||||
} else {
|
||||
Ok(true)
|
||||
true
|
||||
}
|
||||
} else {
|
||||
Ok(true)
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -478,26 +478,30 @@ impl ChatServer {
|
|||
.as_str()
|
||||
.ok_or_else(|| LemmyError::from_message("missing op"))?;
|
||||
|
||||
let res = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
|
||||
let fut = (message_handler_crud)(context, msg.id, user_operation_crud.clone(), data);
|
||||
match user_operation_crud {
|
||||
UserOperationCrud::Register => rate_limiter.register().wrap(ip, fut).await,
|
||||
UserOperationCrud::CreatePost => rate_limiter.post().wrap(ip, fut).await,
|
||||
UserOperationCrud::CreateCommunity => rate_limiter.register().wrap(ip, fut).await,
|
||||
UserOperationCrud::CreateComment => rate_limiter.comment().wrap(ip, fut).await,
|
||||
_ => rate_limiter.message().wrap(ip, fut).await,
|
||||
}
|
||||
// check if api call passes the rate limit, and generate future for later execution
|
||||
let (passed, fut) = if let Ok(user_operation_crud) = UserOperationCrud::from_str(op) {
|
||||
let passed = match user_operation_crud {
|
||||
UserOperationCrud::Register => rate_limiter.register().check(ip).await,
|
||||
UserOperationCrud::CreatePost => rate_limiter.post().check(ip).await,
|
||||
UserOperationCrud::CreateCommunity => rate_limiter.register().check(ip).await,
|
||||
UserOperationCrud::CreateComment => rate_limiter.comment().check(ip).await,
|
||||
_ => rate_limiter.message().check(ip).await,
|
||||
};
|
||||
let fut = (message_handler_crud)(context, msg.id, user_operation_crud, data);
|
||||
(passed, fut)
|
||||
} else {
|
||||
let user_operation = UserOperation::from_str(op)?;
|
||||
let fut = (message_handler)(context, msg.id, user_operation.clone(), data);
|
||||
match user_operation {
|
||||
UserOperation::GetCaptcha => rate_limiter.post().wrap(ip, fut).await,
|
||||
_ => rate_limiter.message().wrap(ip, fut).await,
|
||||
}
|
||||
}?;
|
||||
let passed = match user_operation {
|
||||
UserOperation::GetCaptcha => rate_limiter.post().check(ip).await,
|
||||
_ => rate_limiter.message().check(ip).await,
|
||||
};
|
||||
let fut = (message_handler)(context, msg.id, user_operation, data);
|
||||
(passed, fut)
|
||||
};
|
||||
|
||||
if let Some(r) = res {
|
||||
Ok(r)
|
||||
// if rate limit passed, execute api call future
|
||||
if passed {
|
||||
fut.await
|
||||
} else {
|
||||
// if rate limit was hit, respond with empty message
|
||||
Ok("".to_string())
|
||||
|
|
Loading…
Reference in a new issue