Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions internal/api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ func writeJSON(w http.ResponseWriter, status int, v any) {

func writeError(w http.ResponseWriter, r *http.Request, status int, msg string, err error) {
if status >= 500 && err != nil {
slog.ErrorContext(r.Context(), "❌ handler error", "path", r.URL.Path, "err", err)
ContextLogger(r.Context()).Error("❌ handler error", "path", r.URL.Path, "err", err)
}
// NF-P1-3: docs/rest-api.md promises error bodies of shape
// {"error":"...","request_id":"..."}. Echo the per-request ID into
Expand Down Expand Up @@ -457,7 +457,7 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) {
// absolute-path containment before creating the file.
name := filepath.Base(fh.Filename)
if name == "" || name == "." || name == ".." || strings.ContainsAny(name, "/\\") {
slog.Warn("⚠️ upload: skipping invalid filename", "filename", fh.Filename)
ContextLogger(r.Context()).Warn("⚠️ upload: skipping invalid filename", "filename", fh.Filename)
continue
}
dst := filepath.Join(tmpDir, name)
Expand All @@ -467,7 +467,7 @@ func (h *handlers) upload(w http.ResponseWriter, r *http.Request) {
return
}
if !strings.HasPrefix(absDst, absTmp+string(os.PathSeparator)) {
slog.Warn("⚠️ upload: entry escapes tmp dir; skipping",
ContextLogger(r.Context()).Warn("⚠️ upload: entry escapes tmp dir; skipping",
"filename", fh.Filename, "resolved", absDst)
continue
}
Expand Down Expand Up @@ -661,7 +661,7 @@ func enforceUploadLimit(w http.ResponseWriter, r *http.Request, limit int64) boo
}
// Fast path: Content-Length is declared and already exceeds the limit.
if r.ContentLength > limit {
slog.Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit)
ContextLogger(r.Context()).Warn("⚠️ upload: rejected oversize request", "content_length", r.ContentLength, "limit", limit)
writeTooLarge(w, limit)
return false
}
Expand Down
21 changes: 21 additions & 0 deletions internal/api/log_context.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package api

import (
"context"
"log/slog"
)

// ContextLogger returns slog.Default() enriched with the per-request ID
// when the context carries one. Handler trees that funnel log calls
// through this helper get free request-level log correlation; downstream
// code that needs the ID for metric labels should still read it via
// RequestIDFromContext(ctx) directly.
//
// Callers that don't need the enrichment can keep using slog.Default()
// — the helper is additive, not mandatory.
func ContextLogger(ctx context.Context) *slog.Logger {
if id := RequestIDFromContext(ctx); id != "" {
return slog.Default().With("req_id", id)
}
return slog.Default()
}
42 changes: 42 additions & 0 deletions internal/api/log_context_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package api

import (
"bytes"
"context"
"log/slog"
"strings"
"testing"
)

func TestContextLogger_AddsReqIDFromContext(t *testing.T) {
// Cannot t.Parallel() because we mutate slog.Default.
var buf bytes.Buffer
prev := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
t.Cleanup(func() { slog.SetDefault(prev) })

ctx := context.WithValue(context.Background(), ctxRequestIDKey{}, "abc123")
ContextLogger(ctx).Info("hello", "k", "v")

out := buf.String()
if !strings.Contains(out, "req_id=abc123") {
t.Fatalf("expected req_id=abc123 in log output; got %q", out)
}
if !strings.Contains(out, "k=v") {
t.Fatalf("expected k=v to survive; got %q", out)
}
}

func TestContextLogger_NoReqIDWhenMissing(t *testing.T) {
var buf bytes.Buffer
prev := slog.Default()
slog.SetDefault(slog.New(slog.NewTextHandler(&buf, nil)))
t.Cleanup(func() { slog.SetDefault(prev) })

ContextLogger(context.Background()).Info("hello")

out := buf.String()
if strings.Contains(out, "req_id=") {
t.Fatalf("req_id should be absent when context has no ID; got %q", out)
}
}
19 changes: 11 additions & 8 deletions internal/api/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,17 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re
mux.Handle("/", spaHandler(ui.Assets, cfg))

// Middleware ordering (outermost → innermost):
// logging → recovery → auth → project → mux
// project scope sits BELOW auth (an unauthenticated caller never
// reaches the registry) and ABOVE the mux (so handlers and the MCP
// server see the resolved slug via ProjectFromContext).
return loggingMiddleware(
recoveryMiddleware(
bearerAuthMiddleware(cfg.Server.APIKey,
projectMiddleware(cfg, registry, mux))))
// securityHeaders → logging → recovery → auth → project → mux
// securityHeaders sits outermost so CSP + baseline headers are
// applied to every response (including 401s, 404s, and panic
// recoveries). project scope sits BELOW auth (an unauthenticated
// caller never reaches the registry) and ABOVE the mux (so handlers
// and the MCP server see the resolved slug via ProjectFromContext).
return securityHeadersMiddleware(cfg)(
loggingMiddleware(
recoveryMiddleware(
bearerAuthMiddleware(cfg.Server.APIKey,
projectMiddleware(cfg, registry, mux)))))
}

func spaHandler(assets fs.FS, _ *config.Config) http.Handler {
Expand Down
57 changes: 57 additions & 0 deletions internal/api/security_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package api

import (
"net/http"

"github.com/RandomCodeSpace/docsiq/internal/config"
)

// Header values are deliberately strict for the air-gapped deployment
// posture: no CDN origins, no inline scripts, WASM allowed (shiki uses
// it for syntax highlighting), no iframing. Inline styles are permitted
// because Tailwind + shadcn/ui emit them.
const (
contentSecurityPolicy = "default-src 'self'; " +
"script-src 'self' 'wasm-unsafe-eval'; " +
"style-src 'self' 'unsafe-inline'; " +
"connect-src 'self'; " +
"img-src 'self' data:; " +
"font-src 'self'; " +
"frame-ancestors 'none'; " +
"base-uri 'self'"

permissionsPolicy = "camera=(), microphone=(), geolocation=(), payment=(), usb=()"

hstsValue = "max-age=31536000; includeSubDomains"
)

// securityHeadersMiddleware sets browser-side hardening headers on every
// response that actually carries a body (i.e. non-OPTIONS):
// - Content-Security-Policy
// - X-Content-Type-Options: nosniff
// - Referrer-Policy: strict-origin-when-cross-origin
// - Permissions-Policy (disables camera/mic/geo/payment/usb)
// - Strict-Transport-Security (only when cfg.Server.HSTSEnabled=true)
//
// Intended to be wrapped as the OUTERMOST middleware so headers are
// emitted on panic recoveries and auth failures too.
func securityHeadersMiddleware(cfg *config.Config) func(http.Handler) http.Handler {
hstsEnabled := cfg != nil && cfg.Server.HSTSEnabled
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method == http.MethodOptions {
next.ServeHTTP(w, r)
return
}
h := w.Header()
h.Set("Content-Security-Policy", contentSecurityPolicy)
h.Set("X-Content-Type-Options", "nosniff")
h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
h.Set("Permissions-Policy", permissionsPolicy)
if hstsEnabled {
h.Set("Strict-Transport-Security", hstsValue)
}
next.ServeHTTP(w, r)
})
}
}
128 changes: 128 additions & 0 deletions internal/api/security_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package api

import (
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/RandomCodeSpace/docsiq/internal/config"
)

func TestSecurityHeaders_CSPOnEveryResponse(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
cfg := &config.Config{}
h := securityHeadersMiddleware(cfg)(next)

req := httptest.NewRequest(http.MethodGet, "/health", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

csp := rr.Header().Get("Content-Security-Policy")
if csp == "" {
t.Fatal("CSP header missing")
}
for _, want := range []string{
"default-src 'self'",
"script-src 'self' 'wasm-unsafe-eval'",
"style-src 'self' 'unsafe-inline'",
"connect-src 'self'",
"img-src 'self' data:",
"font-src 'self'",
"frame-ancestors 'none'",
"base-uri 'self'",
} {
if !strings.Contains(csp, want) {
t.Errorf("CSP missing directive %q: got %q", want, csp)
}
}
}

func TestSecurityHeaders_SkipsOPTIONS(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNoContent)
})
cfg := &config.Config{}
h := securityHeadersMiddleware(cfg)(next)

req := httptest.NewRequest(http.MethodOptions, "/api/ping", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

if rr.Header().Get("Content-Security-Policy") != "" {
t.Errorf("CSP should not be set on OPTIONS; got %q", rr.Header().Get("Content-Security-Policy"))
}
if rr.Code != http.StatusNoContent {
t.Errorf("OPTIONS should pass through; got status %d", rr.Code)
}
}

func TestSecurityHeaders_PreservesExistingHeaders(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("X-Custom", "xyz")
w.WriteHeader(http.StatusOK)
})
cfg := &config.Config{}
h := securityHeadersMiddleware(cfg)(next)

req := httptest.NewRequest(http.MethodGet, "/health", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

if rr.Header().Get("X-Custom") != "xyz" {
t.Errorf("downstream header clobbered")
}
}

func TestSecurityHeaders_BaselineHeaders(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
cfg := &config.Config{}
h := securityHeadersMiddleware(cfg)(next)

req := httptest.NewRequest(http.MethodGet, "/health", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

if got := rr.Header().Get("X-Content-Type-Options"); got != "nosniff" {
t.Errorf("X-Content-Type-Options: want nosniff, got %q", got)
}
if got := rr.Header().Get("Referrer-Policy"); got != "strict-origin-when-cross-origin" {
t.Errorf("Referrer-Policy: got %q", got)
}
perms := rr.Header().Get("Permissions-Policy")
for _, want := range []string{"camera=()", "microphone=()", "geolocation=()", "payment=()", "usb=()"} {
if !strings.Contains(perms, want) {
t.Errorf("Permissions-Policy missing %q: got %q", want, perms)
}
}
if rr.Header().Get("Strict-Transport-Security") != "" {
t.Error("HSTS should not be set when HSTSEnabled=false")
}
}

func TestSecurityHeaders_HSTSOnlyWhenEnabled(t *testing.T) {
t.Parallel()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
cfg := &config.Config{}
cfg.Server.HSTSEnabled = true
h := securityHeadersMiddleware(cfg)(next)

req := httptest.NewRequest(http.MethodGet, "/health", nil)
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)

hsts := rr.Header().Get("Strict-Transport-Security")
if !strings.Contains(hsts, "max-age=31536000") {
t.Errorf("HSTS missing max-age; got %q", hsts)
}
}
Loading
Loading