diff --git a/cmd/serve.go b/cmd/serve.go index e9755ad..d15e1fe 100644 --- a/cmd/serve.go +++ b/cmd/serve.go @@ -188,22 +188,29 @@ var serveCmd = &cobra.Command{ }() <-ctx.Done() - slog.Info("🛑 shutting down...") - shutCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + // Block 3.7: explicit per-phase log lines so operators can see + // exactly where a stuck shutdown is hung. Deadline raised from + // 10s to 30s to match the workq drain phase and the spec. + slog.Info("🛑 shutting down HTTP server...") + shutCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() if err := srv.Shutdown(shutCtx); err != nil { - slog.Error("❌ shutdown error", "err", err) + slog.Error("❌ HTTP server shutdown error", "err", err) return err } + slog.Info("✅ HTTP server shutdown complete") // Drain workq within its own 30s deadline. Server.Shutdown has already // stopped accepting new HTTP requests, so no new jobs can be submitted; // all that remains is letting in-flight pipelines finish or honour the // cancelled ctx. + slog.Info("🛑 draining workq...") drainCtx, drainCancel := context.WithTimeout(context.Background(), 30*time.Second) defer drainCancel() if err := pool.Close(drainCtx); err != nil { slog.Warn("⚠️ workq drain timeout; some indexing jobs were cancelled mid-flight", "err", err) + } else { + slog.Info("✅ workq drained") } slog.Info("✅ shutdown complete") diff --git a/internal/api/itest/doubles.go b/internal/api/itest/doubles.go index e3fa445..10bb313 100644 --- a/internal/api/itest/doubles.go +++ b/internal/api/itest/doubles.go @@ -38,6 +38,10 @@ func (p *FakeProvider) Name() string { return "fake" } // ModelID returns a stable model identifier that callers can log. func (p *FakeProvider) ModelID() string { return "fake-model-v1" } +// BatchCeiling declares no ceiling — this fake accepts any batch size +// since it never makes a real upstream call. Block 3.4. +func (p *FakeProvider) BatchCeiling() int { return 0 } + // Complete returns a deterministic string derived from the prompt. The // output embeds the sha256 prefix of the prompt so tests can assert the // provider actually saw what they sent. diff --git a/internal/api/panic_enrichment_test.go b/internal/api/panic_enrichment_test.go new file mode 100644 index 0000000..e22d630 --- /dev/null +++ b/internal/api/panic_enrichment_test.go @@ -0,0 +1,99 @@ +package api + +import ( + "bytes" + "context" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +// TestRecoveryMiddleware_EnrichedLog verifies that a panic is logged +// with req_id, route, method, user, and a stack trace. Block 3.7. +func TestRecoveryMiddleware_EnrichedLog(t *testing.T) { + // Capture slog output via a TextHandler into a buffer. + var buf bytes.Buffer + prev := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }))) + t.Cleanup(func() { slog.SetDefault(prev) }) + + panicky := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom") + }) + handler := recoveryMiddleware(panicky) + + // Seed request with a request id (mimicking loggingMiddleware + // having already run) and a user ctx value. 
+ req := httptest.NewRequest(http.MethodPost, "/api/documents/abc", nil) + ctx := context.WithValue(req.Context(), ctxRequestIDKey{}, "rid-test-123") + ctx = withUserForTest(ctx, "alice") + req = req.WithContext(ctx) + + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d; want 500", rec.Code) + } + + logOutput := buf.String() + for _, want := range []string{ + "panic recovered", + "req_id=rid-test-123", + "method=POST", + "route=/api/documents/abc", + "user=alice", + "panic=boom", + } { + if !strings.Contains(logOutput, want) { + t.Errorf("log missing %q\nlog: %s", want, logOutput) + } + } + // A stack trace marker ("goroutine " or "runtime/panic.go") must + // appear somewhere in the log. slog serializes newlines as \n + // literals inside the stack attribute — either is acceptable. + if !strings.Contains(logOutput, "goroutine") && !strings.Contains(logOutput, "runtime/panic") { + t.Errorf("log missing stack trace marker\nlog: %s", logOutput) + } +} + +// TestRecoveryMiddleware_NoUserNoReqID verifies the middleware still +// recovers cleanly when neither ctxUserKey nor request id are set. +func TestRecoveryMiddleware_NoUserNoReqID(t *testing.T) { + var buf bytes.Buffer + prev := slog.Default() + slog.SetDefault(slog.New(slog.NewTextHandler(&buf, &slog.HandlerOptions{ + Level: slog.LevelDebug, + }))) + t.Cleanup(func() { slog.SetDefault(prev) }) + + panicky := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + panic("boom-bare") + }) + handler := recoveryMiddleware(panicky) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusInternalServerError { + t.Fatalf("status = %d; want 500", rec.Code) + } + logOutput := buf.String() + if strings.Contains(logOutput, "req_id=") { + t.Errorf("unexpected req_id attr on unset ctx\nlog: %s", logOutput) + } + if strings.Contains(logOutput, "user=") { + t.Errorf("unexpected user attr on unset ctx\nlog: %s", logOutput) + } +} + +// withUserForTest is a test-only helper that injects a user id into +// ctx using the same key the real auth middleware would use. +func withUserForTest(ctx context.Context, user string) context.Context { + return context.WithValue(ctx, ctxUserKey{}, user) +} diff --git a/internal/api/request_id.go b/internal/api/request_id.go index b7d2f4a..d35f5e3 100644 --- a/internal/api/request_id.go +++ b/internal/api/request_id.go @@ -10,6 +10,12 @@ import ( // struct keeps the key collision-free with other packages. type ctxRequestIDKey struct{} +// ctxUserKey is the typed context key for the authenticated user ID. +// Reserved for future auth integration; today only the panic +// recoveryMiddleware reads it. Empty string means the request is +// unauthenticated (allowed on public routes). Block 3.7. +type ctxUserKey struct{} + // RequestIDFromContext returns the ID attached by loggingMiddleware. Empty // string when called from a context that never hit the middleware (e.g. // a direct handler call from a test that skips the router wrapper). diff --git a/internal/api/request_timeout.go b/internal/api/request_timeout.go new file mode 100644 index 0000000..75cf5f1 --- /dev/null +++ b/internal/api/request_timeout.go @@ -0,0 +1,65 @@ +package api + +import ( + "net/http" + "strings" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +// isUploadRoute reports whether r should be granted UploadTimeout. 
+// The carve-out covers two long-running paths: +// - POST /api/upload — multipart document upload +// - POST /api/projects/{p}/import — tar / bulk notes import +// +// A /api/projects/* POST is only an upload if the trailing segment is +// /import — this avoids granting 10-minute timeouts to note-write POSTs +// on the same prefix. Block 3.2. +func isUploadRoute(r *http.Request) bool { + if r.Method != http.MethodPost { + return false + } + switch { + case r.URL.Path == "/api/upload": + return true + case strings.HasPrefix(r.URL.Path, "/api/projects/") && strings.HasSuffix(r.URL.Path, "/import"): + return true + } + return false +} + +// requestTimeoutMiddleware wraps inner in http.TimeoutHandler with +// cfg.Server.RequestTimeout as the default bound, bumped to +// cfg.Server.UploadTimeout for upload routes. +// +// Zero timeout means "no cap" — useful for local dev. In that case +// inner is returned unchanged. +// +// Layering rationale (Block 3.2 comment): this middleware sits INSIDE +// securityHeadersMiddleware (so a 503 still carries CSP) and OUTSIDE +// loggingMiddleware (so the timeout is logged). See router.go. +func requestTimeoutMiddleware(cfg *config.Config) func(http.Handler) http.Handler { + return func(inner http.Handler) http.Handler { + reqTimeout := cfg.Server.RequestTimeout + upTimeout := cfg.Server.UploadTimeout + + // Pre-build the two TimeoutHandler instances so each request + // just dispatches to one — no per-request allocation. + defaultTO := inner + if reqTimeout > 0 { + defaultTO = http.TimeoutHandler(inner, reqTimeout, "request timeout") + } + uploadTO := inner + if upTimeout > 0 { + uploadTO = http.TimeoutHandler(inner, upTimeout, "upload timeout") + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if isUploadRoute(r) { + uploadTO.ServeHTTP(w, r) + return + } + defaultTO.ServeHTTP(w, r) + }) + } +} diff --git a/internal/api/router.go b/internal/api/router.go index 85019e2..c8b64cd 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -2,10 +2,12 @@ package api import ( "context" + "fmt" "io/fs" "log/slog" "net/http" "path" + "runtime/debug" "strings" "time" @@ -166,11 +168,15 @@ func NewRouter(prov llm.Provider, emb *embedder.Embedder, cfg *config.Config, re // recoveries). project scope sits BELOW auth (an unauthenticated // caller never reaches the registry) and ABOVE the mux (so handlers // and the MCP server see the resolved slug via ProjectFromContext). + // Block 3.2: requestTimeoutMiddleware sits INSIDE securityHeaders + // (so 503 timeouts still carry CSP) and OUTSIDE loggingMiddleware + // (so operators still see the latency spike in logs). return securityHeadersMiddleware(cfg)( - loggingMiddleware( - recoveryMiddleware( - bearerAuthMiddleware(cfg.Server.APIKey, - projectMiddleware(cfg, registry, mux))))) + requestTimeoutMiddleware(cfg)( + loggingMiddleware( + recoveryMiddleware( + bearerAuthMiddleware(cfg.Server.APIKey, + projectMiddleware(cfg, registry, mux)))))) } func spaHandler(assets fs.FS, _ *config.Config) http.Handler { @@ -204,12 +210,38 @@ func spaHandler(assets fs.FS, _ *config.Config) http.Handler { }) } -// recoveryMiddleware catches panics in handlers and returns a 500 response. +// recoveryMiddleware catches panics in handlers, logs them with +// request context (req_id, route, method, user if authed) plus the +// full stack, then returns a 500 response. 
The enriched log surface +// is Block 3.7's requirement: during a production panic you need +// enough context to reconstruct the request without tailing raw +// stderr. func recoveryMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { defer func() { if rec := recover(); rec != nil { - slog.Error("❌ panic recovered", "path", r.URL.Path, "panic", rec) + // Gather every piece of request context that exists on + // the ctx — any absent value surfaces as "" and gets + // filtered from the attr list. + rid := RequestIDFromContext(r.Context()) + user, _ := r.Context().Value(ctxUserKey{}).(string) + + stack := debug.Stack() + + attrs := []any{ + "route", r.URL.Path, + "method", r.Method, + "panic", fmt.Sprint(rec), + "stack", string(stack), + } + if rid != "" { + attrs = append(attrs, "req_id", rid) + } + if user != "" { + attrs = append(attrs, "user", user) + } + + slog.Error("❌ panic recovered", attrs...) http.Error(w, "internal server error", http.StatusInternalServerError) } }() diff --git a/internal/api/shutdown_integration_test.go b/internal/api/shutdown_integration_test.go index d88293e..35e71d0 100644 --- a/internal/api/shutdown_integration_test.go +++ b/internal/api/shutdown_integration_test.go @@ -25,6 +25,9 @@ import ( // goroutine blocks in a select on a channel that is not drained // synchronously by *DB.Close(); go-sqlite3's docs acknowledge a // brief post-Close window. Tolerated as a stdlib artifact. +// - database/sql.connectionCleaner — spawned by SetConnMaxLifetime +// (Block 3.6 sets this to 1h on every store). Exits when the +// *sql.DB is closed; same post-Close window as connectionOpener. // // goleak.IgnoreCurrent() baselines out any stdlib goroutines that // predate test start (Go runtime housekeeping, test framework, etc.). @@ -34,6 +37,7 @@ func TestShutdown_NoGoroutineLeaks(t *testing.T) { goleak.IgnoreAnyFunction("net/http.(*persistConn).readLoop"), goleak.IgnoreAnyFunction("net/http.(*persistConn).writeLoop"), goleak.IgnoreAnyFunction("database/sql.(*DB).connectionOpener"), + goleak.IgnoreAnyFunction("database/sql.(*DB).connectionCleaner"), goleak.IgnoreAnyFunction("internal/poll.runtime_pollWait"), ) diff --git a/internal/api/timeout_test.go b/internal/api/timeout_test.go new file mode 100644 index 0000000..3b2f89c --- /dev/null +++ b/internal/api/timeout_test.go @@ -0,0 +1,118 @@ +package api + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/RandomCodeSpace/docsiq/internal/config" +) + +// TestRequestTimeoutMiddleware_FiresOnSlowHandler: a handler that +// sleeps past the request timeout returns 503 Service Unavailable. 
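+// The 503 and its body are http.TimeoutHandler's stdlib contract:
+// once the deadline passes it replies 503 Service Unavailable with
+// the msg argument as the body, e.g. for non-upload routes:
+//
+//	http.TimeoutHandler(inner, reqTimeout, "request timeout")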
+func TestRequestTimeoutMiddleware_FiresOnSlowHandler(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.RequestTimeout = 50 * time.Millisecond + cfg.Server.UploadTimeout = 1 * time.Second + + slow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + _, _ = w.Write([]byte("too late")) + }) + handler := requestTimeoutMiddleware(cfg)(slow) + + req := httptest.NewRequest(http.MethodGet, "/api/stats", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusServiceUnavailable { + t.Fatalf("status = %d; want 503", rec.Code) + } + body, _ := io.ReadAll(rec.Body) + if !strings.Contains(string(body), "request timeout") { + t.Fatalf("body = %q; want substring 'request timeout'", body) + } +} + +// TestRequestTimeoutMiddleware_UploadRouteGetsExtendedTimeout: an upload +// request that completes within UploadTimeout (but exceeds +// RequestTimeout) succeeds. +func TestRequestTimeoutMiddleware_UploadRouteGetsExtendedTimeout(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.RequestTimeout = 50 * time.Millisecond + cfg.Server.UploadTimeout = 500 * time.Millisecond + + slow := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + time.Sleep(200 * time.Millisecond) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + }) + handler := requestTimeoutMiddleware(cfg)(slow) + + req := httptest.NewRequest(http.MethodPost, "/api/upload", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 200 (upload route under UploadTimeout)", rec.Code) + } +} + +// TestRequestTimeoutMiddleware_FastHandlerUnaffected: a handler that +// responds well within the timeout is passed through unchanged. +func TestRequestTimeoutMiddleware_FastHandlerUnaffected(t *testing.T) { + t.Parallel() + cfg := &config.Config{} + cfg.Server.RequestTimeout = 100 * time.Millisecond + cfg.Server.UploadTimeout = 1 * time.Second + + fast := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Test", "ok") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok":true}`)) + }) + handler := requestTimeoutMiddleware(cfg)(fast) + + req := httptest.NewRequest(http.MethodGet, "/api/stats", nil) + rec := httptest.NewRecorder() + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d; want 200", rec.Code) + } + if got := rec.Header().Get("X-Test"); got != "ok" { + t.Fatalf("X-Test = %q; want ok", got) + } +} + +// TestIsUploadRoute_Classification: rules for which routes receive the +// upload timeout. 
+func TestIsUploadRoute_Classification(t *testing.T) { + t.Parallel() + cases := []struct { + method, path string + want bool + }{ + {http.MethodPost, "/api/upload", true}, + {http.MethodGet, "/api/upload", false}, // GET → request timeout + {http.MethodPost, "/api/projects/foo/import", true}, + {http.MethodPost, "/api/projects/foo/notes", false}, + {http.MethodPost, "/api/projects/foo", false}, + {http.MethodPost, "/api/stats", false}, + } + for _, c := range cases { + c := c + t.Run(c.method+" "+c.path, func(t *testing.T) { + req := httptest.NewRequest(c.method, c.path, nil) + got := isUploadRoute(req) + if got != c.want { + t.Fatalf("isUploadRoute(%s %s) = %v; want %v", c.method, c.path, got, c.want) + } + }) + } +} diff --git a/internal/config/config.go b/internal/config/config.go index 249ebc8..81a457b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/spf13/viper" ) @@ -50,6 +51,13 @@ type LLMConfig struct { Azure AzureConfig `mapstructure:"azure"` Ollama OllamaConfig `mapstructure:"ollama"` OpenAI OpenAIConfig `mapstructure:"openai"` + + // CallTimeout caps the end-to-end duration of a single provider + // call (Complete / Embed / EmbedBatch). Any retry wrapper counts + // against this deadline — the timeout is NOT reset between + // attempts. Zero disables the per-call cap (caller's ctx is + // authoritative). Default 60s. Block 3.3. + CallTimeout time.Duration `mapstructure:"call_timeout"` } // OpenAIConfig configures the direct OpenAI (api.openai.com) provider, @@ -152,6 +160,16 @@ type ServerConfig struct { WorkqWorkers int `mapstructure:"workq_workers"` // 0 → runtime.NumCPU() WorkqDepth int `mapstructure:"workq_depth"` // 0 → 64 HSTSEnabled bool `mapstructure:"hsts_enabled"` // emits Strict-Transport-Security when true + + // RequestTimeout caps the duration of every HTTP handler except + // the carve-outs listed in isUploadRoute. Block 3.2 default 30s. + // Zero disables the cap (not recommended in production). + RequestTimeout time.Duration `mapstructure:"request_timeout"` + + // UploadTimeout caps long-running upload / import endpoints + // (POST /api/upload, POST /api/projects/{project}/import). Block + // 3.2 default 10m. + UploadTimeout time.Duration `mapstructure:"upload_timeout"` } func Load(cfgFile string) (*Config, error) { @@ -197,6 +215,10 @@ func Load(cfgFile string) (*Config, error) { v.SetDefault("llm.openai.embed_model", "text-embedding-3-small") v.SetDefault("llm.openai.organization", "") + // LLM — per-call timeout (Block 3.3). Default 60s caps every + // Complete / Embed / EmbedBatch invocation. + v.SetDefault("llm.call_timeout", 60*time.Second) + // Indexing v.SetDefault("indexing.chunk_size", 512) v.SetDefault("indexing.chunk_overlap", 50) @@ -218,6 +240,8 @@ func Load(cfgFile string) (*Config, error) { v.SetDefault("server.workq_workers", 0) // 0 → runtime.NumCPU() v.SetDefault("server.workq_depth", 64) v.SetDefault("server.hsts_enabled", false) + v.SetDefault("server.request_timeout", 30*time.Second) + v.SetDefault("server.upload_timeout", 10*time.Minute) // Config file search paths. Only ~/.docsiq and CWD are consulted. 
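+	// Illustrative only: a ~/.docsiq/config.yaml exercising the new
+	// timeout knobs could read as below (example values, not the
+	// defaults; duration strings decode via viper's standard
+	// string-to-time.Duration hook):
+	//
+	//	server:
+	//	  request_timeout: 45s
+	//	  upload_timeout: 15m
+	//	llm:
+	//	  call_timeout: 90s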
newCfgDir := filepath.Join(home, ".docsiq") @@ -244,6 +268,9 @@ func Load(cfgFile string) (*Config, error) { _ = v.BindEnv("server.workq_workers", "DOCSIQ_SERVER_WORKQ_WORKERS") _ = v.BindEnv("server.workq_depth", "DOCSIQ_SERVER_WORKQ_DEPTH") _ = v.BindEnv("server.hsts_enabled", "DOCSIQ_SERVER_HSTS_ENABLED") + _ = v.BindEnv("llm.call_timeout", "DOCSIQ_LLM_CALL_TIMEOUT") + _ = v.BindEnv("server.request_timeout", "DOCSIQ_SERVER_REQUEST_TIMEOUT") + _ = v.BindEnv("server.upload_timeout", "DOCSIQ_SERVER_UPLOAD_TIMEOUT") if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { diff --git a/internal/crawler/crawler.go b/internal/crawler/crawler.go index d46529b..9393c93 100644 --- a/internal/crawler/crawler.go +++ b/internal/crawler/crawler.go @@ -54,7 +54,7 @@ func Crawl(ctx context.Context, rootURL string, opts Options) ([]*Page, error) { // Try sitemap first var urls []string if !opts.SkipSitemap { - urls, err = discoverSitemap(client, base) + urls, err = discoverSitemap(ctx, client, base) if err != nil { slog.Debug("🔍 sitemap not found, falling back to BFS", "url", rootURL, "reason", err) } else { @@ -131,7 +131,7 @@ type urlSet struct { } `xml:"url"` } -func discoverSitemap(client *http.Client, base *url.URL) ([]string, error) { +func discoverSitemap(ctx context.Context, client *http.Client, base *url.URL) ([]string, error) { candidates := []string{ base.Scheme + "://" + base.Host + "/sitemap.xml", base.String() + "/sitemap.xml", @@ -139,7 +139,7 @@ func discoverSitemap(client *http.Client, base *url.URL) ([]string, error) { } for _, candidate := range candidates { - urls, err := parseSitemap(client, candidate, base) + urls, err := parseSitemap(ctx, client, candidate, base) if err == nil && len(urls) > 0 { slog.Debug("🔍 sitemap parsed", "url", candidate, "entries", len(urls)) return urls, nil @@ -148,9 +148,16 @@ func discoverSitemap(client *http.Client, base *url.URL) ([]string, error) { return nil, fmt.Errorf("no sitemap found") } -func parseSitemap(client *http.Client, sitemapURL string, base *url.URL) ([]string, error) { - resp, err := client.Get(sitemapURL) +func parseSitemap(ctx context.Context, client *http.Client, sitemapURL string, base *url.URL) ([]string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, sitemapURL, nil) + if err != nil { + return nil, fmt.Errorf("sitemap request: %w", err) + } + resp, err := client.Do(req) if err != nil || resp.StatusCode != http.StatusOK { + if resp != nil { + _ = resp.Body.Close() + } return nil, fmt.Errorf("sitemap not found") } defer resp.Body.Close() @@ -166,7 +173,7 @@ func parseSitemap(client *http.Client, sitemapURL string, base *url.URL) ([]stri slog.Debug("🔍 sitemap index found", "url", sitemapURL, "sub_sitemaps", len(idx.Sitemaps)) var all []string for _, s := range idx.Sitemaps { - sub, err := parseSitemap(client, s.Loc, base) + sub, err := parseSitemap(ctx, client, s.Loc, base) if err == nil { all = append(all, sub...) 
} @@ -220,7 +227,7 @@ func bfsCrawl(ctx context.Context, client *http.Client, base *url.URL, opts Opti continue } - links := extractLinks(client, item.u, base) + links := extractLinks(ctx, client, item.u, base) slog.Debug("🔗 BFS page links extracted", "url", item.u, "depth", item.depth, "links", len(links)) for _, l := range links { if !visited[l] { @@ -234,12 +241,20 @@ func bfsCrawl(ctx context.Context, client *http.Client, base *url.URL, opts Opti return found } -func extractLinks(client *http.Client, pageURL string, base *url.URL) []string { - resp, err := client.Get(pageURL) +func extractLinks(ctx context.Context, client *http.Client, pageURL string, base *url.URL) []string { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil) + if err != nil { + slog.Debug("⚠️ failed to build request for link extraction", "url", pageURL, "err", err) + return nil + } + resp, err := client.Do(req) if err != nil || resp.StatusCode != http.StatusOK { if err != nil { slog.Debug("⚠️ failed to fetch page for link extraction", "url", pageURL, "err", err) } + if resp != nil { + _ = resp.Body.Close() + } return nil } defer resp.Body.Close() diff --git a/internal/embedder/batch_test.go b/internal/embedder/batch_test.go new file mode 100644 index 0000000..e4b4653 --- /dev/null +++ b/internal/embedder/batch_test.go @@ -0,0 +1,125 @@ +package embedder + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/RandomCodeSpace/docsiq/internal/llm" +) + +// recordingProvider implements llm.Provider and captures every +// EmbedBatch call's slice length. It returns zero-filled vectors of +// length 4 per input text. +type recordingProvider struct { + mu sync.Mutex + ceiling int + callSizes []int + delay time.Duration +} + +func (r *recordingProvider) Name() string { return "recording" } +func (r *recordingProvider) ModelID() string { return "recording-v1" } +func (r *recordingProvider) BatchCeiling() int { return r.ceiling } + +func (r *recordingProvider) Complete(ctx context.Context, prompt string, opts ...llm.Option) (string, error) { + return "", nil +} + +func (r *recordingProvider) Embed(ctx context.Context, text string) ([]float32, error) { + return []float32{0, 0, 0, 0}, nil +} + +func (r *recordingProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { + r.mu.Lock() + r.callSizes = append(r.callSizes, len(texts)) + r.mu.Unlock() + if r.delay > 0 { + select { + case <-time.After(r.delay): + case <-ctx.Done(): + return nil, ctx.Err() + } + } + out := make([][]float32, len(texts)) + for i := range texts { + out[i] = []float32{0, 0, 0, 0} + } + return out, nil +} + +// TestEmbedder_New_ClampsToBatchCeiling: a user asking for batchSize=5000 +// against an OpenAI-like provider with ceiling=2048 gets clamped to 2048. +func TestEmbedder_New_ClampsToBatchCeiling(t *testing.T) { + t.Parallel() + p := &recordingProvider{ceiling: 2048} + e := New(p, 5000) + if e.batchSize != 2048 { + t.Fatalf("batchSize = %d; want 2048 (clamped to ceiling)", e.batchSize) + } +} + +// TestEmbedder_New_BelowCeilingIsUnchanged: a user asking for 100 against +// a ceiling of 2048 keeps 100. +func TestEmbedder_New_BelowCeilingIsUnchanged(t *testing.T) { + t.Parallel() + p := &recordingProvider{ceiling: 2048} + e := New(p, 100) + if e.batchSize != 100 { + t.Fatalf("batchSize = %d; want 100 (unchanged)", e.batchSize) + } +} + +// TestEmbedder_EmbedTexts_ChunksToBatchSize: 500 texts with batchSize=100 +// results in 5 EmbedBatch calls, each of size 100. 
+func TestEmbedder_EmbedTexts_ChunksToBatchSize(t *testing.T) { + t.Parallel() + p := &recordingProvider{ceiling: 2048} + e := New(p, 100) + + texts := make([]string, 500) + for i := range texts { + texts[i] = "t" + } + + if _, err := e.EmbedTexts(context.Background(), texts); err != nil { + t.Fatalf("EmbedTexts: %v", err) + } + + p.mu.Lock() + defer p.mu.Unlock() + if len(p.callSizes) != 5 { + t.Fatalf("EmbedBatch calls = %d; want 5 (500 / 100)", len(p.callSizes)) + } + for i, n := range p.callSizes { + if n != 100 { + t.Fatalf("call[%d] size = %d; want 100", i, n) + } + } +} + +// TestEmbedder_EmbedTexts_PreservesOrder: returned vectors are assembled +// in input order, even with concurrent batches. +func TestEmbedder_EmbedTexts_PreservesOrder(t *testing.T) { + t.Parallel() + p := &recordingProvider{ceiling: 2048, delay: 5 * time.Millisecond} + e := New(llm.Provider(p), 50) + + texts := make([]string, 250) + for i := range texts { + texts[i] = "t" + } + vecs, err := e.EmbedTexts(context.Background(), texts) + if err != nil { + t.Fatalf("EmbedTexts: %v", err) + } + if len(vecs) != 250 { + t.Fatalf("vecs len = %d; want 250", len(vecs)) + } + for i, v := range vecs { + if len(v) != 4 { + t.Fatalf("vecs[%d] len = %d; want 4", i, len(v)) + } + } +} diff --git a/internal/embedder/embedder.go b/internal/embedder/embedder.go index 4a4edde..4581f96 100644 --- a/internal/embedder/embedder.go +++ b/internal/embedder/embedder.go @@ -17,6 +17,11 @@ type Embedder struct { // New creates a new Embedder. If provider is nil (LLM disabled via // provider=none), New returns nil. Callers must check for nil before use. +// +// Block 3.4: the caller-supplied batchSize is clamped to the provider's +// declared batch ceiling (OpenAI 2048, Azure 16, Ollama 128). A +// batchSize exceeding the ceiling would cause silent truncation or +// explicit 400s depending on the provider. func New(provider llm.Provider, batchSize int) *Embedder { if provider == nil { return nil @@ -24,6 +29,9 @@ func New(provider llm.Provider, batchSize int) *Embedder { if batchSize <= 0 { batchSize = 20 } + if ceiling := provider.BatchCeiling(); ceiling > 0 && batchSize > ceiling { + batchSize = ceiling + } return &Embedder{provider: provider, batchSize: batchSize, concurrency: 4} } diff --git a/internal/llm/httpclient.go b/internal/llm/httpclient.go new file mode 100644 index 0000000..a7f8843 --- /dev/null +++ b/internal/llm/httpclient.go @@ -0,0 +1,47 @@ +// Package llm — HTTP client pooling per provider (Block 3.5). +// +// langchaingo constructs a fresh net/http.Transport inside each +// provider constructor by default. For a long-running server that +// calls the same provider on every request, that leaks connections: +// every call-site allocates its own idle-conn pool, TLS session +// cache, and DNS resolver bucket. Pooling one *http.Client per +// provider (constructed here) fixes the leak. +package llm + +import ( + "net" + "net/http" + "time" +) + +// newHTTPClient returns a *http.Client tuned for long-lived LLM +// provider traffic. 
The transport settings are spec-driven:
+//   - MaxIdleConns=100 — plenty of headroom for bursty batching
+//   - MaxIdleConnsPerHost=10 — matches langchaingo default fan-out
+//   - IdleConnTimeout=90s — trim idle conns before cloud LBs do
+//   - TLSHandshakeTimeout=10s — fail fast on broken TLS upstreams
+//   - ResponseHeaderTimeout=60s — distinct from body-stream timeout;
+//     bounds the silent-server failure mode
+//
+// Deliberately NOT set: Client.Timeout — it would hard-cut streaming
+// bodies; per-call timeouts come from ctx (Task 3 / Block 3.3).
+// The 30s dial timeout on the Dialer below is deliberate: it mirrors
+// http.DefaultTransport, and ctx cancellation still interrupts an
+// in-flight dial, so it does not fight ctx-driven shutdown.
+func newHTTPClient() *http.Client {
+	tr := &http.Transport{
+		Proxy: http.ProxyFromEnvironment,
+		DialContext: (&net.Dialer{
+			Timeout:   30 * time.Second,
+			KeepAlive: 30 * time.Second,
+		}).DialContext,
+		ForceAttemptHTTP2:     true,
+		MaxIdleConns:          100,
+		MaxIdleConnsPerHost:   10,
+		IdleConnTimeout:       90 * time.Second,
+		TLSHandshakeTimeout:   10 * time.Second,
+		ResponseHeaderTimeout: 60 * time.Second,
+		ExpectContinueTimeout: 1 * time.Second,
+	}
+	return &http.Client{Transport: tr}
+}
diff --git a/internal/llm/httpclient_test.go b/internal/llm/httpclient_test.go
new file mode 100644
index 0000000..7133884
--- /dev/null
+++ b/internal/llm/httpclient_test.go
@@ -0,0 +1,48 @@
+package llm
+
+import (
+	"net/http"
+	"testing"
+	"time"
+)
+
+// TestNewHTTPClient_TransportSettings verifies the tuned transport
+// settings required by Block 3.5.
+func TestNewHTTPClient_TransportSettings(t *testing.T) {
+	t.Parallel()
+	c := newHTTPClient()
+	if c == nil {
+		t.Fatal("newHTTPClient returned nil")
+	}
+	tr, ok := c.Transport.(*http.Transport)
+	if !ok {
+		t.Fatalf("Transport = %T; want *http.Transport", c.Transport)
+	}
+	if got, want := tr.MaxIdleConns, 100; got != want {
+		t.Errorf("MaxIdleConns = %d; want %d", got, want)
+	}
+	if got, want := tr.MaxIdleConnsPerHost, 10; got != want {
+		t.Errorf("MaxIdleConnsPerHost = %d; want %d", got, want)
+	}
+	if got, want := tr.IdleConnTimeout, 90*time.Second; got != want {
+		t.Errorf("IdleConnTimeout = %v; want %v", got, want)
+	}
+	if got, want := tr.TLSHandshakeTimeout, 10*time.Second; got != want {
+		t.Errorf("TLSHandshakeTimeout = %v; want %v", got, want)
+	}
+	if got, want := tr.ResponseHeaderTimeout, 60*time.Second; got != want {
+		t.Errorf("ResponseHeaderTimeout = %v; want %v", got, want)
+	}
+}
+
+// TestNewHTTPClient_NoClientTimeout asserts we do NOT set
+// http.Client.Timeout — that would hard-cut the body mid-stream on
+// large embedding responses. Per-call timeouts live on ctx instead
+// (Task 3).
+func TestNewHTTPClient_NoClientTimeout(t *testing.T) {
+	t.Parallel()
+	c := newHTTPClient()
+	if c.Timeout != 0 {
+		t.Fatalf("Client.Timeout = %v; want 0 (use ctx per-call)", c.Timeout)
+	}
+}
diff --git a/internal/llm/openai.go b/internal/llm/openai.go
index 6a58e0d..c047627 100644
--- a/internal/llm/openai.go
+++ b/internal/llm/openai.go
@@ -45,10 +45,16 @@ func newOpenAIProvider(cfg *config.LLMConfig) (Provider, error) {
 		embedModel = defaultOpenAIEmbedModel
 	}
 
+	// Block 3.5: one pooled *http.Client shared across chat + embed
+	// langchaingo handles. Same connection pool for every outbound
+	// request the lcProvider makes.
+ httpClient := newHTTPClient() + chatOpts := []openai.Option{ openai.WithToken(oc.APIKey), openai.WithBaseURL(baseURL), openai.WithModel(chatModel), + openai.WithHTTPClient(httpClient), } if oc.Organization != "" { chatOpts = append(chatOpts, openai.WithOrganization(oc.Organization)) @@ -66,6 +72,7 @@ func newOpenAIProvider(cfg *config.LLMConfig) (Provider, error) { // even for an embedding-only client — it refuses to build // without a chat model set. Reuse the same model here. openai.WithModel(chatModel), + openai.WithHTTPClient(httpClient), } if oc.Organization != "" { embedOpts = append(embedOpts, openai.WithOrganization(oc.Organization)) @@ -82,9 +89,12 @@ func newOpenAIProvider(cfg *config.LLMConfig) (Provider, error) { // care about for vector-dimension consistency — mirroring the // Azure provider's convention. return &lcProvider{ - llm: chatLLM, - emb: emb, - name: "openai", - modelID: embedModel, + llm: chatLLM, + emb: emb, + name: "openai", + modelID: embedModel, + httpClient: httpClient, + callTimeout: cfg.CallTimeout, + batchCeiling: 2048, }, nil } diff --git a/internal/llm/provider.go b/internal/llm/provider.go index d0b0681..3dfc4d8 100644 --- a/internal/llm/provider.go +++ b/internal/llm/provider.go @@ -3,6 +3,8 @@ package llm import ( "context" "fmt" + "net/http" + "time" "github.com/RandomCodeSpace/docsiq/internal/config" "github.com/tmc/langchaingo/embeddings" @@ -20,7 +22,7 @@ type callOptions struct { jsonMode bool } -func WithMaxTokens(n int) Option { return func(o *callOptions) { o.maxTokens = n } } +func WithMaxTokens(n int) Option { return func(o *callOptions) { o.maxTokens = n } } func WithTemperature(t float64) Option { return func(o *callOptions) { o.temperature = t } } func WithJSONMode() Option { return func(o *callOptions) { o.jsonMode = true } } @@ -39,6 +41,12 @@ type Provider interface { EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) Name() string ModelID() string + // BatchCeiling returns the maximum number of texts that can be + // passed to EmbedBatch in a single call. Callers that need to + // process larger inputs must slice to this ceiling. Zero means + // "no declared ceiling" (rare — only for providers that don't + // care). Block 3.4. + BatchCeiling() int } // NewProvider constructs the configured provider. @@ -74,12 +82,41 @@ type lcProvider struct { emb embeddings.Embedder name string modelID string + + // Block 3.5: pooled HTTP client shared with the langchaingo + // sub-clients. Stored here so tests can assert on it and so + // future work can swap it (e.g. for a tracing transport). + httpClient *http.Client + + // Block 3.3: optional per-call timeout wrapped around ctx. Zero + // means "no timeout" (caller's ctx is authoritative); positive + // values trigger context.WithTimeout in Complete/Embed/EmbedBatch. + callTimeout time.Duration + + // Block 3.4: provider-declared batch ceiling. EmbedBatch slices + // input to this size; caller-visible chunking also uses this + // value so the Embedder can construct correctly-sized jobs. + batchCeiling int } -func (p *lcProvider) Name() string { return p.name } -func (p *lcProvider) ModelID() string { return p.modelID } +func (p *lcProvider) Name() string { return p.name } +func (p *lcProvider) ModelID() string { return p.modelID } +func (p *lcProvider) BatchCeiling() int { return p.batchCeiling } + +// withCallTimeout returns a child ctx bounded by p.callTimeout when +// positive, plus its cancel. 
Zero/negative callTimeout returns the +// parent ctx unchanged and a no-op cancel — callers always defer +// cancel() without branching. Block 3.3. +func (p *lcProvider) withCallTimeout(parent context.Context) (context.Context, context.CancelFunc) { + if p.callTimeout <= 0 { + return parent, func() {} + } + return context.WithTimeout(parent, p.callTimeout) +} func (p *lcProvider) Complete(ctx context.Context, prompt string, opts ...Option) (string, error) { + ctx, cancel := p.withCallTimeout(ctx) + defer cancel() o := applyOptions(opts) callOpts := []llms.CallOption{ llms.WithMaxTokens(o.maxTokens), @@ -92,17 +129,105 @@ func (p *lcProvider) Complete(ctx context.Context, prompt string, opts ...Option } func (p *lcProvider) Embed(ctx context.Context, text string) ([]float32, error) { + ctx, cancel := p.withCallTimeout(ctx) + defer cancel() return p.emb.EmbedQuery(ctx, text) } +// EmbedBatch embeds texts in provider-sized chunks. Input is sliced to +// at-most p.batchCeiling per upstream request. Per-chunk results are +// pushed through a buffered channel — a slow consumer backpressures +// the producer once the buffer fills. +// +// The function assembles the final [][]float32 in input order. Errors +// from any chunk short-circuit the whole call via ctx cancellation. +// +// When batchCeiling <= 0 we fall back to a single upstream call — no +// chunking, no buffer. That path preserves behaviour for providers +// that have not declared a ceiling. Block 3.4. func (p *lcProvider) EmbedBatch(ctx context.Context, texts []string) ([][]float32, error) { - return p.emb.EmbedDocuments(ctx, texts) + ctx, cancel := p.withCallTimeout(ctx) + defer cancel() + + if len(texts) == 0 { + return nil, nil + } + + // No declared ceiling — single pass, preserve old behaviour. + if p.batchCeiling <= 0 { + return p.emb.EmbedDocuments(ctx, texts) + } + + ceiling := p.batchCeiling + if len(texts) <= ceiling { + return p.emb.EmbedDocuments(ctx, texts) + } + + // Chunk boundaries (start, end) pairs — deterministic order. + type chunk struct { + start, end int + } + var chunks []chunk + for i := 0; i < len(texts); i += ceiling { + end := i + ceiling + if end > len(texts) { + end = len(texts) + } + chunks = append(chunks, chunk{start: i, end: end}) + } + + type chunkResult struct { + start int + vecs [][]float32 + err error + } + // Buffer sized 2 chunks (equivalent to 2*ceiling vector slots). One + // chunk completed, one en route; a slow consumer backpressures the + // third. Concurrent multi-chunk dispatch is the Embedder's job. + results := make(chan chunkResult, 2) + + // Producer: iterate chunks serially. Serial emission is intentional + // — the buffer provides headroom for a single slow consumer step. + go func() { + defer close(results) + for _, c := range chunks { + slice := texts[c.start:c.end] + vecs, err := p.emb.EmbedDocuments(ctx, slice) + select { + case results <- chunkResult{start: c.start, vecs: vecs, err: err}: + case <-ctx.Done(): + return + } + if err != nil { + return + } + } + }() + + out := make([][]float32, len(texts)) + for r := range results { + if r.err != nil { + return nil, fmt.Errorf("embed batch [%d:%d]: %w", r.start, r.start+len(r.vecs), r.err) + } + for i, v := range r.vecs { + out[r.start+i] = v + } + } + + // Defensive: every slot must be populated. If a chunk errored + // between buffer push and loop drain we'd have returned above; + // reaching here means every result arrived. 
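+	// The ctx.Err() below covers the remaining gap: a cancellation
+	// that closed the channel early leaves nil rows, and it surfaces
+	// here as a non-nil error, so callers discard the partial slice.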
+ return out, ctx.Err() } func newOllamaProvider(cfg *config.LLMConfig) (Provider, error) { + // Block 3.5: one pooled *http.Client shared across chat + embed handles. + httpClient := newHTTPClient() + chatLLM, err := ollama.New( ollama.WithServerURL(cfg.Ollama.BaseURL), ollama.WithModel(cfg.Ollama.ChatModel), + ollama.WithHTTPClient(httpClient), ) if err != nil { return nil, fmt.Errorf("ollama chat LLM: %w", err) @@ -110,6 +235,7 @@ func newOllamaProvider(cfg *config.LLMConfig) (Provider, error) { embedLLM, err := ollama.New( ollama.WithServerURL(cfg.Ollama.BaseURL), ollama.WithModel(cfg.Ollama.EmbedModel), + ollama.WithHTTPClient(httpClient), ) if err != nil { return nil, fmt.Errorf("ollama embed LLM: %w", err) @@ -118,18 +244,30 @@ func newOllamaProvider(cfg *config.LLMConfig) (Provider, error) { if err != nil { return nil, fmt.Errorf("ollama embedder: %w", err) } - return &lcProvider{llm: chatLLM, emb: emb, name: "ollama", modelID: cfg.Ollama.EmbedModel}, nil + return &lcProvider{ + llm: chatLLM, + emb: emb, + name: "ollama", + modelID: cfg.Ollama.EmbedModel, + httpClient: httpClient, + callTimeout: cfg.CallTimeout, + batchCeiling: 128, + }, nil } func newAzureProvider(cfg *config.LLMConfig) (Provider, error) { az := &cfg.Azure + // Block 3.5: one pooled *http.Client shared across chat + embed handles. + httpClient := newHTTPClient() + chatLLM, err := openai.New( openai.WithBaseURL(az.ChatEndpoint()), openai.WithToken(az.ChatAPIKey()), openai.WithAPIVersion(az.ChatAPIVersion()), openai.WithAPIType(openai.APITypeAzure), openai.WithModel(az.ChatModel()), + openai.WithHTTPClient(httpClient), ) if err != nil { return nil, fmt.Errorf("azure openai chat LLM: %w", err) @@ -141,6 +279,7 @@ func newAzureProvider(cfg *config.LLMConfig) (Provider, error) { openai.WithAPIVersion(az.EmbedAPIVersion()), openai.WithAPIType(openai.APITypeAzure), openai.WithEmbeddingModel(az.EmbedModel()), + openai.WithHTTPClient(httpClient), ) if err != nil { return nil, fmt.Errorf("azure openai embed LLM: %w", err) @@ -150,5 +289,13 @@ func newAzureProvider(cfg *config.LLMConfig) (Provider, error) { if err != nil { return nil, fmt.Errorf("azure openai embedder: %w", err) } - return &lcProvider{llm: chatLLM, emb: emb, name: "azure", modelID: az.EmbedModel()}, nil + return &lcProvider{ + llm: chatLLM, + emb: emb, + name: "azure", + modelID: az.EmbedModel(), + httpClient: httpClient, + callTimeout: cfg.CallTimeout, + batchCeiling: 16, + }, nil } diff --git a/internal/llm/timeout_test.go b/internal/llm/timeout_test.go new file mode 100644 index 0000000..6c6204f --- /dev/null +++ b/internal/llm/timeout_test.go @@ -0,0 +1,159 @@ +package llm + +import ( + "context" + "errors" + "sync" + "testing" + "time" + + "github.com/tmc/langchaingo/embeddings" + "github.com/tmc/langchaingo/llms" +) + +// stubModel implements llms.Model by blocking forever on GenerateContent +// until the context is cancelled. It proves the provider honours ctx +// deadlines rather than swallowing them. 
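+// The stub satisfies both llms.Model entry points (Call and
+// GenerateContent) through the same ctx-blocking generate helper, so
+// either path exercises the deadline.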
+type stubModel struct{} + +func (stubModel) Call(ctx context.Context, prompt string, opts ...llms.CallOption) (string, error) { + return (stubModel{}).generate(ctx) +} + +func (stubModel) GenerateContent(ctx context.Context, msgs []llms.MessageContent, opts ...llms.CallOption) (*llms.ContentResponse, error) { + if _, err := (stubModel{}).generate(ctx); err != nil { + return nil, err + } + return &llms.ContentResponse{}, nil +} + +func (stubModel) generate(ctx context.Context) (string, error) { + <-ctx.Done() + return "", ctx.Err() +} + +// stubEmbedder blocks on EmbedDocuments / EmbedQuery until ctx done. +type stubEmbedder struct{} + +func (stubEmbedder) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { + <-ctx.Done() + return nil, ctx.Err() +} + +func (stubEmbedder) EmbedQuery(ctx context.Context, text string) ([]float32, error) { + <-ctx.Done() + return nil, ctx.Err() +} + +var _ embeddings.Embedder = stubEmbedder{} + +func TestLcProvider_Complete_HonoursCallTimeout(t *testing.T) { + t.Parallel() + p := &lcProvider{ + llm: stubModel{}, + emb: stubEmbedder{}, + name: "stub", + modelID: "stub", + callTimeout: 50 * time.Millisecond, + } + start := time.Now() + _, err := p.Complete(context.Background(), "hello") + elapsed := time.Since(start) + if err == nil { + t.Fatal("Complete: want non-nil error on timeout, got nil") + } + if !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Complete error: want context.DeadlineExceeded, got %v", err) + } + if elapsed > 500*time.Millisecond { + t.Fatalf("Complete returned after %v; callTimeout=50ms — deadline not propagated", elapsed) + } +} + +func TestLcProvider_Embed_HonoursCallTimeout(t *testing.T) { + t.Parallel() + p := &lcProvider{ + llm: stubModel{}, + emb: stubEmbedder{}, + callTimeout: 50 * time.Millisecond, + } + start := time.Now() + _, err := p.Embed(context.Background(), "hello") + elapsed := time.Since(start) + if err == nil || !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Embed error: want DeadlineExceeded, got %v", err) + } + if elapsed > 500*time.Millisecond { + t.Fatalf("Embed elapsed = %v; want < 500ms", elapsed) + } +} + +func TestLcProvider_ZeroCallTimeout_LeavesParentCtxAuthoritative(t *testing.T) { + t.Parallel() + p := &lcProvider{ + llm: stubModel{}, + emb: stubEmbedder{}, + callTimeout: 0, // disabled — parent ctx wins + } + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + _, err := p.Complete(ctx, "hello") + if err == nil || !errors.Is(err, context.DeadlineExceeded) { + t.Fatalf("Complete error with parent deadline: want DeadlineExceeded, got %v", err) + } +} + +// chunkCountingEmbedder counts how many times EmbedDocuments is called +// and with what sizes. Used to verify provider-level chunking. 
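+// The stub's vectors encode {call ordinal, index within chunk}, so an
+// assertion could also prove per-chunk ordering, not just call sizes.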
+type chunkCountingEmbedder struct { + mu sync.Mutex + callSizes []int +} + +func (c *chunkCountingEmbedder) EmbedDocuments(ctx context.Context, texts []string) ([][]float32, error) { + c.mu.Lock() + c.callSizes = append(c.callSizes, len(texts)) + c.mu.Unlock() + out := make([][]float32, len(texts)) + for i := range texts { + out[i] = []float32{float32(len(c.callSizes)), float32(i)} + } + return out, nil +} + +func (c *chunkCountingEmbedder) EmbedQuery(ctx context.Context, text string) ([]float32, error) { + return []float32{0}, nil +} + +func TestLcProvider_EmbedBatch_ChunksToCeiling(t *testing.T) { + t.Parallel() + ce := &chunkCountingEmbedder{} + p := &lcProvider{ + llm: stubModel{}, + emb: ce, + batchCeiling: 16, // Azure-sized + } + + texts := make([]string, 50) + for i := range texts { + texts[i] = "t" + } + + vecs, err := p.EmbedBatch(context.Background(), texts) + if err != nil { + t.Fatalf("EmbedBatch: %v", err) + } + if len(vecs) != 50 { + t.Fatalf("vecs len = %d; want 50", len(vecs)) + } + + ce.mu.Lock() + defer ce.mu.Unlock() + // 50 / 16 = 3 full chunks of 16 + 1 tail of 2 → 4 calls. + if len(ce.callSizes) != 4 { + t.Fatalf("chunk calls = %d; want 4", len(ce.callSizes)) + } + if ce.callSizes[0] != 16 || ce.callSizes[1] != 16 || ce.callSizes[2] != 16 || ce.callSizes[3] != 2 { + t.Fatalf("chunk sizes = %v; want [16 16 16 2]", ce.callSizes) + } +} diff --git a/internal/pipeline/pipeline_test.go b/internal/pipeline/pipeline_test.go index c83e278..4f27e33 100644 --- a/internal/pipeline/pipeline_test.go +++ b/internal/pipeline/pipeline_test.go @@ -18,8 +18,9 @@ import ( // zero vector or an empty string — nothing reaches the network. type nopProvider struct{} -func (nopProvider) Name() string { return "nop" } -func (nopProvider) ModelID() string { return "nop-0" } +func (nopProvider) Name() string { return "nop" } +func (nopProvider) ModelID() string { return "nop-0" } +func (nopProvider) BatchCeiling() int { return 0 } func (nopProvider) Complete(_ context.Context, _ string, _ ...llm.Option) (string, error) { return "", nil } diff --git a/internal/search/local_test.go b/internal/search/local_test.go index a67c985..d6ff19d 100644 --- a/internal/search/local_test.go +++ b/internal/search/local_test.go @@ -21,8 +21,9 @@ type mockProvider struct { embedCalls int } -func (m *mockProvider) Name() string { return "mock" } -func (m *mockProvider) ModelID() string { return m.modelID } +func (m *mockProvider) Name() string { return "mock" } +func (m *mockProvider) ModelID() string { return m.modelID } +func (m *mockProvider) BatchCeiling() int { return 0 } func (m *mockProvider) Complete(ctx context.Context, prompt string, opts ...llm.Option) (string, error) { m.completeCalls++ return "answer: " + prompt[:min(len(prompt), 32)], nil diff --git a/internal/store/store.go b/internal/store/store.go index ea7f9e8..a2c56fa 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -28,20 +28,48 @@ type Store struct { // open is the low-level SQLite opener. It is unexported — the only public // factory is OpenForProject. Kept as a helper because the project registry // and the per-project store both use the same DSN+migrate recipe. +// +// Block 3.6 hardening: +// - PRAGMAs set explicitly via Exec after sql.Open (driver-portable). +// - PRAGMA synchronous=NORMAL — the WAL-safe sweet spot that cuts two +// fsyncs per commit to one. +// - MaxOpenConns=4 + MaxIdleConns=2 allows concurrent readers under +// WAL; the writer is already serialized by SQLite itself. 
+// - ConnMaxLifetime=1h guards against stale connections in long-lived +// server processes. func open(path string) (*Store, error) { if path == "" { return nil, fmt.Errorf("open db: path is empty") } - // P1-2: _busy_timeout lets SQLite wait-and-retry on SQLITE_BUSY - // rather than failing instantly. With MaxOpenConns=1 this prevents - // spurious "database is locked" errors under concurrent load. - db, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on&_busy_timeout=5000") + // _busy_timeout retained in the DSN as a belt-and-braces default: + // the explicit PRAGMA below is the authoritative setting, but the + // DSN form protects against an early query landing before the + // PRAGMA Exec completes. + db, err := sql.Open("sqlite3", path+"?_busy_timeout=5000") if err != nil { return nil, fmt.Errorf("open db: %w", err) } - db.SetMaxOpenConns(1) // SQLite WAL allows 1 writer + + pragmas := []string{ + `PRAGMA journal_mode=WAL`, + `PRAGMA busy_timeout=5000`, + `PRAGMA synchronous=NORMAL`, + `PRAGMA foreign_keys=ON`, + } + for _, p := range pragmas { + if _, err := db.Exec(p); err != nil { + _ = db.Close() + return nil, fmt.Errorf("open db: %s: %w", p, err) + } + } + + db.SetMaxOpenConns(4) + db.SetMaxIdleConns(2) + db.SetConnMaxLifetime(1 * time.Hour) + s := &Store{db: db} if err := s.migrate(); err != nil { + _ = db.Close() return nil, err } return s, nil @@ -80,6 +108,13 @@ func OpenForProject(dataDir, slug string) (*Store, error) { func (s *Store) Close() error { return s.db.Close() } +// Ping verifies the database connection is alive. Uses PingContext so +// a cancelled ctx surfaces as ctx.Err(); callers (e.g. /readyz) can +// differentiate "request cancelled" from "SQLite broken". +func (s *Store) Ping(ctx context.Context) error { + return s.db.PingContext(ctx) +} + func (s *Store) DB() *sql.DB { return s.db } func (s *Store) migrate() error { diff --git a/internal/store/store_hardening_test.go b/internal/store/store_hardening_test.go new file mode 100644 index 0000000..bd76551 --- /dev/null +++ b/internal/store/store_hardening_test.go @@ -0,0 +1,97 @@ +package store + +import ( + "context" + "testing" + "time" +) + +// TestOpen_HardeningPragmas verifies Block 3.6 — every PRAGMA the spec +// requires is observable on a freshly-opened store. +func TestOpen_HardeningPragmas(t *testing.T) { + t.Parallel() + dir := t.TempDir() + s, err := OpenForProject(dir, "harden") + if err != nil { + t.Fatalf("OpenForProject: %v", err) + } + defer s.Close() + + cases := []struct { + name string + sql string + want string + }{ + {"journal_mode", `PRAGMA journal_mode`, "wal"}, + {"foreign_keys", `PRAGMA foreign_keys`, "1"}, + {"synchronous", `PRAGMA synchronous`, "1"}, // 1 = NORMAL + } + for _, c := range cases { + c := c + t.Run(c.name, func(t *testing.T) { + var got string + if err := s.DB().QueryRow(c.sql).Scan(&got); err != nil { + t.Fatalf("%s: %v", c.sql, err) + } + if got != c.want { + t.Fatalf("%s = %q; want %q", c.sql, got, c.want) + } + }) + } + + t.Run("busy_timeout_ge_5000", func(t *testing.T) { + var got int + if err := s.DB().QueryRow(`PRAGMA busy_timeout`).Scan(&got); err != nil { + t.Fatalf("PRAGMA busy_timeout: %v", err) + } + if got < 5000 { + t.Fatalf("busy_timeout = %d ms; want >= 5000", got) + } + }) +} + +// TestOpen_PoolSettings asserts the raised MaxOpenConns / MaxIdleConns +// values survive the Open recipe. 
MaxOpenConns=4, MaxIdleConns=2, +// ConnMaxLifetime=1h are not individually observable via sql.DB stats +// without opening connections; we assert on Stats().MaxOpenConnections +// which reflects SetMaxOpenConns. +func TestOpen_PoolSettings(t *testing.T) { + t.Parallel() + dir := t.TempDir() + s, err := OpenForProject(dir, "pool") + if err != nil { + t.Fatalf("OpenForProject: %v", err) + } + defer s.Close() + + stats := s.DB().Stats() + if stats.MaxOpenConnections != 4 { + t.Fatalf("MaxOpenConnections = %d; want 4", stats.MaxOpenConnections) + } +} + +// TestStore_PingContext asserts the new context-aware Ping method. +// A cancelled context must surface as a ctx.Err(), not a generic +// database error, so that /readyz can distinguish "caller gave up" +// from "SQLite is sick". +func TestStore_PingContext(t *testing.T) { + t.Parallel() + dir := t.TempDir() + s, err := OpenForProject(dir, "ping") + if err != nil { + t.Fatalf("OpenForProject: %v", err) + } + defer s.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if err := s.Ping(ctx); err != nil { + t.Fatalf("Ping: %v", err) + } + + cancelled, cancel2 := context.WithCancel(context.Background()) + cancel2() + if err := s.Ping(cancelled); err == nil { + t.Fatalf("Ping on cancelled ctx: want non-nil error, got nil") + } +} diff --git a/scripts/ctx-audit.sh b/scripts/ctx-audit.sh new file mode 100755 index 0000000..1bbe20f --- /dev/null +++ b/scripts/ctx-audit.sh @@ -0,0 +1,80 @@ +#!/usr/bin/env bash +# scripts/ctx-audit.sh — Block 3.1 static check. +# +# Two guarantees across internal/{llm,embedder,extractor,crawler,store}: +# 1. No HTTP call bypasses ctx — http.Get / http.Post / client.Get / +# client.Post / http.DefaultClient / http.NewRequest (non-ctx) are +# all banned. Use http.NewRequestWithContext + client.Do instead. +# 2. No DB call bypasses ctx — .Query / .Exec / .QueryRow are banned. +# Use .QueryContext / .ExecContext / .QueryRowContext instead. +# The migrate() / open() PRAGMA paths are exempt because ctx is +# not yet available at store construction time. +# +# Exits non-zero if any violation is found. Intended as a CI gate. +# +# Note on exported-func auditing: doing a robust first-arg-type check in +# bash against Go source requires a full parser (return tuples like +# `func F(int) (int, error)` break naive regex). The Go compiler itself +# plus `go vet` already ensures ctx-accepting function signatures are +# respected at every call site; the value this script adds is the I/O +# side-channel check that vet does not cover. We intentionally scope to +# (1) + (2). +set -euo pipefail + +PACKAGES=( + internal/llm + internal/embedder + internal/extractor + internal/crawler + internal/store +) + +ROOT="$(git rev-parse --show-toplevel 2>/dev/null || pwd)" +cd "$ROOT" + +fail=0 + +echo "==> HTTP calls without ctx" +for pkg in "${PACKAGES[@]}"; do + hits="$(grep -rnE \ + -e 'http\.Get\(' \ + -e 'http\.Post\(' \ + -e '(^|[^A-Za-z_])client\.Get\(' \ + -e '(^|[^A-Za-z_])client\.Post\(' \ + -e 'http\.DefaultClient\.' \ + -e 'http\.NewRequest\(' \ + "$pkg" --include='*.go' --exclude='*_test.go' 2>/dev/null || true)" + if [ -n "$hits" ]; then + echo "$hits" + fail=1 + fi +done + +echo "==> DB calls without ctx" +for pkg in "${PACKAGES[@]}"; do + # Find .Query/.Exec/.QueryRow( that are not *Context variants. + # Strip comment-only lines. 
Then exclude the two known-safe lines + # inside internal/store/store.go where ctx is not yet available: + # - migrate()'s schema + migrations execs + # - open()'s PRAGMA exec (before the Store is returned) + hits="$(grep -rnE \ + -e '\.(Query|Exec|QueryRow)\(' \ + "$pkg" --include='*.go' --exclude='*_test.go' 2>/dev/null \ + | grep -vE '\.(Query|Exec|QueryRow)Context\(' \ + | grep -vE '^\S+\.go:[0-9]+:\s*//' \ + | grep -v 'store\.go.*s\.db\.Exec(schema)' \ + | grep -v 'store\.go.*s\.db\.Exec(m)' \ + | grep -vE 'store\.go:[0-9]+:\s*if _, err := db\.Exec\(p\)' \ + || true)" + if [ -n "$hits" ]; then + echo "$hits" + fail=1 + fi +done + +if [ $fail -ne 0 ]; then + echo "" + echo "CTX AUDIT FAILED — see output above." + exit 1 +fi +echo "CTX AUDIT OK"
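+
+# For reference, the rewrite check (1) pushes toward, shown with the
+# same names as the crawler fix in this commit (Go, illustrative):
+#
+#   resp, err := client.Get(pageURL)                        // flagged
+#
+#   req, err := http.NewRequestWithContext(ctx, http.MethodGet, pageURL, nil)
+#   if err != nil {
+#       return nil
+#   }
+#   resp, err := client.Do(req)                             // passes
+#
+# Check (2) is the same move on database/sql: db.QueryContext(ctx, q)
+# where db.Query(q) used to be.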