Skip to content

Commit 9fe89dc

Browse files
committed
split functionality between async and sync applications, to have full compatibility with sync/async approaches to db querying
1 parent 4a273b9 commit 9fe89dc

3 files changed

Lines changed: 371 additions & 27 deletions

File tree

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import importlib
2-
import os
31
from typing import Type
42

53
from dateutil.parser import parse as date_parse
@@ -11,32 +9,9 @@
119
from sqlmodel.ext.asyncio.session import AsyncSession
1210
from sqlmodel.sql.expression import SelectOfScalar
1311

14-
load_dotenv() # take environment variables from .env.
15-
16-
17-
def get_val(val: str):
18-
"""
19-
Quick utility to pull environmental variable values after
20-
loading dot env. It does one thing: either return a value
21-
from a string representation of an environmental key or
22-
a Null value.
23-
24-
:param val: str
25-
:return: str | None
26-
"""
27-
return os.environ.get(val, None)
28-
29-
30-
def get_sql_dialect_import(dialect: str):
31-
"""
32-
A utility function to dynamically load the correct SQL Dialect from the
33-
SQLAlchemy package.
34-
:param dialect: str
35-
36-
:return: func
37-
"""
38-
return importlib.import_module(f"sqlalchemy.dialects" f".{dialect}").insert
12+
from sqlmodel_crud_utils.utils import get_sql_dialect_import, get_val
3913

14+
load_dotenv() # take environment variables from .env.
4015

4116
upsert = get_sql_dialect_import(dialect=get_val("SQL_DIALECT"))
4217

sqlmodel_crud_utils/sync.py

Lines changed: 343 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,343 @@
1+
from typing import Type
2+
3+
from dateutil.parser import parse as date_parse
4+
from dotenv import load_dotenv
5+
from loguru import logger
6+
from sqlalchemy.exc import MultipleResultsFound
7+
from sqlalchemy.orm import selectinload
8+
from sqlmodel import Session, SQLModel, select
9+
from sqlmodel.sql.expression import SelectOfScalar
10+
11+
from sqlmodel_crud_utils.utils import get_sql_dialect_import, get_val
12+
13+
load_dotenv() # take environment variables from .env.
14+
15+
upsert = get_sql_dialect_import(dialect=get_val("SQL_DIALECT"))
16+
17+
18+
def get_result_from_query(query: SelectOfScalar, session: Session):
19+
"""
20+
Processes an SQLModel query object and returns a singular result from the
21+
return payload. If more than one row is returned, then only the first row is
22+
returned. If no rows are available, then a null value is returned.
23+
24+
:param query: SelectOfScalar
25+
:param session: Session
26+
27+
:return: Row
28+
"""
29+
results = session.exec(query)
30+
try:
31+
results = results.one_or_none()
32+
except MultipleResultsFound:
33+
results = session.exec(query)
34+
results = results.first()
35+
36+
return results
37+
38+
39+
def get_one_or_create(
40+
session_inst: Session,
41+
model: type[SQLModel],
42+
create_method_kwargs: dict = None,
43+
selectin: bool = False,
44+
select_in_key: str | None = None,
45+
**kwargs,
46+
):
47+
"""
48+
This function either returns an existing data row from the database or
49+
creates a new instance and saves it to the DB.
50+
51+
:param session_inst: Session
52+
:param model: SQLModel ORM
53+
:param create_method_kwargs: dict
54+
:param selectin: bool
55+
:param select_in_key: str | None
56+
:param kwargs: keyword args
57+
:return: Tuple[Row, bool]
58+
"""
59+
60+
def _get_entry(sqlmodel, **key_args):
61+
stmnt = select(sqlmodel).filter_by(**key_args)
62+
results = get_result_from_query(query=stmnt, session=session_inst)
63+
64+
if results:
65+
if selectin and select_in_key:
66+
stmnt = stmnt.options(
67+
selectinload(getattr(sqlmodel, select_in_key))
68+
)
69+
results = get_result_from_query(
70+
query=stmnt, session=session_inst
71+
)
72+
return results, True
73+
else:
74+
return results, False
75+
76+
results, exists = _get_entry(model, **kwargs)
77+
if results:
78+
return results, exists
79+
else:
80+
kwargs.update(create_method_kwargs or {})
81+
created = model()
82+
[setattr(created, k, v) for k, v in kwargs.items()]
83+
session_inst.add(created)
84+
session_inst.commit()
85+
return created, False
86+
87+
88+
def write_row(data_row: Type[SQLModel], session_inst: Session):
89+
"""
90+
Writes a new instance of an SQLModel ORM model to the database, with an
91+
exception catch that rolls back the session in the event of failure.
92+
93+
:param data_row: Type[SQLModel]
94+
:param session_inst: Session
95+
:return: Tuple[bool, ScalarResult]
96+
"""
97+
try:
98+
session_inst.add(data_row)
99+
session_inst.commit()
100+
101+
return True, data_row
102+
except Exception as e:
103+
session_inst.rollback()
104+
logger.error(
105+
f"Writing data row to table failed. See error message: "
106+
f"{type(e), e, e.args}"
107+
)
108+
109+
return False, None
110+
111+
112+
def insert_data_rows(data_rows, session_inst: Session):
113+
try:
114+
session_inst.add_all(data_rows)
115+
session_inst.commit()
116+
117+
return True, data_rows
118+
119+
except Exception as e:
120+
logger.error(
121+
f"Writing data rows to table failed. See error message: "
122+
f"{type(e), e, e.args}"
123+
)
124+
logger.info(
125+
"Attempting to write individual entries. This can be a "
126+
"bit taxing, so please consider your payload to the DB"
127+
)
128+
129+
session_inst.rollback()
130+
processed_rows, failed_rows = [], []
131+
for row in data_rows:
132+
success, processed_row = write_row(row, session_inst=session_inst)
133+
if not success:
134+
failed_rows.append(row)
135+
else:
136+
processed_rows.append(row)
137+
138+
if processed_rows:
139+
status = True
140+
else:
141+
status = (False,)
142+
return status, {"success": processed_rows, "failed": failed_rows}
143+
144+
145+
def get_row(
146+
id_str: str or int,
147+
session_inst: Session,
148+
model: type[SQLModel],
149+
selectin: bool = False,
150+
select_in_keys: list[str] | None = None,
151+
pk_field: str = "id",
152+
):
153+
stmnt = select(model).where(getattr(model, pk_field) == id_str)
154+
if selectin and select_in_keys:
155+
if isinstance(select_in_keys, list) is False:
156+
select_in_keys = [select_in_keys]
157+
158+
for key in select_in_keys:
159+
stmnt = stmnt.options(selectinload(getattr(model, key)))
160+
results = session_inst.exec(stmnt)
161+
162+
row = results.one_or_none()
163+
164+
if not row:
165+
success = False
166+
else:
167+
success = True
168+
169+
return success, row
170+
171+
172+
def get_rows(
173+
session_inst: Session,
174+
model: type[SQLModel],
175+
selectin: bool = False,
176+
select_in_keys: list[str] | None = None,
177+
page_size: int = 100,
178+
page: int = 1,
179+
stmnt: SelectOfScalar | None = None,
180+
**kwargs,
181+
):
182+
# kwargs = {k: v for k, v in kwargs.items() if v}
183+
if stmnt is None:
184+
stmnt = select(model)
185+
if kwargs:
186+
if ["date" in x for x in kwargs] and any(
187+
x in y for y in kwargs for x in ("lte", "gte")
188+
):
189+
date_keys = [x for x in kwargs.keys() if "date" in x]
190+
for key in date_keys:
191+
if "lte" in key:
192+
model_key = key.replace("__lte", "")
193+
date_val = kwargs.pop(key)
194+
if isinstance(date_val, str):
195+
date_val = date_parse(date_val)
196+
stmnt = stmnt.where(
197+
getattr(model, model_key) < date_val
198+
)
199+
elif "gte" in key:
200+
model_key = key.replace("__gte", "")
201+
logger.info(model_key)
202+
date_val = kwargs.pop(key)
203+
if isinstance(date_val, str):
204+
date_val = date_parse(date_val)
205+
stmnt = stmnt.where(
206+
getattr(model, model_key) > date_val
207+
)
208+
else:
209+
date_val = kwargs.pop(key)
210+
if isinstance(date_val, str):
211+
date_val = date_parse(date_val)
212+
stmnt = stmnt.where(getattr(model, key) == date_val)
213+
elif "date" in kwargs:
214+
date_keys = [x for x in kwargs.keys() if "date" in x]
215+
for key in date_keys:
216+
stmnt = stmnt.where(getattr(model, key) == kwargs.pop(key))
217+
else:
218+
pass
219+
sort_desc, sort_field = (
220+
kwargs.pop(x, None) for x in ("sort_desc", "sort_field")
221+
)
222+
if sort_field and sort_desc:
223+
stmnt = stmnt.order_by(getattr(model, sort_field).desc())
224+
elif sort_field:
225+
stmnt = stmnt.order_by(getattr(model, sort_field))
226+
else:
227+
pass
228+
stmnt = stmnt.filter_by(**kwargs)
229+
230+
if selectin and select_in_keys:
231+
if isinstance(select_in_keys, list) is False:
232+
select_in_keys = [select_in_keys]
233+
for key in select_in_keys:
234+
stmnt = stmnt.options(selectinload(getattr(model, key)))
235+
236+
stmnt = stmnt.offset(page - 1).limit(page_size)
237+
_result = session_inst.exec(stmnt)
238+
results = _result.all()
239+
success = True if len(results) > 0 else False
240+
241+
return success, results
242+
243+
244+
def get_rows_within_id_list(
245+
id_str_list: list[str | int],
246+
session_inst: Session,
247+
model: type[SQLModel],
248+
pk_field: str = "id",
249+
):
250+
stmnt = select(model).where(getattr(model, pk_field).in_(id_str_list))
251+
results = session_inst.exec(stmnt)
252+
253+
if results:
254+
success = True
255+
else:
256+
success = False
257+
258+
return success, results
259+
260+
261+
def delete_row(
262+
id_str: str or int,
263+
session_inst: Session,
264+
model: type[SQLModel],
265+
pk_field: str = "id",
266+
):
267+
success = False
268+
stmnt = select(model).where(getattr(model, pk_field) == id_str)
269+
results = session_inst.exec(stmnt)
270+
271+
row = results.one_or_none()
272+
273+
if not row:
274+
pass
275+
else:
276+
try:
277+
session_inst.delete(row)
278+
session_inst.commit()
279+
success = True
280+
except Exception as e:
281+
logger.error(
282+
f"Failed to delete data row. Please see error messages here: "
283+
f"{type(e), e, e.args}"
284+
)
285+
session_inst.rollback()
286+
287+
return success
288+
289+
290+
def bulk_upsert_mappings(
291+
payload: list,
292+
session_inst: Session,
293+
model: type[SQLModel],
294+
pk_field: str = "id",
295+
):
296+
try:
297+
stmnt = upsert(model).values(payload)
298+
stmnt = stmnt.on_conflict_do_update(
299+
index_elements=[getattr(model, pk_field)],
300+
set_={k: getattr(stmnt.excluded, k) for k in payload[0].keys()},
301+
)
302+
session_inst.execute(stmnt)
303+
304+
session_inst.commit()
305+
306+
return True
307+
308+
except Exception as e:
309+
logger.error(
310+
f"Failed to upsert values to DB. Please see error: "
311+
f"{type(e), e, e.args}"
312+
)
313+
return False
314+
315+
316+
def update_row(
317+
id_str: int | str,
318+
data: dict,
319+
session_inst: Session,
320+
model: type[SQLModel],
321+
pk_field: str = "id",
322+
):
323+
success = False
324+
stmnt = select(model).where(getattr(model, pk_field) == id_str)
325+
results = session_inst.exec(stmnt)
326+
327+
row = results.one_or_none()
328+
329+
if row:
330+
[setattr(row, k, v) for k, v in data.items()]
331+
try:
332+
session_inst.add(row)
333+
session_inst.commit()
334+
success = True
335+
except Exception as e:
336+
session_inst.rollback()
337+
logger.error(
338+
f"Updating the data row failed. See error messages: "
339+
f"{type(e), e, e.args}"
340+
)
341+
return success, row
342+
else:
343+
return success, None

0 commit comments

Comments
 (0)