latest sleepqa #1128
Open. joyan2 wants to merge 7 commits into sunlabuiuc:master from joyan2:master
Commits:
- ed6e960 latest sleepqa (joyan2)
- bc8d7b2 removed binary cache files (joyan2)
- 1060a51 updated link (joyan2)
- 59526cd link update (joyan2)
- ed67116 Merge branch 'sunlabuiuc:master' into master (joyan2)
- 0711ff0 Updated .rst files (joyan2)
- 176008a update task (joyan2)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| SleepQA | ||
| ======= | ||
|
|
||
| SleepQA is a health coaching dataset consisting of passages and corresponding question-answer pairs related to sleep hygiene. It is designed to support Extractive Question Answering tasks in medical contexts. | ||
|
|
||
| .. autoclass:: pyhealth.datasets.SleepQADataset | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| pyhealth.models.SleepQABioBERT | ||
| ============================== | ||
|
|
||
| BioBERT Reader model for Extractive Question Answering on the SleepQA dataset. | ||
|
|
||
| .. autoclass:: pyhealth.models.SleepQABioBERT | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| Extractive QA (SleepQA) | ||
| ======================= | ||
|
|
||
| Extractive Question Answering task for the SleepQA dataset. This task converts SleepQA health coaching records into samples containing a context passage, a question, the answer text, and the character-level start and end offsets of the answer within the passage. | ||
|
|
||
| .. autoclass:: pyhealth.tasks.SleepQAExtractiveQA | ||
| :members: | ||
| :undoc-members: | ||
| :show-inheritance: |
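
The sample layout described above can be illustrated with a small sketch of how a character-level start index and an answer string locate a span inside a passage. The passage text and the helper name `answer_span` are made up for illustration; they are not taken from SleepQA or from PyHealth.

```python
from typing import Tuple


def answer_span(passage: str, answer_text: str, answer_start: int) -> Tuple[int, int]:
    """Return (start, end) character offsets of the answer, with end exclusive."""
    answer_end = answer_start + len(answer_text)
    # Guard against offsets that do not actually point at the answer text.
    if passage[answer_start:answer_end] != answer_text:
        raise ValueError("answer_start does not point at answer_text")
    return answer_start, answer_end


# Illustrative data only (not from the dataset).
passage = "Adults should aim for seven to nine hours of sleep each night."
answer_text = "seven to nine hours"
span = answer_span(passage, answer_text, passage.index(answer_text))
print(passage[span[0]:span[1]])
```

Treating the end offset as exclusive means `passage[start:end]` reproduces the answer directly, which is the same convention as `answer_start + len(answer_text)`.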
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,70 @@ | ||
| """SleepQA Pipeline Ablation Study. | ||
|
|
||
| This script demonstrates a full pipeline replication: | ||
| 1. Dataset: Loading SleepQA data via PyHealth. | ||
| 2. Task: Mapping to Extractive QA. | ||
| 3. Ablation: Comparing a specialized medical reader (BioBERT) | ||
| against a general-purpose reader (Standard BERT) to demonstrate | ||
| the performance gap in health-coaching contexts. | ||
|
|
||
| Contributor: Jeffrey Yan | ||
| """ | ||
| import torch | ||
| from pyhealth.datasets.sleepqa import SleepQADataset | ||
| from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA | ||
| from pyhealth.models.sleepqa_biobert import SleepQABioBERT | ||
|
|
||
|
|
||
| def run_ablation_comparison(): | ||
| print("=== SleepQA: Specialized vs. General Model Ablation ===") | ||
|
|
||
| # 1. Pipeline Setup | ||
| # download=True fetches the raw SleepQA JSON into ./data, so the script runs without manual setup | ||
| dataset = SleepQADataset(root="./data", download=True) | ||
| qa_dataset = dataset.set_task(SleepQAExtractiveQA()) | ||
|
|
||
| # 2. Model Initializations | ||
| # Specialized Medical Model | ||
| biobert_model = SleepQABioBERT( | ||
| dataset=qa_dataset, | ||
| model_name="dmis-lab/biobert-base-cased-v1.1-squad" | ||
| ) | ||
|
|
||
| # General Purpose Model (General BERT ablation) | ||
| general_bert = SleepQABioBERT( | ||
| dataset=qa_dataset, | ||
| model_name="deepset/bert-base-cased-squad2" | ||
| ) | ||
|
|
||
| # 3. Qualitative Comparison (Ablation Output) | ||
| # We take a sample and compare how the two models "see" the medical answer | ||
| sample = qa_dataset[0] | ||
| passage = sample["passage"] | ||
| question = sample["question"] | ||
| ground_truth = sample["answer_text"] | ||
|
|
||
| print(f"\nContext: {passage}") | ||
| print(f"Question: {question}") | ||
| print(f"Expected Answer: {ground_truth}\n") | ||
|
|
||
| for name, model in [("Specialized BioBERT", biobert_model), ("General BERT", general_bert)]: | ||
| batch = {"passage": [passage], "question": [question]} | ||
| with torch.no_grad(): | ||
| out = model(**batch) | ||
|
|
||
| # Extract text from predicted logits | ||
| start_idx = int(torch.argmax(out["start_logits"])) | ||
| end_idx = int(torch.argmax(out["end_logits"])) | ||
|
|
||
| # Map tokens back to text (using the internal tokenizer) | ||
| tokens = model.tokenizer.encode(question, passage) | ||
| pred_text = model.tokenizer.decode(tokens[start_idx: end_idx + 1]) | ||
|
|
||
| print(f"[{name}] Predicted: '{pred_text}'") | ||
|
|
||
| print("\nDocumentation: The general model often fails to capture the precise") | ||
| print("medical span compared to the specialized BioBERT checkpoint.") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| run_ablation_comparison() |
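
One caveat about the decoding step in the script above: taking the argmax of the start and end logits independently can produce an end index that precedes the start index, in which case the decoded slice is empty. A minimal sketch of one common alternative follows, which searches for the highest-scoring pair with `start <= end`. It works on plain Python lists rather than tensors, and the names `best_span` and `max_len` are hypothetical, not part of this PR or of PyHealth.

```python
from typing import List, Tuple


def best_span(start_logits: List[float], end_logits: List[float],
              max_len: int = 30) -> Tuple[int, int]:
    """Pick the (start, end) pair with the highest summed score and start <= end."""
    best = (0, 0)
    best_score = float("-inf")
    for s, s_score in enumerate(start_logits):
        # Only consider end positions at or after the start, within max_len tokens.
        for e in range(s, min(s + max_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best_score:
                best_score = score
                best = (s, e)
    return best


# Independent argmax would give start=2, end=1 here, which is not a valid span.
start_logits = [0.1, 0.2, 3.0, 0.0]
end_logits = [0.0, 4.0, 0.5, 2.5]
print(best_span(start_logits, end_logits))
```

The quadratic search is fine for a single passage; the `max_len` cap keeps it bounded on longer inputs.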
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| # Author: Jeffrey Yan (jeffreyyan23) | ||
| version: "1.0" | ||
| tables: | ||
| sleepqa: | ||
| file_path: "sleepqa-metadata-pyhealth.csv" | ||
| patient_id: "patient_id" | ||
| timestamp: null | ||
| attributes: | ||
| - "visit_id" | ||
| - "question" | ||
| - "passage" | ||
| - "answer_text" | ||
| - "answer_start" |
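
As a quick consistency check, the columns named in this config can be compared with the columns that `_index_data` writes to `sleepqa-metadata-pyhealth.csv`. The sketch below hard-codes both lists as they appear in this PR; the helper name `compare_columns` is hypothetical. It shows that `question_id` is written to the CSV but not listed under `attributes`. Whether an unlisted column is dropped or carried through depends on how `BaseDataset` handles it, which this sketch does not verify.

```python
from typing import List, Tuple


def compare_columns(config_columns: List[str],
                    csv_columns: List[str]) -> Tuple[List[str], List[str]]:
    """Return (columns the config expects but the CSV lacks, CSV columns the config omits)."""
    config_set, csv_set = set(config_columns), set(csv_columns)
    missing = sorted(config_set - csv_set)
    unlisted = sorted(csv_set - config_set)
    return missing, unlisted


# patient_id plus the attributes listed in the YAML config.
config_columns = ["patient_id", "visit_id", "question", "passage",
                  "answer_text", "answer_start"]
# Keys of the row dictionaries built in _index_data.
csv_columns = ["patient_id", "visit_id", "question_id", "question",
               "passage", "answer_text", "answer_start"]

missing, unlisted = compare_columns(config_columns, csv_columns)
print("missing from CSV:", missing)
print("not listed in config:", unlisted)
```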
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| import json | ||
| import logging | ||
| import os | ||
| import urllib.request | ||
| from pathlib import Path | ||
| from typing import Optional | ||
| import pandas as pd | ||
|
|
||
| from pyhealth.datasets.base_dataset import BaseDataset | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| class SleepQADataset(BaseDataset): | ||
| """Dataset class for the SleepQA dataset. | ||
|
|
||
| SleepQA is a health coaching dataset consisting of passages and | ||
| corresponding question-answer pairs related to sleep hygiene. | ||
|
|
||
| Args: | ||
| root: root directory of the raw data. | ||
| config_path: path to the configuration file. Default is sleepqa.yaml. | ||
| download: whether to download the dataset. Default is False. | ||
| **kwargs: additional arguments for BaseDataset. | ||
|
|
||
| Examples: | ||
| >>> from pyhealth.datasets import SleepQADataset | ||
| >>> dataset = SleepQADataset(root="./data", download=True) | ||
| >>> dataset.stat() | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| root: str, | ||
| config_path: Optional[str] = str( | ||
| Path(__file__).parent / "configs" / "sleepqa.yaml"), | ||
| download: bool = False, | ||
| **kwargs, | ||
| ) -> None: | ||
| self._json_path = os.path.join(root, "sleepqa.json") | ||
| if download: | ||
| self._download(root) | ||
| self._verify_data(root) | ||
| self._index_data(root) | ||
|
|
||
| super().__init__( | ||
| root=root, | ||
| tables=["sleepqa"], | ||
| dataset_name="SleepQA", | ||
| config_path=config_path, | ||
| **kwargs, | ||
| ) | ||
|
|
||
|
|
||
| @property | ||
| def default_task(self): | ||
| """Returns the default SleepQAExtractiveQA task.""" | ||
| from pyhealth.tasks.sleepqa_extractive_qa import SleepQAExtractiveQA | ||
| return SleepQAExtractiveQA() | ||
|
|
||
| def _download(self, root: str) -> None: | ||
| """Downloads the SleepQA training split (sleep-train.json) from the SleepQA GitHub repository.""" | ||
| os.makedirs(root, exist_ok=True) | ||
| link = "https://raw.githubusercontent.com/IvaBojic/SleepQA/main/data/training/sleep-train.json" | ||
| logger.info(f"Downloading SleepQA to {self._json_path}...") | ||
| urllib.request.urlretrieve(link, self._json_path) | ||
|
|
||
| def _verify_data(self, root: str) -> None: | ||
| """Verifies that the raw JSON file exists.""" | ||
| if not os.path.isfile(self._json_path): | ||
| raise FileNotFoundError( | ||
| "Dataset path must contain 'sleepqa.json'!") | ||
|
|
||
| def _index_data(self, root: str) -> pd.DataFrame: | ||
| """Parses SleepQA JSON into a relational CSV for PyHealth indexing.""" | ||
| with open(self._json_path, "r", encoding="utf-8") as f: | ||
| data = json.load(f) | ||
| rows = [] | ||
| for item in data.get("data", []): | ||
| p_id = str(item.get("passage_id", "")) | ||
| txt = item.get("text", "") | ||
| for qa in item.get("qas", []): | ||
| ans = (qa.get("answers") or [{}])[0]  # tolerate a missing or empty answers list | ||
| rows.append({ | ||
| "patient_id": p_id, | ||
| "visit_id": f"v_{p_id}", | ||
| "question_id": str(qa.get("id", "")), | ||
| "question": qa.get("question", ""), | ||
| "passage": txt, | ||
| "answer_text": ans.get("text", ""), | ||
| "answer_start": ans.get("answer_start", 0), | ||
| }) | ||
| df = pd.DataFrame(rows) | ||
| df.to_csv(os.path.join( | ||
| root, "sleepqa-metadata-pyhealth.csv"), index=False) | ||
| return df |
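
The nested-to-flat mapping that `_index_data` performs can be sketched without pandas or file I/O. The key names (`data`, `passage_id`, `text`, `qas`, `answers`) follow what the method reads; the sample record and the function name `flatten_sleepqa` are invented for illustration. The sketch treats a missing or empty `answers` list as an empty answer, since indexing `[0]` on an empty list would raise `IndexError`.

```python
from typing import Any, Dict, List


def flatten_sleepqa(data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Flatten passage -> qas -> answers into one row per question."""
    rows = []
    for item in data.get("data", []):
        p_id = str(item.get("passage_id", ""))
        text = item.get("text", "")
        for qa in item.get("qas", []):
            # Fall back to an empty answer when "answers" is absent or [].
            answers = qa.get("answers") or [{}]
            ans = answers[0]
            rows.append({
                "patient_id": p_id,
                "visit_id": f"v_{p_id}",
                "question_id": str(qa.get("id", "")),
                "question": qa.get("question", ""),
                "passage": text,
                "answer_text": ans.get("text", ""),
                "answer_start": ans.get("answer_start", 0),
            })
    return rows


# Illustrative record only (not real SleepQA content).
sample = {"data": [{
    "passage_id": 7,
    "text": "Keep the bedroom cool and dark.",
    "qas": [
        {"id": "q1", "question": "How should the bedroom be kept?",
         "answers": [{"text": "cool and dark", "answer_start": 17}]},
        {"id": "q2", "question": "Unanswered question?", "answers": []},
    ],
}]}
rows = flatten_sleepqa(sample)
print(len(rows), rows[0]["answer_text"])
```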
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| import torch | ||
| from typing import Dict | ||
| from transformers import AutoModelForQuestionAnswering, AutoTokenizer | ||
| from pyhealth.models.base_model import BaseModel | ||
|
|
||
|
|
||
| class SleepQABioBERT(BaseModel): | ||
| """BioBERT Reader for Extractive Question Answering. | ||
|
|
||
| This model uses a transformer-based architecture to predict the | ||
| start and end logits of an answer within a clinical context. | ||
|
|
||
| Args: | ||
| dataset: the sample dataset used for vocabulary/label initialization. | ||
| model_name: Hugging Face model checkpoint. Default is dmis-lab/biobert-base-cased-v1.1-squad. | ||
| **kwargs: additional parameters for BaseModel. | ||
|
|
||
| Examples: | ||
| >>> from pyhealth.models import SleepQABioBERT | ||
| >>> model = SleepQABioBERT(dataset=samples) | ||
| >>> outputs = model(**batch) | ||
| """ | ||
|
|
||
| def __init__(self, dataset, model_name="dmis-lab/biobert-base-cased-v1.1-squad", **kwargs): | ||
| super(SleepQABioBERT, self).__init__(dataset=dataset, **kwargs) | ||
| self.tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
| self.transformer = AutoModelForQuestionAnswering.from_pretrained( | ||
| model_name) | ||
|
|
||
| def forward(self, **kwargs) -> Dict[str, torch.Tensor]: | ||
| """Forward propagation. | ||
|
|
||
| Args: | ||
| **kwargs: batch inputs containing 'passage' and 'question', each a list of strings (one entry per sample). | ||
|
|
||
| Returns: | ||
| A dictionary containing start_logits, end_logits, logit (the two stacked along the last dimension), and a placeholder loss that is always 0.0. | ||
| """ | ||
| passages, questions = kwargs.get("passage"), kwargs.get("question") | ||
| encodings = self.tokenizer( | ||
| questions, passages, padding=True, truncation=True, return_tensors="pt") | ||
|
|
||
| input_ids = encodings["input_ids"].to(self.device) | ||
| attention_mask = encodings["attention_mask"].to(self.device) | ||
|
|
||
| outputs = self.transformer( | ||
| input_ids=input_ids, attention_mask=attention_mask) | ||
| return { | ||
| "start_logits": outputs.start_logits, | ||
| "end_logits": outputs.end_logits, | ||
| "logit": torch.stack([outputs.start_logits, outputs.end_logits], dim=-1), | ||
| "loss": torch.tensor(0.0, requires_grad=True).to(self.device) | ||
| } |
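
Because `forward` returns a constant zero loss, this output carries no training signal. For reference, a common formulation for extractive QA is the mean of two cross-entropy terms, one over start positions and one over end positions; Hugging Face question-answering heads compute their loss this way when `start_positions` and `end_positions` are supplied. The sketch below works on plain Python lists for a single example. The function names are hypothetical, and the positions are token indices, not the character offsets produced by the task.

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Negative log-softmax probability of the target index."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def span_loss(start_logits: List[float], end_logits: List[float],
              start_pos: int, end_pos: int) -> float:
    """Mean of the start-position and end-position cross-entropy terms."""
    return 0.5 * (cross_entropy(start_logits, start_pos)
                  + cross_entropy(end_logits, end_pos))


# A confident, correct prediction gives a lower loss than a uniform one.
confident = span_loss([0.0, 5.0, 0.0], [0.0, 0.0, 5.0], start_pos=1, end_pos=2)
uniform = span_loss([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], start_pos=1, end_pos=2)
print(confident < uniform)
```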
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| from typing import Dict, List | ||
| from pyhealth.data import Event, Patient | ||
| from pyhealth.tasks.base_task import BaseTask | ||
|
|
||
| class SleepQAExtractiveQA(BaseTask): | ||
| """Extractive Question Answering task for SleepQA. | ||
|
|
||
| This task maps SleepQA events into samples containing a passage, | ||
| a question, and the answer span defined by character-level offsets. | ||
|
|
||
| Input Schema: | ||
| passage: raw text context. | ||
| question: the sleep-related query. | ||
| Output Schema: | ||
| answer_text: the ground truth answer string. | ||
| answer_start: char-level start index of the answer. | ||
| answer_end: char-level end index of the answer, exclusive (answer_start + len(answer_text)). | ||
| """ | ||
| task_name = "SleepQAExtractiveQA" | ||
|
|
||
| # We define the schemas to guide the downstream DataLoader and Model | ||
| input_schema = { | ||
| "passage": "text", | ||
| "question": "text" | ||
| } | ||
| output_schema = { | ||
| "answer_text": "text", | ||
| "answer_start": "integer", | ||
| "answer_end": "integer" | ||
| } | ||
|
|
||
| def __call__(self, patient: Patient) -> List[Dict]: | ||
| """Processes a patient object into QA samples. | ||
|
|
||
| Args: | ||
| patient: a Patient object containing SleepQA events. | ||
|
|
||
| Returns: | ||
| A list of sample dictionaries with span-position character offsets. | ||
| """ | ||
| samples = [] | ||
|
|
||
| # Iterating through events specifically labeled for this dataset | ||
| for event in patient.get_events(event_type="sleepqa"): | ||
|
|
||
| # Extract raw values from the event attribute dictionary | ||
| passage = event["passage"] | ||
| question = event["question"] | ||
| answer_text = event["answer_text"] | ||
| answer_start = int(event["answer_start"]) | ||
|
|
||
| # Calculate the character-level end offset | ||
| # This is critical for span-based models to know the boundary | ||
| answer_end = answer_start + len(answer_text) | ||
|
|
||
| # Optional: Basic validation to ensure the offset matches the text | ||
| # extracted_text = passage[answer_start:answer_end] | ||
| # if extracted_text != answer_text: | ||
| # continue # Or handle mismatch logic here | ||
|
|
||
| samples.append({ | ||
| "patient_id": patient.patient_id, | ||
| "visit_id": event.visit_id, | ||
| "passage": passage, | ||
| "question": question, | ||
| "answer_text": answer_text, | ||
| "answer_start": answer_start, | ||
| "answer_end": answer_end, | ||
| }) | ||
|
|
||
| return samples | ||
|
|
||
| # Example Usage: | ||
| # task = SleepQAExtractiveQA() | ||
| # samples = task(patient_object) | ||
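
The task emits character-level offsets, while a span-prediction reader works with token positions, so the two have to be aligned at some point. A minimal sketch of that alignment follows. It uses a whitespace tokenizer as a stand-in (an assumption made to keep the example dependency-free); with a Hugging Face fast tokenizer, the per-token offsets would typically come from `return_offsets_mapping=True` instead. The helper names and the sample passage are hypothetical.

```python
import re
from typing import List, Optional, Tuple


def token_offsets(text: str) -> List[Tuple[int, int]]:
    """Character (start, end) offsets of whitespace-separated tokens, end exclusive."""
    return [(m.start(), m.end()) for m in re.finditer(r"\S+", text)]


def char_span_to_token_span(offsets: List[Tuple[int, int]], answer_start: int,
                            answer_end: int) -> Optional[Tuple[int, int]]:
    """Return (first, last) indices of tokens overlapping the character span."""
    covered = [i for i, (s, e) in enumerate(offsets)
               if s < answer_end and e > answer_start]
    if not covered:
        return None  # e.g. the answer was truncated away or the offsets are wrong
    return covered[0], covered[-1]


# Illustrative data only (not from the dataset).
passage = "Avoid caffeine after two in the afternoon."
answer_text = "after two"
answer_start = passage.index(answer_text)
answer_end = answer_start + len(answer_text)

token_span = char_span_to_token_span(token_offsets(passage), answer_start, answer_end)
print(token_span)
```

Returning `None` for an uncovered span gives the caller a clear point to skip or log samples whose answers fall outside the tokenized window.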
You'll need to define new processors for these or this will fail.