Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,3 @@
*.db
*.db-wal
*.db-shm
/.direnv
45 changes: 45 additions & 0 deletions database/src/repos/achievement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,30 @@ impl<'a> AchievementRepo<'a> {
.await?)
}

async fn by_goal_id(&self, goal_id: u32) -> Result<Vec<AchievementGoal>, DatabaseError> {
Ok(query_as(
"
SELECT
achievement.id as achievement_id,
achievement.name as achievement_name,
service_id,
goal2.id as goal_id,
goal2.description as goal_description,
goal2.sequence as goal_sequence

FROM
goal as goal1
inner join achievement on achievement.id = goal1.achievement_id
inner join goal as goal2 on goal2.achievement_id = achievement.id
WHERE
goal1.id = ?;
",
)
.bind(goal_id)
.fetch_all(self.db)
.await?)
}

pub async fn for_service(
&self,
service_id: u32,
Expand Down Expand Up @@ -119,4 +143,25 @@ impl<'a> AchievementRepo<'a> {
tx.commit().await?;
self.by_id(db_achievement.id).await
}

pub async fn unlock_goal(
&self,
user_id: u32,
goal_id: u32,
) -> Result<Vec<AchievementGoal>, DatabaseError> {
query(
"
INSERT INTO
unlock (user_id, goal_id)
VALUES
(?,?);
",
)
.bind(user_id)
.bind(goal_id)
.execute(self.db)
.await?;

self.by_goal_id(goal_id).await
}
}
8 changes: 8 additions & 0 deletions database/src/repos/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,12 @@ impl<'a> ServiceRepo<'a> {
.await?
.ok_or(DatabaseError::NotFound)
}

pub async fn by_id(&self, id: u32) -> Result<Service, DatabaseError> {
sqlx::query_as("SELECT id, name, api_key FROM service WHERE id == ? LIMIT 1;")
.bind(id)
.fetch_optional(self.db)
.await?
.ok_or(DatabaseError::NotFound)
}
}
13 changes: 13 additions & 0 deletions src/dto/achievement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ impl AchievementPayload {

Ok(achievements)
}

pub async fn unlock_goal(
db: &Database,
user_id: u32,
goal_id: u32,
) -> Result<AchievementPayload, AppError> {
let rows = db.achievements().unlock_goal(user_id, goal_id).await?;

// pack rows into an achievement payload
let mut rows = rows.into_iter().peekable();
let achievement = unpack_next_achievement(&mut rows).ok_or(AppError::NotFound)?;
Ok(achievement)
}
}

#[derive(Serialize, Deserialize, Debug, PartialEq)]
Expand Down
6 changes: 5 additions & 1 deletion src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ pub enum AppError {
#[error("Submitted image resolution was too large")]
ImageResTooLarge,

#[error("The requested image was not found")]
#[error("Not found")]
NotFound,

#[error("Submitted file had an incorrect type")]
Expand All @@ -60,6 +60,9 @@ pub enum AppError {
#[error("User was not logged in")]
NotLoggedIn,

#[error("Wrong api key")]
BadApiKey,

#[error("Forbidden")]
Forbidden,

Expand All @@ -80,6 +83,7 @@ impl AppError {
let (status, msg) = match self {
Self::PayloadError(_) => (StatusCode::BAD_REQUEST, "Payload error"),
Self::NotLoggedIn => (StatusCode::UNAUTHORIZED, "Not logged in."),
Self::BadApiKey => (StatusCode::UNAUTHORIZED, "Bad api key."),
Self::Forbidden => (StatusCode::FORBIDDEN, "Forbidden."),
Self::NoFile => (
StatusCode::BAD_REQUEST,
Expand Down
24 changes: 24 additions & 0 deletions src/extractors/api_key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
use axum::{extract::FromRequestParts, http::request::Parts};
use axum_extra::TypedHeader;
use headers::{Authorization, authorization::Bearer};

use crate::error::AppError;

#[derive(Debug)]
pub struct ApiKey(pub String);

impl<S> FromRequestParts<S> for ApiKey
where
S: Send + Sync,
{
type Rejection = AppError;

async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
let header = TypedHeader::<Authorization<Bearer>>::from_request_parts(parts, state).await;

match header {
Ok(TypedHeader(Authorization(bearer))) => Ok(ApiKey(bearer.token().to_string())),
_ => Err(AppError::BadApiKey),
}
}
}
1 change: 1 addition & 0 deletions src/extractors/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
pub mod admin;
pub mod api_key;
pub mod authenticated_user;
pub mod config;
pub mod database;
Expand Down
23 changes: 21 additions & 2 deletions src/handlers/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@ use axum::{Json, extract::Path};
use database::Database;

use crate::{
dto::service::{
ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser,
dto::{
achievement::AchievementPayload,
service::{
ServiceCreatePayload, ServicePatchPayload, ServicePayloadAdmin, ServicePayloadUser,
},
},
error::AppError,
extractors::api_key::ApiKey,
};

pub struct ServiceHandler;
Expand Down Expand Up @@ -42,4 +46,19 @@ impl ServiceHandler {
ServicePayloadAdmin::regenerate_api_key(&db, service_id).await?,
))
}

pub async fn unlock_goal(
db: Database,
Path((user_id, service_id, goal_id)): Path<(u32, u32, u32)>,
ApiKey(api_key): ApiKey,
) -> Result<Json<AchievementPayload>, AppError> {
let expected_api_key = db.services().by_id(service_id).await?.api_key;
if api_key != expected_api_key {
return Err(AppError::BadApiKey);
}

Ok(Json(
AchievementPayload::unlock_goal(&db, user_id, goal_id).await?,
))
}
}
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ fn open_routes() -> Router<AppState> {
.route("/oauth/callback", get(AuthHandler::callback))
.route("/image/{id}", get(ImageHandler::get))
.route("/version", get(VersionHandler::get))
.route(
"/users/{id}/unlock/{service_id}/{goal_id}",
post(ServiceHandler::unlock_goal),
)
}

fn authenticated_routes() -> Router<AppState> {
Expand Down
69 changes: 52 additions & 17 deletions tests/achievement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,26 @@ use zpi::dto::{
goal::GoalCreatePayload,
};

use crate::common::{
into_struct::IntoStruct, router::AuthenticatedRouter, test_objects::TestObjects,
};
use crate::common::{into_struct::IntoStruct, router::TestRouter, test_objects::TestObjects};

mod common;

#[sqlx::test(fixtures("services", "achievements"))]
#[test_log::test]
async fn get_achievements_for_service(db_pool: SqlitePool) {
let router = AuthenticatedRouter::new(db_pool).await;
let response = router.get("/admin/services/1/achievements").await;
async fn get_achievements_for_service(db: SqlitePool) {
let none = TestRouter::new(db.clone());
let response = none.get("/admin/services/1/achievements").await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

let user = TestRouter::as_user(db.clone()).await;
let response = user.get("/admin/services/1/achievements").await;
assert_eq!(response.status(), StatusCode::FORBIDDEN);

let admin = TestRouter::as_admin(db).await;
let response = admin.get("/admin/services/1/achievements").await;
assert_eq!(response.status(), StatusCode::OK);

let data: Vec<AchievementPayload> = response.into_struct().await;

assert_eq!(
data,
vec![TestObjects::achievement_1(), TestObjects::achievement_2()]
Expand All @@ -29,8 +33,7 @@ async fn get_achievements_for_service(db_pool: SqlitePool) {

#[sqlx::test(fixtures("services"))]
#[test_log::test]
async fn post_achievements_for_service(db_pool: SqlitePool) {
let router = AuthenticatedRouter::new(db_pool).await;
async fn post_achievements_for_service(db: SqlitePool) {
let body = AchievementCreatePayload {
name: "Achievements".into(),
goals: vec![
Expand All @@ -44,19 +47,26 @@ async fn post_achievements_for_service(db_pool: SqlitePool) {
},
],
};
let response = router.post("/admin/services/1/achievements", body).await;

let none = TestRouter::new(db.clone());
let response = none.post("/admin/services/1/achievements", &body).await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

let user = TestRouter::as_user(db.clone()).await;
let response = user.post("/admin/services/1/achievements", &body).await;
assert_eq!(response.status(), StatusCode::FORBIDDEN);

let admin = TestRouter::as_admin(db).await;
let response = admin.post("/admin/services/1/achievements", &body).await;
assert_eq!(response.status(), StatusCode::OK);

let data: AchievementPayload = response.into_struct().await;

assert_eq!(data, TestObjects::achievement_1());
}

#[sqlx::test(fixtures("services"))]
#[test_log::test]
async fn post_achievements_wrong_sequence(db_pool: SqlitePool) {
let router = AuthenticatedRouter::new(db_pool).await;
async fn post_achievements_wrong_sequence(db: SqlitePool) {
let mut body = AchievementCreatePayload {
name: "Achievements".into(),
goals: vec![
Expand All @@ -71,13 +81,38 @@ async fn post_achievements_wrong_sequence(db_pool: SqlitePool) {
],
};

let response = router
.clone()
.post("/admin/services/1/achievements", &body)
.await;
let router = TestRouter::as_admin(db.clone()).await;
let response = router.post("/admin/services/1/achievements", &body).await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);

body.goals[1].sequence = 1;
let response = router.post("/admin/services/1/achievements", &body).await;
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
}

#[sqlx::test(fixtures("services", "achievements", "users"))]
#[test_log::test]
async fn unlock_goal(db: SqlitePool) {
let none = TestRouter::new(db.clone());
let response = none.post("/users/1/unlock/1/1", None::<()>).await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

let router = TestRouter::with_api_key(db, "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa");
let response = router.post("/users/1/unlock/1/1", None::<()>).await;
assert_eq!(response.status(), StatusCode::OK);

let data: AchievementPayload = response.into_struct().await;
assert_eq!(data, TestObjects::achievement_1());
}

#[sqlx::test(fixtures("services"))]
#[test_log::test]
async fn unlock_goal_wrong_api_key(db_pool: SqlitePool) {
let router = TestRouter::with_api_key(db_pool, "wrongapikey");

let response = router.post("/users/1/unlock/1/1", None::<()>).await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
}

// TODO wat als goal niet bestaat -> status code 404
// TODO wat als goal al unlocked is -> status code 200
Loading