-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate.py
More file actions
83 lines (70 loc) · 3.63 KB
/
generate.py
File metadata and controls
83 lines (70 loc) · 3.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# generate.py
import torch
from config import config
from data.data_loader import Vocabulary # Assuming vocab is saved and loaded
from models.as2sp import AS2SP
from models.trce import TRCEModel
from models.petr import PETRModel
import amrlib
def load_model_for_generation(model_path, vocab):
"""Loads a trained model from a file."""
device = torch.device(config.DEVICE)
# This is a simplified loading function.
# In a real scenario, you'd save model state_dict and hyperparameters.
if config.MODEL_TYPE == 'AS2SP' or config.MODEL_TYPE == 'RL':
model = AS2SP(len(vocab.word2idx), config.EMBED_SIZE, config.HIDDEN_SIZE, config.ENC_HIDDEN_SIZE, config.DROPOUT)
elif config.MODEL_TYPE == 'TRCE':
model = TRCEModel(len(vocab.word2idx), config.TRANSFORMER_D_MODEL, config.TRANSFORMER_NHEAD,
config.TRANSFORMER_ENCODER_LAYERS, config.TRANSFORMER_DECODER_LAYERS,
config.TRANSFORMER_DIM_FEEDFORWARD, config.DROPOUT)
elif config.MODEL_TYPE == 'PETR':
model = PETRModel(len(vocab.word2idx), config.TRANSFORMER_D_MODEL, config.TRANSFORMER_NHEAD,
config.TRANSFORMER_ENCODER_LAYERS, config.TRANSFORMER_DECODER_LAYERS,
config.TRANSFORMER_DIM_FEEDFORWARD, config.DROPOUT)
else:
raise ValueError(f"Unknown model type: {config.MODEL_TYPE}")
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()
return model
def generate_summary(model, text, vocab):
"""Generates a summary for a single input text."""
device = torch.device(config.DEVICE)
# 1. Parse text to AMR graph
stog = amrlib.load_stog_model(device='cpu')
graph_string = stog.parse_sents([text])[0]
# 2. Transform graph (should match training)
# This logic should be imported from data_loader
from data.data_loader import _transform_graph
transformed_graph = _transform_graph(graph_string, config.GRAPH_TRANSFORMATION)
# 3. Generate with the model
summary_ids = []
with torch.no_grad():
# Generation logic is highly model-specific and simplified here
# This is a placeholder for a proper beam search decoder
decoder_input = torch.tensor([[vocab.word2idx["<sos>"]]], device=device)
for _ in range(config.MAX_SUMMARY_LEN):
# This is a dummy generation loop and will not work as-is
# It needs the full encoder-decoder pass
# For now, we just demonstrate the pipeline
# A real implementation of model.generate() would go here
# For example:
# output = model(transformed_graph, decoder_input)
# next_token_id = output.argmax(-1)[:, -1].item()
# Dummy next token
next_token_id = vocab.word2idx["<unk>"]
if next_token_id == vocab.word2idx["<eos>"]:
break
summary_ids.append(next_token_id)
decoder_input = torch.cat([decoder_input, torch.tensor([[next_token_id]], device=device)], dim=1)
summary_words = [vocab.idx2word.get(idx, "<unk>") for idx in summary_ids]
return " ".join(summary_words)
if __name__ == '__main__':
# Example Usage (requires a trained model and saved vocab)
# vocab = ... # load vocab
# model = load_model_for_generation('path/to/your/model.pt', vocab)
# text = "The boy wants to go to the store."
# summary = generate_summary(model, text, vocab)
# print(f"Input Text: {text}")
# print(f"Generated Summary: {summary}")
print("Generation script is ready. A full implementation requires saved models and vocab.")