diff --git a/init-db.sql b/init-db.sql
index 3c1c960..1fdcab7 100644
--- a/init-db.sql
+++ b/init-db.sql
@@ -6,3 +6,5 @@ CREATE TABLE IF NOT EXISTS files (
     kind varchar(255) not null,
     primary key (id)
 );
+
+ALTER TABLE files ADD COLUMN IF NOT EXISTS delete_on_download boolean NOT NULL DEFAULT false;
diff --git a/src/deleter.rs b/src/deleter.rs
index 8c1a34f..070bdb6 100644
--- a/src/deleter.rs
+++ b/src/deleter.rs
@@ -1,6 +1,11 @@
-use async_std::{channel::Receiver, fs, path::PathBuf, task};
+use async_std::{
+    channel::Receiver,
+    fs,
+    path::{Path, PathBuf},
+    task,
+};
 use chrono::{prelude::*, Duration};
-use futures::{TryStreamExt, future::FutureExt};
+use futures::{future::FutureExt, TryStreamExt};
 use sqlx::{postgres::PgPool, Row};
 
 pub(crate) async fn delete_old_files(receiver: Receiver<()>, db: PgPool, files_dir: PathBuf) {
@@ -13,12 +18,9 @@ pub(crate) async fn delete_old_files(receiver: Receiver<()>, db: PgPool, files_d
         .fetch(&db);
     while let Some(row) = rows.try_next().await.expect("could not load expired files") {
         let file_id: String = row.try_get("file_id").expect("we selected this column");
-        let mut path = files_dir.clone();
-        path.push(&file_id);
-        if path.exists().await {
-            log::info!("delete file {}", file_id);
-            fs::remove_file(&path).await.expect("could not delete file");
-        }
+        delete_content(&file_id, &files_dir)
+            .await
+            .expect("could not delete file");
     }
 
     sqlx::query("DELETE FROM files WHERE valid_till < $1")
@@ -29,6 +31,28 @@
     }
 }
 
+pub(crate) async fn delete_by_id(
+    db: &PgPool,
+    file_id: &str,
+    files_dir: &Path,
+) -> Result<(), sqlx::Error> {
+    delete_content(file_id, &files_dir).await?;
+    sqlx::query("DELETE FROM files WHERE file_id = $1")
+        .bind(file_id)
+        .execute(db)
+        .await?;
+    Ok(())
+}
+
+async fn delete_content(file_id: &str, files_dir: &Path) -> Result<(), std::io::Error> {
+    let path = files_dir.join(file_id);
+    if path.exists().await {
+        log::info!("delete file {}", file_id);
+        fs::remove_file(&path).await?;
+    }
+    Ok(())
+}
+
 async fn wait_for_file_expiry(receiver: &Receiver<()>, db: &PgPool) {
     let mut rows = sqlx::query("SELECT MIN(valid_till) as min from files").fetch(db);
     let row = rows
diff --git a/src/main.rs b/src/main.rs
index 8c43a36..2421d85 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -15,9 +15,11 @@ use async_std::{
     path::{Path, PathBuf},
     task,
 };
+use env_logger::Env;
 use file_kind::FileKind;
 use futures::TryStreamExt;
 use mime::Mime;
+use multipart::UploadConfig;
 use rand::prelude::SliceRandom;
 use sqlx::{
     postgres::{PgPool, PgPoolOptions, PgRow},
@@ -58,7 +60,12 @@
     let parsed_multipart =
         multipart::parse_multipart(payload, &file_id, &filename, config.max_file_size).await;
 
-    let (original_name, valid_till, kind) = match parsed_multipart {
+    let UploadConfig {
+        original_name,
+        valid_till,
+        kind,
+        delete_on_download,
+    } = match parsed_multipart {
         Ok(data) => data,
         Err(err) => {
             if filename.exists().await {
@@ -73,12 +80,14 @@
     };
 
     let db_insert = sqlx::query(
-        "INSERT INTO Files (file_id, file_name, valid_till, kind) VALUES ($1, $2, $3, $4)",
+        "INSERT INTO Files (file_id, file_name, valid_till, kind, delete_on_download) \
+         VALUES ($1, $2, $3, $4, $5)",
     )
     .bind(&file_id)
-    .bind(original_name.as_ref().unwrap_or(&file_id))
+    .bind(&original_name)
     .bind(valid_till.naive_local())
     .bind(kind.to_string())
+    .bind(delete_on_download)
     .execute(db.as_ref())
     .await;
     if db_insert.is_err() {
@@ -93,22 +102,24 @@
     }
 
     log::info!(
-        "create new file {} (valid_till: {}, kind: {})",
+        "{} create new file {} (valid_till: {}, kind: {}, delete_on_download: {})",
+        req.connection_info().realip_remote_addr().unwrap_or("-"),
         file_id,
         valid_till,
-        kind
+        kind,
+        delete_on_download
     );
 
     expiry_watch_sender.send(()).await.unwrap();
 
-    let redirect = if kind == FileKind::BINARY && original_name.is_some() {
-        let encoded_name = urlencoding::encode(original_name.as_ref().unwrap());
+    let redirect = if kind == FileKind::BINARY {
+        let encoded_name = urlencoding::encode(&original_name);
         format!("/upload/{}/{}", file_id, encoded_name)
     } else {
         format!("/upload/{}", file_id)
     };
-    let url = get_file_url(&req, &file_id, original_name.as_deref());
+    let url = get_file_url(&req, &file_id, Some(&original_name));
     Ok(HttpResponse::SeeOther()
         .header("location", redirect)
         .body(format!("{}\n", url)))
 }
@@ -153,9 +164,10 @@
     config: web::Data<Config>,
 ) -> Result<HttpResponse, Error> {
     let id = req.match_info().query("id");
-    let mut rows = sqlx::query("SELECT file_id, file_name from files WHERE file_id = $1")
-        .bind(id)
-        .fetch(db.as_ref());
+    let mut rows =
+        sqlx::query("SELECT file_id, file_name, delete_on_download from files WHERE file_id = $1")
+            .bind(id)
+            .fetch(db.as_ref());
     let row: PgRow = rows
         .try_next()
         .await
@@ -164,12 +176,13 @@
 
     let file_id: String = row.get("file_id");
     let file_name: String = row.get("file_name");
+    let delete_on_download: bool = row.get("delete_on_download");
     let mut path = config.files_dir.clone();
     path.push(&file_id);
 
     let download = req.query_string().contains("dl");
     let (content_type, mut content_disposition) = get_content_types(&path, &file_name);
-    if content_type.type_() == mime::TEXT && !download {
+    let response = if content_type.type_() == mime::TEXT && !download {
         let content = fs::read_to_string(path).await.map_err(|_| {
             error::ErrorInternalServerError("this file should be here but could not be found")
         })?;
@@ -188,7 +201,13 @@
             .set_content_type(content_type)
             .set_content_disposition(content_disposition);
         file.into_response(&req)
+    };
+    if delete_on_download {
+        deleter::delete_by_id(&db, &file_id, &config.files_dir)
+            .await
+            .map_err(|_| error::ErrorInternalServerError("could not delete file"))?;
     }
+    response
 }
 
 fn get_content_types(path: &Path, filename: &str) -> (Mime, ContentDisposition) {
@@ -263,10 +282,12 @@
         .await
         .expect("could not create db pool");
 
-    sqlx::query(include_str!("../init-db.sql"))
-        .execute(&pool)
-        .await
-        .expect("could not create table Files");
+    for query in include_str!("../init-db.sql").split_inclusive(";") {
+        sqlx::query(query)
+            .execute(&pool)
+            .await
+            .expect("could not initialize database schema");
+    }
 
     pool
 }
@@ -279,10 +300,7 @@ struct Config {
 
 #[actix_web::main]
 async fn main() -> std::io::Result<()> {
-    if env::var("RUST_LOG").is_err() {
-        env::set_var("RUST_LOG", "info");
-    }
-    env_logger::init();
+    env_logger::Builder::from_env(Env::default().default_filter_or("info,sqlx=warn")).init();
 
     let pool: PgPool = setup_db().await;
     let max_file_size = env::var("UPLOAD_MAX_BYTES")
@@ -318,7 +336,7 @@
     HttpServer::new({
         move || {
             App::new()
-                .wrap(middleware::Logger::default())
+                .wrap(middleware::Logger::new(r#"%{r}a "%r" =%s %bbytes %Tsec"#))
                 .app_data(db.clone())
                 .app_data(expiry_watch_sender.clone())
                 .data(config.clone())
diff --git a/src/multipart.rs b/src/multipart.rs
index 538a235..9d2f09f 100644
--- a/src/multipart.rs
+++ b/src/multipart.rs
@@ -5,15 +5,23 @@ use async_std::{fs, fs::File, path::Path, prelude::*};
 use chrono::{prelude::*, Duration};
 use futures::{StreamExt, TryStreamExt};
 
+pub(crate) struct UploadConfig {
+    pub original_name: String,
+    pub valid_till: DateTime<Local>,
+    pub kind: FileKind,
+    pub delete_on_download: bool,
+}
+
 pub(crate) async fn parse_multipart(
     mut payload: Multipart,
     file_id: &str,
     filename: &Path,
     max_size: Option<u64>,
-) -> Result<(Option<String>, DateTime<Local>, FileKind), error::Error> {
+) -> Result<UploadConfig, error::Error> {
     let mut original_name: Option<String> = None;
     let mut keep_for: Option<String> = None;
     let mut kind: Option<FileKind> = None;
+    let mut delete_on_download = false;
 
     while let Ok(Some(field)) = payload.try_next().await {
         let name = get_field_name(&field)?;
@@ -45,31 +53,41 @@
                     .map_err(|_| error::ErrorInternalServerError("could not create file"))?;
                 write_to_file(&mut file, field, max_size).await?;
             }
+            "delete_on_download" => {
+                delete_on_download = parse_string(name, field).await? != "false";
+            }
             _ => {}
         };
     }
 
-    if let Some(original_name) = &original_name {
-        if original_name.len() > 255 {
-            return Err(error::ErrorBadRequest("filename is too long"));
-        }
-    }
-
-    let validity_secs = keep_for
-        .map(|timeout| timeout.parse())
-        .transpose()
-        .map_err(|e| error::ErrorBadRequest(format!("field validity_secs is not a number: {}", e)))?
-        .unwrap_or(1800); // default to 30 minutes
-    let max_validity_secs = Duration::days(31).num_seconds();
-    if validity_secs > max_validity_secs {
-        return Err(error::ErrorBadRequest(format!(
-            "maximum allowed validity is {} seconds, but you specified {} seconds",
-            max_validity_secs, validity_secs
-        )));
-    }
-    let valid_till = Local::now() + Duration::seconds(validity_secs);
+    let original_name = original_name.ok_or_else(|| error::ErrorBadRequest("no content found"))?;
     let kind = kind.ok_or_else(|| error::ErrorBadRequest("no content found"))?;
-    Ok((original_name, valid_till, kind))
+
+    if original_name.len() > 255 {
+        return Err(error::ErrorBadRequest("filename is too long"));
+    }
+    let valid_till = if let Some(keep_for) = keep_for {
+        let keep_for = keep_for.parse().map_err(|e| {
+            error::ErrorBadRequest(format!("field keep_for is not a number: {}", e))
+        })?;
+        let max_keep_for = Duration::days(31).num_seconds();
+        if keep_for > max_keep_for {
+            return Err(error::ErrorBadRequest(format!(
+                "maximum allowed validity is {} seconds, but you specified {} seconds",
+                max_keep_for, keep_for
+            )));
+        }
+        Local::now() + Duration::seconds(keep_for)
+    } else {
+        Local::now() + Duration::seconds(1800)
+    };
+
+    Ok(UploadConfig {
+        original_name,
+        valid_till,
+        kind,
+        delete_on_download,
+    })
 }
 
 fn get_field_name(field: &Field) -> Result<String, error::Error> {
diff --git a/static/index.css b/static/index.css
index 51c5fd8..93e41e5 100644
--- a/static/index.css
+++ b/static/index.css
@@ -48,6 +48,10 @@ textarea,
   max-width: calc(100vw - 3rem - 4px);
 }
 
+input[type="checkbox"] {
+  margin-bottom: 1.5rem;
+}
+
 .button {
   cursor: pointer;
 }
diff --git a/template/index.html b/template/index.html
index 02016d7..f56a0f9 100644
--- a/template/index.html
+++ b/template/index.html
@@ -29,6 +29,13 @@
+ + +