Skip to content

Commit c2ae257

Browse files
aksOpsclaude
andcommitted
feat: introduce GraphBackend protocol and refactor GraphStore as facade
Extract NetworkXBackend from GraphStore, define GraphBackend and CypherBackend protocols, fix 3 NetworkX leaks in query.py, views.py, layer_classifier.py. Zero behavioral changes — all 361 tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 2e34d7c commit c2ae257

7 files changed

Lines changed: 273 additions & 124 deletions

File tree

src/code_intelligence/classifiers/layer_classifier.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import re
66
from typing import Any, Sequence
77

8-
from code_intelligence.models.graph import GraphNode, NodeKind, SourceLocation
8+
from code_intelligence.models.graph import GraphNode, NodeKind
99

1010
_FRONTEND_NODE_KINDS = {NodeKind.COMPONENT, NodeKind.HOOK}
1111
_BACKEND_NODE_KINDS = {NodeKind.GUARD, NodeKind.MIDDLEWARE, NodeKind.ENDPOINT, NodeKind.REPOSITORY, NodeKind.DATABASE_CONNECTION, NodeKind.QUERY}
@@ -35,25 +35,10 @@ def classify(self, nodes: Sequence[GraphNode]) -> None:
3535
node.properties["layer"] = self._classify_one(node)
3636

3737
def classify_store(self, store: Any) -> None:
38-
"""Classify nodes directly in a GraphStore, updating networkx data in-place."""
39-
for node_id, data in store.graph.nodes(data=True):
40-
if "kind" not in data:
41-
continue
42-
# Build a lightweight GraphNode for classification
43-
loc = data.get("location")
44-
if isinstance(loc, dict):
45-
loc = SourceLocation(**loc)
46-
node = GraphNode(
47-
id=data.get("id", node_id),
48-
kind=NodeKind(data["kind"]),
49-
label=data.get("label", ""),
50-
location=loc,
51-
properties=data.get("properties", {}),
52-
)
38+
"""Classify nodes in a GraphStore, updating properties via public API."""
39+
for node in store.all_nodes():
5340
layer = self._classify_one(node)
54-
# Update the networkx data dict directly
55-
props = data.setdefault("properties", {})
56-
props["layer"] = layer
41+
store.update_node_properties(node.id, {"layer": layer})
5742

5843
def _classify_one(self, node: GraphNode) -> str:
5944
if node.kind in _FRONTEND_NODE_KINDS:
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
"""Protocol definitions for graph storage backends."""
2+
3+
from __future__ import annotations
4+
5+
from typing import Any, Protocol, runtime_checkable
6+
7+
from code_intelligence.models.graph import (
8+
EdgeKind,
9+
GraphEdge,
10+
GraphNode,
11+
NodeKind,
12+
)
13+
14+
15+
@runtime_checkable
16+
class GraphBackend(Protocol):
17+
"""Contract that every graph storage backend must satisfy."""
18+
19+
def add_node(self, node: GraphNode) -> None: ...
20+
def add_edge(self, edge: GraphEdge) -> None: ...
21+
def clear(self) -> None: ...
22+
def get_node(self, node_id: str) -> GraphNode | None: ...
23+
def has_node(self, node_id: str) -> bool: ...
24+
def get_edges_between(self, source: str, target: str) -> list[GraphEdge]: ...
25+
def all_nodes(self) -> list[GraphNode]: ...
26+
def all_edges(self) -> list[GraphEdge]: ...
27+
def nodes_by_kind(self, kind: NodeKind) -> list[GraphNode]: ...
28+
def edges_by_kind(self, kind: EdgeKind) -> list[GraphEdge]: ...
29+
30+
@property
31+
def node_count(self) -> int: ...
32+
@property
33+
def edge_count(self) -> int: ...
34+
35+
def neighbors(
36+
self, node_id: str,
37+
edge_kinds: set[EdgeKind] | None = None,
38+
direction: str = "both",
39+
) -> list[str]: ...
40+
41+
def find_cycles(self, limit: int = 100) -> list[list[str]]: ...
42+
def shortest_path(self, source: str, target: str) -> list[str] | None: ...
43+
def subgraph(self, node_ids: set[str]) -> GraphBackend: ...
44+
def update_node_properties(self, node_id: str, properties: dict[str, Any]) -> None: ...
45+
def close(self) -> None: ...
46+
47+
48+
@runtime_checkable
49+
class CypherBackend(Protocol):
50+
"""Optional capability for backends supporting Cypher queries."""
51+
52+
def query_cypher(self, cypher: str, params: dict[str, Any] | None = None) -> list[dict[str, Any]]: ...
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""Graph backend factory."""
2+
3+
from __future__ import annotations
4+
5+
from typing import TYPE_CHECKING
6+
7+
if TYPE_CHECKING:
8+
from code_intelligence.graph.backend import GraphBackend
9+
10+
11+
def create_backend(backend_name: str = "networkx", **kwargs) -> GraphBackend:
12+
"""Create a graph backend by name."""
13+
if backend_name == "networkx":
14+
from code_intelligence.graph.backends.networkx import NetworkXBackend
15+
return NetworkXBackend()
16+
elif backend_name == "kuzu":
17+
from code_intelligence.graph.backends.kuzu import KuzuBackend
18+
return KuzuBackend(db_path=kwargs.get("path", ".code-intelligence/graph.kuzu"))
19+
elif backend_name == "sqlite":
20+
from code_intelligence.graph.backends.sqlite_backend import SqliteGraphBackend
21+
return SqliteGraphBackend(db_path=kwargs.get("path", ".code-intelligence/graph.db"))
22+
else:
23+
raise ValueError(f"Unknown graph backend: {backend_name}")
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
"""NetworkX-backed graph backend."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
from typing import Any
7+
8+
import networkx as nx
9+
10+
from code_intelligence.models.graph import (
11+
EdgeKind,
12+
GraphEdge,
13+
GraphNode,
14+
NodeKind,
15+
)
16+
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class NetworkXBackend:
21+
"""In-memory graph backend using NetworkX MultiDiGraph."""
22+
23+
def __init__(self) -> None:
24+
self._g: nx.MultiDiGraph = nx.MultiDiGraph()
25+
26+
@property
27+
def node_count(self) -> int:
28+
return self._g.number_of_nodes()
29+
30+
@property
31+
def edge_count(self) -> int:
32+
return self._g.number_of_edges()
33+
34+
def add_node(self, node: GraphNode) -> None:
35+
if node.id in self._g:
36+
logger.debug("Duplicate node ID %s, keeping first", node.id)
37+
return
38+
self._g.add_node(node.id, **node.model_dump())
39+
40+
def add_edge(self, edge: GraphEdge) -> None:
41+
self._g.add_edge(edge.source, edge.target, **edge.model_dump())
42+
43+
def clear(self) -> None:
44+
self._g.clear()
45+
46+
def get_node(self, node_id: str) -> GraphNode | None:
47+
if node_id not in self._g:
48+
return None
49+
return GraphNode(**self._g.nodes[node_id])
50+
51+
def has_node(self, node_id: str) -> bool:
52+
return node_id in self._g
53+
54+
def get_edges_between(self, source: str, target: str) -> list[GraphEdge]:
55+
if not self._g.has_edge(source, target):
56+
return []
57+
return [GraphEdge(**data) for _key, data in self._g[source][target].items()]
58+
59+
def all_nodes(self) -> list[GraphNode]:
60+
return [
61+
GraphNode(**data)
62+
for _, data in self._g.nodes(data=True)
63+
if "id" in data and "kind" in data
64+
]
65+
66+
def all_edges(self) -> list[GraphEdge]:
67+
return [
68+
GraphEdge(**data)
69+
for _, _, data in self._g.edges(data=True)
70+
if "source" in data and "target" in data
71+
]
72+
73+
def nodes_by_kind(self, kind: NodeKind) -> list[GraphNode]:
74+
return [
75+
GraphNode(**data)
76+
for _, data in self._g.nodes(data=True)
77+
if data.get("kind") == kind.value and "id" in data
78+
]
79+
80+
def edges_by_kind(self, kind: EdgeKind) -> list[GraphEdge]:
81+
return [
82+
GraphEdge(**data)
83+
for _, _, data in self._g.edges(data=True)
84+
if data.get("kind") == kind.value and "source" in data
85+
]
86+
87+
def neighbors(self, node_id: str, edge_kinds: set[EdgeKind] | None = None, direction: str = "both") -> list[str]:
88+
result: set[str] = set()
89+
if direction in ("out", "both"):
90+
for _, target, data in self._g.out_edges(node_id, data=True):
91+
if edge_kinds is None or EdgeKind(data.get("kind", "")) in edge_kinds:
92+
result.add(target)
93+
if direction in ("in", "both"):
94+
for source, _, data in self._g.in_edges(node_id, data=True):
95+
if edge_kinds is None or EdgeKind(data.get("kind", "")) in edge_kinds:
96+
result.add(source)
97+
return sorted(result)
98+
99+
def find_cycles(self, limit: int = 100) -> list[list[str]]:
100+
cycles: list[list[str]] = []
101+
for cycle in nx.simple_cycles(self._g):
102+
cycles.append(cycle)
103+
if len(cycles) >= limit:
104+
break
105+
return cycles
106+
107+
def shortest_path(self, source: str, target: str) -> list[str] | None:
108+
try:
109+
return nx.shortest_path(self._g, source, target)
110+
except nx.NetworkXNoPath:
111+
return None
112+
113+
def subgraph(self, node_ids: set[str]) -> NetworkXBackend:
114+
new_backend = NetworkXBackend()
115+
sub = self._g.subgraph(node_ids)
116+
new_backend._g = nx.MultiDiGraph(sub)
117+
return new_backend
118+
119+
def update_node_properties(self, node_id: str, properties: dict[str, Any]) -> None:
120+
if node_id in self._g:
121+
data = self._g.nodes[node_id]
122+
props = data.get("properties", {})
123+
props.update(properties)
124+
data["properties"] = props
125+
126+
def close(self) -> None:
127+
pass # In-memory, nothing to close

src/code_intelligence/graph/query.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,11 +162,11 @@ def execute(self) -> GraphStore:
162162
if spec.direction != "both":
163163
# For directional focus, do a BFS in the specified direction
164164
focused_ids.update(
165-
nid for nid, _ in ego_store.graph.nodes(data=True)
165+
n.id for n in ego_store.all_nodes()
166166
)
167167
else:
168168
focused_ids.update(
169-
nid for nid, _ in ego_store.graph.nodes(data=True)
169+
n.id for n in ego_store.all_nodes()
170170
)
171171
store = store.subgraph(focused_ids)
172172

0 commit comments

Comments
 (0)