Skip to content

Commit 43d6a6d

Browse files
committed
patch: bulk_upsert_mappings now take in multiple pk_fields arg for indexing elements
1 parent 639414d commit 43d6a6d

2 files changed

Lines changed: 8 additions & 4 deletions

File tree

sqlmodel_crud_utils/a_sync.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,14 @@ async def bulk_upsert_mappings(
303303
payload: list,
304304
session_inst: AsyncSession,
305305
model: type[SQLModel],
306-
pk_field: str = "id",
306+
pk_fields: list[str] | None = None,
307307
):
308+
if not pk_fields:
309+
pk_fields = ["id"]
308310
try:
309311
stmnt = upsert(model).values(payload)
310312
stmnt = stmnt.on_conflict_do_update(
311-
index_elements=[getattr(model, pk_field)],
313+
index_elements=[getattr(model, x) for x in pk_fields],
312314
set_={k: getattr(stmnt.excluded, k) for k in payload[0].keys()},
313315
)
314316
await session_inst.execute(stmnt)

sqlmodel_crud_utils/sync.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,12 +300,14 @@ def bulk_upsert_mappings(
300300
payload: list,
301301
session_inst: Session,
302302
model: type[SQLModel],
303-
pk_field: str = "id",
303+
pk_fields: list[str] | None = None,
304304
):
305+
if not pk_fields:
306+
pk_fields = ["id"]
305307
try:
306308
stmnt = upsert(model).values(payload)
307309
stmnt = stmnt.on_conflict_do_update(
308-
index_elements=[getattr(model, pk_field)],
310+
index_elements=[getattr(model, x) for x in pk_fields],
309311
set_={k: getattr(stmnt.excluded, k) for k in payload[0].keys()},
310312
)
311313
session_inst.execute(stmnt)

0 commit comments

Comments
 (0)