132 lines
4.0 KiB
Rust
132 lines
4.0 KiB
Rust
use std::time::Duration as StdDuration;
|
|
|
|
use serde::{Serialize, Deserialize};
|
|
use chrono::naive::serde::ts_seconds::serialize as ts_seconds_naive;
|
|
use chrono::{Duration as ChronoDuration, NaiveDateTime, Utc, DateTime};
|
|
use diesel::prelude::*;
|
|
use rand::Rng;
|
|
use rand::rngs::OsRng;
|
|
use rand::distributions::Alphanumeric;
|
|
use rocket::request::{FromRequest, Request, Outcome};
|
|
use rocket::outcome::try_outcome;
|
|
|
|
use crate::schema::*;
|
|
use crate::DbConn;
|
|
use crate::models::user::UserInfo;
|
|
use crate::models::errors::{UserError, ErrorResponse, make_500};
|
|
|
|
const BEARER: &str = "Bearer ";
|
|
const AUTH_HEADER: &str = "Authorization";
|
|
pub const COOKIE_NAME: &str = "session_id";
|
|
|
|
|
|
#[derive(Debug, Deserialize, FromForm)]
|
|
pub struct AuthTokenRequest {
|
|
pub email: String,
|
|
pub password: String,
|
|
}
|
|
|
|
#[derive(Debug, Serialize, Queryable, Identifiable, Insertable)]
|
|
#[table_name = "session"]
|
|
#[primary_key(session_id)]
|
|
pub struct Session {
|
|
#[serde(rename = "token")]
|
|
pub session_id: String,
|
|
#[serde(skip)]
|
|
pub user_id: String,
|
|
#[serde(serialize_with = "ts_seconds_naive")]
|
|
pub expires_at: NaiveDateTime,
|
|
}
|
|
|
|
impl Session {
|
|
pub fn generate_id() -> String {
|
|
OsRng
|
|
.sample_iter(&Alphanumeric)
|
|
.take(50)
|
|
.map(char::from)
|
|
.collect()
|
|
}
|
|
|
|
pub fn from_session_id(conn: &diesel::SqliteConnection, id: &str) -> Result<Session, UserError> {
|
|
use crate::schema::session::dsl::*;
|
|
session
|
|
.find(id)
|
|
.get_result(conn)
|
|
.map_err(|_| UserError::ExpiredSession)
|
|
.and_then(|s: Session| {
|
|
let expires = DateTime::<Utc>::from_utc(s.expires_at, Utc);
|
|
if expires < Utc::now() {
|
|
Err(UserError::ExpiredSession)
|
|
} else {
|
|
Ok(s)
|
|
}
|
|
})
|
|
}
|
|
|
|
pub fn new(conn: &diesel::SqliteConnection, user_info: &UserInfo, token_duration: StdDuration) -> Result<Session, UserError> {
|
|
use crate::schema::session::dsl::*;
|
|
|
|
let expires = Utc::now() + ChronoDuration::from_std(token_duration).unwrap();
|
|
|
|
let user_session = Session {
|
|
session_id: Session::generate_id(),
|
|
user_id: user_info.id.clone(),
|
|
expires_at: expires.naive_utc(),
|
|
};
|
|
|
|
diesel::insert_into(session)
|
|
.values(&user_session)
|
|
.execute(conn)
|
|
.map_err(UserError::DbError)?;
|
|
|
|
Ok(user_session)
|
|
|
|
}
|
|
|
|
fn get_token_from_header<'r>(request: &'r Request<'_>) -> Outcome<String, ErrorResponse> {
|
|
let auth_header = match request.headers().get_one(AUTH_HEADER) {
|
|
None => return Outcome::Forward(()),
|
|
Some(auth_header) => auth_header,
|
|
};
|
|
|
|
let token = if auth_header.starts_with(BEARER) {
|
|
auth_header.trim_start_matches(BEARER).to_string()
|
|
} else {
|
|
return ErrorResponse::from(UserError::MalformedHeader).into();
|
|
};
|
|
|
|
Outcome::Success(token)
|
|
}
|
|
|
|
fn get_token_from_cookie<'r>(request: &'r Request<'_>) -> Outcome<String, ErrorResponse> {
|
|
match request.cookies().get(COOKIE_NAME) {
|
|
None => Outcome::Forward(()),
|
|
Some(session_cookie) => Outcome::Success(session_cookie.value().to_string()),
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
#[rocket::async_trait]
|
|
impl<'r> FromRequest<'r> for Session {
|
|
type Error = ErrorResponse;
|
|
|
|
async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
|
|
let token = try_outcome!(
|
|
Session::get_token_from_header(request)
|
|
.forward_then(|_| Session::get_token_from_cookie(request))
|
|
.forward_then(|_| ErrorResponse::from(UserError::MissingToken).into())
|
|
);
|
|
|
|
let conn = try_outcome!(request.guard::<DbConn>().await.map_failure(make_500));
|
|
|
|
conn.run(move |c| {
|
|
match Session::from_session_id(c, &token) {
|
|
Err(e) => ErrorResponse::from(e).into(),
|
|
Ok(s) => Outcome::Success(s),
|
|
}
|
|
}).await
|
|
}
|
|
}
|