Skip to content

Commit 435e235

Browse files
authored
Merge pull request #9 from tc-mateus/fix/request-pipeline
fix: change prediction to pipeline service
2 parents 78ce3c0 + 91c88ac commit 435e235

File tree

7 files changed

+127
-24
lines changed

7 files changed

+127
-24
lines changed

dhl_sdk/_constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
DATASETS_URL = "api/db/v2/datasets"
1616
MODELS_URL = "api/db/v2/pipelineJobs"
1717
TEMPLATES_URL = "api/db/v2/pipelineJobTemplates"
18-
PREDICT_URL = "api/pipeline/v1/predictors"
18+
PREDICT_URL = "api/pipeline/v1/pipeline"
1919

2020

2121
PROCESS_UNIT_MAP = {

dhl_sdk/_input_processing.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
_validate_spectra_format,
2121
)
2222
from dhl_sdk._utils import (
23-
PredictionRequest,
23+
Metadata,
24+
PipelineStage,
25+
PredictionPipelineRequest,
2426
PredictionRequestConfig,
2527
Predictions,
2628
PredictionResponse,
@@ -64,6 +66,8 @@ def variables(self) -> list[Variable]:
6466

6567

6668
class Model(Protocol):
69+
id: str
70+
6771
@property
6872
def dataset(self) -> Dataset:
6973
...
@@ -313,9 +317,21 @@ def format(self) -> list[dict]:
313317
else:
314318
instances[0].append(None)
315319

316-
json_data = PredictionRequest(
317-
instances=instances, config=self.prediction_config
318-
).model_dump(by_alias=True, exclude_none=True, exclude=["sampleId", "steps"])
320+
json_data = PredictionPipelineRequest(
321+
instances=instances,
322+
metadata=Metadata(
323+
variables=[{"id": var.id} for var in input_variables],
324+
),
325+
stages=[PipelineStage(config=self.prediction_config, id=self.model.id)],
326+
).model_dump(
327+
by_alias=True,
328+
exclude_none=True,
329+
include={
330+
"instances": {"__all__": {"__all__": {"timestamps", "values"}}},
331+
"metadata": True,
332+
"stages": True,
333+
},
334+
)
319335

320336
return [json_data]
321337

@@ -413,9 +429,23 @@ def format(self) -> list[dict]:
413429
else:
414430
instances[0].append(None)
415431

416-
json_data = PredictionRequest(
417-
instances=instances, config=self.prediction_config
418-
).model_dump(by_alias=True, exclude_none=True, exclude="sampleId")
432+
json_data = PredictionPipelineRequest(
433+
instances=instances,
434+
metadata=Metadata(
435+
variables=[{"id": var.id} for var in input_variables],
436+
),
437+
stages=[PipelineStage(config=self.prediction_config, id=self.model.id)],
438+
).model_dump(
439+
by_alias=True,
440+
exclude_none=True,
441+
include={
442+
"instances": {
443+
"__all__": {"__all__": {"timestamps", "values", "steps"}}
444+
},
445+
"metadata": True,
446+
"stages": True,
447+
},
448+
)
419449

420450
return [json_data]
421451

dhl_sdk/_spectra_utils.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@
66

77
import numpy as np
88

9-
from dhl_sdk._utils import Instance, PredictionRequest
9+
from dhl_sdk._utils import (
10+
Instance,
11+
Metadata,
12+
PipelineStage,
13+
PredictionPipelineRequest,
14+
SpectraPredictionConfig,
15+
)
1016
from dhl_sdk.exceptions import InvalidSpectraException
1117

1218
# Type Aliases
@@ -23,7 +29,7 @@ class Dataset(Protocol):
2329
def variables(self) -> list:
2430
...
2531

26-
def get_spectrum_index(self) -> int:
32+
def get_spectra_index(self) -> int:
2733
...
2834

2935

@@ -106,7 +112,7 @@ def _convert_to_request(
106112
# get number of vars in model from config
107113
variables = model.dataset.variables
108114
n_vars = len(variables)
109-
spectrum_index = model.dataset.get_spectrum_index()
115+
spectrum_index = model.dataset.get_spectra_index()
110116

111117
request_data = []
112118
# handle pagination
@@ -123,7 +129,22 @@ def _convert_to_request(
123129
)
124130
break
125131

126-
json_data = PredictionRequest(instances=[instance]).model_dump(by_alias=True)
132+
json_data = PredictionPipelineRequest(
133+
instances=[instance],
134+
metadata=Metadata(
135+
variables=[{"id": var.id} for var in model.dataset.variables],
136+
),
137+
stages=[PipelineStage(config=SpectraPredictionConfig(), id=model.id)],
138+
).model_dump(
139+
by_alias=True,
140+
exclude_none=True,
141+
include={
142+
"instances": True,
143+
"metadata": True,
144+
"stages": True,
145+
},
146+
)
147+
127148
request_data.append(json_data)
128149

129150
return request_data

dhl_sdk/_utils.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import urllib.parse as urlparse
55
from datetime import datetime
66
from functools import reduce
7-
from typing import Optional, Union
7+
from typing import Literal, Optional, Union
88

99
import numpy as np
1010
from pydantic import BaseModel, Field, model_validator
@@ -91,12 +91,50 @@ def new(
9191
)
9292

9393

94+
class SpectraPredictionConfig(BaseModel):
95+
"""Pydantic class representing Spectra Prediction Config"""
96+
97+
prediction_mode: Literal["classic", "onlySpectra"] = Field(
98+
default="classic", alias="predictionMode"
99+
)
100+
101+
102+
class OnlyId(BaseModel):
103+
"""Pydantic class representing a sctuc with only the id"""
104+
105+
id: str
106+
107+
94108
class PredictionRequest(BaseModel):
95109
"""Pydantic class representing the expected Predict Request"""
96110

97111
instances: list[list[Optional[Instance]]]
98112
metadata: Optional[dict] = None
99-
config: Optional[PredictionRequestConfig] = None
113+
config: Optional[Union[PredictionRequestConfig, SpectraPredictionConfig]] = None
114+
115+
116+
class Metadata(BaseModel):
117+
"""Pydantic class representing Metadata for Predict Request"""
118+
119+
experiments: list[Optional[OnlyId]] = [None]
120+
variables: list[OnlyId]
121+
122+
123+
class PipelineStage(BaseModel):
124+
"""Pydantic class representing the Prediction Pipeline Stage"""
125+
126+
config: Union[PredictionRequestConfig, SpectraPredictionConfig]
127+
id: str
128+
merge_strategy: str = Field(default="merge", alias="mergeStrategy")
129+
type: str = Field(default="predict")
130+
131+
132+
class PredictionPipelineRequest(BaseModel):
133+
"""Pydantic class representing the expected Predict Request"""
134+
135+
instances: list[list[Optional[Instance]]]
136+
metadata: Metadata
137+
stages: list[PipelineStage] = None
100138

101139

102140
class PredictionResponse(BaseModel):

dhl_sdk/entities.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,13 +119,20 @@ def __init__(self, **data):
119119
super().__init__(**data)
120120
self._client = data["client"]
121121

122-
def get_spectrum_index(self) -> int:
123-
"""Get the index of the spectrum variable"""
122+
def get_spectra_index(self) -> int:
123+
"""Get the index of the spectra variable"""
124124
for index, variable in enumerate(self.variables):
125125
if variable.variant == "spectrum":
126126
return index
127127
raise ValueError("No spectrum variable found in dataset")
128128

129+
def get_spectra_code(self) -> str:
130+
"""Get variable code of spectra variable"""
131+
for variable in self.variables:
132+
if variable.variant == "spectrum":
133+
return variable.code
134+
raise ValueError("No spectrum variable found in dataset")
135+
129136
@staticmethod
130137
def requests(client: Client) -> CRUDClient["SpectraDataset"]:
131138
# pylint: disable=missing-function-docstring
@@ -163,12 +170,10 @@ def get_predictions(self, preprocessor: Preprocessor) -> dict:
163170
"The provided inputs failed the validation step"
164171
)
165172

166-
predict_url = f"{PREDICT_URL}/{self.id}/predict"
167-
168173
predictions = []
169174
for prediction_data in json_data:
170175
try:
171-
response = self._client.post(predict_url, prediction_data)
176+
response = self._client.post(PREDICT_URL, prediction_data)
172177
response.raise_for_status()
173178

174179
# in case of an error in the response (not HTTP)
@@ -295,7 +300,13 @@ def predict(
295300
spectra=spectra, inputs=inputs, model=self
296301
)
297302

298-
return super().get_predictions(spectra_processing_strategy)
303+
predictions = super().get_predictions(spectra_processing_strategy)
304+
305+
spectra_code = self.dataset.get_spectra_code()
306+
if spectra_code in predictions:
307+
predictions.pop(spectra_code)
308+
309+
return predictions
299310

300311
@property
301312
def inputs(self) -> list[str]:
@@ -316,7 +327,7 @@ def spectra_size(self) -> int:
316327

317328
def _get_spectra_size(self) -> int:
318329
"""Get the size of the spectra from variable information in the API"""
319-
spectrum = self.dataset.variables[self.dataset.get_spectrum_index()]
330+
spectrum = self.dataset.variables[self.dataset.get_spectra_index()]
320331
return spectrum.size
321332

322333
@staticmethod

examples.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -807,7 +807,7 @@
807807
"\n",
808808
"prediction_config = PredictionConfig(model_confidence=50)\n",
809809
"\n",
810-
"result = model_hist.predict(timestamps, steps, inputs, timestamps_unit=\"s\", config = prediction_config)"
810+
"result = model_hist.predict(timestamps, steps, inputs, timestamps_unit=\"s\", config = prediction_config)"
811811
]
812812
},
813813
{

tests/test_utils.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,11 @@ def setUp(self):
3333
"spectrum": {"xAxis": {"dimension": 4}},
3434
}
3535
self.model_no_inputs = Mock()
36+
self.model_no_inputs.id = "model-id-1"
3637
self.model_no_inputs.inputs = []
3738
self.model_no_inputs.dataset.variables = [Variable(**spectrum_var)]
3839
self.model_with_inputs = Mock()
40+
self.model_with_inputs.id = "model-id-2"
3941
self.model_with_inputs.dataset.variables = [
4042
Variable(**spectrum_var),
4143
Variable(id="id-123", code="var1", variant="numeric", name="variable 1"),
@@ -213,7 +215,7 @@ def test_validation_with_input(self):
213215
def test_convert_to_request(self):
214216
model = self.model_with_inputs
215217
model.spectra_size = 4
216-
model.dataset.get_spectrum_index.return_value = 0
218+
model.dataset.get_spectra_index.return_value = 0
217219

218220
spectra = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 6.0, 6.0], [7.0, 8.0, 9.0, 9.0]]
219221
inputs = {"var1": [0, 1, 0], "var2": [1, 1, 1]}
@@ -231,7 +233,7 @@ def test_convert_to_request(self):
231233
def test_convert_request_noinput(self):
232234
model = self.model_no_inputs
233235
model.spectra_size = 4
234-
model.dataset.get_spectrum_index.return_value = 0
236+
model.dataset.get_spectra_index.return_value = 0
235237

236238
spectra = [[1.0, 2.0, 3.0, 3.0], [4.0, 5.0, 6.0, 6.0], [7.0, 8.0, 9.0, 9.0]]
237239
processor = SpectraPreprocessor(spectra=spectra, model=model, inputs=None)
@@ -283,6 +285,7 @@ def setUp(self):
283285
"group": {"code": "X"},
284286
}
285287
self.model = Mock()
288+
self.model.id = "model-id-1"
286289
self.model.dataset.variables = [
287290
Variable(**var1),
288291
Variable(**var2),

0 commit comments

Comments
 (0)