use axum::{
Router,
routing::{get, post, put, delete},
Json,
Extension,
extract::{Path, Query, State},
response::{IntoResponse, Response},
http::{StatusCode, Method},
middleware,
};
use tower_http::{
cors::{CorsLayer, Any},
trace::TraceLayer,
};
use wacht::{
init_from_env,
middleware::*,
models::*,
UsersApi,
OrganizationsApi,
WorkspacesApi,
Error as WachtError,
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{info, error};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
// Define custom permissions.
// Each invocation passes a type name, a permission string, and a scope
// (`Organization` or `Workspace`). `CanManageUsers` and `CanEditProjects` are
// used as handler arguments further down; `CanViewReports` is not referenced
// in this file.
// NOTE(review): the macro is presumably provided by the Wacht SDK and the
// generated types presumably act as axum extractors — confirm against the SDK.
require_permission!(CanManageUsers, "users:manage", Organization);
require_permission!(CanViewReports, "reports:view", Organization);
require_permission!(CanEditProjects, "projects:edit", Workspace);
// Application state
/// Shared application state handed to the router via `with_state`.
///
/// Currently empty; `main` wraps it in an `Arc` before passing it to
/// `create_app`.
#[derive(Clone)]
struct AppState {
// Add any shared state here
}
// Error handling
/// Errors a handler can return.
///
/// The `IntoResponse` impl for this type turns each variant into an HTTP
/// status plus a JSON error body.
#[derive(Debug, thiserror::Error)]
enum AppError {
/// An error returned by the Wacht SDK. `#[from]` lets `?` convert a
/// `WachtError` into this variant.
#[error("Wacht API error: {0}")]
Wacht(#[from] WachtError),
/// The request was rejected by a handler-side check (answered with 400).
#[error("Validation error: {0}")]
Validation(String),
/// The requested resource does not exist (answered with 404).
#[error("Not found: {0}")]
NotFound(String),
/// Any other server-side failure (answered with 500).
#[error("Internal error: {0}")]
Internal(String),
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
let (status, message) = match &self {
AppError::Wacht(WachtError::Api { status, message, .. }) => {
(*status, message.clone())
}
AppError::Wacht(_) => {
(StatusCode::INTERNAL_SERVER_ERROR, "Internal error".to_string())
}
AppError::Validation(msg) => {
(StatusCode::BAD_REQUEST, msg.clone())
}
AppError::NotFound(msg) => {
(StatusCode::NOT_FOUND, msg.clone())
}
AppError::Internal(msg) => {
(StatusCode::INTERNAL_SERVER_ERROR, msg.clone())
}
};
error!("Request failed: {} - {}", status, message);
(
status,
Json(serde_json::json!({
"error": {
"message": message,
"type": match &self {
AppError::Wacht(_) => "api_error",
AppError::Validation(_) => "validation_error",
AppError::NotFound(_) => "not_found",
AppError::Internal(_) => "internal_error",
}
}
}))
).into_response()
}
}
/// Result alias used by the handlers in this file: the error type is always `AppError`.
type Result<T> = std::result::Result<T, AppError>;
// Request/Response types
/// Query-string parameters accepted by `list_users`.
///
/// All three values are optional and are forwarded unchanged to
/// `UsersApi::users_get`.
#[derive(Debug, Deserialize)]
struct ListQuery {
// Forwarded as the first argument of `UsersApi::users_get`.
limit: Option<i32>,
// Forwarded as the second argument of `UsersApi::users_get`.
offset: Option<i32>,
// Forwarded (as `Option<&str>`) as the third argument of `UsersApi::users_get`.
search: Option<String>,
}
/// One user entry in the JSON array returned by `list_users`.
#[derive(Debug, Serialize)]
struct UserResponse {
id: String,
email: String,
// Built in `list_users` as "<first_name> <last_name>".
name: String,
// `list_users` currently always sets this to "member".
role: String,
// Names of the organizations taken from the user's organization memberships.
organizations: Vec<String>,
}
/// JSON body accepted by `create_project`.
#[derive(Debug, Deserialize)]
struct CreateProjectRequest {
// Required; `create_project` rejects the request when this is empty.
name: String,
// Optional; echoed back in the response as-is.
description: Option<String>,
}
/// Entry point: loads `.env`, installs tracing, initializes the Wacht SDK and
/// serves the router on `0.0.0.0:<SERVER_PORT>` (port 3000 when the variable
/// is not set).
///
/// # Errors
///
/// Returns an error when SDK initialization fails, `SERVER_PORT` is not a
/// valid `u16`, the listener cannot be bound, or the server itself fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // A missing .env file is fine; real environment variables still apply.
    dotenv::dotenv().ok();

    // Tracing: filter directives come from the environment, output via the fmt layer.
    let env_filter = tracing_subscriber::EnvFilter::from_default_env();
    tracing_subscriber::registry()
        .with(env_filter)
        .with(tracing_subscriber::fmt::layer())
        .init();
    info!("Starting Wacht Axum application");

    init_from_env().await?;
    info!("Wacht SDK initialized");

    let app = create_app(Arc::new(AppState {}));

    // Fall back to 3000 when SERVER_PORT cannot be read; a value that is
    // present but not a valid u16 is still an error.
    let port: u16 = match std::env::var("SERVER_PORT") {
        Ok(raw) => raw.parse()?,
        Err(_) => 3000,
    };
    let addr = format!("0.0.0.0:{}", port);
    let listener = tokio::net::TcpListener::bind(&addr).await?;
    info!("Server listening on {}", addr);

    axum::serve(listener, app).await?;
    Ok(())
}
/// Builds the application router.
///
/// Route groups:
/// - public: `/`, `/health`, `/login`
/// - `/api/user`: profile and organization endpoints
/// - `/api/admin`: user management, behind the `admin:access` organization permission
/// - `/api/workspace`: project endpoints
///
/// Layer order matters: a layer added later wraps the ones added before it.
/// `AuthLayer` is therefore applied only to the `/api/*` routers, and the CORS
/// and trace layers are added after it so they sit outside authentication.
/// That way the public routes never pass through `AuthLayer`, CORS preflight
/// requests are answered before authentication runs, and requests rejected by
/// authentication are still traced.
fn create_app(state: Arc<AppState>) -> Router {
    // Public routes (no auth required)
    let public_routes = Router::new()
        .route("/", get(root_handler))
        .route("/health", get(health_handler))
        .route("/login", post(login_handler));

    // User routes (auth required)
    let user_routes = Router::new()
        .route("/profile", get(get_profile))
        .route("/profile", put(update_profile))
        .route("/organizations", get(list_organizations))
        .layer(middleware::from_fn(require_auth_middleware));

    // Admin routes (auth + permissions required).
    // The auth middleware is added last so it runs before the permission layer.
    let admin_routes = Router::new()
        .route("/users", get(list_users))
        .route("/users/:id", get(get_user))
        .route("/users/:id", delete(delete_user))
        .layer(PermissionLayer::organization("admin:access"))
        .layer(middleware::from_fn(require_auth_middleware));

    // Project routes (workspace permissions)
    let project_routes = Router::new()
        .route("/projects", get(list_projects))
        .route("/projects", post(create_project))
        .route("/projects/:id", put(update_project))
        .layer(middleware::from_fn(require_auth_middleware));

    // Everything under /api goes through AuthLayer; public routes do not.
    let protected_routes = Router::new()
        .nest("/api/user", user_routes)
        .nest("/api/admin", admin_routes)
        .nest("/api/workspace", project_routes)
        .layer(AuthLayer::new());

    // Combine all routes; CORS and tracing wrap both groups.
    Router::new()
        .merge(public_routes)
        .merge(protected_routes)
        .layer(
            CorsLayer::new()
                .allow_origin(Any)
                .allow_methods([Method::GET, Method::POST, Method::PUT, Method::DELETE])
                .allow_headers(Any),
        )
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
// Middleware function
/// Pass-through middleware: forwards the request to the next service and
/// returns its response unchanged.
///
/// It performs no checks of its own; token validation is expected to have
/// been done by `AuthLayer`. It exists as an example of a custom middleware.
async fn require_auth_middleware(
    req: axum::extract::Request,
    next: middleware::Next,
) -> Result<Response> {
    Ok(next.run(req).await)
}
// Public handlers
/// `GET /` — returns a fixed plain-text banner.
async fn root_handler() -> &'static str {
    const BANNER: &str = "Wacht Axum Example API";
    BANNER
}
async fn health_handler() -> Json<serde_json::Value> {
Json(serde_json::json!({
"status": "ok",
"timestamp": chrono::Utc::now().to_rfc3339()
}))
}
/// `POST /login` — placeholder endpoint.
///
/// No credentials are checked; the response always contains the same dummy
/// token and a reminder message. Replace with real login logic.
async fn login_handler() -> Json<serde_json::Value> {
    let placeholder = serde_json::json!({
        "token": "jwt-token-here",
        "message": "Login endpoint - implement your login logic"
    });
    Json(placeholder)
}
// Authenticated handlers
/// Returns the caller's profile.
///
/// Looks up the user identified by `auth.user_id` through the Wacht users API
/// and combines it with the organization and workspace ids from the auth
/// context.
///
/// # Errors
///
/// Propagates any error from the user lookup as `AppError::Wacht`.
async fn get_profile(
    Extension(auth): Extension<AuthContext>,
) -> Result<Json<serde_json::Value>> {
    let user = UsersApi::users_user_id_get(&auth.user_id).await?;
    let full_name = format!("{} {}", user.first_name, user.last_name);

    let profile = serde_json::json!({
        "id": user.id,
        "email": user.email,
        "name": full_name,
        "username": user.username,
        "organization": auth.organization_id,
        "workspace": auth.workspace_id,
    });
    Ok(Json(profile))
}
async fn update_profile(
Extension(auth): Extension<AuthContext>,
Json(payload): Json<UpdateUserRequest>,
) -> Result<Json<User>> {
let user = UsersApi::users_user_id_patch(&auth.user_id, payload).await?;
Ok(Json(user))
}
/// Returns the organizations reported by the Wacht organizations API.
///
/// The auth context is extracted only as a guard: the request must carry an
/// `AuthContext` extension, but its contents are not read here (hence the
/// `_auth` binding, which avoids an unused-variable warning).
///
/// NOTE(review): the query passes `None, None` and is not filtered by the
/// caller's identity, so this presumably returns every organization the SDK
/// credentials can see — confirm that this is intended.
///
/// # Errors
///
/// Propagates any error from the organizations call as `AppError::Wacht`.
async fn list_organizations(
    Extension(_auth): Extension<AuthContext>,
) -> Result<Json<Vec<Organization>>> {
    let orgs = OrganizationsApi::organizations_get(None, None).await?;
    Ok(Json(orgs.data))
}
// Admin handlers with permission checks
async fn list_users(
_perm: CanManageUsers,
Query(params): Query<ListQuery>,
) -> Result<Json<Vec<UserResponse>>> {
let users = UsersApi::users_get(
params.limit,
params.offset,
params.search.as_deref()
).await?;
let response: Vec<UserResponse> = users.data
.into_iter()
.map(|u| UserResponse {
id: u.id,
email: u.email,
name: format!("{} {}", u.first_name, u.last_name),
role: "member".to_string(), // You'd determine this from permissions
organizations: u.organization_memberships
.into_iter()
.map(|m| m.organization.name)
.collect(),
})
.collect();
Ok(Json(response))
}
/// Fetches a single user by id.
///
/// # Errors
///
/// An SDK `Api` error whose status is 404 becomes
/// `AppError::NotFound("User <id> not found")`; every other SDK error is
/// wrapped as `AppError::Wacht`.
async fn get_user(
    _perm: CanManageUsers,
    Path(user_id): Path<String>,
) -> Result<Json<User>> {
    match UsersApi::users_user_id_get(&user_id).await {
        Ok(user) => Ok(Json(user)),
        Err(WachtError::Api { status, .. }) if status == StatusCode::NOT_FOUND => {
            Err(AppError::NotFound(format!("User {} not found", user_id)))
        }
        Err(other) => Err(other.into()),
    }
}
async fn delete_user(
_perm: CanManageUsers,
Path(user_id): Path<String>,
) -> Result<StatusCode> {
UsersApi::users_user_id_delete(&user_id).await?;
Ok(StatusCode::NO_CONTENT)
}
// Workspace handlers
/// Lists the projects of the caller's current workspace.
///
/// Placeholder implementation: the response echoes the workspace id and an
/// empty `projects` array.
///
/// # Errors
///
/// Returns `AppError::Validation` when the auth context has no workspace id.
async fn list_projects(
    Extension(auth): Extension<AuthContext>,
) -> Result<Json<serde_json::Value>> {
    // A workspace must be selected before projects can be listed.
    let workspace_id = match auth.workspace_id {
        Some(id) => id,
        None => return Err(AppError::Validation("No workspace selected".to_string())),
    };

    // This would call your projects API
    let listing = serde_json::json!({
        "workspace_id": workspace_id,
        "projects": []
    });
    Ok(Json(listing))
}
/// Creates a project in the caller's current workspace.
///
/// Placeholder implementation: nothing is persisted; the response echoes the
/// payload with the fixed id `proj_123`, the workspace id and the caller's
/// user id.
///
/// # Errors
///
/// Returns `AppError::Validation` when the project name is empty or consists
/// only of whitespace, or when the auth context has no workspace id.
async fn create_project(
    _perm: CanEditProjects,
    Extension(auth): Extension<AuthContext>,
    Json(payload): Json<CreateProjectRequest>,
) -> Result<Json<serde_json::Value>> {
    // Validate input. Trim first: a plain `is_empty()` check would accept a
    // name made up solely of whitespace.
    if payload.name.trim().is_empty() {
        return Err(AppError::Validation("Project name is required".to_string()));
    }
    let workspace_id = auth
        .workspace_id
        .ok_or_else(|| AppError::Validation("No workspace selected".to_string()))?;

    // Create project in workspace
    Ok(Json(serde_json::json!({
        "id": "proj_123",
        "name": payload.name,
        "description": payload.description,
        "workspace_id": workspace_id,
        "created_by": auth.user_id,
    })))
}
/// Updates a project.
///
/// Placeholder implementation: nothing is persisted; the response echoes the
/// project id from the path, `"updated": true`, and the submitted JSON body
/// under `data`.
async fn update_project(
    _perm: CanEditProjects,
    Path(project_id): Path<String>,
    Json(payload): Json<serde_json::Value>,
) -> Result<Json<serde_json::Value>> {
    let echo = serde_json::json!({
        "id": project_id,
        "updated": true,
        "data": payload
    });
    Ok(Json(echo))
}