|
61 | 61 | PyArrowFileIO, |
62 | 62 | StatsAggregator, |
63 | 63 | _ConvertToArrowSchema, |
| 64 | + _determine_partitions, |
64 | 65 | _primitive_to_physical, |
65 | 66 | _read_deletes, |
66 | 67 | bin_pack_arrow_table, |
|
69 | 70 | schema_to_pyarrow, |
70 | 71 | ) |
71 | 72 | from pyiceberg.manifest import DataFile, DataFileContent, FileFormat |
72 | | -from pyiceberg.partitioning import PartitionSpec |
| 73 | +from pyiceberg.partitioning import PartitionField, PartitionSpec |
73 | 74 | from pyiceberg.schema import Schema, make_compatible_name, visit |
74 | 75 | from pyiceberg.table import FileScanTask, TableProperties |
75 | 76 | from pyiceberg.table.metadata import TableMetadataV2 |
76 | | -from pyiceberg.typedef import UTF8 |
| 77 | +from pyiceberg.transforms import IdentityTransform |
| 78 | +from pyiceberg.typedef import UTF8, Record |
77 | 79 | from pyiceberg.types import ( |
78 | 80 | BinaryType, |
79 | 81 | BooleanType, |
@@ -1718,3 +1720,81 @@ def test_bin_pack_arrow_table(arrow_table_with_null: pa.Table) -> None: |
1718 | 1720 | # and will produce half the number of files if we double the target size |
1719 | 1721 | bin_packed = bin_pack_arrow_table(bigger_arrow_tbl, target_file_size=arrow_table_with_null.nbytes * 2) |
1720 | 1722 | assert len(list(bin_packed)) == 5 |
| 1723 | + |
| 1724 | + |
def test_partition_for_demo() -> None:
    """Partition a small arrow table by identity transforms on two columns.

    Checks that _determine_partitions produces one entry per distinct
    (n_legs, year) pair and that no input row is lost across the partitions.
    """
    pa_schema = pa.schema([("year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    iceberg_schema = Schema(
        NestedField(field_id=1, name="year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    arrow_table = pa.Table.from_pydict(
        {
            "year": [2020, 2022, 2022, 2022, 2021, 2022, 2022, 2019, 2021],
            "n_legs": [2, 2, 2, 4, 4, 4, 4, 5, 100],
            "animal": ["Flamingo", "Parrot", "Parrot", "Horse", "Dog", "Horse", "Horse", "Brittle stars", "Centipede"],
        },
        schema=pa_schema,
    )
    spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )

    partitions = _determine_partitions(spec, iceberg_schema, arrow_table)

    # One partition key per distinct (n_legs, year) combination in the data.
    observed_keys = {part.partition_key.partition for part in partitions}
    assert observed_keys == {
        Record(n_legs_identity=2, year_identity=2020),
        Record(n_legs_identity=100, year_identity=2021),
        Record(n_legs_identity=4, year_identity=2021),
        Record(n_legs_identity=4, year_identity=2022),
        Record(n_legs_identity=2, year_identity=2022),
        Record(n_legs_identity=5, year_identity=2019),
    }
    # Partitioning must only regroup rows, never drop or duplicate them.
    assert pa.concat_tables([part.arrow_table_partition for part in partitions]).num_rows == arrow_table.num_rows
| 1755 | + |
| 1756 | + |
def test_identity_partition_on_multi_columns() -> None:
    """Verify identity partitioning on two columns is independent of row order.

    The same 12 rows (6 unique values, 5 distinct partition keys — including
    None in each partition column) are shuffled repeatedly; every shuffle must
    yield the same set of partition keys and preserve all input rows.
    """
    # Hoisted to the top of the function (PEP 8): the import was previously
    # buried between the fixture setup and the shuffle loop.
    import random

    test_pa_schema = pa.schema([("born_year", pa.int64()), ("n_legs", pa.int64()), ("animal", pa.string())])
    test_schema = Schema(
        NestedField(field_id=1, name="born_year", field_type=StringType(), required=False),
        NestedField(field_id=2, name="n_legs", field_type=IntegerType(), required=True),
        NestedField(field_id=3, name="animal", field_type=StringType(), required=False),
        schema_id=1,
    )
    # 5 partitions, 6 unique row values, 12 rows
    test_rows = [
        (2021, 4, "Dog"),
        (2022, 4, "Horse"),
        (2022, 4, "Another Horse"),
        (2021, 100, "Centipede"),
        (None, 4, "Kirin"),
        (2021, None, "Fish"),
    ] * 2
    expected = {Record(n_legs_identity=n_legs, year_identity=year) for year, n_legs, _ in test_rows}
    partition_spec = PartitionSpec(
        PartitionField(source_id=2, field_id=1002, transform=IdentityTransform(), name="n_legs_identity"),
        PartitionField(source_id=1, field_id=1001, transform=IdentityTransform(), name="year_identity"),
    )
    # Single definition of the row ordering used to compare tables at the end.
    sort_keys = [("born_year", "ascending"), ("n_legs", "ascending"), ("animal", "ascending")]

    # there are 12! / ((2!)^6) = 7,484,400 permutations, too many to pick all
    for _ in range(1000):
        random.shuffle(test_rows)
        test_data = {
            "born_year": [row[0] for row in test_rows],
            "n_legs": [row[1] for row in test_rows],
            "animal": [row[2] for row in test_rows],
        }
        arrow_table = pa.Table.from_pydict(test_data, schema=test_pa_schema)

        result = _determine_partitions(partition_spec, test_schema, arrow_table)

        assert {table_partition.partition_key.partition for table_partition in result} == expected
        concatenated_arrow_table = pa.concat_tables([table_partition.arrow_table_partition for table_partition in result])
        assert concatenated_arrow_table.num_rows == arrow_table.num_rows
        # Partitioning only regroups rows; after sorting, both sides must match exactly.
        assert concatenated_arrow_table.sort_by(sort_keys) == arrow_table.sort_by(sort_keys)
0 commit comments