mas_storage_pg/compat/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7//! A module containing PostgreSQL implementation of repositories for the
8//! compatibility layer
9
10mod access_token;
11mod refresh_token;
12mod session;
13mod sso_login;
14
15pub use self::{
16    access_token::PgCompatAccessTokenRepository, refresh_token::PgCompatRefreshTokenRepository,
17    session::PgCompatSessionRepository, sso_login::PgCompatSsoLoginRepository,
18};
19
20#[cfg(test)]
21mod tests {
22    use chrono::Duration;
23    use mas_data_model::{Device, UserAgent};
24    use mas_storage::{
25        Clock, Pagination, RepositoryAccess,
26        clock::MockClock,
27        compat::{
28            CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionFilter,
29            CompatSessionRepository, CompatSsoLoginFilter,
30        },
31        user::UserRepository,
32    };
33    use rand::SeedableRng;
34    use rand_chacha::ChaChaRng;
35    use sqlx::PgPool;
36    use ulid::Ulid;
37
38    use crate::PgRepository;
39
40    #[sqlx::test(migrator = "crate::MIGRATOR")]
41    async fn test_session_repository(pool: PgPool) {
42        let mut rng = ChaChaRng::seed_from_u64(42);
43        let clock = MockClock::default();
44        let mut repo = PgRepository::from_pool(&pool).await.unwrap();
45
46        // Create a user
47        let user = repo
48            .user()
49            .add(&mut rng, &clock, "john".to_owned())
50            .await
51            .unwrap();
52
53        let all = CompatSessionFilter::new().for_user(&user);
54        let active = all.active_only();
55        let finished = all.finished_only();
56        let pagination = Pagination::first(10);
57
58        assert_eq!(repo.compat_session().count(all).await.unwrap(), 0);
59        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
60        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
61
62        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
63        assert!(full_list.edges.is_empty());
64        let active_list = repo
65            .compat_session()
66            .list(active, pagination)
67            .await
68            .unwrap();
69        assert!(active_list.edges.is_empty());
70        let finished_list = repo
71            .compat_session()
72            .list(finished, pagination)
73            .await
74            .unwrap();
75        assert!(finished_list.edges.is_empty());
76
77        // Start a compat session for that user
78        let device = Device::generate(&mut rng);
79        let device_str = device.as_str().to_owned();
80        let session = repo
81            .compat_session()
82            .add(&mut rng, &clock, &user, device.clone(), None, false)
83            .await
84            .unwrap();
85        assert_eq!(session.user_id, user.id);
86        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
87        assert!(session.is_valid());
88        assert!(!session.is_finished());
89
90        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
91        assert_eq!(repo.compat_session().count(active).await.unwrap(), 1);
92        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 0);
93
94        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
95        assert_eq!(full_list.edges.len(), 1);
96        assert_eq!(full_list.edges[0].0.id, session.id);
97        let active_list = repo
98            .compat_session()
99            .list(active, pagination)
100            .await
101            .unwrap();
102        assert_eq!(active_list.edges.len(), 1);
103        assert_eq!(active_list.edges[0].0.id, session.id);
104        let finished_list = repo
105            .compat_session()
106            .list(finished, pagination)
107            .await
108            .unwrap();
109        assert!(finished_list.edges.is_empty());
110
111        // Lookup the session and check it didn't change
112        let session_lookup = repo
113            .compat_session()
114            .lookup(session.id)
115            .await
116            .unwrap()
117            .expect("compat session not found");
118        assert_eq!(session_lookup.id, session.id);
119        assert_eq!(session_lookup.user_id, user.id);
120        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
121        assert!(session_lookup.is_valid());
122        assert!(!session_lookup.is_finished());
123
124        // Record a user-agent for the session
125        assert!(session_lookup.user_agent.is_none());
126        let session = repo
127            .compat_session()
128            .record_user_agent(session_lookup, UserAgent::parse("Mozilla/5.0".to_owned()))
129            .await
130            .unwrap();
131        assert_eq!(session.user_agent.as_deref(), Some("Mozilla/5.0"));
132
133        // Reload the session and check again
134        let session_lookup = repo
135            .compat_session()
136            .lookup(session.id)
137            .await
138            .unwrap()
139            .expect("compat session not found");
140        assert_eq!(session_lookup.user_agent.as_deref(), Some("Mozilla/5.0"));
141
142        // Look up the session by device
143        let list = repo
144            .compat_session()
145            .list(
146                CompatSessionFilter::new()
147                    .for_user(&user)
148                    .for_device(&device),
149                pagination,
150            )
151            .await
152            .unwrap();
153        assert_eq!(list.edges.len(), 1);
154        let session_lookup = &list.edges[0].0;
155        assert_eq!(session_lookup.id, session.id);
156        assert_eq!(session_lookup.user_id, user.id);
157        assert_eq!(session.device.as_ref().unwrap().as_str(), device_str);
158        assert!(session_lookup.is_valid());
159        assert!(!session_lookup.is_finished());
160
161        // Finish the session
162        let session = repo.compat_session().finish(&clock, session).await.unwrap();
163        assert!(!session.is_valid());
164        assert!(session.is_finished());
165
166        assert_eq!(repo.compat_session().count(all).await.unwrap(), 1);
167        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
168        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 1);
169
170        let full_list = repo.compat_session().list(all, pagination).await.unwrap();
171        assert_eq!(full_list.edges.len(), 1);
172        assert_eq!(full_list.edges[0].0.id, session.id);
173        let active_list = repo
174            .compat_session()
175            .list(active, pagination)
176            .await
177            .unwrap();
178        assert!(active_list.edges.is_empty());
179        let finished_list = repo
180            .compat_session()
181            .list(finished, pagination)
182            .await
183            .unwrap();
184        assert_eq!(finished_list.edges.len(), 1);
185        assert_eq!(finished_list.edges[0].0.id, session.id);
186
187        // Reload the session and check again
188        let session_lookup = repo
189            .compat_session()
190            .lookup(session.id)
191            .await
192            .unwrap()
193            .expect("compat session not found");
194        assert!(!session_lookup.is_valid());
195        assert!(session_lookup.is_finished());
196
197        // Now add another session, with an SSO login this time
198        let unknown_session = session;
199        // Start a new SSO login
200        let login = repo
201            .compat_sso_login()
202            .add(
203                &mut rng,
204                &clock,
205                "login-token".to_owned(),
206                "https://example.com/callback".parse().unwrap(),
207            )
208            .await
209            .unwrap();
210        assert!(login.is_pending());
211
212        // Start a compat session for that user
213        let device = Device::generate(&mut rng);
214        let sso_login_session = repo
215            .compat_session()
216            .add(&mut rng, &clock, &user, device, None, false)
217            .await
218            .unwrap();
219
220        // Associate the login with the session
221        let login = repo
222            .compat_sso_login()
223            .fulfill(&clock, login, &sso_login_session)
224            .await
225            .unwrap();
226        assert!(login.is_fulfilled());
227
228        // Now query the session list with both the unknown and SSO login session type
229        // filter
230        let all = CompatSessionFilter::new().for_user(&user);
231        let sso_login = all.sso_login_only();
232        let unknown = all.unknown_only();
233        assert_eq!(repo.compat_session().count(all).await.unwrap(), 2);
234        assert_eq!(repo.compat_session().count(sso_login).await.unwrap(), 1);
235        assert_eq!(repo.compat_session().count(unknown).await.unwrap(), 1);
236
237        let list = repo
238            .compat_session()
239            .list(sso_login, pagination)
240            .await
241            .unwrap();
242        assert_eq!(list.edges.len(), 1);
243        assert_eq!(list.edges[0].0.id, sso_login_session.id);
244        let list = repo
245            .compat_session()
246            .list(unknown, pagination)
247            .await
248            .unwrap();
249        assert_eq!(list.edges.len(), 1);
250        assert_eq!(list.edges[0].0.id, unknown_session.id);
251
252        // Check that combining the two filters works
253        // At this point, there is one active SSO login session and one finished unknown
254        // session
255        assert_eq!(
256            repo.compat_session()
257                .count(all.sso_login_only().active_only())
258                .await
259                .unwrap(),
260            1
261        );
262        assert_eq!(
263            repo.compat_session()
264                .count(all.sso_login_only().finished_only())
265                .await
266                .unwrap(),
267            0
268        );
269        assert_eq!(
270            repo.compat_session()
271                .count(all.unknown_only().active_only())
272                .await
273                .unwrap(),
274            0
275        );
276        assert_eq!(
277            repo.compat_session()
278                .count(all.unknown_only().finished_only())
279                .await
280                .unwrap(),
281            1
282        );
283
284        // Check that we can batch finish sessions
285        let affected = repo
286            .compat_session()
287            .finish_bulk(&clock, all.sso_login_only().active_only())
288            .await
289            .unwrap();
290        assert_eq!(affected, 1);
291        assert_eq!(repo.compat_session().count(finished).await.unwrap(), 2);
292        assert_eq!(repo.compat_session().count(active).await.unwrap(), 0);
293    }
294
295    #[sqlx::test(migrator = "crate::MIGRATOR")]
296    async fn test_access_token_repository(pool: PgPool) {
297        const FIRST_TOKEN: &str = "first_access_token";
298        const SECOND_TOKEN: &str = "second_access_token";
299        let mut rng = ChaChaRng::seed_from_u64(42);
300        let clock = MockClock::default();
301        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
302
303        // Create a user
304        let user = repo
305            .user()
306            .add(&mut rng, &clock, "john".to_owned())
307            .await
308            .unwrap();
309
310        // Start a compat session for that user
311        let device = Device::generate(&mut rng);
312        let session = repo
313            .compat_session()
314            .add(&mut rng, &clock, &user, device, None, false)
315            .await
316            .unwrap();
317
318        // Add an access token to that session
319        let token = repo
320            .compat_access_token()
321            .add(
322                &mut rng,
323                &clock,
324                &session,
325                FIRST_TOKEN.to_owned(),
326                Some(Duration::try_minutes(1).unwrap()),
327            )
328            .await
329            .unwrap();
330        assert_eq!(token.session_id, session.id);
331        assert_eq!(token.token, FIRST_TOKEN);
332
333        // Commit the txn and grab a new transaction, to test a conflict
334        repo.save().await.unwrap();
335
336        {
337            let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
338            // Adding the same token a second time should conflict
339            assert!(
340                repo.compat_access_token()
341                    .add(
342                        &mut rng,
343                        &clock,
344                        &session,
345                        FIRST_TOKEN.to_owned(),
346                        Some(Duration::try_minutes(1).unwrap()),
347                    )
348                    .await
349                    .is_err()
350            );
351            repo.cancel().await.unwrap();
352        }
353
354        // Grab a new repo
355        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
356
357        // Looking up via ID works
358        let token_lookup = repo
359            .compat_access_token()
360            .lookup(token.id)
361            .await
362            .unwrap()
363            .expect("compat access token not found");
364        assert_eq!(token.id, token_lookup.id);
365        assert_eq!(token_lookup.session_id, session.id);
366
367        // Looking up via the token value works
368        let token_lookup = repo
369            .compat_access_token()
370            .find_by_token(FIRST_TOKEN)
371            .await
372            .unwrap()
373            .expect("compat access token not found");
374        assert_eq!(token.id, token_lookup.id);
375        assert_eq!(token_lookup.session_id, session.id);
376
377        // Token is currently valid
378        assert!(token.is_valid(clock.now()));
379
380        clock.advance(Duration::try_minutes(1).unwrap());
381        // Token should have expired
382        assert!(!token.is_valid(clock.now()));
383
384        // Add a second access token, this time without expiration
385        let token = repo
386            .compat_access_token()
387            .add(&mut rng, &clock, &session, SECOND_TOKEN.to_owned(), None)
388            .await
389            .unwrap();
390        assert_eq!(token.session_id, session.id);
391        assert_eq!(token.token, SECOND_TOKEN);
392
393        // Token is currently valid
394        assert!(token.is_valid(clock.now()));
395
396        // Make it expire
397        repo.compat_access_token()
398            .expire(&clock, token)
399            .await
400            .unwrap();
401
402        // Reload it
403        let token = repo
404            .compat_access_token()
405            .find_by_token(SECOND_TOKEN)
406            .await
407            .unwrap()
408            .expect("compat access token not found");
409
410        // Token is not valid anymore
411        assert!(!token.is_valid(clock.now()));
412
413        repo.save().await.unwrap();
414    }
415
416    #[sqlx::test(migrator = "crate::MIGRATOR")]
417    async fn test_refresh_token_repository(pool: PgPool) {
418        const ACCESS_TOKEN: &str = "access_token";
419        const REFRESH_TOKEN: &str = "refresh_token";
420        let mut rng = ChaChaRng::seed_from_u64(42);
421        let clock = MockClock::default();
422        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
423
424        // Create a user
425        let user = repo
426            .user()
427            .add(&mut rng, &clock, "john".to_owned())
428            .await
429            .unwrap();
430
431        // Start a compat session for that user
432        let device = Device::generate(&mut rng);
433        let session = repo
434            .compat_session()
435            .add(&mut rng, &clock, &user, device, None, false)
436            .await
437            .unwrap();
438
439        // Add an access token to that session
440        let access_token = repo
441            .compat_access_token()
442            .add(&mut rng, &clock, &session, ACCESS_TOKEN.to_owned(), None)
443            .await
444            .unwrap();
445
446        let refresh_token = repo
447            .compat_refresh_token()
448            .add(
449                &mut rng,
450                &clock,
451                &session,
452                &access_token,
453                REFRESH_TOKEN.to_owned(),
454            )
455            .await
456            .unwrap();
457        assert_eq!(refresh_token.session_id, session.id);
458        assert_eq!(refresh_token.access_token_id, access_token.id);
459        assert_eq!(refresh_token.token, REFRESH_TOKEN);
460        assert!(refresh_token.is_valid());
461        assert!(!refresh_token.is_consumed());
462
463        // Look it up by ID and check everything matches
464        let refresh_token_lookup = repo
465            .compat_refresh_token()
466            .lookup(refresh_token.id)
467            .await
468            .unwrap()
469            .expect("refresh token not found");
470        assert_eq!(refresh_token_lookup.id, refresh_token.id);
471        assert_eq!(refresh_token_lookup.session_id, session.id);
472        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
473        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
474        assert!(refresh_token_lookup.is_valid());
475        assert!(!refresh_token_lookup.is_consumed());
476
477        // Look it up by token and check everything matches
478        let refresh_token_lookup = repo
479            .compat_refresh_token()
480            .find_by_token(REFRESH_TOKEN)
481            .await
482            .unwrap()
483            .expect("refresh token not found");
484        assert_eq!(refresh_token_lookup.id, refresh_token.id);
485        assert_eq!(refresh_token_lookup.session_id, session.id);
486        assert_eq!(refresh_token_lookup.access_token_id, access_token.id);
487        assert_eq!(refresh_token_lookup.token, REFRESH_TOKEN);
488        assert!(refresh_token_lookup.is_valid());
489        assert!(!refresh_token_lookup.is_consumed());
490
491        // Consume it
492        let refresh_token = repo
493            .compat_refresh_token()
494            .consume(&clock, refresh_token)
495            .await
496            .unwrap();
497        assert!(!refresh_token.is_valid());
498        assert!(refresh_token.is_consumed());
499
500        // Reload it and check again
501        let refresh_token_lookup = repo
502            .compat_refresh_token()
503            .find_by_token(REFRESH_TOKEN)
504            .await
505            .unwrap()
506            .expect("refresh token not found");
507        assert!(!refresh_token_lookup.is_valid());
508        assert!(refresh_token_lookup.is_consumed());
509
510        // Consuming it again should not work
511        assert!(
512            repo.compat_refresh_token()
513                .consume(&clock, refresh_token)
514                .await
515                .is_err()
516        );
517
518        repo.save().await.unwrap();
519    }
520
521    #[sqlx::test(migrator = "crate::MIGRATOR")]
522    async fn test_compat_sso_login_repository(pool: PgPool) {
523        let mut rng = ChaChaRng::seed_from_u64(42);
524        let clock = MockClock::default();
525        let mut repo = PgRepository::from_pool(&pool).await.unwrap().boxed();
526
527        // Create a user
528        let user = repo
529            .user()
530            .add(&mut rng, &clock, "john".to_owned())
531            .await
532            .unwrap();
533
534        // Lookup an unknown SSO login
535        let login = repo.compat_sso_login().lookup(Ulid::nil()).await.unwrap();
536        assert_eq!(login, None);
537
538        let all = CompatSsoLoginFilter::new();
539        let for_user = all.for_user(&user);
540        let pending = all.pending_only();
541        let fulfilled = all.fulfilled_only();
542        let exchanged = all.exchanged_only();
543
544        // Check the initial counts
545        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 0);
546        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
547        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
548        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
549        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
550
551        // Lookup an unknown login token
552        let login = repo
553            .compat_sso_login()
554            .find_by_token("login-token")
555            .await
556            .unwrap();
557        assert_eq!(login, None);
558
559        // Start a new SSO login
560        let login = repo
561            .compat_sso_login()
562            .add(
563                &mut rng,
564                &clock,
565                "login-token".to_owned(),
566                "https://example.com/callback".parse().unwrap(),
567            )
568            .await
569            .unwrap();
570        assert!(login.is_pending());
571
572        // Check the counts
573        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
574        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 0);
575        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 1);
576        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
577        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
578
579        // Lookup the login by ID
580        let login_lookup = repo
581            .compat_sso_login()
582            .lookup(login.id)
583            .await
584            .unwrap()
585            .expect("login not found");
586        assert_eq!(login_lookup, login);
587
588        // Find the login by token
589        let login_lookup = repo
590            .compat_sso_login()
591            .find_by_token("login-token")
592            .await
593            .unwrap()
594            .expect("login not found");
595        assert_eq!(login_lookup, login);
596
597        // Exchanging before fulfilling should not work
598        // Note: It should also not poison the SQL transaction
599        let res = repo
600            .compat_sso_login()
601            .exchange(&clock, login.clone())
602            .await;
603        assert!(res.is_err());
604
605        // Start a compat session for that user
606        let device = Device::generate(&mut rng);
607        let session = repo
608            .compat_session()
609            .add(&mut rng, &clock, &user, device, None, false)
610            .await
611            .unwrap();
612
613        // Associate the login with the session
614        let login = repo
615            .compat_sso_login()
616            .fulfill(&clock, login, &session)
617            .await
618            .unwrap();
619        assert!(login.is_fulfilled());
620
621        // Check the counts
622        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
623        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
624        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
625        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 1);
626        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 0);
627
628        // Fulfilling again should not work
629        // Note: It should also not poison the SQL transaction
630        let res = repo
631            .compat_sso_login()
632            .fulfill(&clock, login.clone(), &session)
633            .await;
634        assert!(res.is_err());
635
636        // Exchange that login
637        let login = repo
638            .compat_sso_login()
639            .exchange(&clock, login)
640            .await
641            .unwrap();
642        assert!(login.is_exchanged());
643
644        // Check the counts
645        assert_eq!(repo.compat_sso_login().count(all).await.unwrap(), 1);
646        assert_eq!(repo.compat_sso_login().count(for_user).await.unwrap(), 1);
647        assert_eq!(repo.compat_sso_login().count(pending).await.unwrap(), 0);
648        assert_eq!(repo.compat_sso_login().count(fulfilled).await.unwrap(), 0);
649        assert_eq!(repo.compat_sso_login().count(exchanged).await.unwrap(), 1);
650
651        // Exchange again should not work
652        // Note: It should also not poison the SQL transaction
653        let res = repo
654            .compat_sso_login()
655            .exchange(&clock, login.clone())
656            .await;
657        assert!(res.is_err());
658
659        // Fulfilling after exchanging should not work
660        // Note: It should also not poison the SQL transaction
661        let res = repo
662            .compat_sso_login()
663            .fulfill(&clock, login.clone(), &session)
664            .await;
665        assert!(res.is_err());
666
667        let pagination = Pagination::first(10);
668
669        // List all logins
670        let logins = repo.compat_sso_login().list(all, pagination).await.unwrap();
671        assert!(!logins.has_next_page);
672        assert_eq!(logins.edges, &[login.clone()]);
673
674        // List the logins for the user
675        let logins = repo
676            .compat_sso_login()
677            .list(for_user, pagination)
678            .await
679            .unwrap();
680        assert!(!logins.has_next_page);
681        assert_eq!(logins.edges, &[login.clone()]);
682
683        // List only the pending logins for the user
684        let logins = repo
685            .compat_sso_login()
686            .list(for_user.pending_only(), pagination)
687            .await
688            .unwrap();
689        assert!(!logins.has_next_page);
690        assert!(logins.edges.is_empty());
691
692        // List only the fulfilled logins for the user
693        let logins = repo
694            .compat_sso_login()
695            .list(for_user.fulfilled_only(), pagination)
696            .await
697            .unwrap();
698        assert!(!logins.has_next_page);
699        assert!(logins.edges.is_empty());
700
701        // List only the exchanged logins for the user
702        let logins = repo
703            .compat_sso_login()
704            .list(for_user.exchanged_only(), pagination)
705            .await
706            .unwrap();
707        assert!(!logins.has_next_page);
708        assert_eq!(logins.edges, &[login]);
709    }
710}