Commit 52af333

review comments

Signed-off-by: Akihiko Kuroda <akihikokuroda2020@gmail.com>

1 parent 5623420 commit 52af333

2 files changed: 77 additions & 51 deletions

mellea/stdlib/requirements/rag.py (73 additions & 47 deletions)
@@ -40,6 +40,9 @@ class GroundednessRequirement(Requirement):
             grounded. If False (default), response is grounded iff all spans
             needing citations are FULLY supported. If True, response is grounded
             if spans are fully or partially supported.
+        max_new_tokens: Maximum tokens for LLM judgment outputs. Increase this
+            if LLM outputs are being truncated (particularly for complex
+            responses with many spans). Default is 500.
         description: Custom description for the requirement. If None,
             generates a default description.
 
@@ -63,10 +66,12 @@ def __init__(
         self,
         documents: Iterable[Document] | Iterable[str] | None = None,
         allow_partial_support: bool = False,
+        max_new_tokens: int = 500,
         description: str | None = None,
     ):
         """Initialize grounded requirement."""
         self.allow_partial_support = allow_partial_support
+        self.max_new_tokens = max_new_tokens
 
         # Convert documents to Document objects if provided
         if documents is not None:
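
As a usage note for the new knob: a minimal sketch of constructing the requirement with a larger budget (the document string here is invented for illustration):

# Hypothetical usage sketch: raise the judgment token budget when outputs
# for span-heavy responses come back truncated at the default of 500.
req = GroundednessRequirement(
    documents=["The Eiffel Tower is 330 metres tall and was completed in 1889."],
    allow_partial_support=True,
    max_new_tokens=1000,
)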
@@ -177,14 +182,11 @@ async def validate(
 
         try:
             # Step 1: Citation Generation
-            # Call intrinsic directly for explicit control over model options
-            from ..components.intrinsic._util import call_intrinsic
+            # Import lazily to avoid circular dependency
+            from ..components.intrinsic.rag import find_citations
 
-            citation_context = context_before_response.add(
-                Message("assistant", response, documents=list(documents))
-            )
-            citations: list[dict] = call_intrinsic(
-                "citations", citation_context, backend
+            citations: list[dict] = find_citations(
+                response, list(documents), context_before_response, backend
             )
             logger.debug(
                 f"Step 1 - Citations generated: {len(citations)} citations found"
@@ -219,7 +221,12 @@ async def validate(
             # Step 3: Citation Support
             try:
                 span_support = await self._assess_citation_support(
-                    response, citations, span_necessity, backend, context_before_response
+                    response,
+                    citations,
+                    span_necessity,
+                    backend,
+                    context_before_response,
+                    documents,
                 )
                 logger.debug(
                     f"Step 3 - Citation support assessed: {len(span_support)} spans"
@@ -273,7 +280,10 @@ async def _identify_citation_necessity(
         result, _ = await backend.generate_from_context(
             action,
             context,
-            model_options={"temperature": 0.0, "max_new_tokens": 500},
+            model_options={
+                "temperature": 0.0,
+                "max_new_tokens": self.max_new_tokens,
+            },
         )
         await result.avalue()
         output_text = result.value
@@ -295,6 +305,7 @@ async def _assess_citation_support(
         span_necessity: dict[tuple[int, int], bool],
         backend: Backend,
         context: ChatContext,
+        documents: list[Document],
     ) -> dict[tuple[int, int], str]:
         """Assess level of support for spans that need citations.
 
@@ -307,6 +318,7 @@ async def _assess_citation_support(
             span_necessity: Mapping of span (begin, end) to needs_citation flag
             backend: Backend for LLM judgment
             context: Chat context
+            documents: List of source documents for context in LLM assessment
 
         Returns:
             Dictionary mapping span (begin, end) to support level
@@ -358,7 +370,7 @@ async def _assess_citation_support(
             return span_support
 
         # Single batch LLM call for all spans
-        prompt = self._build_batch_support_prompt(response, spans_to_assess)
+        prompt = self._build_batch_support_prompt(response, spans_to_assess, documents)
         logger.debug(
             f"Batch support assessment prompt (spans={len(spans_to_assess)}):\n{prompt}\n"
         )
@@ -368,7 +380,10 @@ async def _assess_citation_support(
         result, _ = await backend.generate_from_context(
             action,
             context,
-            model_options={"temperature": 0.0, "max_new_tokens": 500},
+            model_options={
+                "temperature": 0.0,
+                "max_new_tokens": self.max_new_tokens,
+            },
         )
         await result.avalue()
         output_text = result.value
@@ -424,7 +439,7 @@ def _extract_response_spans(
         covered_ranges.sort()
         merged_ranges: list[tuple[int, int]] = []
         for begin, end in covered_ranges:
-            if merged_ranges and begin <= merged_ranges[-1][1]:
+            if merged_ranges and begin < merged_ranges[-1][1]:
                 merged_ranges[-1] = (
                     merged_ranges[-1][0],
                     max(merged_ranges[-1][1], end),
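
Why `<` instead of `<=`: with `<=`, ranges that merely touch, such as (0, 5) and (5, 9), were fused, erasing the citation boundary between adjacent spans; strict `<` fuses only true overlaps. A minimal sketch of the merge step in isolation (same logic, not the class method):

def merge(covered_ranges: list[tuple[int, int]]) -> list[tuple[int, int]]:
    covered_ranges.sort()
    merged: list[tuple[int, int]] = []
    for begin, end in covered_ranges:
        if merged and begin < merged[-1][1]:  # true overlap only, not mere adjacency
            merged[-1] = (merged[-1][0], max(merged[-1][1], end))
        else:
            merged.append((begin, end))
    return merged

assert merge([(0, 5), (5, 9)]) == [(0, 5), (5, 9)]  # touching ranges stay separate
assert merge([(0, 6), (5, 9)]) == [(0, 9)]          # overlapping ranges are fused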
@@ -437,41 +452,38 @@ def _extract_response_spans(
             f"Response span extraction - coverage: {covered_chars}/{len(response)} chars covered by citations"
         )
 
-        # Check if a position is covered by any citation
-        def is_covered(pos: int) -> bool:
-            for begin, end in merged_ranges:
-                if begin <= pos < end:
-                    return True
-                if begin > pos:
-                    break
-            return False
-
         # Extract spans by finding boundaries between covered and uncovered regions
+        # Iterate over merged_ranges boundaries rather than every character for efficiency
         spans: list[dict] = []
-        current_span_start = 0
-        current_is_covered = is_covered(0) if response else False
-
-        for i in range(1, len(response) + 1):
-            # Check if we're at a boundary (coverage changed or end of response)
-            at_end = i == len(response)
-            next_is_covered = False if at_end else is_covered(i)
-            at_boundary = at_end or next_is_covered != current_is_covered
-
-            if at_boundary:
-                span_text = response[current_span_start:i].strip()
-                if span_text:  # Only include non-empty spans
-                    spans.append(
-                        {
-                            "begin": current_span_start,
-                            "end": i,
-                            "text": span_text,
-                            "is_cited": current_is_covered,
-                        }
-                    )
 
-                current_span_start = i
-                if not at_end:
-                    current_is_covered = next_is_covered
+        # Build boundary points from merged ranges
+        boundaries = [0]  # Start of response
+        for begin, end in merged_ranges:
+            boundaries.append(begin)
+            boundaries.append(end)
+        boundaries.append(len(response))  # End of response
+        boundaries = sorted(set(boundaries))  # Remove duplicates and sort
+
+        # Process each span between boundaries
+        for i in range(len(boundaries) - 1):
+            span_start = boundaries[i]
+            span_end = boundaries[i + 1]
+
+            # Determine if this span is covered by any merged range
+            is_cited = any(
+                begin <= span_start and span_end <= end for begin, end in merged_ranges
+            )
+
+            span_text = response[span_start:span_end].strip()
+            if span_text:  # Only include non-empty spans
+                spans.append(
+                    {
+                        "begin": span_start,
+                        "end": span_end,
+                        "text": span_text,
+                        "is_cited": is_cited,
+                    }
+                )
 
         logger.debug(f"Response span extraction - extracted {len(spans)} spans")
         for span in spans:
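
The rewritten extraction walks the O(k) boundary points of the merged ranges instead of scanning every character of the response. A self-contained toy run of the same boundary walk (invented input, not the class method):

response = "Cited sentence. Uncited sentence."
merged_ranges = [(0, 15)]  # citation covers "Cited sentence."

boundaries = sorted({0, len(response), *[b for r in merged_ranges for b in r]})
spans = []
for span_start, span_end in zip(boundaries, boundaries[1:]):
    is_cited = any(b <= span_start and span_end <= e for b, e in merged_ranges)
    text = response[span_start:span_end].strip()
    if text:  # skip whitespace-only gaps between boundaries
        spans.append({"begin": span_start, "end": span_end,
                      "text": text, "is_cited": is_cited})

# spans ->
# [{'begin': 0, 'end': 15, 'text': 'Cited sentence.', 'is_cited': True},
#  {'begin': 15, 'end': 33, 'text': 'Uncited sentence.', 'is_cited': False}]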
@@ -518,7 +530,7 @@ def _build_necessity_prompt(self, response: str, spans: list[dict]) -> str:
         return prompt
 
     def _build_batch_support_prompt(
-        self, response: str, spans_to_assess: list[dict]
+        self, response: str, spans_to_assess: list[dict], documents: list[Document]
     ) -> str:
         """Build prompt to assess citation support level for multiple spans at once.
 
@@ -530,6 +542,7 @@ def _build_batch_support_prompt(
             spans_to_assess: List of span dicts with keys:
                 - text: span text
                 - citations: list of citation records for this span
+            documents: List of source documents for context
 
         Returns:
             Formatted prompt for LLM expecting JSON array output
@@ -561,13 +574,26 @@ def _build_batch_support_prompt(
 
         spans_formatted = ",\n".join(span_assessments)
 
+        # Build source documents section for context
+        documents_section = ""
+        if documents:
+            doc_lines = []
+            for doc in documents:
+                doc_id = doc.doc_id if hasattr(doc, "doc_id") else "unknown"
+                doc_text = doc.text if hasattr(doc, "text") else str(doc)
+                doc_lines.append(f"Document {doc_id}:\n{doc_text}")
+            documents_section = "Source Documents:\n" + "\n\n".join(doc_lines) + "\n\n"
+
         prompt = (
-            "Assess the level of support for each response span based on provided citations.\n\n"
+            "Assess the level of support for each response span based on provided citations "
+            "and source documents.\n\n"
             "For each span, determine if the citations fully support, partially support, "
-            "or do not support the span.\n\n"
+            "or do not support the span. Consider the full context from the source documents "
+            "where the citations appear.\n\n"
             "Respond with a JSON array of the form:\n"
            '[{"span_id": ..., "support_level": ...}, ...]\n\n'
             "Support levels must be ONLY one of: FULLY_SUPPORTED, PARTIALLY_SUPPORTED, or NOT_SUPPORTED.\n\n"
+            f"{documents_section}"
             f"Response context:\n{response}\n\n"
             f"Spans to assess:\n[\n{spans_formatted}\n]\n\n"
             "JSON Output:\n"

test/stdlib/requirements/test_groundedness_requirement.py (4 additions & 4 deletions)
@@ -317,7 +317,7 @@ async def test_identify_citation_necessity_none_output():
 
 
 @pytest.mark.asyncio
-async def test_assess_citation_support_overlap_edge_case():
+async def test_assess_citation_support_overlap_edge_case(sample_docs):
     """Test citation support assessment with edge cases in span-citation overlap.
 
     This tests the scenario where a span is marked as not having citations,
@@ -355,7 +355,7 @@ async def test_assess_citation_support_overlap_edge_case():
     context = ChatContext().add(Message("user", "Test question"))
 
     span_support = await req._assess_citation_support(
-        response, citations, span_necessity, mock_backend, context
+        response, citations, span_necessity, mock_backend, context, sample_docs
    )
 
     # Should attempt to assess support even though span isn't covered by citations
@@ -404,7 +404,7 @@ async def test_identify_citation_necessity_prompt_as_action():
     assert messages[0].content == "Original context message"
 
 
-def test_build_batch_support_prompt():
+def test_build_batch_support_prompt(sample_docs):
     """Test building the batch support prompt for multiple spans."""
     req = GroundednessRequirement()
 
@@ -430,7 +430,7 @@ def test_build_batch_support_prompt():
         },
     ]
 
-    prompt = req._build_batch_support_prompt(response, spans_to_assess)
+    prompt = req._build_batch_support_prompt(response, spans_to_assess, sample_docs)
 
     # Verify prompt structure
     assert "JSON array" in prompt or "json" in prompt.lower()
