diff --git a/Cargo.lock b/Cargo.lock index 17a006fda..6fe932a29 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1683,6 +1683,12 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "downcast-rs" version = "1.2.1" @@ -1924,6 +1930,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "fs2" version = "0.4.3" @@ -3115,6 +3127,7 @@ dependencies = [ "activitypub_federation", "actix-web", "anyhow", + "async-trait", "chrono", "diesel", "diesel-async", @@ -3124,6 +3137,7 @@ dependencies = [ "lemmy_db_schema", "lemmy_db_views_actor", "lemmy_utils", + "mockall", "moka", "once_cell", "reqwest 0.11.27", @@ -3286,7 +3300,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.5", ] [[package]] @@ -3613,6 +3627,33 @@ version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "moka" version = "0.12.7" @@ -4361,6 +4402,32 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -5718,6 +5785,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "test-context" version = "0.3.0" diff --git a/crates/db_schema/src/newtypes.rs b/crates/db_schema/src/newtypes.rs index c5c9e8e84..9aeaa5266 100644 --- a/crates/db_schema/src/newtypes.rs +++ b/crates/db_schema/src/newtypes.rs @@ -107,7 +107,7 @@ pub struct PrivateMessageReportId(i32); #[cfg_attr(feature = "full", derive(DieselNewType, TS))] #[cfg_attr(feature = "full", ts(export))] /// The site id. -pub struct SiteId(i32); +pub struct SiteId(pub i32); #[derive(Debug, Copy, Clone, Hash, Eq, PartialEq, Serialize, Deserialize, Default)] #[cfg_attr(feature = "full", derive(DieselNewType, TS))] diff --git a/crates/federate/Cargo.toml b/crates/federate/Cargo.toml index 8d98938a8..b8b438901 100644 --- a/crates/federate/Cargo.toml +++ b/crates/federate/Cargo.toml @@ -34,6 +34,7 @@ tokio = { workspace = true, features = ["full"] } tracing.workspace = true moka.workspace = true tokio-util = "0.7.11" +async-trait.workspace = true [dev-dependencies] serial_test = { workspace = true } @@ -42,3 +43,4 @@ actix-web.workspace = true tracing-test = "0.2.5" uuid.workspace = true test-context = "0.3.0" +mockall = "0.12.1" diff --git a/crates/federate/src/inboxes.rs b/crates/federate/src/inboxes.rs index 45fce8119..d99ff5b75 100644 --- a/crates/federate/src/inboxes.rs +++ b/crates/federate/src/inboxes.rs @@ -1,8 +1,9 @@ use crate::util::LEMMY_TEST_FAST_FEDERATION; use anyhow::Result; +use async_trait::async_trait; use chrono::{DateTime, TimeZone, Utc}; use lemmy_db_schema::{ - newtypes::{CommunityId, InstanceId}, + newtypes::{CommunityId, DbUrl, InstanceId}, source::{activity::SentActivity, site::Site}, utils::{ActualDbPool, DbPool}, }; @@ -33,7 +34,52 @@ static FOLLOW_ADDITIONS_RECHECK_DELAY: Lazy = Lazy::new(|| { static FOLLOW_REMOVALS_RECHECK_DELAY: Lazy = Lazy::new(|| chrono::TimeDelta::try_hours(1).expect("TimeDelta out of bounds")); -pub(crate) struct CommunityInboxCollector { +#[async_trait] +pub trait DataSource: Send + Sync { + async fn read_site_from_instance_id( + &self, + instance_id: InstanceId, + ) -> Result, diesel::result::Error>; + async fn get_instance_followed_community_inboxes( + &self, + instance_id: InstanceId, + last_fetch: DateTime, + ) -> Result, diesel::result::Error>; +} +pub struct DbDataSource { + pool: ActualDbPool, +} + +impl DbDataSource { + pub fn new(pool: ActualDbPool) -> Self { + Self { pool } + } +} + +#[async_trait] +impl DataSource for DbDataSource { + async fn read_site_from_instance_id( + &self, + instance_id: InstanceId, + ) -> Result, diesel::result::Error> { + Site::read_from_instance_id(&mut DbPool::Pool(&self.pool), instance_id).await + } + + async fn get_instance_followed_community_inboxes( + &self, + instance_id: InstanceId, + last_fetch: DateTime, + ) -> Result, diesel::result::Error> { + CommunityFollowerView::get_instance_followed_community_inboxes( + &mut DbPool::Pool(&self.pool), + instance_id, + last_fetch, + ) + .await + } +} + +pub(crate) struct CommunityInboxCollector { // load site lazily because if an instance is first seen due to being on allowlist, // the corresponding row in `site` may not exist yet since that is only added once // `fetch_instance_actor_for_object` is called. @@ -45,16 +91,26 @@ pub(crate) struct CommunityInboxCollector { last_incremental_communities_fetch: DateTime, instance_id: InstanceId, domain: String, - pool: ActualDbPool, + pub(crate) data_source: T, } -impl CommunityInboxCollector { - pub fn new( + +pub type RealCommunityInboxCollector = CommunityInboxCollector; + +impl CommunityInboxCollector { + pub fn new_real( pool: ActualDbPool, instance_id: InstanceId, domain: String, - ) -> CommunityInboxCollector { + ) -> RealCommunityInboxCollector { + CommunityInboxCollector::new(DbDataSource::new(pool), instance_id, domain) + } + pub fn new( + data_source: T, + instance_id: InstanceId, + domain: String, + ) -> CommunityInboxCollector { CommunityInboxCollector { - pool, + data_source, site_loaded: false, site: None, followed_communities: HashMap::new(), @@ -73,7 +129,10 @@ impl CommunityInboxCollector { if activity.send_all_instances { if !self.site_loaded { - self.site = Site::read_from_instance_id(&mut self.pool(), self.instance_id).await?; + self.site = self + .data_source + .read_site_from_instance_id(self.instance_id) + .await?; self.site_loaded = true; } if let Some(site) = &self.site { @@ -145,22 +204,397 @@ impl CommunityInboxCollector { // published date is not exact let new_last_fetch = Utc::now() - chrono::TimeDelta::try_seconds(10).expect("TimeDelta out of bounds"); - Ok(( - CommunityFollowerView::get_instance_followed_community_inboxes( - &mut self.pool(), - instance_id, - last_fetch, - ) - .await? - .into_iter() - .fold(HashMap::new(), |mut map, (c, u)| { + + let inboxes = self + .data_source + .get_instance_followed_community_inboxes(instance_id, last_fetch) + .await?; + + let map: HashMap> = + inboxes.into_iter().fold(HashMap::new(), |mut map, (c, u)| { map.entry(c).or_default().insert(u.into()); map - }), - new_last_fetch, - )) - } - fn pool(&self) -> DbPool<'_> { - DbPool::Pool(&self.pool) + }); + + Ok((map, new_last_fetch)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use lemmy_db_schema::{ + newtypes::{ActivityId, CommunityId, InstanceId, SiteId}, + source::activity::{ActorType, SentActivity}, + }; + use mockall::{mock, predicate::*}; + use serde_json::json; + + mock! { + DataSource {} + #[async_trait] + impl DataSource for DataSource { + async fn read_site_from_instance_id(&self, instance_id: InstanceId) -> Result, diesel::result::Error>; + async fn get_instance_followed_community_inboxes( + &self, + instance_id: InstanceId, + last_fetch: DateTime, + ) -> Result, diesel::result::Error>; + } + } + + fn setup_collector() -> CommunityInboxCollector { + let mock_data_source = MockDataSource::new(); + let instance_id = InstanceId(1); + let domain = "example.com".to_string(); + CommunityInboxCollector::new(mock_data_source, instance_id, domain) + } + + #[tokio::test] + async fn test_get_inbox_urls_empty() { + let mut collector = setup_collector(); + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![], + send_community_followers_of: None, + send_all_instances: false, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn test_get_inbox_urls_send_all_instances() { + let mut collector = setup_collector(); + let site = Site { + id: SiteId(1), + name: "Test Site".to_string(), + sidebar: None, + published: Utc::now(), + updated: None, + icon: None, + banner: None, + description: None, + actor_id: Url::parse("https://example.com/site").unwrap().into(), + last_refreshed_at: Utc::now(), + inbox_url: Url::parse("https://example.com/inbox").unwrap().into(), + private_key: None, + public_key: "test_key".to_string(), + instance_id: InstanceId(1), + content_warning: None, + }; + + collector + .data_source + .expect_read_site_from_instance_id() + .return_once(move |_| Ok(Some(site))); + + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![], + send_community_followers_of: None, + send_all_instances: true, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], Url::parse("https://example.com/inbox").unwrap()); + } + + #[tokio::test] + async fn test_get_inbox_urls_community_followers() { + let mut collector = setup_collector(); + let community_id = CommunityId(1); + + collector + .data_source + .expect_get_instance_followed_community_inboxes() + .return_once(move |_, _| { + Ok(vec![ + ( + community_id, + Url::parse("https://follower1.example.com/inbox").unwrap().into(), + ), + ( + community_id, + Url::parse("https://follower2.example.com/inbox").unwrap().into(), + ), + ]) + }); + + collector.update_communities().await.unwrap(); + + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![], + send_community_followers_of: Some(community_id), + send_all_instances: false, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert_eq!(result.len(), 2); + assert!(result.contains(&Url::parse("https://follower1.example.com/inbox").unwrap())); + assert!(result.contains(&Url::parse("https://follower2.example.com/inbox").unwrap())); + } + + #[tokio::test] + async fn test_get_inbox_urls_send_inboxes() { + let mut collector = setup_collector(); + collector.domain = "example.com".to_string(); + + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![ + Some( + Url::parse("https://example.com/user1/inbox") + .unwrap() + .into(), + ), + Some( + Url::parse("https://example.com/user2/inbox") + .unwrap() + .into(), + ), + Some( + Url::parse("https://other-domain.com/user3/inbox") + .unwrap() + .into(), + ), + ], + send_community_followers_of: None, + send_all_instances: false, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert_eq!(result.len(), 2); + assert!(result.contains(&Url::parse("https://example.com/user1/inbox").unwrap())); + assert!(result.contains(&Url::parse("https://example.com/user2/inbox").unwrap())); + assert!(!result.contains(&Url::parse("https://other-domain.com/user3/inbox").unwrap())); + } + + #[tokio::test] + async fn test_get_inbox_urls_combined() { + let mut collector = setup_collector(); + collector.domain = "example.com".to_string(); + let community_id = CommunityId(1); + + let site = Site { + id: SiteId(1), + name: "Test Site".to_string(), + sidebar: None, + published: Utc::now(), + updated: None, + icon: None, + banner: None, + description: None, + actor_id: Url::parse("https://example.com/site").unwrap().into(), + last_refreshed_at: Utc::now(), + inbox_url: Url::parse("https://example.com/site_inbox").unwrap().into(), + private_key: None, + public_key: "test_key".to_string(), + instance_id: InstanceId(1), + content_warning: None, + }; + + collector + .data_source + .expect_read_site_from_instance_id() + .return_once(move |_| Ok(Some(site))); + + collector + .data_source + .expect_get_instance_followed_community_inboxes() + .return_once(move |_, _| { + Ok(vec![( + community_id, + Url::parse("https://follower.example.com/inbox").unwrap().into(), + )]) + }); + + collector.update_communities().await.unwrap(); + + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![ + Some( + Url::parse("https://example.com/user1/inbox") + .unwrap() + .into(), + ), + Some( + Url::parse("https://other-domain.com/user2/inbox") + .unwrap() + .into(), + ), + ], + send_community_followers_of: Some(community_id), + send_all_instances: true, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert_eq!(result.len(), 3); + assert!(result.contains(&Url::parse("https://example.com/site_inbox").unwrap())); + assert!(result.contains(&Url::parse("https://follower.example.com/inbox").unwrap())); + assert!(result.contains(&Url::parse("https://example.com/user1/inbox").unwrap())); + assert!(!result.contains(&Url::parse("https://other-domain.com/user2/inbox").unwrap())); + } + + #[tokio::test] + async fn test_update_communities() { + let mut collector = setup_collector(); + let community_id1 = CommunityId(1); + let community_id2 = CommunityId(2); + let community_id3 = CommunityId(3); + + collector + .data_source + .expect_get_instance_followed_community_inboxes() + .times(2) + .returning(move |_, last_fetch| { + if last_fetch == Utc.timestamp_nanos(0) { + Ok(vec![ + ( + community_id1, + Url::parse("https://follower1.example.com/inbox").unwrap().into(), + ), + ( + community_id2, + Url::parse("https://follower2.example.com/inbox").unwrap().into(), + ), + ]) + } else { + Ok(vec![( + community_id3, + Url::parse("https://follower3.example.com/inbox").unwrap().into(), + )]) + } + }); + + // First update + collector.update_communities().await.unwrap(); + assert_eq!(collector.followed_communities.len(), 2); + assert!(collector.followed_communities[&community_id1] + .contains(&Url::parse("https://follower1.example.com/inbox").unwrap())); + assert!(collector.followed_communities[&community_id2] + .contains(&Url::parse("https://follower2.example.com/inbox").unwrap())); + + // Simulate time passing + collector.last_full_communities_fetch = Utc::now() - chrono::TimeDelta::try_minutes(3).unwrap(); + collector.last_incremental_communities_fetch = + Utc::now() - chrono::TimeDelta::try_minutes(3).unwrap(); + + // Second update (incremental) + collector.update_communities().await.unwrap(); + assert_eq!(collector.followed_communities.len(), 3); + assert!(collector.followed_communities[&community_id1] + .contains(&Url::parse("https://follower1.example.com/inbox").unwrap())); + assert!(collector.followed_communities[&community_id3] + .contains(&Url::parse("https://follower3.example.com/inbox").unwrap())); + assert!(collector.followed_communities[&community_id2] + .contains(&Url::parse("https://follower2.example.com/inbox").unwrap())); + } + + #[tokio::test] + async fn test_get_inbox_urls_no_duplicates() { + let mut collector = setup_collector(); + collector.domain = "example.com".to_string(); + let community_id = CommunityId(1); + + let site = Site { + id: SiteId(1), + name: "Test Site".to_string(), + sidebar: None, + published: Utc::now(), + updated: None, + icon: None, + banner: None, + description: None, + actor_id: Url::parse("https://example.com/site").unwrap().into(), + last_refreshed_at: Utc::now(), + inbox_url: Url::parse("https://example.com/site_inbox").unwrap().into(), + private_key: None, + public_key: "test_key".to_string(), + instance_id: InstanceId(1), + content_warning: None, + }; + + collector + .data_source + .expect_read_site_from_instance_id() + .return_once(move |_| Ok(Some(site))); + + collector + .data_source + .expect_get_instance_followed_community_inboxes() + .return_once(move |_, _| { + Ok(vec![( + community_id, + Url::parse("https://example.com/site_inbox").unwrap().into(), + )]) + }); + + collector.update_communities().await.unwrap(); + + let activity = SentActivity { + id: ActivityId(1), + ap_id: Url::parse("https://example.com/activities/1") + .unwrap() + .into(), + data: json!({}), + sensitive: false, + published: Utc::now(), + send_inboxes: vec![Some( + Url::parse("https://example.com/site_inbox").unwrap().into(), + )], + send_community_followers_of: Some(community_id), + send_all_instances: true, + actor_type: ActorType::Person, + actor_apub_id: None, + }; + + let result = collector.get_inbox_urls(&activity).await.unwrap(); + assert_eq!(result.len(), 1); + assert!(result.contains(&Url::parse("https://example.com/site_inbox").unwrap())); } } diff --git a/crates/federate/src/worker.rs b/crates/federate/src/worker.rs index 247071c36..28210d99d 100644 --- a/crates/federate/src/worker.rs +++ b/crates/federate/src/worker.rs @@ -1,5 +1,5 @@ use crate::{ - inboxes::CommunityInboxCollector, + inboxes::RealCommunityInboxCollector, send::{SendActivityResult, SendRetryTask, SendSuccessInfo}, util::{ get_activity_cached, @@ -65,7 +65,7 @@ pub(crate) struct InstanceWorker { state: FederationQueueState, last_state_insert: DateTime, pool: ActualDbPool, - inbox_collector: CommunityInboxCollector, + inbox_collector: RealCommunityInboxCollector, // regularily send stats back to the SendManager stats_sender: UnboundedSender, // each HTTP send will report back to this channel concurrently @@ -92,7 +92,7 @@ impl InstanceWorker { let (report_send_result, receive_send_result) = tokio::sync::mpsc::unbounded_channel::(); let mut worker = InstanceWorker { - inbox_collector: CommunityInboxCollector::new( + inbox_collector: RealCommunityInboxCollector::new_real( pool.clone(), instance.id, instance.domain.clone(),