diff --git a/src/main.rs b/src/main.rs index 65f3e9aee25dd0d59414de46e50b3288ec924906..8547f6bd6842f970ee4fd7d64d88a219ae29da5c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -41,7 +41,7 @@ use diesel::pg::PgConnection; use diesel::r2d2::ConnectionManager; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; use futures::{StreamExt, TryFutureExt, TryStreamExt}; -use handlebars::Handlebars; +use handlebars::{Handlebars, RenderError}; use itertools::Itertools; use lazy_static::lazy_static; use listenfd::ListenFd; @@ -404,7 +404,7 @@ impl FromRequest for BaseContext { let req = req.clone(); Box::pin(async move { let identity = Option::<Identity>::from_request(&req, &mut Payload::None).await; - let logged_user = identity?.and_then(|identity| get_identity(&identity)); + let logged_user = identity?.and_then(|identity| get_identity(&Some(&identity))); let flash_messages = IncomingFlashMessages::from_request(&req, &mut Payload::None).await?; Ok(BaseContext { @@ -416,13 +416,21 @@ impl FromRequest for BaseContext { } } +fn render<Err: From<RenderError>, Ctx: serde::Serialize>( + hb: &Handlebars, + name: &str, + context: &Ctx, +) -> Result<HttpResponse, Err> { + Ok(HttpResponse::Ok().body(hb.render(name, context)?)) +} + #[get("/login")] async fn get_login(base: BaseContext, hb: web::Data<Handlebars<'_>>) -> GetResult { #[derive(Serialize)] struct Context { base: BaseContext, } - Ok(HttpResponse::Ok().body(hb.render("login", &Context { base })?)) + render(&hb, "login", &Context { base }) } #[get("/me")] @@ -432,7 +440,7 @@ async fn get_me(base: BaseContext, identity: Identity, hb: web::Data<Handlebars< struct Context { base: BaseContext, } - Ok(HttpResponse::Ok().body(hb.render("me", &Context { base })?)) + render(&hb, "me", &Context { base }) } #[get("/problems/{id}/assets/{filename}")] @@ -479,15 +487,17 @@ struct LoginForm { password: String, } -fn get_identity(identity: &Identity) -> Option<LoggedUser> { - let identity = identity.id(); - identity - .ok() - .and_then(|identity| serde_json::from_str(&identity).ok()) +fn get_identity(identity: &Option<&Identity>) -> Option<LoggedUser> { + identity.as_ref().and_then(|identity| { + let identity = identity.id(); + identity + .ok() + .and_then(|identity| serde_json::from_str(&identity).ok()) + }) } fn require_identity(identity: &Identity) -> Result<LoggedUser, UnauthorizedError> { - get_identity(&identity).ok_or(UnauthorizedError {}) + get_identity(&Some(identity)).ok_or(UnauthorizedError {}) } fn format_duration(duration: chrono::Duration) -> String { @@ -581,20 +591,19 @@ async fn get_contest_by_id( let submissions = submission::get_submissions_user_by_contest(&mut connection, logged_user.id, contest_id)?; - Ok(HttpResponse::Ok().body( - hb.render( - "contest", - &Context { - base, - contest: get_formatted_contest(&tz, &contest), - problems: problems - .iter() - .map(|p| get_formatted_problem_by_contest_with_score(p, &contest)) - .collect(), - submissions: get_formatted_submissions(&tz, &submissions), - }, - )?, - )) + render( + &hb, + "contest", + &Context { + base, + contest: get_formatted_contest(&tz, &contest), + problems: problems + .iter() + .map(|p| get_formatted_problem_by_contest_with_score(p, &contest)) + .collect(), + submissions: get_formatted_submissions(&tz, &submissions), + }, + ) } #[get("/contests/{id}/scoreboard")] @@ -661,7 +670,8 @@ async fn get_contest_scoreboard_by_id( )) }); - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "scoreboard", &Context { base, @@ -669,7 +679,7 @@ async fn get_contest_scoreboard_by_id( scores, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } fn assert_contest_not_started(logged_user: &LoggedUser, contest: &Contest) -> Result<(), GetError> { @@ -748,7 +758,8 @@ async fn get_contest_problem_by_id_label( &problem_label, )?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "contest_problem", &Context { base, @@ -759,7 +770,7 @@ async fn get_contest_problem_by_id_label( language: session.get("language")?, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[derive(Serialize, Deserialize, Clone)] @@ -896,13 +907,14 @@ async fn get_submissions( submission::get_submissions_user(&mut connection, logged_user.id)? }; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "submissions", &Context { base, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[get("/submissions/me/")] @@ -924,13 +936,14 @@ async fn get_submissions_me( let submissions = submission::get_submissions_user(&mut connection, logged_user.id)?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "submissions", &Context { base, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[get("/submissions/{uuid}")] @@ -976,7 +989,8 @@ async fn get_submission( )); } - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "submission", &Context { base, @@ -995,7 +1009,7 @@ async fn get_submission( failed_test: submission.failed_test, }, }, - )?)) + ) } #[get("/submissions/me/contests/{id}")] @@ -1020,13 +1034,14 @@ async fn get_submissions_me_by_contest_id( let submissions = submission::get_submissions_user_by_contest(&mut connection, logged_user.id, contest_id)?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "submissions", &Context { base, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[get("/submissions/me/contests/{id}/{label}")] @@ -1055,13 +1070,14 @@ async fn get_submissions_me_by_contest_id_problem_label( &problem_label, )?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "submissions", &Context { base, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[derive(Serialize, Deserialize)] @@ -1401,12 +1417,12 @@ fn get_formatted_contests( #[get("/")] async fn get_main( base: BaseContext, - identity: Identity, + identity: Option<Identity>, pool: web::Data<DbPool>, hb: web::Data<Handlebars<'_>>, tz: web::Data<Tz>, ) -> GetResult { - let logged_user = get_identity(&identity); + let logged_user = get_identity(&identity.as_ref()); #[derive(Serialize)] struct Context { @@ -1418,14 +1434,15 @@ async fn get_main( let mut connection = pool.get()?; let submissions = submission::get_submissions(&mut connection)?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "main", &Context { base, contests: get_formatted_contests(&mut connection, logged_user.map(|u| u.id), &tz)?, submissions: get_formatted_submissions(&tz, &submissions), }, - )?)) + ) } #[get("/contests/")] @@ -1445,13 +1462,14 @@ async fn get_contests( } let mut connection = pool.get()?; - Ok(HttpResponse::Ok().body(hb.render( + render( + &hb, "contests", &Context { base, contests: get_formatted_contests(&mut connection, Some(logged_user.id), &tz)?, }, - )?)) + ) } #[post("/contests/")] @@ -1641,6 +1659,15 @@ async fn create_contest( .into()) } + let main_solution = &metadata + .assets + .solutions + .solution + .iter() + .find(|s| s.tag == "main") + .ok_or(PostError::Validation("No main solution".into()))? + .source; + let problem = problem::upsert_problem( &mut connection, problem::NewProblem { @@ -1662,27 +1689,8 @@ async fn create_contest( validator_language: map_codeforces_language( &metadata.assets.validators.validator[0].source.r#type, )?, - main_solution_path: metadata - .assets - .solutions - .solution - .iter() - .find(|s| s.tag == "main") - .ok_or(PostError::Validation("No main solution".into()))? - .source - .path - .clone(), - main_solution_language: map_codeforces_language( - &metadata - .assets - .solutions - .solution - .iter() - .find(|s| s.tag == "main") - .ok_or(PostError::Validation("No main solution".into()))? - .source - .r#type, - )?, + main_solution_path: main_solution.path.clone(), + main_solution_language: map_codeforces_language(&main_solution.r#type)?, test_pattern: metadata.judging.testset[0].input_path_pattern.value.clone(), test_count: metadata.judging.testset[0] .test_count