diff --git a/.asf.yaml b/.asf.yaml index dd4975435cf0..36f01b88a724 100644 --- a/.asf.yaml +++ b/.asf.yaml @@ -46,6 +46,9 @@ github: strict: true # don't require any jobs to pass contexts: [] + pull_requests: + # enable updating head branches of pull requests + allow_update_branch: true # publishes the content of the `asf-site` branch to # https://arrow.apache.org/rust/ diff --git a/.github/actions/setup-builder/action.yaml b/.github/actions/setup-builder/action.yaml index 20da777ec0e5..209d58e2d86e 100644 --- a/.github/actions/setup-builder/action.yaml +++ b/.github/actions/setup-builder/action.yaml @@ -16,16 +16,7 @@ # under the License. name: Prepare Rust Builder -description: 'Prepare Rust Build Environment' -inputs: - rust-version: - description: 'version of rust to install (e.g. stable)' - required: false - default: 'stable' - target: - description: 'target architecture(s)' - required: false - default: 'x86_64-unknown-linux-gnu' +description: "Prepare Rust Build Environment" runs: using: "composite" steps: @@ -43,6 +34,9 @@ runs: /usr/local/cargo/git/db/ key: cargo-cache3-${{ hashFiles('**/Cargo.toml') }} restore-keys: cargo-cache3- + - name: Setup Rust toolchain + shell: bash + run: rustup install - name: Generate lockfile shell: bash run: cargo fetch @@ -51,12 +45,6 @@ runs: run: | apt-get update apt-get install -y protobuf-compiler - - name: Setup Rust toolchain - shell: bash - run: | - echo "Installing ${{ inputs.rust-version }}" - rustup toolchain install ${{ inputs.rust-version }} --target ${{ inputs.target }} - rustup default ${{ inputs.rust-version }} - name: Disable debuginfo generation # Disable full debug symbol generation to speed up CI build and keep memory down # "1" means line tables only, which is useful for panic tracebacks. @@ -65,6 +53,9 @@ runs: - name: Enable backtraces shell: bash run: echo "RUST_BACKTRACE=1" >> $GITHUB_ENV + - name: Disable incremental compilation + shell: bash + run: echo CARGO_INCREMENTAL=0 >> $GITHUB_ENV - name: Fixup git permissions # https://github.com/actions/checkout/issues/766 shell: bash diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 7ccf01fed2bd..2da398d7d861 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -14,6 +14,10 @@ updates: applies-to: version-updates patterns: - "prost*" + tonic: + applies-to: version-updates + patterns: + - "tonic*" - package-ecosystem: "github-actions" directory: "/" schedule: diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index e999f505bca1..c2d07f49ab88 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -1,20 +1,38 @@ # Which issue does this PR close? + -Closes #NNN. +- Closes #NNN. # Rationale for this change + # What changes are included in this PR? + + +# Are these changes tested? + + # Are there any user-facing changes? + diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index 0b90a78577e5..3a0b28d2d101 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -56,7 +56,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -68,7 +68,10 @@ jobs: - name: Test arrow-schema run: cargo test -p arrow-schema --all-features - name: Test arrow-array - run: cargo test -p arrow-array --all-features + run: | + cargo test -p arrow-array --all-features + # Disable feature `force_validate` + cargo test -p arrow-array --features=ffi - name: Test arrow-select run: cargo test -p arrow-select --all-features - name: Test arrow-cast @@ -112,7 +115,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -140,13 +143,15 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - target: wasm32-unknown-unknown,wasm32-wasip1 + - name: Install wasm32 targets + run: | + rustup target add wasm32-unknown-unknown + rustup target add wasm32-wasip1 - name: Build wasm32-unknown-unknown run: cargo build -p arrow --no-default-features --features=json,csv,ipc,ffi --target wasm32-unknown-unknown - name: Build wasm32-wasip1 @@ -158,7 +163,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/arrow_flight.yml b/.github/workflows/arrow_flight.yml index 2659a0d987b8..426255f0f3c3 100644 --- a/.github/workflows/arrow_flight.yml +++ b/.github/workflows/arrow_flight.yml @@ -47,7 +47,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -60,7 +60,7 @@ jobs: cargo test -p arrow-flight --all-features - name: Test --examples run: | - cargo test -p arrow-flight --features=flight-sql,tls --examples + cargo test -p arrow-flight --features=flight-sql,tls-ring --examples vendor: name: Verify Vendored Code @@ -68,7 +68,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Run gen @@ -82,7 +82,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/audit.yml b/.github/workflows/audit.yml index e6254ea24a58..d568fcc0f069 100644 --- a/.github/workflows/audit.yml +++ b/.github/workflows/audit.yml @@ -36,7 +36,7 @@ jobs: name: Audit runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Install cargo-audit run: cargo install cargo-audit - name: Run audit check diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index b28e8c20cfe7..f20f0b143696 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -38,9 +38,9 @@ jobs: name: Release Audit Tool (RAT) runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: 3.8 - name: Audit licenses @@ -50,8 +50,8 @@ jobs: name: Markdown format runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - - uses: actions/setup-node@v4 + - uses: actions/checkout@v6 + - uses: actions/setup-node@v6 with: node-version: "14" - name: Prettier check diff --git a/.github/workflows/dev_pr.yml b/.github/workflows/dev_pr.yml index 0d60ae006796..7b0c2566a3bf 100644 --- a/.github/workflows/dev_pr.yml +++ b/.github/workflows/dev_pr.yml @@ -37,14 +37,14 @@ jobs: contents: read pull-requests: write steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Assign GitHub labels if: | github.event_name == 'pull_request_target' && (github.event.action == 'opened' || github.event.action == 'synchronize') - uses: actions/labeler@v5.0.0 + uses: actions/labeler@v6.0.1 with: repo-token: ${{ secrets.GITHUB_TOKEN }} configuration-path: .github/workflows/dev_pr/labeler.yml diff --git a/.github/workflows/dev_pr/labeler.yml b/.github/workflows/dev_pr/labeler.yml index 64299bd507d3..edb6d036174c 100644 --- a/.github/workflows/dev_pr/labeler.yml +++ b/.github/workflows/dev_pr/labeler.yml @@ -37,6 +37,11 @@ arrow: - 'arrow-string/**/*' - 'arrow/**/*' +arrow-avro: + - changed-files: + - any-glob-to-any-file: + - 'arrow-avro/**/*' + arrow-flight: - changed-files: - any-glob-to-any-file: @@ -46,7 +51,13 @@ parquet: - changed-files: - any-glob-to-any-file: - 'parquet/**/*' - - 'parquet-variant/**/*' + +parquet-variant: + - changed-files: + - any-glob-to-any-file: + - 'parquet-variant/**/*' + - 'parquet-variant-compute/**/*' + - 'parquet-variant-json/**/*' parquet-derive: - changed-files: diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index d6ec0622f6ed..12e22abce06d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -34,28 +34,20 @@ jobs: docs: name: Rustdocs are clean runs-on: ubuntu-latest - strategy: - matrix: - arch: [ amd64 ] - rust: [ nightly ] container: - image: ${{ matrix.arch }}/rust + image: amd64/rust env: RUSTDOCFLAGS: "-Dwarnings --enable-index-page -Zunstable-options" steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - - name: Install python dev - run: | - apt update - apt install -y libpython3.11-dev - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - rust-version: ${{ matrix.rust }} + - name: Install Nightly Rust + run: rustup install nightly - name: Run cargo doc - run: cargo doc --document-private-items --no-deps --workspace --all-features + run: cargo +nightly doc --document-private-items --no-deps --workspace --all-features - name: Fix file permissions shell: sh run: | @@ -64,7 +56,7 @@ jobs: echo "::warning title=Invalid file permissions automatically fixed::$line" done - name: Upload artifacts - uses: actions/upload-pages-artifact@v3 + uses: actions/upload-pages-artifact@v4 with: name: crate-docs path: target/doc @@ -77,9 +69,9 @@ jobs: contents: write runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Download crate docs - uses: actions/download-artifact@v4 + uses: actions/download-artifact@v7 with: name: crate-docs path: website/build diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 1b6eeb15dca4..cc74650812e9 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -63,6 +63,7 @@ jobs: ARROW_INTEGRATION_CPP: ON ARROW_INTEGRATION_CSHARP: ON ARCHERY_INTEGRATION_TARGET_IMPLEMENTATIONS: "rust" + ARCHERY_INTEGRATION_WITH_DOTNET: "1" ARCHERY_INTEGRATION_WITH_GO: "1" ARCHERY_INTEGRATION_WITH_JAVA: "1" ARCHERY_INTEGRATION_WITH_JS: "1" @@ -77,52 +78,112 @@ jobs: run: shell: bash steps: + - name: Monitor disk usage - Initial + run: | + echo "=== Initial Disk Usage ===" + df -h / + echo "" + + - name: Remove unnecessary preinstalled software + run: | + echo "=== Cleaning up host disk space ===" + echo "Disk space before cleanup:" + df -h / + + # Clean apt cache + apt-get clean || true + + # Remove GitHub Actions tool cache + rm -rf /__t/* || true + + # Remove large packages from host filesystem (mounted at /host/) + rm -rf /host/usr/share/dotnet || true + rm -rf /host/usr/local/lib/android || true + rm -rf /host/usr/local/.ghcup || true + rm -rf /host/opt/hostedtoolcache/CodeQL || true + + echo "" + echo "Disk space after cleanup:" + df -h / + echo "" + # This is necessary so that actions/checkout can find git - name: Export conda path run: echo "/opt/conda/envs/arrow/bin" >> $GITHUB_PATH # This is necessary so that Rust can find cargo - name: Export cargo path run: echo "/root/.cargo/bin" >> $GITHUB_PATH - - name: Check rustup - run: which rustup - - name: Check cmake - run: which cmake + + # Checkout repos (using shallow clones with fetch-depth: 1) - name: Checkout Arrow - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: apache/arrow submodules: true - fetch-depth: 0 + fetch-depth: 1 - name: Checkout Arrow Rust - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: path: rust - fetch-depth: 0 + submodules: true + fetch-depth: 1 + - name: Checkout Arrow .NET + uses: actions/checkout@v6 + with: + repository: apache/arrow-dotnet + path: dotnet + fetch-depth: 1 - name: Checkout Arrow Go - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: apache/arrow-go path: go + fetch-depth: 1 - name: Checkout Arrow Java - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: apache/arrow-java path: java + fetch-depth: 1 - name: Checkout Arrow JavaScript - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: apache/arrow-js path: js + fetch-depth: 1 - name: Checkout Arrow nanoarrow - uses: actions/checkout@v4 + uses: actions/checkout@v6 with: repository: apache/arrow-nanoarrow path: nanoarrow + fetch-depth: 1 + + - name: Monitor disk usage - After checkouts + run: | + echo "=== After Checkouts ===" + df -h / + echo "" + - name: Build run: conda run --no-capture-output ci/scripts/integration_arrow_build.sh $PWD /build + + - name: Monitor disk usage - After build + if: always() + run: | + echo "=== After Build ===" + df -h / + echo "" + - name: Run run: conda run --no-capture-output ci/scripts/integration_arrow.sh $PWD /build + - name: Monitor disk usage - After tests + if: always() + run: | + echo "=== After Tests ===" + df -h / + echo "" + # test FFI against the C-Data interface exposed by pyarrow pyarrow-integration-test: name: Pyarrow C Data Interface @@ -133,7 +194,7 @@ jobs: # PyArrow 15 was the first version to introduce StringView/BinaryView support pyarrow: ["15", "16", "17"] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -142,17 +203,17 @@ jobs: rustup default ${{ matrix.rust }} rustup component add rustfmt clippy - name: Cache Cargo - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: /home/runner/.cargo key: cargo-maturin-cache- - name: Cache Rust dependencies - uses: actions/cache@v4 + uses: actions/cache@v5 with: path: /home/runner/target # this key is not equal because maturin uses different compilation flags. key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}- - - uses: actions/setup-python@v5 + - uses: actions/setup-python@v6 with: python-version: '3.8' - name: Upgrade pip and setuptools @@ -165,8 +226,9 @@ jobs: - name: Run Rust tests run: | source venv/bin/activate - cargo test -p arrow-pyarrow - - name: Run tests + cd arrow-pyarrow-testing + cargo test + - name: Run Python tests run: | source venv/bin/activate cd arrow-pyarrow-integration-testing diff --git a/.github/workflows/miri.yaml b/.github/workflows/miri.yaml index ce67546a104b..f7269f535249 100644 --- a/.github/workflows/miri.yaml +++ b/.github/workflows/miri.yaml @@ -47,7 +47,7 @@ jobs: name: MIRI runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain diff --git a/.github/workflows/parquet-geospatial.yml b/.github/workflows/parquet-geospatial.yml new file mode 100644 index 000000000000..77bd8f97b4f7 --- /dev/null +++ b/.github/workflows/parquet-geospatial.yml @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +--- +# tests for parquet-geospatial crate +name: "parquet-geospatial" + +concurrency: + group: ${{ github.repository }}-${{ github.head_ref || github.sha }}-${{ github.workflow }} + cancel-in-progress: true + +# trigger for all PRs that touch certain files and changes to main +on: + push: + branches: + - main + pull_request: + paths: + - parquet-geospatial/** + - .github/** + +jobs: + # test the crate + linux-test: + name: Test + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Test parquet-geospatial + run: cargo test -p parquet-geospatial + + # test compilation + linux-features: + name: Check Compilation + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v6 + with: + submodules: true + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Check compilation (parquet-geospatial) + run: cargo check -p parquet-geospatial + + clippy: + name: Clippy + runs-on: ubuntu-latest + container: + image: amd64/rust + steps: + - uses: actions/checkout@v6 + - name: Setup Rust toolchain + uses: ./.github/actions/setup-builder + - name: Setup Clippy + run: rustup component add clippy + - name: Run clippy (parquet-geospatial) + run: cargo clippy -p parquet-geospatial --all-targets --all-features -- -D warnings diff --git a/.github/workflows/parquet-variant.yml b/.github/workflows/parquet-variant.yml index 6fc5c3a8cd00..3e4563286b22 100644 --- a/.github/workflows/parquet-variant.yml +++ b/.github/workflows/parquet-variant.yml @@ -31,6 +31,8 @@ on: pull_request: paths: - parquet-variant/** + - parquet-variant-json/** + - parquet-variant-compute/** - .github/** jobs: @@ -41,13 +43,17 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - - name: Test + - name: Test parquet-variant run: cargo test -p parquet-variant + - name: Test parquet-variant-json + run: cargo test -p parquet-variant-json + - name: Test parquet-variant-compute + run: cargo test -p parquet-variant-compute # test compilation linux-features: @@ -56,13 +62,17 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - - name: Check compilation + - name: Check compilation (parquet-variant) run: cargo check -p parquet-variant + - name: Check compilation (parquet-variant-json) + run: cargo check -p parquet-variant-json + - name: Check compilation (parquet-variant-compute) + run: cargo check -p parquet-variant-compute clippy: name: Clippy @@ -70,10 +80,14 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy run: rustup component add clippy - - name: Run clippy + - name: Run clippy (parquet-variant) run: cargo clippy -p parquet-variant --all-targets --all-features -- -D warnings + - name: Run clippy (parquet-variant-json) + run: cargo clippy -p parquet-variant-json --all-targets --all-features -- -D warnings + - name: Run clippy (parquet-variant-compute) + run: cargo clippy -p parquet-variant-compute --all-targets --all-features -- -D warnings diff --git a/.github/workflows/parquet.yml b/.github/workflows/parquet.yml index 96c7ab8f4e3a..8b94efd91f90 100644 --- a/.github/workflows/parquet.yml +++ b/.github/workflows/parquet.yml @@ -42,6 +42,9 @@ on: - arrow-json/** - arrow-avro/** - parquet/** + - parquet-variant/** + - parquet-variant-compute/** + - parquet-variant-json/** - .github/** jobs: @@ -52,7 +55,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -75,7 +78,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -113,6 +116,15 @@ jobs: run: cargo check -p parquet --all-targets --no-default-features --features json - name: Check compilation --no-default-features --features encryption --features async run: cargo check -p parquet --no-default-features --features encryption --features async + - name: Check compilation --no-default-features --features flate2, this is expected to fail + run: if `cargo check -p parquet --no-default-features --features flate2 2>/dev/null`; then false; else true; fi + - name: Check compilation --no-default-features --features flate2 --features flate2-rust_backened + run: cargo check -p parquet --no-default-features --features flate2 --features flate2-rust_backened + - name: Check compilation --no-default-features --features flate2 --features flate2-zlib-rs + run: cargo check -p parquet --no-default-features --features flate2 --features flate2-zlib-rs + - name: Check compilation --no-default-features --features variant_experimental + run: cargo check -p parquet --no-default-features --features variant_experimental + # test the parquet crate builds against wasm32 in stable rust wasm32-build: @@ -121,13 +133,15 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - with: - target: wasm32-unknown-unknown,wasm32-wasip1 + - name: Install wasm32 targets + run: | + rustup target add wasm32-unknown-unknown + rustup target add wasm32-wasip1 - name: Install clang # Needed for zlib compilation run: apt-get update && apt-get install -y clang gcc-multilib - name: Build wasm32-unknown-unknown @@ -142,9 +156,9 @@ jobs: matrix: rust: [ stable ] steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Python - uses: actions/setup-python@v5 + uses: actions/setup-python@v6 with: python-version: "3.10" cache: "pip" @@ -175,7 +189,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/parquet_derive.yml b/.github/workflows/parquet_derive.yml index 17aec724a820..b1541b5dfb0b 100644 --- a/.github/workflows/parquet_derive.yml +++ b/.github/workflows/parquet_derive.yml @@ -43,7 +43,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Setup Rust toolchain @@ -57,7 +57,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup Clippy diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8f87c50649d3..6e0d10106cbe 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -33,7 +33,7 @@ jobs: runs-on: ubuntu-latest timeout-minutes: 5 steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Create GitHub Releases run: | version=${GITHUB_REF_NAME} diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index a20575391b48..77fccdbebc46 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -30,14 +30,13 @@ on: pull_request: jobs: - # Check workspace wide compile and test with default features for # mac macos: name: Test on Mac runs-on: macos-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Install protoc with brew @@ -52,8 +51,7 @@ jobs: # do not produce debug symbols to keep memory usage down export RUSTFLAGS="-C debuginfo=0" # PyArrow tests happen in integration.yml. - cargo test --workspace --exclude arrow-pyarrow - + cargo test --workspace # Check workspace wide compile and test with default features for # windows @@ -61,7 +59,7 @@ jobs: name: Test on Windows runs-on: windows-latest steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 with: submodules: true - name: Install protobuf compiler in /d/protoc @@ -84,9 +82,7 @@ jobs: # do not produce debug symbols to keep memory usage down export RUSTFLAGS="-C debuginfo=0" export PATH=$PATH:/d/protoc/bin - # PyArrow tests happen in integration.yml. - cargo test --workspace --exclude arrow-pyarrow - + cargo test --workspace # Run cargo fmt for all crates lint: @@ -95,7 +91,7 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - name: Setup rustfmt @@ -117,20 +113,12 @@ jobs: container: image: amd64/rust steps: - - uses: actions/checkout@v4 + - uses: actions/checkout@v6 - name: Setup Rust toolchain uses: ./.github/actions/setup-builder - - name: Install cargo-msrv - run: cargo install cargo-msrv - - name: Downgrade arrow-pyarrow-integration-testing dependencies - working-directory: arrow-pyarrow-integration-testing - # Necessary because half 2.5 requires rust 1.81 or newer - run: | - cargo update -p half --precise 2.4.0 - - name: Downgrade workspace dependencies - # Necessary because half 2.5 requires rust 1.81 or newer - run: | - cargo update -p half --precise 2.4.0 + - name: Install cargo-msrv (if needed) + # cargo-msrv binary may be cached by the cargo cache step in setup-builder, and cargo install will error if it is already installed + run: if which cargo-msrv ; then echo "using existing cargo-msrv binary" ; else cargo install cargo-msrv ; fi - name: Check all packages run: | # run `cargo msrv verify --manifest-path "path/to/Cargo.toml"` to see problematic dependencies diff --git a/.github/workflows/take.yml b/.github/workflows/take.yml index dd21c794960e..94a95f6e31a2 100644 --- a/.github/workflows/take.yml +++ b/.github/workflows/take.yml @@ -28,7 +28,7 @@ jobs: if: (!github.event.issue.pull_request) && github.event.comment.body == 'take' runs-on: ubuntu-latest steps: - - uses: actions/github-script@v7 + - uses: actions/github-script@v8 with: script: | github.rest.issues.addAssignees({ diff --git a/.gitignore b/.gitignore index 05091a4e975d..127182a8f99e 100644 --- a/.gitignore +++ b/.gitignore @@ -25,6 +25,9 @@ __blobstorage__ *.bak2 # OS-specific .gitignores +# cargo insta temp files +*.pending-snap + # Mac .gitignore # General .DS_Store @@ -99,4 +102,4 @@ parquet/pytest/venv/ __pycache__/ # Parquet file from arrow_reader_clickbench -hits_1.parquet \ No newline at end of file +hits_1.parquet diff --git a/CHANGELOG-old.md b/CHANGELOG-old.md index 941c9f26382c..a651a860f893 100644 --- a/CHANGELOG-old.md +++ b/CHANGELOG-old.md @@ -19,6 +19,1318 @@ # Historical Changelog + +## [57.1.0](https://github.com/apache/arrow-rs/tree/57.1.0) (2025-11-20) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/57.0.0...57.1.0) + +**Implemented enhancements:** + +- Eliminate bound checks in filter kernels [\#8865](https://github.com/apache/arrow-rs/issues/8865) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Respect page index policy option for ParquetObjectReader when it's not skip [\#8856](https://github.com/apache/arrow-rs/issues/8856) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Speed up collect\_bool and remove `unsafe` [\#8848](https://github.com/apache/arrow-rs/issues/8848) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Error reading parquet FileMetaData with empty lists encoded as element-type=0 [\#8826](https://github.com/apache/arrow-rs/issues/8826) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- ValueStatistics methods can't be used from generic context in external crate [\#8823](https://github.com/apache/arrow-rs/issues/8823) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Custom Pretty-Printing Implementation for Column when Formatting Record Batches [\#8821](https://github.com/apache/arrow-rs/issues/8821) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet-concat: supports bloom filter and page index [\#8804](https://github.com/apache/arrow-rs/issues/8804) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Parquet\] virtual row number support [\#7299](https://github.com/apache/arrow-rs/issues/7299) +- \[Variant\] Enforce shredded-type validation in `shred_variant` [\#8795](https://github.com/apache/arrow-rs/issues/8795) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Simplify decision logic to call `FilterBuilder::optimize` or not [\#8781](https://github.com/apache/arrow-rs/issues/8781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add variant to arrow for DataType::{Binary, LargeBinary, BinaryView} [\#8767](https://github.com/apache/arrow-rs/issues/8767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Provide algorithm that allows zipping arrays whose values are not prealigned [\#8752](https://github.com/apache/arrow-rs/issues/8752) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] ParquetMetadataReader decodes too much metadata under point-get scenerio [\#8751](https://github.com/apache/arrow-rs/issues/8751) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `arrow-json` supports encoding binary arrays, but not decoding [\#8736](https://github.com/apache/arrow-rs/issues/8736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow `FilterPredicate` instances to be reused for RecordBatches [\#8692](https://github.com/apache/arrow-rs/issues/8692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- ArrowJsonBatch::from\_batch is incomplete [\#8684](https://github.com/apache/arrow-rs/issues/8684) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet-layout: More info about layout including footer size, page index, bloom filter? [\#8682](https://github.com/apache/arrow-rs/issues/8682) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Rewrite `ParquetRecordBatchStream` \(async API\) in terms of the PushDecoder [\#8677](https://github.com/apache/arrow-rs/issues/8677) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[JSON\] Add encoding for binary view [\#8674](https://github.com/apache/arrow-rs/issues/8674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Refactor arrow-cast decimal casting to unify the rescale logic used in Parquet variant casts [\#8670](https://github.com/apache/arrow-rs/issues/8670) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Support Uuid/`FixedSizeBinary(16)` shredding [\#8665](https://github.com/apache/arrow-rs/issues/8665) +- \[Parquet\]There should be an encoding counter to know how many encodings the repo supports in total [\#8662](https://github.com/apache/arrow-rs/issues/8662) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve `parse_data_type` for `List`, `ListView`, `LargeList`, `LargeListView`, `FixedSizeList`, `Union`, `Map`, `RunEndCoded`. [\#8648](https://github.com/apache/arrow-rs/issues/8648) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Support variant to arrow primitive support null/time/decimal\_\* [\#8637](https://github.com/apache/arrow-rs/issues/8637) +- Return error from `RleDecoder::reset` rather than panic [\#8632](https://github.com/apache/arrow-rs/issues/8632) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add bitwise ops on `BooleanBufferBuilder` and `MutableBuffer` that mutate directly the buffer [\#8618](https://github.com/apache/arrow-rs/issues/8618) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add variant\_to\_arrow Utf-8, LargeUtf8, Utf8View types support [\#8567](https://github.com/apache/arrow-rs/issues/8567) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Fixed bugs:** + +- Regression: Parsing `List(Int64)` results in nullable list in 57.0.0 and a non-nullable list in 57.1.0 [\#8883](https://github.com/apache/arrow-rs/issues/8883) +- Regression: FixedSlizeList data type parsing fails on 57.1.0 [\#8880](https://github.com/apache/arrow-rs/issues/8880) +- \(dyn ArrayFormatterFactory + 'static\) can't be safely shared between threads [\#8875](https://github.com/apache/arrow-rs/issues/8875) +- RowNumber reader has wrong row group ordering [\#8864](https://github.com/apache/arrow-rs/issues/8864) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `ThriftMetadataWriter::write_column_indexes` cannot handle a `ColumnIndexMetaData::NONE` [\#8815](https://github.com/apache/arrow-rs/issues/8815) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- "Archery test With other arrows" Integration test failing on main: [\#8813](https://github.com/apache/arrow-rs/issues/8813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] Writing in 57.0.0 seems 10% slower than 56.0.0 [\#8783](https://github.com/apache/arrow-rs/issues/8783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet reader cannot handle files with unknown logical types [\#8776](https://github.com/apache/arrow-rs/issues/8776) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- zip now treats nulls as false in provided mask regardless of the underlying bit value [\#8721](https://github.com/apache/arrow-rs/issues/8721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[avro\] Incorrect version in crate.io landing page [\#8691](https://github.com/apache/arrow-rs/issues/8691) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Array: ViewType gc\(\) has bug when array sum length exceed i32::MAX [\#8681](https://github.com/apache/arrow-rs/issues/8681) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Parquet 56: encounter `error: item_reader def levels are None` when reading nested field with row filter [\#8657](https://github.com/apache/arrow-rs/issues/8657) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Degnerate and non-nullable `FixedSizeListArray`s are not handled [\#8623](https://github.com/apache/arrow-rs/issues/8623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\]Performance Degradation with RowFilter on Unsorted Columns due to Fragmented ReadPlan [\#8565](https://github.com/apache/arrow-rs/issues/8565) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- docs: Add example for creating a `MutableBuffer` from `Buffer` [\#8853](https://github.com/apache/arrow-rs/pull/8853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: Add examples for creating MutableBuffer from Vec [\#8852](https://github.com/apache/arrow-rs/pull/8852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve ParquetDecoder docs [\#8802](https://github.com/apache/arrow-rs/pull/8802) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update docs for zero copy conversion of ScalarBuffer [\#8772](https://github.com/apache/arrow-rs/pull/8772) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add example to convert `PrimitiveArray` to a `Vec` [\#8771](https://github.com/apache/arrow-rs/pull/8771) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: Add links for arrow-avro [\#8770](https://github.com/apache/arrow-rs/pull/8770) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Parquet\] Minor: Update comments in page decompressor [\#8764](https://github.com/apache/arrow-rs/pull/8764) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Document limitations of the `arrow_integration_test` crate [\#8738](https://github.com/apache/arrow-rs/pull/8738) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([phil-opp](https://github.com/phil-opp)) +- docs: Add link to the Arrow implementation status page [\#8732](https://github.com/apache/arrow-rs/pull/8732) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: Update Parquet readme implementation status [\#8731](https://github.com/apache/arrow-rs/pull/8731) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- `RowConverter::from_binary` should opportunistically take ownership of the buffer [\#8685](https://github.com/apache/arrow-rs/issues/8685) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up filter some more \(up to 2x\) [\#8868](https://github.com/apache/arrow-rs/pull/8868) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Speed up `collect_bool` and remove `unsafe`, optimize `take_bits`, `take_native` for null values [\#8849](https://github.com/apache/arrow-rs/pull/8849) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Change `BooleanBuffer::append_packed_range` to use `apply_bitwise_binary_op` [\#8812](https://github.com/apache/arrow-rs/pull/8812) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Parquet\] Avoid copying `LogicalType` in `ColumnOrder::get_sort_order`, deprecate `get_logical_type` [\#8789](https://github.com/apache/arrow-rs/pull/8789) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- perf: Speed up Parquet file writing \(10%, back to speed of 56\) [\#8786](https://github.com/apache/arrow-rs/pull/8786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- perf: override `ArrayIter` default impl for `nth`, `nth_back`, `last` and `count` [\#8785](https://github.com/apache/arrow-rs/pull/8785) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Parquet\] Reduce one copy in `SerializedPageReader` [\#8745](https://github.com/apache/arrow-rs/pull/8745) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Small optimization in Parquet varint decoder [\#8742](https://github.com/apache/arrow-rs/pull/8742) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- perf: override `count`, `nth`, `nth_back`, `last` and `max` for BitIterator [\#8696](https://github.com/apache/arrow-rs/pull/8696) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Add `FilterPredicate::filter_record_batch` [\#8693](https://github.com/apache/arrow-rs/pull/8693) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pepijnve](https://github.com/pepijnve)) +- perf: zero-copy path in `RowConverter::from_binary` [\#8686](https://github.com/apache/arrow-rs/pull/8686) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mzabaluev](https://github.com/mzabaluev)) +- perf: add optimized zip implementation for scalars [\#8653](https://github.com/apache/arrow-rs/pull/8653) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- feat: add `apply_unary_op` and `apply_binary_op` bitwise operations [\#8619](https://github.com/apache/arrow-rs/pull/8619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Parquet\]Optimize the performance in record reader [\#8607](https://github.com/apache/arrow-rs/pull/8607) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hhhizzz](https://github.com/hhhizzz)) + +**Closed issues:** + +- Variant to NullType conversion ignores strict casting [\#8810](https://github.com/apache/arrow-rs/issues/8810) +- Unify display representation for `Field` [\#8784](https://github.com/apache/arrow-rs/issues/8784) +- Misleading configuration name: skip\_arrow\_metadata [\#8780](https://github.com/apache/arrow-rs/issues/8780) +- Inconsistent display for types with Metadata [\#8761](https://github.com/apache/arrow-rs/issues/8761) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Internal `arrow-integration-test` crate is linked from `arrow` docs [\#8739](https://github.com/apache/arrow-rs/issues/8739) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add benchmark for RunEndEncoded casting [\#8709](https://github.com/apache/arrow-rs/issues/8709) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Varaint\] Support `VariantArray::value` to return a `Result` [\#8672](https://github.com/apache/arrow-rs/issues/8672) + +**Merged pull requests:** + +- Fix regression caused by changes in Display for DataType - display \(`List(non-null Int64)` instead of `List(nullable Int64)` [\#8890](https://github.com/apache/arrow-rs/pull/8890) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) +- Support parsing for old style FixedSizeList [\#8882](https://github.com/apache/arrow-rs/pull/8882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Make ArrayFormatterFactory Send + Sync and add a test [\#8878](https://github.com/apache/arrow-rs/pull/8878) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- Make `ArrowReaderOptions::with_virtual_columns` error rather than panic on invalid input [\#8867](https://github.com/apache/arrow-rs/pull/8867) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix errors when reading nested Lists with pushdown predicates. [\#8866](https://github.com/apache/arrow-rs/pull/8866) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix `RowNumberReader` when not all row groups are selected [\#8863](https://github.com/apache/arrow-rs/pull/8863) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([vustef](https://github.com/vustef)) +- Respect page index policy option for ParquetObjectReader when it's not skip [\#8857](https://github.com/apache/arrow-rs/pull/8857) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- build\(deps\): update apache-avro requirement from 0.20.0 to 0.21.0 [\#8832](https://github.com/apache/arrow-rs/pull/8832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Allow Users to Provide Custom `ArrayFormatter`s when Pretty-Printing Record Batches [\#8829](https://github.com/apache/arrow-rs/pull/8829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- Allow reading of improperly constructed empty lists in Parquet metadata [\#8827](https://github.com/apache/arrow-rs/pull/8827) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[Variant\] Fix cast logic for Variant to Arrow for DataType::Null [\#8825](https://github.com/apache/arrow-rs/pull/8825) ([klion26](https://github.com/klion26)) +- remove T: ParquetValueType bound on ValueStatistics [\#8824](https://github.com/apache/arrow-rs/pull/8824) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([pmarks](https://github.com/pmarks)) +- build\(deps\): update lz4\_flex requirement from 0.11 to 0.12 [\#8820](https://github.com/apache/arrow-rs/pull/8820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix bug in handling of empty Parquet page index structures [\#8817](https://github.com/apache/arrow-rs/pull/8817) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Parquet-concat: supports page index and bloom filter [\#8811](https://github.com/apache/arrow-rs/pull/8811) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- \[Doc\] Correct `ListArray` documentation [\#8803](https://github.com/apache/arrow-rs/pull/8803) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- \[Parquet\] Add additional docs for `ArrowReaderOptions` and `ArrowReaderMetadata` [\#8798](https://github.com/apache/arrow-rs/pull/8798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Enforce shredded-type validation in `shred_variant` [\#8796](https://github.com/apache/arrow-rs/pull/8796) ([liamzwbao](https://github.com/liamzwbao)) +- Add `VariantPath::is_empty` [\#8791](https://github.com/apache/arrow-rs/pull/8791) ([friendlymatthew](https://github.com/friendlymatthew)) +- Add FilterBuilder::is\_optimize\_beneficial [\#8782](https://github.com/apache/arrow-rs/pull/8782) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pepijnve](https://github.com/pepijnve)) +- \[Parquet\] Allow reading of files with unknown logical types [\#8777](https://github.com/apache/arrow-rs/pull/8777) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- bench: add `ArrayIter` benchmarks [\#8774](https://github.com/apache/arrow-rs/pull/8774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Update Rust toolchain to 1.91 [\#8769](https://github.com/apache/arrow-rs/pull/8769) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Add variant to arrow for `DataType::{Binary/LargeBinary/BinaryView}` [\#8768](https://github.com/apache/arrow-rs/pull/8768) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([klion26](https://github.com/klion26)) +- feat: parse `DataType::Union`, `DataType::Map`, `DataType::RunEndEncoded` [\#8765](https://github.com/apache/arrow-rs/pull/8765) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dqkqd](https://github.com/dqkqd)) +- Add options to control various aspects of Parquet metadata decoding [\#8763](https://github.com/apache/arrow-rs/pull/8763) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- feat: Ensure consistent metadata display for data types [\#8760](https://github.com/apache/arrow-rs/pull/8760) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mhilton](https://github.com/mhilton)) +- Clean up predicate\_cache tests [\#8755](https://github.com/apache/arrow-rs/pull/8755) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- refactor `test_cache_projection_excludes_nested_columns` to use high level APIs [\#8754](https://github.com/apache/arrow-rs/pull/8754) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add `merge` and `merge_n` kernels [\#8753](https://github.com/apache/arrow-rs/pull/8753) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pepijnve](https://github.com/pepijnve)) +- Fix lint in arrow-flight by updating assert\_cmd after it upgraded [\#8741](https://github.com/apache/arrow-rs/pull/8741) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([vegarsti](https://github.com/vegarsti)) +- Remove link to internal `arrow-integration-test` crate from main `arrow` crate [\#8740](https://github.com/apache/arrow-rs/pull/8740) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([phil-opp](https://github.com/phil-opp)) +- Implement hex decoding of JSON strings to binary arrays [\#8737](https://github.com/apache/arrow-rs/pull/8737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([phil-opp](https://github.com/phil-opp)) +- \[Parquet\] Adaptive Parquet Predicate Pushdown [\#8733](https://github.com/apache/arrow-rs/pull/8733) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hhhizzz](https://github.com/hhhizzz)) +- \[Parquet\] Return error from `RleDecoder::reload` rather than panic [\#8729](https://github.com/apache/arrow-rs/pull/8729) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([liamzwbao](https://github.com/liamzwbao)) +- fix: `ArrayIter` does not report size hint correctly after advancing from the iterator back [\#8728](https://github.com/apache/arrow-rs/pull/8728) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- perf: Use Vec::with\_capacity in cast\_to\_run\_end\_encoded [\#8726](https://github.com/apache/arrow-rs/pull/8726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vegarsti](https://github.com/vegarsti)) +- \[Variant\] Fix the index of an item in VariantArray in a unit test [\#8725](https://github.com/apache/arrow-rs/pull/8725) ([martin-g](https://github.com/martin-g)) +- build\(deps\): bump actions/download-artifact from 5 to 6 [\#8720](https://github.com/apache/arrow-rs/pull/8720) ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Variant\] Add try\_value/value for VariantArray [\#8719](https://github.com/apache/arrow-rs/pull/8719) ([klion26](https://github.com/klion26)) +- General virtual columns support + row numbers as a first use-case [\#8715](https://github.com/apache/arrow-rs/pull/8715) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([vustef](https://github.com/vustef)) +- feat: Parquet-layout add Index and Footer info [\#8712](https://github.com/apache/arrow-rs/pull/8712) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- fix: `zip` now treats nulls as false in provided mask regardless of the underlying bit value [\#8711](https://github.com/apache/arrow-rs/pull/8711) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Add benchmark for casting to RunEndEncoded \(REE\) [\#8710](https://github.com/apache/arrow-rs/pull/8710) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vegarsti](https://github.com/vegarsti)) +- \[Minor\]: Document visibility for enums produced by Thrift macros [\#8706](https://github.com/apache/arrow-rs/pull/8706) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update `arrow-avro` `README.md` version to 57 [\#8695](https://github.com/apache/arrow-rs/pull/8695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Fix: ViewType gc on huge batch would produce bad output [\#8694](https://github.com/apache/arrow-rs/pull/8694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mapleFU](https://github.com/mapleFU)) +- Refactor arrow-cast decimal casting to unify the rescale logic used in Parquet variant casts [\#8689](https://github.com/apache/arrow-rs/pull/8689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- check bit width to avoid panic in DeltaBitPackDecoder [\#8688](https://github.com/apache/arrow-rs/pull/8688) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rambleraptor](https://github.com/rambleraptor)) +- \[thrift-remodel\] Use `thrift_enum` macro for `ConvertedType` [\#8680](https://github.com/apache/arrow-rs/pull/8680) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[JSON\] Map key supports utf8 view [\#8679](https://github.com/apache/arrow-rs/pull/8679) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mapleFU](https://github.com/mapleFU)) +- \[JSON\] Add encoding for binary view [\#8675](https://github.com/apache/arrow-rs/pull/8675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mapleFU](https://github.com/mapleFU)) +- \[Parquet\] Account for FileDecryptor in ParquetMetaData heap size calculation [\#8671](https://github.com/apache/arrow-rs/pull/8671) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- chore: update `OffsetBuffer::from_lengths(std::iter::repeat_n(, ));` with `OffsetBuffer::from_repeated_length(, );` [\#8669](https://github.com/apache/arrow-rs/pull/8669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Support `shred_variant` for Uuids [\#8666](https://github.com/apache/arrow-rs/pull/8666) ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Remove `create_test_variant_array` helper method [\#8664](https://github.com/apache/arrow-rs/pull/8664) ([friendlymatthew](https://github.com/friendlymatthew)) +- \[parquet\] Adding counting method in thrift\_enum macro to support ENCODING\_SLOTS [\#8663](https://github.com/apache/arrow-rs/pull/8663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([hhhizzz](https://github.com/hhhizzz)) +- chore: add test case of RowSelection::trim [\#8660](https://github.com/apache/arrow-rs/pull/8660) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lichuang](https://github.com/lichuang)) +- feat: add `new_repeated` to `ByteArray` [\#8659](https://github.com/apache/arrow-rs/pull/8659) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- perf: add `repeat_slice_n_times` to `MutableBuffer` [\#8658](https://github.com/apache/arrow-rs/pull/8658) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- perf: add optimized function to create offset with same length [\#8656](https://github.com/apache/arrow-rs/pull/8656) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] `rescale_decimal` followup [\#8655](https://github.com/apache/arrow-rs/pull/8655) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- feat: parse DataType `List`, `ListView`, `LargeList`, `LargeListView`, `FixedSizeList` [\#8649](https://github.com/apache/arrow-rs/pull/8649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dqkqd](https://github.com/dqkqd)) +- Support more operations on ListView [\#8645](https://github.com/apache/arrow-rs/pull/8645) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([a10y](https://github.com/a10y)) +- \[Variant\] Implement primitive type access for null/time/decimal\* [\#8638](https://github.com/apache/arrow-rs/pull/8638) ([klion26](https://github.com/klion26)) +- \[Variant\] refactor: Split builder.rs into several smaller files [\#8635](https://github.com/apache/arrow-rs/pull/8635) ([Weijun-H](https://github.com/Weijun-H)) +- add `try_new_with_length` constructor to `FixedSizeList` [\#8624](https://github.com/apache/arrow-rs/pull/8624) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([connortsui20](https://github.com/connortsui20)) +- Change some panics to errors in parquet decoder [\#8602](https://github.com/apache/arrow-rs/pull/8602) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rambleraptor](https://github.com/rambleraptor)) +- Support `variant_to_arrow` for utf8 [\#8600](https://github.com/apache/arrow-rs/pull/8600) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([sdf-jkl](https://github.com/sdf-jkl)) +- Cast support for RunEndEncoded arrays [\#8589](https://github.com/apache/arrow-rs/pull/8589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vegarsti](https://github.com/vegarsti)) + + + +## [57.0.0](https://github.com/apache/arrow-rs/tree/57.0.0) (2025-10-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/56.2.0...57.0.0) + +**Breaking changes:** + +- Use `Arc` everywhere to be be consistent with `FileDecryptionProperties` [\#8626](https://github.com/apache/arrow-rs/pull/8626) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- feat: Improve DataType display for `RunEndEncoded` [\#8596](https://github.com/apache/arrow-rs/pull/8596) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Add `ArrowError::AvroError`, remaining types and roundtrip tests to `arrow-avro`, [\#8595](https://github.com/apache/arrow-rs/pull/8595) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[thrift-remodel\] Refactor Thrift encryption and store encodings as bitmask [\#8587](https://github.com/apache/arrow-rs/pull/8587) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- feat: Enhance `Map` display formatting in DataType [\#8570](https://github.com/apache/arrow-rs/pull/8570) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- feat: Enhance DataType display formatting for `ListView` and `LargeListView` variants [\#8569](https://github.com/apache/arrow-rs/pull/8569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Use custom thrift parser for parquet metadata \(phase 1 of Thrift remodel\) [\#8530](https://github.com/apache/arrow-rs/pull/8530) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- refactor: improve display formatting for Union [\#8529](https://github.com/apache/arrow-rs/pull/8529) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Use `Arc` to reduce size of ParquetMetadata and avoid copying when `encryption` is enabled [\#8470](https://github.com/apache/arrow-rs/pull/8470) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix for column name based projection mask creation [\#8447](https://github.com/apache/arrow-rs/pull/8447) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Improve Display formatting of DataType::Timestamp [\#8425](https://github.com/apache/arrow-rs/pull/8425) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- Use more compact Debug formatting of Field [\#8424](https://github.com/apache/arrow-rs/pull/8424) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- Reuse zstd compression context when writing IPC [\#8405](https://github.com/apache/arrow-rs/pull/8405) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([albertlockett](https://github.com/albertlockett)) +- \[Decimal\] Add scale argument to validation functions to ensure accurate error logging [\#8396](https://github.com/apache/arrow-rs/pull/8396) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Quote `DataType::Struct` field names in `Display` formatting [\#8291](https://github.com/apache/arrow-rs/pull/8291) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- Improve `Display` for `DataType` and `Field` [\#8290](https://github.com/apache/arrow-rs/pull/8290) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- Bump pyo3 to 0.26.0 [\#8286](https://github.com/apache/arrow-rs/pull/8286) ([mbrobbel](https://github.com/mbrobbel)) + +**Implemented enhancements:** + +- Added Avro support (new `arrow-avro` crate) [\#4886](https://github.com/apache/arrow-rs/issues/4886) +- parquet-rewrite: supports compression level and write batch size [\#8639](https://github.com/apache/arrow-rs/issues/8639) +- Error not panic when int96 stastistics aren't size 12 [\#8614](https://github.com/apache/arrow-rs/issues/8614) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Make `VariantArray` iterable [\#8612](https://github.com/apache/arrow-rs/issues/8612) +- \[Variant\] impl `PartialEq` for `VariantArray` [\#8610](https://github.com/apache/arrow-rs/issues/8610) +- \[Variant\] Remove potential panics when probing `VariantArray` [\#8609](https://github.com/apache/arrow-rs/issues/8609) +- \[Variant\] Remove ceremony of going from list of `Variant` to `VariantArray` [\#8606](https://github.com/apache/arrow-rs/issues/8606) +- Eliminate redundant validation in `RecordBatch::project` [\#8591](https://github.com/apache/arrow-rs/issues/8591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[PARQUET\]\[BENCH\] Arrow writer bench with compression and/or page v2 [\#8559](https://github.com/apache/arrow-rs/issues/8559) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] casting functions are confusingly named [\#8531](https://github.com/apache/arrow-rs/issues/8531) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support writing GeospatialStatistics in Parquet writer [\#8523](https://github.com/apache/arrow-rs/issues/8523) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[thrift-remodel\] Optimize `convert_row_groups` [\#8517](https://github.com/apache/arrow-rs/issues/8517) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add variant to arrow primitive support for boolean/timestamp/time [\#8515](https://github.com/apache/arrow-rs/issues/8515) +- Test `thrift-remodel` branch with DataFusion [\#8513](https://github.com/apache/arrow-rs/issues/8513) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Make `UnionArray::is_dense` Method Public [\#8503](https://github.com/apache/arrow-rs/issues/8503) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `append_n` method to `FixedSizeBinaryDictionaryBuilder` [\#8497](https://github.com/apache/arrow-rs/issues/8497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] Reduce size of ParquetMetadata when encryption feature is enabled [\#8469](https://github.com/apache/arrow-rs/issues/8469) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Parquet\] Remove useless mut requirements in geting bloom filter function [\#8461](https://github.com/apache/arrow-rs/issues/8461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Change `serde` dependency to `serde_core` where applicable [\#8451](https://github.com/apache/arrow-rs/issues/8451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] Split `ParquetMetadataReader` into IO/decoder state machine and thrift parsing [\#8439](https://github.com/apache/arrow-rs/issues/8439) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove compiler warning for redundant config enablement [\#8412](https://github.com/apache/arrow-rs/issues/8412) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add geospatial statistics creation support for GEOMETRY/GEOGRAPHY Parquet logical types [\#8411](https://github.com/apache/arrow-rs/issues/8411) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow_json` lacks `with_timestamp_format` functions like `arrow_csv` had offered [\#8398](https://github.com/apache/arrow-rs/issues/8398) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Unify API for writing column chunks / row groups in parallel [\#8389](https://github.com/apache/arrow-rs/issues/8389) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Reuse zstd context in arrow IPC writer [\#8386](https://github.com/apache/arrow-rs/issues/8386) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[Variant\] Support reading/writing Parquet Variant LogicalType [\#8370](https://github.com/apache/arrow-rs/issues/8370) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Implement a `shred_variant` function [\#8361](https://github.com/apache/arrow-rs/issues/8361) +- \[Parquet\] Expose ReadPlan and ReadPlanBuilder [\#8347](https://github.com/apache/arrow-rs/issues/8347) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `List` [\#8337](https://github.com/apache/arrow-rs/issues/8337) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `Struct` [\#8336](https://github.com/apache/arrow-rs/issues/8336) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `Time64(Microsecond)` [\#8334](https://github.com/apache/arrow-rs/issues/8334) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `Decimal128` [\#8332](https://github.com/apache/arrow-rs/issues/8332) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `Timestamp(Microsecond, _)` and `Timestamp(Nanosecond, _)` [\#8331](https://github.com/apache/arrow-rs/issues/8331) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] \[Shredding\] Support typed\_access for `Date32` [\#8330](https://github.com/apache/arrow-rs/issues/8330) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support strict casting for all data types [\#8303](https://github.com/apache/arrow-rs/issues/8303) +- \[Variant\] Support typed access for string types in variant\_get [\#8285](https://github.com/apache/arrow-rs/issues/8285) +- \[Variant\]: Implement `DataType::FixedSizeList` support for `cast_to_variant` kernel [\#8281](https://github.com/apache/arrow-rs/issues/8281) + +**Fixed bugs:** + +- Fix arrow-avro Writer Documentation related to AvroBinaryFormat [\#8631](https://github.com/apache/arrow-rs/issues/8631) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal -\> Decimal cast wrongly fails for large scale reduction [\#8579](https://github.com/apache/arrow-rs/issues/8579) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] Avoid fetching multiple pages when `max_predicate_cache_size`is 0 [\#8542](https://github.com/apache/arrow-rs/issues/8542) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- DataType parsing no longer works correctly for old formatted timestamps [\#8539](https://github.com/apache/arrow-rs/issues/8539) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] ArrowWriter flush does not work [\#8534](https://github.com/apache/arrow-rs/issues/8534) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `arrow::compute::interleave` fails with struct arrays with no fields [\#8533](https://github.com/apache/arrow-rs/issues/8533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] Over memory consumation for writer page v1 compressed [\#8526](https://github.com/apache/arrow-rs/issues/8526) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Incorrect Behavior of Collecting a filtered iterator to a BooleanArray [\#8505](https://github.com/apache/arrow-rs/issues/8505) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Parquet\] ProjectionMask::columns name handling is bug prone [\#8443](https://github.com/apache/arrow-rs/issues/8443) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Shredded typed\_value columns must have valid variant types [\#8435](https://github.com/apache/arrow-rs/issues/8435) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- cargo test -p parquet fails with default `ulimit` [\#8406](https://github.com/apache/arrow-rs/issues/8406) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Column with List\(Struct\) causes failed to decode level data for struct array [\#8404](https://github.com/apache/arrow-rs/issues/8404) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Binaryview Utf8 Cast Issue [\#8403](https://github.com/apache/arrow-rs/issues/8403) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Decimal precision validation displays value without accounting for scale [\#8382](https://github.com/apache/arrow-rs/issues/8382) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] `VariantArray::data_type` returns `StructType`, causing `Array::as_struct` to panic [\#8319](https://github.com/apache/arrow-rs/issues/8319) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] writing a VariantArray to parquet panics [\#8296](https://github.com/apache/arrow-rs/issues/8296) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Documentation updates:** + +- Docs: Add more comments to the Parquet writer code [\#8383](https://github.com/apache/arrow-rs/pull/8383) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- \[parquet\] Improve encoding mask API \(wrap bare i32 in a struct w/ docs\) [\#8588](https://github.com/apache/arrow-rs/issues/8588) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- bench: create `zip` kernel benchmarks [\#8654](https://github.com/apache/arrow-rs/pull/8654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Skip redundant validation checks in RecordBatch\#project [\#8583](https://github.com/apache/arrow-rs/pull/8583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pepijnve](https://github.com/pepijnve)) +- \[thrift-remodel\] Remove conversion functions for row group and column metadata [\#8574](https://github.com/apache/arrow-rs/pull/8574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[PARQUET\] Improve memory efficency for compressed writer parquet 1.0 [\#8527](https://github.com/apache/arrow-rs/pull/8527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lilianm](https://github.com/lilianm)) +- perf: improve `GenericByteBuilder::append_array` to use SIMD for extending the offsets [\#8388](https://github.com/apache/arrow-rs/pull/8388) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) + +**Closed issues:** + +- Utf-8, LargeUtf8, Utf8View [\#8601](https://github.com/apache/arrow-rs/issues/8601) +- \[Variant\] Improve the get type logic for DataType in variant to arrow row builder [\#8538](https://github.com/apache/arrow-rs/issues/8538) +- Add a README.md for arrow-avro [\#8504](https://github.com/apache/arrow-rs/issues/8504) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix UnionArray references to "positive" values [\#8418](https://github.com/apache/arrow-rs/issues/8418) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] `metadata` field should be marked is non-nullable [\#8410](https://github.com/apache/arrow-rs/issues/8410) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Avro\] Example read\_with\_utf8view.rs fails to run with error "Error: ParseError\("Unexpected EOF while reading Avro header"\)" [\#8380](https://github.com/apache/arrow-rs/issues/8380) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Geospatial\]: Add CI checks for `parquet-geospatial` crate [\#8377](https://github.com/apache/arrow-rs/issues/8377) +- \[Geospatial\] Create new `parquet-geometry` crate [\#8374](https://github.com/apache/arrow-rs/issues/8374) + +**Merged pull requests:** + +- parquet-rewrite: add write\_batch\_size and compression\_level config [\#8642](https://github.com/apache/arrow-rs/pull/8642) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- Introduce a ThriftProtocolError to avoid allocating and formattings strings for error messages [\#8636](https://github.com/apache/arrow-rs/pull/8636) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[thrift-remodel\] Add macro to reduce boilerplate necessary to implement Thrift serialization [\#8634](https://github.com/apache/arrow-rs/pull/8634) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix Writer docs and rename `AvroBinaryFormat` to `AvroSoeFormat` [\#8633](https://github.com/apache/arrow-rs/pull/8633) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Bulk insert elements into List and Object Builders [\#8629](https://github.com/apache/arrow-rs/pull/8629) ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] impl `PartialEq` and `FromIterator>` for `VariantArray` [\#8627](https://github.com/apache/arrow-rs/pull/8627) ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Remove ceremony from iterator of variants into VariantArray [\#8625](https://github.com/apache/arrow-rs/pull/8625) ([friendlymatthew](https://github.com/friendlymatthew)) +- Undeprecate `ArrowWriter::into_serialized_writer` and add docs [\#8621](https://github.com/apache/arrow-rs/pull/8621) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- fix: incorrect assertion in `BitChunks::new` [\#8620](https://github.com/apache/arrow-rs/pull/8620) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Clean up redundant `get_type_name` [\#8617](https://github.com/apache/arrow-rs/pull/8617) ([liamzwbao](https://github.com/liamzwbao)) +- \[Minor\] Hide thrift macros [\#8616](https://github.com/apache/arrow-rs/pull/8616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Deprecate `parquet::format` module [\#8615](https://github.com/apache/arrow-rs/pull/8615) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[Variant\] Make `VariantArray` iterable [\#8613](https://github.com/apache/arrow-rs/pull/8613) ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Impl `Extend` for `VariantArrayBuilder` [\#8611](https://github.com/apache/arrow-rs/pull/8611) ([friendlymatthew](https://github.com/friendlymatthew)) +- build\(deps\): bump actions/setup-node from 5 to 6 [\#8604](https://github.com/apache/arrow-rs/pull/8604) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Check int96 min/max instead of panicking [\#8603](https://github.com/apache/arrow-rs/pull/8603) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rambleraptor](https://github.com/rambleraptor)) +- \[thrift-remodel\] Refactor Parquet Thrift code into new `thrift` module [\#8599](https://github.com/apache/arrow-rs/pull/8599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[Parquet\] Remove use of `parquet::format` in metadata bench code [\#8598](https://github.com/apache/arrow-rs/pull/8598) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lichuang](https://github.com/lichuang)) +- Remove experimental warning from `extension` module [\#8597](https://github.com/apache/arrow-rs/pull/8597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Adding `try_append_value` implementation to `ByteViewBuilder` [\#8594](https://github.com/apache/arrow-rs/pull/8594) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([samueleresca](https://github.com/samueleresca)) +- Add RecordBatch::project microbenchmark [\#8592](https://github.com/apache/arrow-rs/pull/8592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([pepijnve](https://github.com/pepijnve)) +- \[parquet\] Add a sync fn to ArrowWriter that flushes Writer [\#8586](https://github.com/apache/arrow-rs/pull/8586) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PiotrSrebrny](https://github.com/PiotrSrebrny)) +- chore: use magic number`FOOTER_SIZE` instead of hard code number [\#8585](https://github.com/apache/arrow-rs/pull/8585) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lichuang](https://github.com/lichuang)) +- Add support for run-end encoded \(REE\) arrays in arrow-avro [\#8584](https://github.com/apache/arrow-rs/pull/8584) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Unify API for writing column chunks / row groups in parallel [\#8582](https://github.com/apache/arrow-rs/pull/8582) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- Fix linting issues missed by \#8506 [\#8581](https://github.com/apache/arrow-rs/pull/8581) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix broken decimal-\>decimal casting with large scale reduction [\#8580](https://github.com/apache/arrow-rs/pull/8580) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([scovich](https://github.com/scovich)) +- Migrate `arrow` and workspace to Rust 2024 [\#8578](https://github.com/apache/arrow-rs/pull/8578) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Fix doctests of parquet push decoded without default features [\#8577](https://github.com/apache/arrow-rs/pull/8577) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mbrobbel](https://github.com/mbrobbel)) +- Avoid panics and warnings when building avro without default features [\#8576](https://github.com/apache/arrow-rs/pull/8576) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Add support for 64-bit Schema Registry IDs \(Id64\) in arrow-avro [\#8575](https://github.com/apache/arrow-rs/pull/8575) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- fix: bug when struct nullability determined from `Dict<_, ByteArray>>` column [\#8573](https://github.com/apache/arrow-rs/pull/8573) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- fix: Support `interleave_struct` to handle empty fields [\#8563](https://github.com/apache/arrow-rs/pull/8563) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- \[Variant\] Define and use VariantDecimalType trait [\#8562](https://github.com/apache/arrow-rs/pull/8562) ([scovich](https://github.com/scovich)) +- \[PARQUET\] Update parquet writer bench with compression and pagev2 [\#8560](https://github.com/apache/arrow-rs/pull/8560) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lilianm](https://github.com/lilianm)) +- Replace serde with `serde_core` when possible [\#8558](https://github.com/apache/arrow-rs/pull/8558) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- fix: use default field name when name is None in Field conversion [\#8557](https://github.com/apache/arrow-rs/pull/8557) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Add arrow-avro README.md file [\#8556](https://github.com/apache/arrow-rs/pull/8556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- minor\(parquet\): Fix test\_not\_found on Windows [\#8555](https://github.com/apache/arrow-rs/pull/8555) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nuno-faria](https://github.com/nuno-faria)) +- \[Parquet\] Avoid fetching multiple pages when the predicate cache is disabled [\#8554](https://github.com/apache/arrow-rs/pull/8554) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([nuno-faria](https://github.com/nuno-faria)) +- \[Variant\] Support variant to `Decimal32/64/128/256` [\#8552](https://github.com/apache/arrow-rs/pull/8552) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- Arrow-avro Writer Dense Union support [\#8550](https://github.com/apache/arrow-rs/pull/8550) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Arrow-Avro: Resolve named field discrepancies [\#8546](https://github.com/apache/arrow-rs/pull/8546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Migrate `arrow-avro` to Rust 2024 [\#8545](https://github.com/apache/arrow-rs/pull/8545) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- feat: Export `is_dense` public [\#8544](https://github.com/apache/arrow-rs/pull/8544) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Fix "Incorrect Behavior of Collecting a filtered iterator to a BooleanArray" [\#8543](https://github.com/apache/arrow-rs/pull/8543) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- Support old syntax for DataType parsing [\#8541](https://github.com/apache/arrow-rs/pull/8541) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Variant\] Decimal unshredding support [\#8540](https://github.com/apache/arrow-rs/pull/8540) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Variant\] Improve documentation and make kernels consistent [\#8536](https://github.com/apache/arrow-rs/pull/8536) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- feat: support casting from null to float16 [\#8535](https://github.com/apache/arrow-rs/pull/8535) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chenkovsky](https://github.com/chenkovsky)) +- Add benchmarks for FromIter \(PrimitiveArray and BooleanArray\) [\#8525](https://github.com/apache/arrow-rs/pull/8525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- Support writing GeospatialStatistics in Parquet writer [\#8524](https://github.com/apache/arrow-rs/pull/8524) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([paleolimbot](https://github.com/paleolimbot)) +- Fix some new rustdoc warnings [\#8522](https://github.com/apache/arrow-rs/pull/8522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[Variant\] Reverse VariantAsPrimitive trait to PrimitiveFromVariant [\#8519](https://github.com/apache/arrow-rs/pull/8519) ([scovich](https://github.com/scovich)) +- \[Variant\] Add variant to arrow primitive support for boolean/timestamp/time [\#8516](https://github.com/apache/arrow-rs/pull/8516) ([klion26](https://github.com/klion26)) +- \[Variant\] Add list support to unshred\_variant [\#8514](https://github.com/apache/arrow-rs/pull/8514) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Migrate `parquet-variant-json` to Rust 2024 [\#8512](https://github.com/apache/arrow-rs/pull/8512) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet-variant-compute` to Rust 2024 [\#8511](https://github.com/apache/arrow-rs/pull/8511) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet-variant` to Rust 2024 [\#8510](https://github.com/apache/arrow-rs/pull/8510) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet-geospatial` to Rust 2024 [\#8509](https://github.com/apache/arrow-rs/pull/8509) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet_derive_test` to Rust 2024 [\#8508](https://github.com/apache/arrow-rs/pull/8508) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet_derive` to Rust 2024 [\#8507](https://github.com/apache/arrow-rs/pull/8507) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `parquet` to Rust 2024 [\#8506](https://github.com/apache/arrow-rs/pull/8506) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] ReadOnlyMetadataBuilder borrows its underlying VariantMetadata [\#8502](https://github.com/apache/arrow-rs/pull/8502) ([scovich](https://github.com/scovich)) +- \[Variant\] Add a VariantBuilderExt impl for VariantValueArrayBuilder [\#8501](https://github.com/apache/arrow-rs/pull/8501) ([scovich](https://github.com/scovich)) +- build\(deps\): update sysinfo requirement from 0.36.0 to 0.37.1 [\#8500](https://github.com/apache/arrow-rs/pull/8500) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Variant\] Introduce new BorrowedShreddingState concept [\#8499](https://github.com/apache/arrow-rs/pull/8499) ([scovich](https://github.com/scovich)) +- Add `append_n` method to `FixedSizeBinaryDictionaryBuilder` [\#8498](https://github.com/apache/arrow-rs/pull/8498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- Fix docs.rs build: Use `doc_cfg` instead of removed `doc_auto_cfg` [\#8494](https://github.com/apache/arrow-rs/pull/8494) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Remove allow unused from arrow-avro lib.rs file [\#8493](https://github.com/apache/arrow-rs/pull/8493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Regression Testing, Bug Fixes, and Public API Tightening for arrow-avro [\#8492](https://github.com/apache/arrow-rs/pull/8492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Migrate `arrow-string` to Rust 2024 [\#8491](https://github.com/apache/arrow-rs/pull/8491) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-select` to Rust 2024 [\#8490](https://github.com/apache/arrow-rs/pull/8490) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-schema` to Rust 2024 [\#8489](https://github.com/apache/arrow-rs/pull/8489) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-row` to Rust 2024 [\#8488](https://github.com/apache/arrow-rs/pull/8488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-pyarrow-testing` to Rust 2024 [\#8487](https://github.com/apache/arrow-rs/pull/8487) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-pyarrow-integration-testing` to Rust 2024 [\#8486](https://github.com/apache/arrow-rs/pull/8486) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-pyarrow` to Rust 2024 [\#8485](https://github.com/apache/arrow-rs/pull/8485) ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-ord` to Rust 2024 [\#8484](https://github.com/apache/arrow-rs/pull/8484) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Support strict casting for Decimals [\#8483](https://github.com/apache/arrow-rs/pull/8483) ([liamzwbao](https://github.com/liamzwbao)) +- feat\(json\): Add temporal formatting options when write to JSON [\#8482](https://github.com/apache/arrow-rs/pull/8482) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([linyihai](https://github.com/linyihai)) +- \[Variant\] Define and use unshred\_variant function [\#8481](https://github.com/apache/arrow-rs/pull/8481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Minor\] Remove private APIs from Parquet metadata benchmark [\#8478](https://github.com/apache/arrow-rs/pull/8478) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add examples of using `Field::try_extension_type` [\#8475](https://github.com/apache/arrow-rs/pull/8475) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix Rustfmt in arrow-cast [\#8473](https://github.com/apache/arrow-rs/pull/8473) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Disable incremental builds in CI [\#8471](https://github.com/apache/arrow-rs/pull/8471) ([mbrobbel](https://github.com/mbrobbel)) +- Update Rust toolchain to 1.90 [\#8468](https://github.com/apache/arrow-rs/pull/8468) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Parquet\] Minor: Remove mut ref for getting row-group bloom filter [\#8462](https://github.com/apache/arrow-rs/pull/8462) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- refactor: split `num` dependency [\#8459](https://github.com/apache/arrow-rs/pull/8459) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([crepererum](https://github.com/crepererum)) +- Migrate `arrow-json` to Rust 2024 [\#8458](https://github.com/apache/arrow-rs/pull/8458) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-ipc` to Rust 2024 [\#8457](https://github.com/apache/arrow-rs/pull/8457) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-flight` to Rust 2024 [\#8456](https://github.com/apache/arrow-rs/pull/8456) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-data` to Rust 2024 [\#8455](https://github.com/apache/arrow-rs/pull/8455) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-csv` to Rust 2024 [\#8454](https://github.com/apache/arrow-rs/pull/8454) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-cast` to Rust 2024 [\#8453](https://github.com/apache/arrow-rs/pull/8453) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-buffer` to Rust 2024 [\#8452](https://github.com/apache/arrow-rs/pull/8452) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-array` to Rust 2024 [\#8450](https://github.com/apache/arrow-rs/pull/8450) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Migrate `arrow-arith` to Rust 2024 [\#8449](https://github.com/apache/arrow-rs/pull/8449) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Expose `fields` in `StructBuilder` [\#8448](https://github.com/apache/arrow-rs/pull/8448) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) +- \[Variant\] Simpler shredding state [\#8444](https://github.com/apache/arrow-rs/pull/8444) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Unpin comfytable [\#8440](https://github.com/apache/arrow-rs/pull/8440) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Variant integration fixes [\#8438](https://github.com/apache/arrow-rs/pull/8438) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Refactor: extract FooterTail from ParquetMetadataReader [\#8437](https://github.com/apache/arrow-rs/pull/8437) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Refactor: Move parquet metadata parsing code into its own module [\#8436](https://github.com/apache/arrow-rs/pull/8436) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update `UnionArray` wording to 'non-negative' [\#8434](https://github.com/apache/arrow-rs/pull/8434) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jdockerty](https://github.com/jdockerty)) +- Adds Duration\(TimeUnit\) support to arrow-avro reader and writer [\#8433](https://github.com/apache/arrow-rs/pull/8433) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Update release schedule [\#8432](https://github.com/apache/arrow-rs/pull/8432) ([mbrobbel](https://github.com/mbrobbel)) +- expose read plan and plan builder via mod [\#8431](https://github.com/apache/arrow-rs/pull/8431) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([yeya24](https://github.com/yeya24)) +- Bump MSRV to 1.85 [\#8429](https://github.com/apache/arrow-rs/pull/8429) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Fix clippy [\#8426](https://github.com/apache/arrow-rs/pull/8426) ([alamb](https://github.com/alamb)) +- Fix red main by updating test [\#8421](https://github.com/apache/arrow-rs/pull/8421) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([emilk](https://github.com/emilk)) +- Implement AsRef for Schema and Field [\#8417](https://github.com/apache/arrow-rs/pull/8417) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- \[Variant\] mark metadata field as non-nullable [\#8416](https://github.com/apache/arrow-rs/pull/8416) ([ding-young](https://github.com/ding-young)) +- Respect `CastOptions.safe` when casting `BinaryView` → `Utf8View` \(return `null` for invalid UTF‑8\) [\#8415](https://github.com/apache/arrow-rs/pull/8415) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) +- Add Parquet geospatial statistics utility [\#8414](https://github.com/apache/arrow-rs/pull/8414) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([paleolimbot](https://github.com/paleolimbot)) +- Remove explicit default cfg option [\#8413](https://github.com/apache/arrow-rs/pull/8413) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([abacef](https://github.com/abacef)) +- Support parquet canonical extension type roundtrip [\#8409](https://github.com/apache/arrow-rs/pull/8409) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Support reading/writing `VariantArray` to parquet with Variant LogicalType [\#8408](https://github.com/apache/arrow-rs/pull/8408) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Follow-up on arrow-avro Documentation [\#8402](https://github.com/apache/arrow-rs/pull/8402) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\]\[Shredding\] Support typed\_access for timestamp\_micro/timestamp\_nano [\#8401](https://github.com/apache/arrow-rs/pull/8401) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- Expose ReadPlan and ReadPlanBuilder [\#8399](https://github.com/apache/arrow-rs/pull/8399) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([yeya24](https://github.com/yeya24)) +- Propagate errors instead of panics: Replace usages of `new` with `try_new` for Array types [\#8397](https://github.com/apache/arrow-rs/pull/8397) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- \[Variant\] Fix NULL handling for shredded object fields [\#8395](https://github.com/apache/arrow-rs/pull/8395) ([scovich](https://github.com/scovich)) +- Add Arrow Variant Extension Type, remove `Array` impl for `VariantArray` and `ShreddedVariantFieldArray` [\#8392](https://github.com/apache/arrow-rs/pull/8392) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor cleanup creating Schema [\#8391](https://github.com/apache/arrow-rs/pull/8391) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Geospatial\]: Add CI checks for `parquet-geospatial` crate [\#8390](https://github.com/apache/arrow-rs/pull/8390) ([kylebarron](https://github.com/kylebarron)) +- Follow-up Improvements to Avro union handling [\#8385](https://github.com/apache/arrow-rs/pull/8385) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- fix: reset the offset of 'file\_for\_view' [\#8381](https://github.com/apache/arrow-rs/pull/8381) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([TrevorADHD](https://github.com/TrevorADHD)) +- \[Variant\] \[Shredding\] feat: Support typed\_access for Date32 [\#8379](https://github.com/apache/arrow-rs/pull/8379) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) +- \[Geospatial\]: Scaffolding for new `parquet-geospatial` crate [\#8375](https://github.com/apache/arrow-rs/pull/8375) ([kylebarron](https://github.com/kylebarron)) +- Avro writer prefix support [\#8371](https://github.com/apache/arrow-rs/pull/8371) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- \[Variant\] Define new shred\_variant function [\#8366](https://github.com/apache/arrow-rs/pull/8366) ([scovich](https://github.com/scovich)) +- Add arrow-avro Reader support for Dense Union and Union resolution \(Part 2\) [\#8349](https://github.com/apache/arrow-rs/pull/8349) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Move ParquetMetadata decoder state machine into ParquetMetadataPushDecoder [\#8340](https://github.com/apache/arrow-rs/pull/8340) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\]: Implement `DataType::FixedSizeList` support for `cast_to_variant` kernel [\#8282](https://github.com/apache/arrow-rs/pull/8282) ([liamzwbao](https://github.com/liamzwbao)) + +## [56.2.0](https://github.com/apache/arrow-rs/tree/56.2.0) (2025-09-19) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/56.1.0...56.2.0) + +- \[Variant\] \[Shredding\] Support typed\_access for Utf8 and BinaryView [\#8364](https://github.com/apache/arrow-rs/pull/8364) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([petern48](https://github.com/petern48)) +- Fix casting floats to Decimal64 [\#8363](https://github.com/apache/arrow-rs/pull/8363) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- \[Variant\] Implement new VariantValueArrayBuilder [\#8360](https://github.com/apache/arrow-rs/pull/8360) ([scovich](https://github.com/scovich)) +- \[Variant\] Add constants for empty variant metadata [\#8359](https://github.com/apache/arrow-rs/pull/8359) ([scovich](https://github.com/scovich)) +- \[Variant\] Allow lossless casting from integer to floating point [\#8357](https://github.com/apache/arrow-rs/pull/8357) ([scovich](https://github.com/scovich)) +- \[Variant\] Minor code cleanups [\#8356](https://github.com/apache/arrow-rs/pull/8356) ([scovich](https://github.com/scovich)) +- \[Variant\] Remove unused metadata from variant ShreddingState [\#8355](https://github.com/apache/arrow-rs/pull/8355) ([scovich](https://github.com/scovich)) +- Adds Map & Enum support, round-trip & benchmark tests [\#8353](https://github.com/apache/arrow-rs/pull/8353) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- \[Variant\] \[Shredding\] feat: Support typed\_access for FixedSizeBinary [\#8352](https://github.com/apache/arrow-rs/pull/8352) ([petern48](https://github.com/petern48)) +- Add arrow-avro Reader support for Dense Union and Union resolution \(Part 1\) [\#8348](https://github.com/apache/arrow-rs/pull/8348) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] feat: Support typed\_access for Boolean [\#8346](https://github.com/apache/arrow-rs/pull/8346) ([Weijun-H](https://github.com/Weijun-H)) +- \[Variant\] Make VariantToArrowRowBuilder an enum [\#8345](https://github.com/apache/arrow-rs/pull/8345) ([scovich](https://github.com/scovich)) +- \[Variant\] Rename VariantShreddingRowBuilder to VariantToArrowRowBuilder [\#8344](https://github.com/apache/arrow-rs/pull/8344) ([scovich](https://github.com/scovich)) +- \[Variant\] Add tests for variant\_get requesting Some struct [\#8343](https://github.com/apache/arrow-rs/pull/8343) ([scovich](https://github.com/scovich)) +- \[Variant\] Add nullable arg to StructArrayBuilder::with\_field [\#8342](https://github.com/apache/arrow-rs/pull/8342) ([scovich](https://github.com/scovich)) +- Minor: avoid an `Arc::clone` in CacheOptions for Parquet PredicateCache [\#8338](https://github.com/apache/arrow-rs/pull/8338) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix `can_cast_types` for temporal to `Utf8View` [\#8328](https://github.com/apache/arrow-rs/pull/8328) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Update `variant_integration` test to use final approved `parquet-testing` data [\#8325](https://github.com/apache/arrow-rs/pull/8325) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] ParentState tracks builder-specific state in a uniform way [\#8324](https://github.com/apache/arrow-rs/pull/8324) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Variant\] Remove boilerplate from make\_shredding\_row\_builder [\#8322](https://github.com/apache/arrow-rs/pull/8322) ([scovich](https://github.com/scovich)) +- \[Variant\] Move VariantAsPrimitive to type\_conversions.rs [\#8321](https://github.com/apache/arrow-rs/pull/8321) ([scovich](https://github.com/scovich)) +- \[Variant\] Remove unused output builder files [\#8320](https://github.com/apache/arrow-rs/pull/8320) ([scovich](https://github.com/scovich)) +- Add arrow-avro examples and Reader documentation [\#8316](https://github.com/apache/arrow-rs/pull/8316) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Expose predicates from RowFilter [\#8315](https://github.com/apache/arrow-rs/pull/8315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([yeya24](https://github.com/yeya24)) +- \[Variant\] Implement row builders for cast\_to\_variant [\#8299](https://github.com/apache/arrow-rs/pull/8299) ([scovich](https://github.com/scovich)) +- Adds additional type support to arrow-avro writer [\#8298](https://github.com/apache/arrow-rs/pull/8298) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Use apache/arrow-dotnet for integration test [\#8295](https://github.com/apache/arrow-rs/pull/8295) ([kou](https://github.com/kou)) +- Add projection with default values support to `RecordDecoder` [\#8293](https://github.com/apache/arrow-rs/pull/8293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add array/map/fixed schema resolution and default value support to arrow-avro codec [\#8292](https://github.com/apache/arrow-rs/pull/8292) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Bump actions/labeler from 6.0.0 to 6.0.1 [\#8288](https://github.com/apache/arrow-rs/pull/8288) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/github-script from 7 to 8 [\#8287](https://github.com/apache/arrow-rs/pull/8287) ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Variant\] Add as\_u\* for Variant [\#8284](https://github.com/apache/arrow-rs/pull/8284) ([klion26](https://github.com/klion26)) +- \[Variant\] Support Shredded Objects in variant\_get \(take 2\) [\#8280](https://github.com/apache/arrow-rs/pull/8280) ([scovich](https://github.com/scovich)) +- Bump actions/setup-node from 4 to 5 [\#8279](https://github.com/apache/arrow-rs/pull/8279) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/setup-python from 5 to 6 [\#8278](https://github.com/apache/arrow-rs/pull/8278) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/labeler from 5.0.0 to 6.0.0 [\#8276](https://github.com/apache/arrow-rs/pull/8276) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Impl `Display` for `Tz` [\#8275](https://github.com/apache/arrow-rs/pull/8275) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- Added List and Struct Encoding to arrow-avro Writer [\#8274](https://github.com/apache/arrow-rs/pull/8274) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add into\_builder method for WriterProperties [\#8272](https://github.com/apache/arrow-rs/pull/8272) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([corwinjoy](https://github.com/corwinjoy)) +- chore\(parquet/record/field\): dont truncate timestamps on display [\#8266](https://github.com/apache/arrow-rs/pull/8266) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Erigara](https://github.com/Erigara)) +- \[Parquet\] Write row group with async writer [\#8262](https://github.com/apache/arrow-rs/pull/8262) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lilianm](https://github.com/lilianm)) +- Parquet: Do not compress v2 data page when compress is bad quality [\#8257](https://github.com/apache/arrow-rs/pull/8257) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- Add Decimal32 and Decimal64 support to arrow-avro Reader [\#8255](https://github.com/apache/arrow-rs/pull/8255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Minor\] Backport changes to metadata benchmark [\#8251](https://github.com/apache/arrow-rs/pull/8251) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update hashbrown requirement from 0.15.1 to 0.16.0 [\#8248](https://github.com/apache/arrow-rs/pull/8248) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Sort: Change lexsort comment from stable to unstable [\#8245](https://github.com/apache/arrow-rs/pull/8245) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mapleFU](https://github.com/mapleFU)) +- pin comfy-table to 7.1.2 [\#8244](https://github.com/apache/arrow-rs/pull/8244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zachschuermann](https://github.com/zachschuermann)) +- Adds Confluent wire format handling to arrow-avro crate [\#8242](https://github.com/apache/arrow-rs/pull/8242) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- feat: gRPC compression support for flight CLI [\#8240](https://github.com/apache/arrow-rs/pull/8240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- feat: `SSLKEYLOGFILE` support for flight CLI [\#8239](https://github.com/apache/arrow-rs/pull/8239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- \[Variant\] Refactor `cast_to_variant` [\#8235](https://github.com/apache/arrow-rs/pull/8235) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] add strict mode to cast\_to\_variant [\#8233](https://github.com/apache/arrow-rs/pull/8233) ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Add Variant::as\_f16 [\#8232](https://github.com/apache/arrow-rs/pull/8232) ([klion26](https://github.com/klion26)) +- Unpin nightly rust version \(MIRI job\) [\#8229](https://github.com/apache/arrow-rs/pull/8229) ([mbrobbel](https://github.com/mbrobbel)) +- Update apache-avro requirement from 0.14.0 to 0.20.0 [\#8226](https://github.com/apache/arrow-rs/pull/8226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/upload-pages-artifact from 3 to 4 [\#8224](https://github.com/apache/arrow-rs/pull/8224) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Added arrow-avro enum mapping support for schema resolution [\#8223](https://github.com/apache/arrow-rs/pull/8223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Added arrow-avro schema resolution value skipping [\#8220](https://github.com/apache/arrow-rs/pull/8220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Fix error condition in doc comment of `Field::try_canonical_extension_type` [\#8216](https://github.com/apache/arrow-rs/pull/8216) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\]: Implement `DataType::Duration` support for `cast_to_variant` kernel [\#8215](https://github.com/apache/arrow-rs/pull/8215) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] feat: remove unnecessary unwraps in `Object::finish` [\#8214](https://github.com/apache/arrow-rs/pull/8214) ([Weijun-H](https://github.com/Weijun-H)) +- \[avro\] Fix Avro decoder bitmap corruption when nullable field decoding fails [\#8213](https://github.com/apache/arrow-rs/pull/8213) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- Restore accidentally removed method Block::to\_ne\_bytes [\#8211](https://github.com/apache/arrow-rs/pull/8211) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[avro\] Support all default types for avro schema's record field [\#8210](https://github.com/apache/arrow-rs/pull/8210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- \[Variant\] Support read-only metadata builders [\#8208](https://github.com/apache/arrow-rs/pull/8208) ([scovich](https://github.com/scovich)) +- Avro to arrow schema conversion fails when a field has a default type that is not string [\#8209](https://github.com/apache/arrow-rs/issues/8209) +- parquet: No method named `to_ne_bytes` found for struct `bloom_filter::Block` for target `s390x-unknown-linux-gnu` [\#8207](https://github.com/apache/arrow-rs/issues/8207) +- \[Variant\] cast\_to\_variant will panic on certain `Date64` or Timestamp Values values [\#8155](https://github.com/apache/arrow-rs/issues/8155) +- Parquet: Avoid page-size overflows i32 [\#8264](https://github.com/apache/arrow-rs/pull/8264) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) + +**Documentation updates:** + +- Update docstring comment for Writer::write\(\) in writer.rs [\#8267](https://github.com/apache/arrow-rs/pull/8267) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([YKoustubhRao](https://github.com/YKoustubhRao)) + +**Closed issues:** + +- comfy-table release 7.2.0 breaks MSRV [\#8243](https://github.com/apache/arrow-rs/issues/8243) +- \[Variant\] Add `Variant::as_f16` [\#8228](https://github.com/apache/arrow-rs/issues/8228) +- Support appending raw bytes to variant objects and lists [\#8217](https://github.com/apache/arrow-rs/issues/8217) +- `VariantArrayBuilder` uses `ParentState` for simpler rollbacks [\#8205](https://github.com/apache/arrow-rs/issues/8205) +- Make `ObjectBuilder::finish` signature infallible [\#8184](https://github.com/apache/arrow-rs/issues/8184) +- Improve performance of `i256` to `f64` [\#8013](https://github.com/apache/arrow-rs/issues/8013) + +**Merged pull requests:** + +- \[Variant\] Support Variant to PrimitiveArrow for unsigned integer [\#8369](https://github.com/apache/arrow-rs/pull/8369) ([klion26](https://github.com/klion26)) +- \[Variant\] \[Shredding\] Support typed\_access for Utf8 and BinaryView [\#8364](https://github.com/apache/arrow-rs/pull/8364) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([petern48](https://github.com/petern48)) +- Fix casting floats to Decimal64 [\#8363](https://github.com/apache/arrow-rs/pull/8363) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- \[Variant\] Implement new VariantValueArrayBuilder [\#8360](https://github.com/apache/arrow-rs/pull/8360) ([scovich](https://github.com/scovich)) +- \[Variant\] Add constants for empty variant metadata [\#8359](https://github.com/apache/arrow-rs/pull/8359) ([scovich](https://github.com/scovich)) +- \[Variant\] Allow lossless casting from integer to floating point [\#8357](https://github.com/apache/arrow-rs/pull/8357) ([scovich](https://github.com/scovich)) +- \[Variant\] Minor code cleanups [\#8356](https://github.com/apache/arrow-rs/pull/8356) ([scovich](https://github.com/scovich)) +- \[Variant\] Remove unused metadata from variant ShreddingState [\#8355](https://github.com/apache/arrow-rs/pull/8355) ([scovich](https://github.com/scovich)) +- Adds Map & Enum support, round-trip & benchmark tests [\#8353](https://github.com/apache/arrow-rs/pull/8353) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- \[Variant\] \[Shredding\] feat: Support typed\_access for FixedSizeBinary [\#8352](https://github.com/apache/arrow-rs/pull/8352) ([petern48](https://github.com/petern48)) +- Add arrow-avro Reader support for Dense Union and Union resolution \(Part 1\) [\#8348](https://github.com/apache/arrow-rs/pull/8348) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] feat: Support typed\_access for Boolean [\#8346](https://github.com/apache/arrow-rs/pull/8346) ([Weijun-H](https://github.com/Weijun-H)) +- \[Variant\] Make VariantToArrowRowBuilder an enum [\#8345](https://github.com/apache/arrow-rs/pull/8345) ([scovich](https://github.com/scovich)) +- \[Variant\] Rename VariantShreddingRowBuilder to VariantToArrowRowBuilder [\#8344](https://github.com/apache/arrow-rs/pull/8344) ([scovich](https://github.com/scovich)) +- \[Variant\] Add tests for variant\_get requesting Some struct [\#8343](https://github.com/apache/arrow-rs/pull/8343) ([scovich](https://github.com/scovich)) +- \[Variant\] Add nullable arg to StructArrayBuilder::with\_field [\#8342](https://github.com/apache/arrow-rs/pull/8342) ([scovich](https://github.com/scovich)) +- Minor: avoid an `Arc::clone` in CacheOptions for Parquet PredicateCache [\#8338](https://github.com/apache/arrow-rs/pull/8338) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix `can_cast_types` for temporal to `Utf8View` [\#8328](https://github.com/apache/arrow-rs/pull/8328) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Update `variant_integration` test to use final approved `parquet-testing` data [\#8325](https://github.com/apache/arrow-rs/pull/8325) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] ParentState tracks builder-specific state in a uniform way [\#8324](https://github.com/apache/arrow-rs/pull/8324) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Variant\] Remove boilerplate from make\_shredding\_row\_builder [\#8322](https://github.com/apache/arrow-rs/pull/8322) ([scovich](https://github.com/scovich)) +- \[Variant\] Move VariantAsPrimitive to type\_conversions.rs [\#8321](https://github.com/apache/arrow-rs/pull/8321) ([scovich](https://github.com/scovich)) +- \[Variant\] Remove unused output builder files [\#8320](https://github.com/apache/arrow-rs/pull/8320) ([scovich](https://github.com/scovich)) +- Add arrow-avro examples and Reader documentation [\#8316](https://github.com/apache/arrow-rs/pull/8316) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Expose predicates from RowFilter [\#8315](https://github.com/apache/arrow-rs/pull/8315) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([yeya24](https://github.com/yeya24)) +- \[Variant\] Implement row builders for cast\_to\_variant [\#8299](https://github.com/apache/arrow-rs/pull/8299) ([scovich](https://github.com/scovich)) +- Adds additional type support to arrow-avro writer [\#8298](https://github.com/apache/arrow-rs/pull/8298) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Use apache/arrow-dotnet for integration test [\#8295](https://github.com/apache/arrow-rs/pull/8295) ([kou](https://github.com/kou)) +- Add projection with default values support to `RecordDecoder` [\#8293](https://github.com/apache/arrow-rs/pull/8293) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add array/map/fixed schema resolution and default value support to arrow-avro codec [\#8292](https://github.com/apache/arrow-rs/pull/8292) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Bump actions/labeler from 6.0.0 to 6.0.1 [\#8288](https://github.com/apache/arrow-rs/pull/8288) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/github-script from 7 to 8 [\#8287](https://github.com/apache/arrow-rs/pull/8287) ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Variant\] Add as\_u\* for Variant [\#8284](https://github.com/apache/arrow-rs/pull/8284) ([klion26](https://github.com/klion26)) +- \[Variant\] Support Shredded Objects in variant\_get \(take 2\) [\#8280](https://github.com/apache/arrow-rs/pull/8280) ([scovich](https://github.com/scovich)) +- Bump actions/setup-node from 4 to 5 [\#8279](https://github.com/apache/arrow-rs/pull/8279) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/setup-python from 5 to 6 [\#8278](https://github.com/apache/arrow-rs/pull/8278) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/labeler from 5.0.0 to 6.0.0 [\#8276](https://github.com/apache/arrow-rs/pull/8276) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Impl `Display` for `Tz` [\#8275](https://github.com/apache/arrow-rs/pull/8275) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- Added List and Struct Encoding to arrow-avro Writer [\#8274](https://github.com/apache/arrow-rs/pull/8274) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add into\_builder method for WriterProperties [\#8272](https://github.com/apache/arrow-rs/pull/8272) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([corwinjoy](https://github.com/corwinjoy)) +- chore\(parquet/record/field\): dont truncate timestamps on display [\#8266](https://github.com/apache/arrow-rs/pull/8266) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Erigara](https://github.com/Erigara)) +- \[Parquet\] Write row group with async writer [\#8262](https://github.com/apache/arrow-rs/pull/8262) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([lilianm](https://github.com/lilianm)) +- Parquet: Do not compress v2 data page when compress is bad quality [\#8257](https://github.com/apache/arrow-rs/pull/8257) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- Add Decimal32 and Decimal64 support to arrow-avro Reader [\#8255](https://github.com/apache/arrow-rs/pull/8255) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Minor\] Backport changes to metadata benchmark [\#8251](https://github.com/apache/arrow-rs/pull/8251) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update hashbrown requirement from 0.15.1 to 0.16.0 [\#8248](https://github.com/apache/arrow-rs/pull/8248) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Sort: Change lexsort comment from stable to unstable [\#8245](https://github.com/apache/arrow-rs/pull/8245) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mapleFU](https://github.com/mapleFU)) +- pin comfy-table to 7.1.2 [\#8244](https://github.com/apache/arrow-rs/pull/8244) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zachschuermann](https://github.com/zachschuermann)) +- Adds Confluent wire format handling to arrow-avro crate [\#8242](https://github.com/apache/arrow-rs/pull/8242) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- feat: gRPC compression support for flight CLI [\#8240](https://github.com/apache/arrow-rs/pull/8240) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- feat: `SSLKEYLOGFILE` support for flight CLI [\#8239](https://github.com/apache/arrow-rs/pull/8239) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([crepererum](https://github.com/crepererum)) +- \[Variant\] Refactor `cast_to_variant` [\#8235](https://github.com/apache/arrow-rs/pull/8235) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] add strict mode to cast\_to\_variant [\#8233](https://github.com/apache/arrow-rs/pull/8233) ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Add Variant::as\_f16 [\#8232](https://github.com/apache/arrow-rs/pull/8232) ([klion26](https://github.com/klion26)) +- Unpin nightly rust version \(MIRI job\) [\#8229](https://github.com/apache/arrow-rs/pull/8229) ([mbrobbel](https://github.com/mbrobbel)) +- Update apache-avro requirement from 0.14.0 to 0.20.0 [\#8226](https://github.com/apache/arrow-rs/pull/8226) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Bump actions/upload-pages-artifact from 3 to 4 [\#8224](https://github.com/apache/arrow-rs/pull/8224) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Added arrow-avro enum mapping support for schema resolution [\#8223](https://github.com/apache/arrow-rs/pull/8223) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Added arrow-avro schema resolution value skipping [\#8220](https://github.com/apache/arrow-rs/pull/8220) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Fix error condition in doc comment of `Field::try_canonical_extension_type` [\#8216](https://github.com/apache/arrow-rs/pull/8216) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\]: Implement `DataType::Duration` support for `cast_to_variant` kernel [\#8215](https://github.com/apache/arrow-rs/pull/8215) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] feat: remove unnecessary unwraps in `Object::finish` [\#8214](https://github.com/apache/arrow-rs/pull/8214) ([Weijun-H](https://github.com/Weijun-H)) +- \[avro\] Fix Avro decoder bitmap corruption when nullable field decoding fails [\#8213](https://github.com/apache/arrow-rs/pull/8213) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- Restore accidentally removed method Block::to\_ne\_bytes [\#8211](https://github.com/apache/arrow-rs/pull/8211) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[avro\] Support all default types for avro schema's record field [\#8210](https://github.com/apache/arrow-rs/pull/8210) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- \[Variant\] Support read-only metadata builders [\#8208](https://github.com/apache/arrow-rs/pull/8208) ([scovich](https://github.com/scovich)) +- \[Variant\] VariantArrayBuilder uses MetadataBuilder and ValueBuilder [\#8206](https://github.com/apache/arrow-rs/pull/8206) ([scovich](https://github.com/scovich)) +- \[Variant\]: Implement DataType::List/LargeList support for cast\_to\_variant kernel [\#8201](https://github.com/apache/arrow-rs/pull/8201) ([sdf-jkl](https://github.com/sdf-jkl)) +- \[Variant\]: Implement `DataType::Union` support for `cast_to_variant` kernel [\#8196](https://github.com/apache/arrow-rs/pull/8196) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] Support typed access for numeric types in variant\_get [\#8179](https://github.com/apache/arrow-rs/pull/8179) ([superserious-dev](https://github.com/superserious-dev)) +- \[Variant\]: Implement `DataType::Union` support for `cast_to_variant` kernel [\#8195](https://github.com/apache/arrow-rs/issues/8195) +- \[Variant\]: Implement `DataType::Duration` support for `cast_to_variant` kernel [\#8194](https://github.com/apache/arrow-rs/issues/8194) +- \[Variant\] Support typed access for numeric types in variant\_get [\#8178](https://github.com/apache/arrow-rs/issues/8178) +- \[Parquet\] Implement a "push style" API for decoding Parquet Metadata [\#8164](https://github.com/apache/arrow-rs/issues/8164) +- \[Variant\] Support creating Variants with pre-existing Metadata [\#8152](https://github.com/apache/arrow-rs/issues/8152) +- \[Variant\] Support Shredded Objects in `variant_get`: typed path access \(STEP 1\) [\#8150](https://github.com/apache/arrow-rs/issues/8150) +- \[Variant\] Add `variant` feature to `parquet` crate [\#8132](https://github.com/apache/arrow-rs/issues/8132) +- \[Parquet\] Concurrent writes with ArrowWriter.get\_column\_writers should parallelize across row groups [\#8115](https://github.com/apache/arrow-rs/issues/8115) +- \[Variant\] Implement `VariantArray::value` for shredded variants [\#8091](https://github.com/apache/arrow-rs/issues/8091) +- \[Variant\] Integration tests for reading parquet w/ Variants [\#8084](https://github.com/apache/arrow-rs/issues/8084) +- \[Variant\]: Implement `DataType::Map` support for `cast_to_variant` kernel [\#8063](https://github.com/apache/arrow-rs/issues/8063) +- \[Variant\]: Implement `DataType::List/LargeList` support for `cast_to_variant` kernel [\#8060](https://github.com/apache/arrow-rs/issues/8060) + +## [56.1.0](https://github.com/apache/arrow-rs/tree/56.1.0) (2025-08-21) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/56.0.0...56.1.0) + +**Implemented enhancements:** + +- Implement cast and other operations on decimal32 and decimal64 \#7815 [\#8204](https://github.com/apache/arrow-rs/issues/8204) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speed up Parquet filter pushdown with predicate cache [\#8203](https://github.com/apache/arrow-rs/issues/8203) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optionally read parquet page indexes [\#8070](https://github.com/apache/arrow-rs/issues/8070) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet reader: add method for sync reader read bloom filter [\#8023](https://github.com/apache/arrow-rs/issues/8023) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[parquet\] Support writing logically equivalent types to `ArrowWriter` [\#8012](https://github.com/apache/arrow-rs/issues/8012) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Improve StringArray\(Utf8\) sort performance [\#7847](https://github.com/apache/arrow-rs/issues/7847) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- feat: arrow-ipc delta dictionary support [\#8001](https://github.com/apache/arrow-rs/pull/8001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JakeDern](https://github.com/JakeDern)) + +**Fixed bugs:** + +- The Rustdocs are clean CI job is failing [\#8175](https://github.com/apache/arrow-rs/issues/8175) +- \[avro\] Bug in resolving avro schema with named type [\#8045](https://github.com/apache/arrow-rs/issues/8045) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Doc test failure \(test arrow-avro/src/lib.rs - reader\) when verifying avro 56.0.0 RC1 release [\#8018](https://github.com/apache/arrow-rs/issues/8018) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Documentation updates:** + +- arrow-row: Document dictionary handling [\#8168](https://github.com/apache/arrow-rs/pull/8168) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Docs: Clarify that Array::value does not check for nulls [\#8065](https://github.com/apache/arrow-rs/pull/8065) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: Fix a typo in README [\#8036](https://github.com/apache/arrow-rs/pull/8036) ([EricccTaiwan](https://github.com/EricccTaiwan)) +- Add more comments to the internal parquet reader [\#7932](https://github.com/apache/arrow-rs/pull/7932) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- perf\(arrow-ipc\): avoid counting nulls in `RecordBatchDecoder` [\#8127](https://github.com/apache/arrow-rs/pull/8127) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Use `Vec` directly in builders [\#7984](https://github.com/apache/arrow-rs/pull/7984) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([liamzwbao](https://github.com/liamzwbao)) +- Improve StringArray\(Utf8\) sort performance \(~2-4x faster\) [\#7860](https://github.com/apache/arrow-rs/pull/7860) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) + +**Closed issues:** + +- \[Variant\] Improve fuzz test for Variant [\#8199](https://github.com/apache/arrow-rs/issues/8199) +- \[Variant\] Improve fuzz test for Variant [\#8198](https://github.com/apache/arrow-rs/issues/8198) +- `VariantArrayBuilder` tracks starting offsets instead of \(offset, len\) pairs [\#8192](https://github.com/apache/arrow-rs/issues/8192) +- Rework `ValueBuilder` API to work with `ParentState` for reliable nested rollbacks [\#8188](https://github.com/apache/arrow-rs/issues/8188) +- \[Variant\] Rename `ValueBuffer` as `ValueBuilder` [\#8186](https://github.com/apache/arrow-rs/issues/8186) +- \[Variant\] Refactor `ParentState` to track and rollback state on behalf of its owning builder [\#8182](https://github.com/apache/arrow-rs/issues/8182) +- \[Variant\] `ObjectBuilder` should detect duplicates at insertion time, not at finish [\#8180](https://github.com/apache/arrow-rs/issues/8180) +- \[Variant\] ObjectBuilder does not reliably check for duplicates [\#8170](https://github.com/apache/arrow-rs/issues/8170) +- [Variant] Support `StringView` and `LargeString` in ´batch_json_string_to_variant` [\#8145](https://github.com/apache/arrow-rs/issues/8145) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Rename `batch_json_string_to_variant` and `batch_variant_to_json_string` json\_to\_variant [\#8144](https://github.com/apache/arrow-rs/issues/8144) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[avro\] Use `tempfile` crate rather than custom temporary file generator in tests [\#8143](https://github.com/apache/arrow-rs/issues/8143) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Avro\] Use `Write` rather `dyn Write` in Decoder [\#8142](https://github.com/apache/arrow-rs/issues/8142) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Nested builder rollback is broken [\#8136](https://github.com/apache/arrow-rs/issues/8136) +- \[Variant\] Add support the remaing primitive type\(timestamp\_nanos/timestampntz\_nanos/uuid\) for parquet variant [\#8126](https://github.com/apache/arrow-rs/issues/8126) +- Meta: Implement missing Arrow 56.0 lint rules - Sequential workflow [\#8121](https://github.com/apache/arrow-rs/issues/8121) +- ARROW-012-015: Add linter rules for remaining Arrow 56.0 breaking changes [\#8120](https://github.com/apache/arrow-rs/issues/8120) +- ARROW-010 & ARROW-011: Add linter rules for Parquet Statistics and Metadata API removals [\#8119](https://github.com/apache/arrow-rs/issues/8119) +- ARROW-009: Add linter rules for IPC Dictionary API removals in Arrow 56.0 [\#8118](https://github.com/apache/arrow-rs/issues/8118) +- ARROW-008: Add linter rule for SerializedPageReaderState usize→u64 breaking change [\#8117](https://github.com/apache/arrow-rs/issues/8117) +- ARROW-007: Add linter rule for Schema.all\_fields\(\) removal in Arrow 56.0 [\#8116](https://github.com/apache/arrow-rs/issues/8116) +- \[Variant\] Implement `ShreddingState::AllNull` variant [\#8088](https://github.com/apache/arrow-rs/issues/8088) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support Shredded Objects in `variant_get` [\#8083](https://github.com/apache/arrow-rs/issues/8083) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::RunEndEncoded` support for `cast_to_variant` kernel [\#8064](https://github.com/apache/arrow-rs/issues/8064) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Dictionary` support for `cast_to_variant` kernel [\#8062](https://github.com/apache/arrow-rs/issues/8062) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Struct` support for `cast_to_variant` kernel [\#8061](https://github.com/apache/arrow-rs/issues/8061) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Decimal32/Decimal64/Decimal128/Decimal256` support for `cast_to_variant` kernel [\#8059](https://github.com/apache/arrow-rs/issues/8059) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Timestamp(..)` support for `cast_to_variant` kernel [\#8058](https://github.com/apache/arrow-rs/issues/8058) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Float16` support for `cast_to_variant` kernel [\#8057](https://github.com/apache/arrow-rs/issues/8057) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Interval` support for `cast_to_variant` kernel [\#8056](https://github.com/apache/arrow-rs/issues/8056) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Time32/Time64` support for `cast_to_variant` kernel [\#8055](https://github.com/apache/arrow-rs/issues/8055) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Date32 / DataType::Date64` support for `cast_to_variant` kernel [\#8054](https://github.com/apache/arrow-rs/issues/8054) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Null` support for `cast_to_variant` kernel [\#8053](https://github.com/apache/arrow-rs/issues/8053) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Boolean` support for `cast_to_variant` kernel [\#8052](https://github.com/apache/arrow-rs/issues/8052) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::FixedSizeBinary` support for `cast_to_variant` kernel [\#8051](https://github.com/apache/arrow-rs/issues/8051) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Binary/LargeBinary/BinaryView` support for `cast_to_variant` kernel [\#8050](https://github.com/apache/arrow-rs/issues/8050) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]: Implement `DataType::Utf8/LargeUtf8/Utf8View` support for `cast_to_variant` kernel [\#8049](https://github.com/apache/arrow-rs/issues/8049) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Implement `cast_to_variant` kernel [\#8043](https://github.com/apache/arrow-rs/issues/8043) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support `variant_get` kernel for shredded variants [\#7941](https://github.com/apache/arrow-rs/issues/7941) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add test for casting `Decimal128` \(`i128::MIN` and `i128::MAX`\) to `f64` with overflow handling [\#7939](https://github.com/apache/arrow-rs/issues/7939) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] + +**Merged pull requests:** + +- \[Variant\] Enhance the variant fuz test to cover time/timestamp/uuid primitive type [\#8200](https://github.com/apache/arrow-rs/pull/8200) ([klion26](https://github.com/klion26)) +- \[Variant\] VariantArrayBuilder tracks only offsets [\#8193](https://github.com/apache/arrow-rs/pull/8193) ([scovich](https://github.com/scovich)) +- \[Variant\] Caller provides ParentState to ValueBuilder methods [\#8189](https://github.com/apache/arrow-rs/pull/8189) ([scovich](https://github.com/scovich)) +- \[Variant\] Rename ValueBuffer as ValueBuilder [\#8187](https://github.com/apache/arrow-rs/pull/8187) ([scovich](https://github.com/scovich)) +- \[Variant\] ParentState handles finish/rollback for builders [\#8185](https://github.com/apache/arrow-rs/pull/8185) ([scovich](https://github.com/scovich)) +- \[Variant\]: Implement `DataType::RunEndEncoded` support for `cast_to_variant` kernel [\#8174](https://github.com/apache/arrow-rs/pull/8174) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\]: Implement `DataType::Dictionary` support for `cast_to_variant` kernel [\#8173](https://github.com/apache/arrow-rs/pull/8173) ([liamzwbao](https://github.com/liamzwbao)) +- Implement `ArrayBuilder` for `UnionBuilder` [\#8169](https://github.com/apache/arrow-rs/pull/8169) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([grtlr](https://github.com/grtlr)) +- \[Variant\] Support `LargeString` and `StringView` in `batch_json_string_to_variant` [\#8163](https://github.com/apache/arrow-rs/pull/8163) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] Rename `batch_json_string_to_variant` and `batch_variant_to_json_string` [\#8161](https://github.com/apache/arrow-rs/pull/8161) ([liamzwbao](https://github.com/liamzwbao)) +- \[Variant\] Add primitive type timestamp\_nanos\(with&without timezone\) and uuid [\#8149](https://github.com/apache/arrow-rs/pull/8149) ([klion26](https://github.com/klion26)) +- refactor\(avro\): Use impl Write instead of dyn Write in encoder [\#8148](https://github.com/apache/arrow-rs/pull/8148) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Xuanwo](https://github.com/Xuanwo)) +- chore: Use tempfile to replace hand-written utils functions [\#8147](https://github.com/apache/arrow-rs/pull/8147) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Xuanwo](https://github.com/Xuanwo)) +- feat: support push batch direct to completed and add biggest coalesce batch support [\#8146](https://github.com/apache/arrow-rs/pull/8146) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Add human-readable impl Debug for Variant [\#8140](https://github.com/apache/arrow-rs/pull/8140) ([scovich](https://github.com/scovich)) +- \[Variant\] Fix broken metadata builder rollback [\#8135](https://github.com/apache/arrow-rs/pull/8135) ([scovich](https://github.com/scovich)) +- \[Variant\]: Implement DataType::Interval support for cast\_to\_variant kernel [\#8125](https://github.com/apache/arrow-rs/pull/8125) ([codephage2020](https://github.com/codephage2020)) +- Add schema resolution and type promotion support to arrow-avro Decoder [\#8124](https://github.com/apache/arrow-rs/pull/8124) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add Initial `arrow-avro` writer implementation with basic type support [\#8123](https://github.com/apache/arrow-rs/pull/8123) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Add Variant::Time primitive and cast logic [\#8114](https://github.com/apache/arrow-rs/pull/8114) ([klion26](https://github.com/klion26)) +- \[Variant\] Support Timestamp to variant for `cast_to_variant` kernel [\#8113](https://github.com/apache/arrow-rs/pull/8113) ([abacef](https://github.com/abacef)) +- Bump actions/checkout from 4 to 5 [\#8110](https://github.com/apache/arrow-rs/pull/8110) ([dependabot[bot]](https://github.com/apps/dependabot)) +- \[Varaint\]: add `DataType::Null` support to cast\_to\_variant [\#8107](https://github.com/apache/arrow-rs/pull/8107) ([feniljain](https://github.com/feniljain)) +- \[Variant\] Adding fixed size byte array to variant and test [\#8106](https://github.com/apache/arrow-rs/pull/8106) ([abacef](https://github.com/abacef)) +- \[VARIANT\] Initial integration tests for variant reads [\#8104](https://github.com/apache/arrow-rs/pull/8104) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\]: Implement `DataType::Decimal32/Decimal64/Decimal128/Decimal256` support for `cast_to_variant` kernel [\#8101](https://github.com/apache/arrow-rs/pull/8101) ([liamzwbao](https://github.com/liamzwbao)) +- Refactor arrow-avro `Decoder` to support partial decoding [\#8100](https://github.com/apache/arrow-rs/pull/8100) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- fix: Validate metadata len in IPC reader [\#8097](https://github.com/apache/arrow-rs/pull/8097) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([JakeDern](https://github.com/JakeDern)) +- \[parquet\] further improve logical type compatibility in ArrowWriter [\#8095](https://github.com/apache/arrow-rs/pull/8095) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- \[Varint\] Implement ShreddingState::AllNull variant [\#8093](https://github.com/apache/arrow-rs/pull/8093) ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Minor: Add comments to tickets for follow on items [\#8092](https://github.com/apache/arrow-rs/pull/8092) ([alamb](https://github.com/alamb)) +- \[VARIANT\] Add support for DataType::Struct for cast\_to\_variant [\#8090](https://github.com/apache/arrow-rs/pull/8090) ([carpecodeum](https://github.com/carpecodeum)) +- \[VARIANT\] Add support for DataType::Utf8/LargeUtf8/Utf8View for cast\_to\_variant [\#8089](https://github.com/apache/arrow-rs/pull/8089) ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\] Implement `DataType::Boolean` support for `cast_to_variant` kernel [\#8085](https://github.com/apache/arrow-rs/pull/8085) ([sdf-jkl](https://github.com/sdf-jkl)) +- \[Variant\] Implement `DataType::{Date32,Date64}` =\> `Variant::Date` [\#8081](https://github.com/apache/arrow-rs/pull/8081) ([superserious-dev](https://github.com/superserious-dev)) +- Fix new clippy lints from Rust 1.89 [\#8078](https://github.com/apache/arrow-rs/pull/8078) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Implement ArrowSchema to AvroSchema conversion logic in arrow-avro [\#8075](https://github.com/apache/arrow-rs/pull/8075) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Implement `DataType::{Binary, LargeBinary, BinaryView}` =\> `Variant::Binary` [\#8074](https://github.com/apache/arrow-rs/pull/8074) ([superserious-dev](https://github.com/superserious-dev)) +- \[Variant\] Implement `DataType::Float16` =\> `Variant::Float` [\#8073](https://github.com/apache/arrow-rs/pull/8073) ([superserious-dev](https://github.com/superserious-dev)) +- create PageIndexPolicy to allow optional indexes [\#8071](https://github.com/apache/arrow-rs/pull/8071) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([kczimm](https://github.com/kczimm)) +- \[Variant\] Minor: use From impl to make conversion infallable [\#8068](https://github.com/apache/arrow-rs/pull/8068) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Bump actions/download-artifact from 4 to 5 [\#8066](https://github.com/apache/arrow-rs/pull/8066) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Added arrow-avro schema resolution foundations and type promotion [\#8047](https://github.com/apache/arrow-rs/pull/8047) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Fix arrow-avro type resolver register bug [\#8046](https://github.com/apache/arrow-rs/pull/8046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([yongkyunlee](https://github.com/yongkyunlee)) +- implement `cast_to_variant` kernel to cast native types to `VariantArray` [\#8044](https://github.com/apache/arrow-rs/pull/8044) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add arrow-avro `SchemaStore` and fingerprinting [\#8039](https://github.com/apache/arrow-rs/pull/8039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add more benchmarks for Parquet thrift decoding [\#8037](https://github.com/apache/arrow-rs/pull/8037) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Support multi-threaded writing of Parquet files with modular encryption [\#8029](https://github.com/apache/arrow-rs/pull/8029) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rok](https://github.com/rok)) +- Add arrow-avro Decoder Benchmarks [\#8025](https://github.com/apache/arrow-rs/pull/8025) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- feat: add method for sync Parquet reader read bloom filter [\#8024](https://github.com/apache/arrow-rs/pull/8024) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([mapleFU](https://github.com/mapleFU)) +- \[Variant\] Add `variant_get` and Shredded `VariantArray` [\#8021](https://github.com/apache/arrow-rs/pull/8021) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Implement arrow-avro SchemaStore and Fingerprinting To Enable Schema Resolution [\#8006](https://github.com/apache/arrow-rs/pull/8006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Parquet\] Add tests for IO/CPU access in parquet reader [\#7971](https://github.com/apache/arrow-rs/pull/7971) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speed up Parquet filter pushdown v4 \(Predicate evaluation cache for async\_reader\) [\#7850](https://github.com/apache/arrow-rs/pull/7850) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Implement cast and other operations on decimal32 and decimal64 [\#7815](https://github.com/apache/arrow-rs/pull/7815) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([CurtHagenlocher](https://github.com/CurtHagenlocher)) +## [56.0.0](https://github.com/apache/arrow-rs/tree/56.0.0) (2025-07-29) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/55.2.0...56.0.0) + +**Breaking changes:** + +- arrow-schema: Remove dict\_id from being required equal for merging [\#7968](https://github.com/apache/arrow-rs/pull/7968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- \[Parquet\] Use `u64` for `SerializedPageReaderState.offset` & `remaining_bytes`, instead of `usize` [\#7918](https://github.com/apache/arrow-rs/pull/7918) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Upgrade tonic dependencies to 0.13.0 version \(try 2\) [\#7839](https://github.com/apache/arrow-rs/pull/7839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Remove deprecated Arrow functions [\#7830](https://github.com/apache/arrow-rs/pull/7830) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([etseidl](https://github.com/etseidl)) +- Remove deprecated temporal functions [\#7813](https://github.com/apache/arrow-rs/pull/7813) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([etseidl](https://github.com/etseidl)) +- Remove functions from parquet crate deprecated in or before 54.0.0 [\#7811](https://github.com/apache/arrow-rs/pull/7811) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- GH-7686: \[Parquet\] Fix int96 min/max stats [\#7687](https://github.com/apache/arrow-rs/pull/7687) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rahulketch](https://github.com/rahulketch)) + +**Implemented enhancements:** + +- \[parquet\] Relax type restriction to allow writing dictionary/native batches for same column [\#8004](https://github.com/apache/arrow-rs/issues/8004) +- Support casting int64 to interval [\#7988](https://github.com/apache/arrow-rs/issues/7988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add `ListBuilder::with_value` for convenience [\#7951](https://github.com/apache/arrow-rs/issues/7951) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add `ObjectBuilder::with_field` for convenience [\#7949](https://github.com/apache/arrow-rs/issues/7949) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Impl PartialEq for VariantObject \#7943 [\#7948](https://github.com/apache/arrow-rs/issues/7948) +- \[Variant\] Offer `simdutf8` as an optional dependency when validating metadata [\#7902](https://github.com/apache/arrow-rs/issues/7902) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Avoid collecting offset iterator [\#7901](https://github.com/apache/arrow-rs/issues/7901) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Remove superfluous check when validating monotonic offsets [\#7900](https://github.com/apache/arrow-rs/issues/7900) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Avoid extra allocation in `ObjectBuilder` [\#7899](https://github.com/apache/arrow-rs/issues/7899) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]\[Compute\] `variant_get` kernel [\#7893](https://github.com/apache/arrow-rs/issues/7893) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\]\[Compute\] Add batch processing for Variant-JSON String conversion [\#7883](https://github.com/apache/arrow-rs/issues/7883) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support `MapArray` in lexsort [\#7881](https://github.com/apache/arrow-rs/issues/7881) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add testing for invalid variants \(fuzz testing??\) [\#7842](https://github.com/apache/arrow-rs/issues/7842) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] VariantMetadata, VariantList and VariantObject are too big for Copy [\#7831](https://github.com/apache/arrow-rs/issues/7831) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow choosing flate2 backend [\#7826](https://github.com/apache/arrow-rs/issues/7826) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Tests for creating "large" `VariantObjects`s [\#7821](https://github.com/apache/arrow-rs/issues/7821) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Tests for creating "large" `VariantList`s [\#7820](https://github.com/apache/arrow-rs/issues/7820) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support VariantBuilder to write to buffers owned by the caller [\#7805](https://github.com/apache/arrow-rs/issues/7805) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Move JSON related functionality to different crate. [\#7800](https://github.com/apache/arrow-rs/issues/7800) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7777](https://github.com/apache/arrow-rs/issues/7777) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] make `serde_json` an optional dependency of `parquet-variant` [\#7775](https://github.com/apache/arrow-rs/issues/7775) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[coalesce\] Implement specialized `BatchCoalescer::push_batch` for `PrimitiveArray` [\#7763](https://github.com/apache/arrow-rs/issues/7763) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add sort\_kernel benchmark for StringViewArray case [\#7758](https://github.com/apache/arrow-rs/issues/7758) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Improved API for accessing Variant Objects and lists [\#7756](https://github.com/apache/arrow-rs/issues/7756) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Buildable reproducible release builds [\#7751](https://github.com/apache/arrow-rs/issues/7751) +- Allow per-column parquet dictionary page size limit [\#7723](https://github.com/apache/arrow-rs/issues/7723) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Test and implement efficient building for "large" Arrays [\#7699](https://github.com/apache/arrow-rs/issues/7699) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Improve VariantBuilder when creating field name dictionaries / sorted dictionaries [\#7698](https://github.com/apache/arrow-rs/issues/7698) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Add input validation in `VariantBuilder` [\#7697](https://github.com/apache/arrow-rs/issues/7697) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Support Nested Data in `VariantBuilder` [\#7696](https://github.com/apache/arrow-rs/issues/7696) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Parquet: Incorrect min/max stats for int96 columns [\#7686](https://github.com/apache/arrow-rs/issues/7686) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add `DictionaryArray::gc` method [\#7683](https://github.com/apache/arrow-rs/issues/7683) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7645](https://github.com/apache/arrow-rs/issues/7645) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Fixed bugs:** + +- \[Variant\] Panic when appending nested objects to VariantBuilder [\#7907](https://github.com/apache/arrow-rs/issues/7907) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Panic when casting large Decimal256 to f64 due to unchecked `unwrap()` [\#7886](https://github.com/apache/arrow-rs/issues/7886) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Incorrect inlined string view comparison after " Add prefix compare for inlined" [\#7874](https://github.com/apache/arrow-rs/issues/7874) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] `test_json_to_variant_object_very_large` takes over 20s [\#7872](https://github.com/apache/arrow-rs/issues/7872) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] If `ObjectBuilder::finalize` is not called, the resulting Variant object is malformed. [\#7863](https://github.com/apache/arrow-rs/issues/7863) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- CSV error message has values transposed [\#7848](https://github.com/apache/arrow-rs/issues/7848) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Concating struct arrays with no fields unnecessarily errors [\#7828](https://github.com/apache/arrow-rs/issues/7828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clippy CI is failing on main after Rust `1.88` upgrade [\#7796](https://github.com/apache/arrow-rs/issues/7796) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[Variant\] Field lookup with out of bounds index causes unwanted behavior [\#7784](https://github.com/apache/arrow-rs/issues/7784) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Error verifying `parquet-variant` crate on 55.2.0 with `verify-release-candidate.sh` [\#7746](https://github.com/apache/arrow-rs/issues/7746) +- `test_to_pyarrow` tests fail during release verification [\#7736](https://github.com/apache/arrow-rs/issues/7736) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[parquet\_derive\] Example for ParquetRecordWriter is broken. [\#7732](https://github.com/apache/arrow-rs/issues/7732) +- \[Variant\] `Variant::Object` can contain two fields with the same field name [\#7730](https://github.com/apache/arrow-rs/issues/7730) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Panic when appending Object or List to VariantBuilder [\#7701](https://github.com/apache/arrow-rs/issues/7701) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Slicing a single-field dense union array creates an array with incorrect `logical_nulls` length [\#7647](https://github.com/apache/arrow-rs/issues/7647) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Documentation updates:** + +- Minor: Upate `cast_with_options` docs about casting integers --\> intervals [\#8002](https://github.com/apache/arrow-rs/pull/8002) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: More docs to `BatchCoalescer` [\#7891](https://github.com/apache/arrow-rs/pull/7891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([2010YOUY01](https://github.com/2010YOUY01)) +- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Performance improvements:** + +- `RowConverter` on list should only encode the sliced list values and not the entire data [\#7993](https://github.com/apache/arrow-rs/issues/7993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Avoid extra allocation in list builder [\#7977](https://github.com/apache/arrow-rs/issues/7977) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Convert JSON to Variant with fewer copies [\#7964](https://github.com/apache/arrow-rs/issues/7964) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optimize sort kernels partition\_validity method [\#7936](https://github.com/apache/arrow-rs/issues/7936) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup sorting for inline views [\#7857](https://github.com/apache/arrow-rs/issues/7857) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: Investigate and improve parquet writing performance [\#7822](https://github.com/apache/arrow-rs/issues/7822) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: optimize sort string\_view performance [\#7790](https://github.com/apache/arrow-rs/issues/7790) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Clickbench microbenchmark spends significant time in memcmp for not\_empty predicate [\#7766](https://github.com/apache/arrow-rs/issues/7766) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Use prefix first for comparisons, resort to data buffer for remaining data on equal values [\#7744](https://github.com/apache/arrow-rs/issues/7744) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Change use of `inline_value` to inline it to a u128 [\#7743](https://github.com/apache/arrow-rs/issues/7743) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add efficient way to upgrade keys for additional dictionary builders [\#7654](https://github.com/apache/arrow-rs/issues/7654) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Perf: Make sort string view fast\(1.5X ~ 3X faster\) [\#7792](https://github.com/apache/arrow-rs/pull/7792) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Add specialized coalesce path for PrimitiveArrays [\#7772](https://github.com/apache/arrow-rs/pull/7772) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- Implement full-range `i256::to_f64` to replace current ±∞ saturation for Decimal256 → Float64 [\#7985](https://github.com/apache/arrow-rs/issues/7985) +- \[Variant\] `impl FromIterator` fpr `VariantPath` [\#7955](https://github.com/apache/arrow-rs/issues/7955) +- `validated` and `is_fully_validated` flags doesn't need to be part of PartialEq [\#7952](https://github.com/apache/arrow-rs/issues/7952) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] remove VariantMetadata::dictionary\_size [\#7947](https://github.com/apache/arrow-rs/issues/7947) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Improve `VariantArray` performance by storing the index of the metadata and value arrays [\#7920](https://github.com/apache/arrow-rs/issues/7920) +- \[Variant\] Converting variant to JSON string seems slow [\#7869](https://github.com/apache/arrow-rs/issues/7869) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Present Variant at Iceberg Summit NYC July 10, 2025 [\#7858](https://github.com/apache/arrow-rs/issues/7858) +- \[Variant\] Avoid second copy of field name in MetadataBuilder [\#7814](https://github.com/apache/arrow-rs/issues/7814) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Remove APIs deprecated in or before 54.0.0 [\#7810](https://github.com/apache/arrow-rs/issues/7810) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- \[Variant\] Make it harder to forget to finish a pending parent i n ObjectBuilder [\#7798](https://github.com/apache/arrow-rs/issues/7798) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Remove explicit ObjectBuilder::finish\(\) and ListBuilder::finish and move to `Drop` impl [\#7780](https://github.com/apache/arrow-rs/issues/7780) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Reduce repetition in tests for arrow-row/src/run.rs [\#7692](https://github.com/apache/arrow-rs/issues/7692) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Add tests for invalid variant values \(aka verify invalid inputs\) [\#7681](https://github.com/apache/arrow-rs/issues/7681) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Introduce structs for Variant::Decimal types [\#7660](https://github.com/apache/arrow-rs/issues/7660) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- Add benchmark for converting StringViewArray with mixed short and long strings [\#8015](https://github.com/apache/arrow-rs/pull/8015) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) +- \[Variant\] impl FromIterator for VariantPath [\#8011](https://github.com/apache/arrow-rs/pull/8011) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([sdf-jkl](https://github.com/sdf-jkl)) +- Create empty buffer for a buffer specified in the C Data Interface with length zero [\#8009](https://github.com/apache/arrow-rs/pull/8009) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- bench: add benchmark for converting list and sliced list to row format [\#8008](https://github.com/apache/arrow-rs/pull/8008) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- bench: benchmark interleave structs [\#8007](https://github.com/apache/arrow-rs/pull/8007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Parquet\] Allow writing compatible DictionaryArrays to parquet writer [\#8005](https://github.com/apache/arrow-rs/pull/8005) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- doc: remove outdated info from CONTRIBUTING doc in project root dir. [\#7998](https://github.com/apache/arrow-rs/pull/7998) ([sonhmai](https://github.com/sonhmai)) +- perf: only encode actual list values in `RowConverter` \(16-26 times faster for small sliced list\) [\#7996](https://github.com/apache/arrow-rs/pull/7996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- test: add tests for converting sliced list to row based [\#7994](https://github.com/apache/arrow-rs/pull/7994) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- perf: Improve `interleave` performance for struct \(3-6 times faster\) [\#7991](https://github.com/apache/arrow-rs/pull/7991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Avoid extra buffer allocation in ListBuilder [\#7987](https://github.com/apache/arrow-rs/pull/7987) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- Implement full-range `i256::to_f64` to eliminate ±∞ saturation for Decimal256 → Float64 casts [\#7986](https://github.com/apache/arrow-rs/pull/7986) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) +- Minor: Restore warning comment on Int96 statistics read [\#7975](https://github.com/apache/arrow-rs/pull/7975) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add additional integration tests to arrow-avro [\#7974](https://github.com/apache/arrow-rs/pull/7974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nathaniel-d-ef](https://github.com/nathaniel-d-ef)) +- Perf: optimize actual\_buffer\_size to use only data buffer capacity for coalesce [\#7967](https://github.com/apache/arrow-rs/pull/7967) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Implement Improved arrow-avro Reader Zero-Byte Record Handling [\#7966](https://github.com/apache/arrow-rs/pull/7966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Perf: improve sort via `partition_validity` to use fast path for bit map scan \(up to 30% faster\) [\#7962](https://github.com/apache/arrow-rs/pull/7962) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Revisit VariantMetadata and Object equality [\#7961](https://github.com/apache/arrow-rs/pull/7961) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add ListBuilder::with\_value for convenience [\#7959](https://github.com/apache/arrow-rs/pull/7959) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] remove VariantMetadata::dictionary\_size [\#7958](https://github.com/apache/arrow-rs/pull/7958) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] VariantMetadata is allowed to contain the empty string [\#7956](https://github.com/apache/arrow-rs/pull/7956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add arrow-avro support for Impala Nullability [\#7954](https://github.com/apache/arrow-rs/pull/7954) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([veronica-m-ef](https://github.com/veronica-m-ef)) +- \[Test\] Add tests for VariantList equality [\#7953](https://github.com/apache/arrow-rs/pull/7953) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Add ObjectBuilder::with\_field for convenience [\#7950](https://github.com/apache/arrow-rs/pull/7950) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Adding code to store metadata and value references in VariantArray [\#7945](https://github.com/apache/arrow-rs/pull/7945) ([abacef](https://github.com/abacef)) +- \[Variant\] Add `variant_kernels` benchmark [\#7944](https://github.com/apache/arrow-rs/pull/7944) ([alamb](https://github.com/alamb)) +- \[Variant\] Impl `PartialEq` for VariantObject [\#7943](https://github.com/apache/arrow-rs/pull/7943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add documentation, tests and cleaner api for Variant::get\_path [\#7942](https://github.com/apache/arrow-rs/pull/7942) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- arrow-ipc: Remove all abilities to preserve dict IDs [\#7940](https://github.com/apache/arrow-rs/pull/7940) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([brancz](https://github.com/brancz)) +- Optimize partition\_validity function used in sort kernels [\#7937](https://github.com/apache/arrow-rs/pull/7937) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Avoid extra allocation in object builder [\#7935](https://github.com/apache/arrow-rs/pull/7935) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- \[Variant\] Avoid collecting offset iterator [\#7934](https://github.com/apache/arrow-rs/pull/7934) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([codephage2020](https://github.com/codephage2020)) +- Minor: Support BinaryView and StringView builders in `make_builder` [\#7931](https://github.com/apache/arrow-rs/pull/7931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kylebarron](https://github.com/kylebarron)) +- chore: bump MSRV to 1.84 [\#7926](https://github.com/apache/arrow-rs/pull/7926) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([mbrobbel](https://github.com/mbrobbel)) +- Update bzip2 requirement from 0.4.4 to 0.6.0 [\#7924](https://github.com/apache/arrow-rs/pull/7924) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Reserve capacity beforehand during large object building [\#7922](https://github.com/apache/arrow-rs/pull/7922) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add `variant_get` compute kernel [\#7919](https://github.com/apache/arrow-rs/pull/7919) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Samyak2](https://github.com/Samyak2)) +- Improve memory usage for `arrow-row -> String/BinaryView` when utf8 validation disabled [\#7917](https://github.com/apache/arrow-rs/pull/7917) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ding-young](https://github.com/ding-young)) +- Restructure compare\_greater function used in parquet statistics for better performance [\#7916](https://github.com/apache/arrow-rs/pull/7916) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Support appending complex variants in `VariantBuilder` [\#7914](https://github.com/apache/arrow-rs/pull/7914) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add `VariantBuilder::new_with_buffers` to write to existing buffers [\#7912](https://github.com/apache/arrow-rs/pull/7912) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Convert JSON to VariantArray without copying \(8 - 32% faster\) [\#7911](https://github.com/apache/arrow-rs/pull/7911) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Use simdutf8 for UTF-8 validation [\#7908](https://github.com/apache/arrow-rs/pull/7908) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) +- \[Variant\] Avoid superflous validation checks [\#7906](https://github.com/apache/arrow-rs/pull/7906) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add `VariantArray` and `VariantArrayBuilder` for constructing Arrow Arrays of Variants [\#7905](https://github.com/apache/arrow-rs/pull/7905) ([alamb](https://github.com/alamb)) +- Update sysinfo requirement from 0.35.0 to 0.36.0 [\#7904](https://github.com/apache/arrow-rs/pull/7904) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix current CI failure [\#7898](https://github.com/apache/arrow-rs/pull/7898) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([viirya](https://github.com/viirya)) +- Remove redundant is\_err checks in Variant tests [\#7897](https://github.com/apache/arrow-rs/pull/7897) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- \[Variant\] test: add variant object tests with different sizes [\#7896](https://github.com/apache/arrow-rs/pull/7896) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([odysa](https://github.com/odysa)) +- \[Variant\] Define basic convenience methods for variant pathing [\#7894](https://github.com/apache/arrow-rs/pull/7894) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- fix: `view_types` benchmark slice should follow by correct len array [\#7892](https://github.com/apache/arrow-rs/pull/7892) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Add arrow-avro support for bzip2 and xz compression [\#7890](https://github.com/apache/arrow-rs/pull/7890) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Add arrow-avro support for Duration type and minor fixes for UUID decoding [\#7889](https://github.com/apache/arrow-rs/pull/7889) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Reduce variant-related struct sizes [\#7888](https://github.com/apache/arrow-rs/pull/7888) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Fix panic on lossy decimal to float casting: round to saturation for overflows [\#7887](https://github.com/apache/arrow-rs/pull/7887) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kosiew](https://github.com/kosiew)) +- Add tests for invalid variant metadata and value [\#7885](https://github.com/apache/arrow-rs/pull/7885) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- \[Variant\] Introduce parquet-variant-compute crate to transform batches of JSON strings to and from Variants [\#7884](https://github.com/apache/arrow-rs/pull/7884) ([harshmotw-db](https://github.com/harshmotw-db)) +- feat: support `MapArray` in lexsort [\#7882](https://github.com/apache/arrow-rs/pull/7882) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- fix: mark `DataType::Map` as unsupported in `RowConverter` [\#7880](https://github.com/apache/arrow-rs/pull/7880) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- \[Variant\] Speedup validation [\#7878](https://github.com/apache/arrow-rs/pull/7878) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- benchmark: Add StringViewArray gc benchmark with not null cases [\#7877](https://github.com/apache/arrow-rs/pull/7877) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[ARROW-RS-7820\]\[Variant\] Add tests for large variant lists [\#7876](https://github.com/apache/arrow-rs/pull/7876) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([klion26](https://github.com/klion26)) +- fix: Incorrect inlined string view comparison after Add prefix compar… [\#7875](https://github.com/apache/arrow-rs/pull/7875) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- perf: speed up StringViewArray gc 1.4 ~5.x faster [\#7873](https://github.com/apache/arrow-rs/pull/7873) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Remove superflous validate call and rename methods [\#7871](https://github.com/apache/arrow-rs/pull/7871) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Benchmark: Add rich testing cases for sort string\(utf8\) [\#7867](https://github.com/apache/arrow-rs/pull/7867) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- chore: update link for `row_filter.rs` [\#7866](https://github.com/apache/arrow-rs/pull/7866) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([haohuaijin](https://github.com/haohuaijin)) +- \[Variant\] List and object builders have no effect until finalized [\#7865](https://github.com/apache/arrow-rs/pull/7865) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Added number to string benches for json\_writer [\#7864](https://github.com/apache/arrow-rs/pull/7864) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([abacef](https://github.com/abacef)) +- \[Variant\] Introduce `parquet-variant-json` crate [\#7862](https://github.com/apache/arrow-rs/pull/7862) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Remove dead code, add comments [\#7861](https://github.com/apache/arrow-rs/pull/7861) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speedup sorting for inline views: 1.4x - 1.7x improvement [\#7856](https://github.com/apache/arrow-rs/pull/7856) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Fix union slice logical\_nulls length [\#7855](https://github.com/apache/arrow-rs/pull/7855) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([codephage2020](https://github.com/codephage2020)) +- Add `get_ref/get_mut` to JSON Writer [\#7854](https://github.com/apache/arrow-rs/pull/7854) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([cetra3](https://github.com/cetra3)) +- \[Minor\] Add Benchmark for RowConverter::append [\#7853](https://github.com/apache/arrow-rs/pull/7853) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Add Enum type support to arrow-avro and Minor Decimal type fix [\#7852](https://github.com/apache/arrow-rs/pull/7852) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- CSV error message has values transposed [\#7851](https://github.com/apache/arrow-rs/pull/7851) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Omega359](https://github.com/Omega359)) +- \[Variant\] Fuzz testing and benchmarks for vaildation [\#7849](https://github.com/apache/arrow-rs/pull/7849) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- \[Variant\] Follow up nits and uncomment test cases [\#7846](https://github.com/apache/arrow-rs/pull/7846) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Make sure ObjectBuilder and ListBuilder to be finalized before its parent builder [\#7843](https://github.com/apache/arrow-rs/pull/7843) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([viirya](https://github.com/viirya)) +- Add decimal32 and decimal64 support to Parquet, JSON and CSV readers and writers [\#7841](https://github.com/apache/arrow-rs/pull/7841) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([CurtHagenlocher](https://github.com/CurtHagenlocher)) +- Implement arrow-avro Reader and ReaderBuilder [\#7834](https://github.com/apache/arrow-rs/pull/7834) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- \[Variant\] Support creating sorted dictionaries [\#7833](https://github.com/apache/arrow-rs/pull/7833) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add Decimal type support to arrow-avro [\#7832](https://github.com/apache/arrow-rs/pull/7832) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Allow concating struct arrays with no fields [\#7829](https://github.com/apache/arrow-rs/pull/7829) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- Add features to configure flate2 [\#7827](https://github.com/apache/arrow-rs/pull/7827) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zeevm](https://github.com/zeevm)) +- make builder public under experimental [\#7825](https://github.com/apache/arrow-rs/pull/7825) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- Improvements for parquet writing performance \(25%-44%\) [\#7824](https://github.com/apache/arrow-rs/pull/7824) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- Use in-memory buffer for arrow\_writer benchmark [\#7823](https://github.com/apache/arrow-rs/pull/7823) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] impl \[Try\]From for VariantDecimalXX types [\#7809](https://github.com/apache/arrow-rs/pull/7809) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- \[Variant\] Speedup `ObjectBuilder` \(62x faster\) [\#7808](https://github.com/apache/arrow-rs/pull/7808) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] Support both fallible and infallible access to variants [\#7807](https://github.com/apache/arrow-rs/pull/7807) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Minor: fix clippy in parquet-variant after logical conflict [\#7803](https://github.com/apache/arrow-rs/pull/7803) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Add flag in `ObjectBuilder` to control validation behavior on duplicate field write [\#7801](https://github.com/apache/arrow-rs/pull/7801) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([micoo227](https://github.com/micoo227)) +- Fix clippy for Rust 1.88 release [\#7797](https://github.com/apache/arrow-rs/pull/7797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- \[Variant\] Simplify `Builder` buffer operations [\#7795](https://github.com/apache/arrow-rs/pull/7795) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- fix: Change panic to error in`take` kernel for StringArrary/BinaryArray on overflow [\#7793](https://github.com/apache/arrow-rs/pull/7793) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([chenkovsky](https://github.com/chenkovsky)) +- Update base64 requirement from 0.21 to 0.22 [\#7791](https://github.com/apache/arrow-rs/pull/7791) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([dependabot[bot]](https://github.com/apps/dependabot)) +- Fix RowConverter when FixedSizeList is not the last [\#7789](https://github.com/apache/arrow-rs/pull/7789) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Add schema with only primitive arrays to `coalesce_kernel` benchmark [\#7788](https://github.com/apache/arrow-rs/pull/7788) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add sort\_kernel benchmark for StringViewArray case [\#7787](https://github.com/apache/arrow-rs/pull/7787) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- \[Variant\] Check pending before `VariantObject::insert` [\#7786](https://github.com/apache/arrow-rs/pull/7786) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] impl Display for VariantDecimalXX [\#7785](https://github.com/apache/arrow-rs/pull/7785) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([scovich](https://github.com/scovich)) +- \[VARIANT\] Add support for the json\_to\_variant API [\#7783](https://github.com/apache/arrow-rs/pull/7783) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([harshmotw-db](https://github.com/harshmotw-db)) +- \[Variant\] Consolidate examples for json writing [\#7782](https://github.com/apache/arrow-rs/pull/7782) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add benchmark for about view array slice [\#7781](https://github.com/apache/arrow-rs/pull/7781) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- \[Variant\] Add negative tests for reading invalid primitive variant values [\#7779](https://github.com/apache/arrow-rs/pull/7779) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- \[Variant\] Support creating nested objects and object with lists [\#7778](https://github.com/apache/arrow-rs/pull/7778) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[VARIANT\] Validate precision in VariantDecimalXX structs and add missing tests [\#7776](https://github.com/apache/arrow-rs/pull/7776) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add tests for `BatchCoalescer::push_batch_with_filter`, fix bug [\#7774](https://github.com/apache/arrow-rs/pull/7774) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Variant\] Minor: make fields in `VariantDecimal*` private, add examples [\#7770](https://github.com/apache/arrow-rs/pull/7770) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Extend the fast path in GenericByteViewArray::is\_eq for comparing against empty strings [\#7767](https://github.com/apache/arrow-rs/pull/7767) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jhorstmann](https://github.com/jhorstmann)) +- \[Variant\] Improve getter API for `VariantList` and `VariantObject` [\#7757](https://github.com/apache/arrow-rs/pull/7757) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Add Variant::as\_object and Variant::as\_list [\#7755](https://github.com/apache/arrow-rs/pull/7755) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- \[Variant\] Fix several overflow panic risks for 32-bit arch [\#7752](https://github.com/apache/arrow-rs/pull/7752) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add testing section to pull request template [\#7749](https://github.com/apache/arrow-rs/pull/7749) ([alamb](https://github.com/alamb)) +- Perf: Add prefix compare for inlined compare and change use of inline\_value to inline it to a u128 [\#7748](https://github.com/apache/arrow-rs/pull/7748) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Move arrow-pyarrow tests that require `pyarrow` to be installed into `arrow-pyarrow-testing` crate [\#7742](https://github.com/apache/arrow-rs/pull/7742) ([alamb](https://github.com/alamb)) +- \[Variant\] Improve write API in `Variant::Object` [\#7741](https://github.com/apache/arrow-rs/pull/7741) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Support nested lists and object lists [\#7740](https://github.com/apache/arrow-rs/pull/7740) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- feat: \[Variant\] Add Validation for Variant Deciaml [\#7738](https://github.com/apache/arrow-rs/pull/7738) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- Add fallible versions of temporal functions that may panic [\#7737](https://github.com/apache/arrow-rs/pull/7737) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adriangb](https://github.com/adriangb)) +- fix: Implement support for appending Object and List variants in VariantBuilder [\#7735](https://github.com/apache/arrow-rs/pull/7735) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([Weijun-H](https://github.com/Weijun-H)) +- parquet\_derive: update in working example for ParquetRecordWriter [\#7733](https://github.com/apache/arrow-rs/pull/7733) ([LanHikari22](https://github.com/LanHikari22)) +- Perf: Optimize comparison kernels for inlined views [\#7731](https://github.com/apache/arrow-rs/pull/7731) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- arrow-row: Refactor arrow-row REE roundtrip tests [\#7729](https://github.com/apache/arrow-rs/pull/7729) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- Allow per-column parquet dictionary page size limit [\#7724](https://github.com/apache/arrow-rs/pull/7724) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([XiangpengHao](https://github.com/XiangpengHao)) +- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) +- \[Variant\] Use `BTreeMap` for `VariantBuilder.dict` and `ObjectBuilder.fields` to maintain invariants upon entry writes [\#7720](https://github.com/apache/arrow-rs/pull/7720) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Introduce `MAX_INLINE_VIEW_LEN` constant for string/byte views [\#7719](https://github.com/apache/arrow-rs/pull/7719) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- add `garbage_collect_dictionary` to `arrow-select` [\#7716](https://github.com/apache/arrow-rs/pull/7716) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([davidhewitt](https://github.com/davidhewitt)) +- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Support `FixedSizeList` RowConverter [\#7705](https://github.com/apache/arrow-rs/pull/7705) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([findepi](https://github.com/findepi)) +- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. [\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- feat: add constructor to efficiently upgrade dict key type to remaining builders [\#7689](https://github.com/apache/arrow-rs/pull/7689) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Variant: Write Variant Values as JSON [\#7670](https://github.com/apache/arrow-rs/pull/7670) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([carpecodeum](https://github.com/carpecodeum)) +- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) +- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- \[Variant\] Simplify creation of Variants from metadata and value [\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) +- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) +- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- Add `pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) +- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) +- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) +- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +## [55.2.0](https://github.com/apache/arrow-rs/tree/55.2.0) (2025-06-22) + +- Add a `strong_count` method to `Buffer` [\#7568](https://github.com/apache/arrow-rs/issues/7568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Create version of LexicographicalComparator that compares fixed number of columns [\#7531](https://github.com/apache/arrow-rs/issues/7531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet-show-bloom-filter should work with integer typed columns [\#7528](https://github.com/apache/arrow-rs/issues/7528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow merging primitive dictionary values in concat and interleave kernels [\#7518](https://github.com/apache/arrow-rs/issues/7518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add efficient concatenation of StructArrays [\#7516](https://github.com/apache/arrow-rs/issues/7516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rename `flight-sql-experimental` to `flight-sql` [\#7498](https://github.com/apache/arrow-rs/issues/7498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Consider moving from ryu to lexical-core for string formatting / casting floats to string. [\#7496](https://github.com/apache/arrow-rs/issues/7496) +- Arithmetic kernels can be safer and faster [\#7494](https://github.com/apache/arrow-rs/issues/7494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup `filter_bytes` by precalculating capacity [\#7465](https://github.com/apache/arrow-rs/issues/7465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\]: Rust API to Create Variant Values [\#7424](https://github.com/apache/arrow-rs/issues/7424) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Rust API to Read Variant Values [\#7423](https://github.com/apache/arrow-rs/issues/7423) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release arrow-rs / parquet Minor version `55.1.0` \(May 2025\) [\#7393](https://github.com/apache/arrow-rs/issues/7393) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support create\_random\_array for Decimal data types [\#7343](https://github.com/apache/arrow-rs/issues/7343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Truncate Parquet page data page statistics [\#7555](https://github.com/apache/arrow-rs/pull/7555) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Fixed bugs:** + +- In arrow\_json, Decoder::decode can panic if it encounters two high surrogates in a row. [\#7712](https://github.com/apache/arrow-rs/issues/7712) +- FlightSQL "GetDbSchemas" and "GetTables" schemas do not fully match the protocol [\#7637](https://github.com/apache/arrow-rs/issues/7637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Cannot read encrypted Parquet file if page index reading is enabled [\#7629](https://github.com/apache/arrow-rs/issues/7629) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `encoding_stats` not present in Parquet generated by `parquet-rewrite` [\#7616](https://github.com/apache/arrow-rs/issues/7616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- When writing parquet plaintext footer files `footer_signing_key_metadata` is not included, encryption alghoritm is always written in footer [\#7599](https://github.com/apache/arrow-rs/issues/7599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `new_null_array` panics when constructing a struct of a dictionary [\#7571](https://github.com/apache/arrow-rs/issues/7571) +- Parquet derive fails to build when Result is aliased [\#7547](https://github.com/apache/arrow-rs/issues/7547) +- Unable to read `Dictionary(u8, FixedSizeBinary(_))` using datafusion. [\#7545](https://github.com/apache/arrow-rs/issues/7545) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- filter\_record\_batch panics with empty struct array. [\#7538](https://github.com/apache/arrow-rs/issues/7538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic in `pretty_format` function when displaying DurationSecondsArray with `i64::MIN` / `i64::MAX` [\#7533](https://github.com/apache/arrow-rs/issues/7533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Record API unable to parse TIME\_MILLIS when encoded as INT32 [\#7510](https://github.com/apache/arrow-rs/issues/7510) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The `read_record_batch` func of the `RecordBatchDecoder` does not respect the `skip_validation` property [\#7508](https://github.com/apache/arrow-rs/issues/7508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow-55.1.0` breaks `filter_record_batch` [\#7500](https://github.com/apache/arrow-rs/issues/7500) +- Files containing binary data with \>=8\_388\_855 bytes per row written with `arrow-rs` can't be read with `pyarrow` [\#7489](https://github.com/apache/arrow-rs/issues/7489) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Bug\] Ingestion with Arrow Flight Sql panic when the input stream is empty or fallible [\#7329](https://github.com/apache/arrow-rs/issues/7329) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Documentation updates:** + +- arrow\_reader\_row\_filter benchmark doesn't capture page cache improvements [\#7460](https://github.com/apache/arrow-rs/issues/7460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add references for defaults in `WriterPropertiesBuilder` [\#7558](https://github.com/apache/arrow-rs/pull/7558) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Clarify Docs: NullBuffer::len is in bits [\#7556](https://github.com/apache/arrow-rs/pull/7556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: fix typo for `Decimal128Array` [\#7525](https://github.com/apache/arrow-rs/pull/7525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([burmecia](https://github.com/burmecia)) +- Minor: Add examples to ProjectionMask documentation [\#7523](https://github.com/apache/arrow-rs/pull/7523) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve documentation for Parquet `WriterProperties` [\#7491](https://github.com/apache/arrow-rs/pull/7491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- \[Variant\] More efficient determination of String vs ShortString [\#7700](https://github.com/apache/arrow-rs/issues/7700) +- \[Variant\] Improve API for iterating over values of a VariantList [\#7685](https://github.com/apache/arrow-rs/issues/7685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Consider validating variants on creation \(rather than read\) [\#7684](https://github.com/apache/arrow-rs/issues/7684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Miri test\_native\_type\_pow test failing [\#7641](https://github.com/apache/arrow-rs/issues/7641) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `coalesce` and `concat` for views [\#7615](https://github.com/apache/arrow-rs/issues/7615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bad min value in row group statistics in some special cases [\#7593](https://github.com/apache/arrow-rs/issues/7593) +- Feature Request: BloomFilter Position Flexibility in `parquet-rewrite` [\#7552](https://github.com/apache/arrow-rs/issues/7552) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) +- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. [\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) +- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- \[Variant\] Simplify creation of Variants from metadata and value [\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) +- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) +- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- Add `pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) +- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) +- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) +- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve `coalesce` kernel tests [\#7626](https://github.com/apache/arrow-rs/pull/7626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Revert "Revert "Improve `coalesce` and `concat` performance for views… [\#7625](https://github.com/apache/arrow-rs/pull/7625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Revert "Improve `coalesce` and `concat` performance for views \(\#7614\)" [\#7623](https://github.com/apache/arrow-rs/pull/7623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Improve coalesce\_kernel benchmark to capture inline vs non inline views [\#7619](https://github.com/apache/arrow-rs/pull/7619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve `coalesce` and `concat` performance for views [\#7614](https://github.com/apache/arrow-rs/pull/7614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- feat: add constructor to help efficiently upgrade key for GenericBytesDictionaryBuilder [\#7611](https://github.com/apache/arrow-rs/pull/7611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- feat: support append\_nulls on additional builders [\#7606](https://github.com/apache/arrow-rs/pull/7606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- feat: add AsyncArrowWriter::into\_inner [\#7604](https://github.com/apache/arrow-rs/pull/7604) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jpopesculian](https://github.com/jpopesculian)) +- Move variant interop test to Rust integration test [\#7602](https://github.com/apache/arrow-rs/pull/7602) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Include footer key metadata when writing encrypted Parquet with a plaintext footer [\#7600](https://github.com/apache/arrow-rs/pull/7600) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rok](https://github.com/rok)) +- Add `coalesce` kernel and`BatchCoalescer` for statefully combining selected b…atches: [\#7597](https://github.com/apache/arrow-rs/pull/7597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add FixedSizeBinary to `take_kernel` benchmark [\#7592](https://github.com/apache/arrow-rs/pull/7592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix GenericBinaryArray docstring. [\#7588](https://github.com/apache/arrow-rs/pull/7588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- fix: error reading multiple batches of `Dict(_, FixedSizeBinary(_))` [\#7585](https://github.com/apache/arrow-rs/pull/7585) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- Revert "Minor: remove filter code deprecated in 2023 \(\#7554\)" [\#7583](https://github.com/apache/arrow-rs/pull/7583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fixed a warning build build: function never used. [\#7577](https://github.com/apache/arrow-rs/pull/7577) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Adding Encoding argument in `parquet-rewrite` [\#7576](https://github.com/apache/arrow-rs/pull/7576) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- feat: add `row_group_is_[max/min]_value_exact` to StatisticsConverter [\#7574](https://github.com/apache/arrow-rs/pull/7574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([CookiePieWw](https://github.com/CookiePieWw)) +- \[array\] Remove unwrap checks from GenericByteArray::value\_unchecked [\#7573](https://github.com/apache/arrow-rs/pull/7573) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- \[benches/row\_format\] fix typo in array lengths [\#7572](https://github.com/apache/arrow-rs/pull/7572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- Add a strong\_count method to Buffer [\#7569](https://github.com/apache/arrow-rs/pull/7569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([westonpace](https://github.com/westonpace)) +- Minor: Enable byte view for clickbench benchmark [\#7565](https://github.com/apache/arrow-rs/pull/7565) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Optimize length calculation in row encoding for fixed-length columns [\#7564](https://github.com/apache/arrow-rs/pull/7564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- Use PR title and description for commit message [\#7563](https://github.com/apache/arrow-rs/pull/7563) ([kou](https://github.com/kou)) +- Use apache/arrow-{go,java,js} in integration test [\#7561](https://github.com/apache/arrow-rs/pull/7561) ([kou](https://github.com/kou)) +- Implement Array Decoding in arrow-avro [\#7559](https://github.com/apache/arrow-rs/pull/7559) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Minor: remove filter code deprecated in 2023 [\#7554](https://github.com/apache/arrow-rs/pull/7554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: Correct docs for `WriterPropertiesBuilder::set_column_index_truncate_length` [\#7553](https://github.com/apache/arrow-rs/pull/7553) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Adding Bloom Filter Position argument in parquet-rewrite [\#7550](https://github.com/apache/arrow-rs/pull/7550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Fix `Result` name collision in parquet\_derive [\#7548](https://github.com/apache/arrow-rs/pull/7548) ([jspaezp](https://github.com/jspaezp)) +- Fix: Converted feature flight-sql-experimental to flight-sql [\#7546](https://github.com/apache/arrow-rs/pull/7546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([kunalsinghdadhwal](https://github.com/kunalsinghdadhwal)) +- Fix CI on main due to logical conflict [\#7542](https://github.com/apache/arrow-rs/pull/7542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix `filter_record_batch` panics with empty struct array [\#7539](https://github.com/apache/arrow-rs/pull/7539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([thorfour](https://github.com/thorfour)) +- \[Variant\] Initial API for reading Variant data and metadata [\#7535](https://github.com/apache/arrow-rs/pull/7535) ([mkarbo](https://github.com/mkarbo)) +- fix: Panic in pretty\_format function when displaying DurationSecondsA… [\#7534](https://github.com/apache/arrow-rs/pull/7534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Create version of LexicographicalComparator that compares fixed number of columns \(~ -15%\) [\#7530](https://github.com/apache/arrow-rs/pull/7530) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make parquet-show-bloom-filter work with integer typed columns [\#7529](https://github.com/apache/arrow-rs/pull/7529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- chore\(deps\): update criterion requirement from 0.5 to 0.6 [\#7527](https://github.com/apache/arrow-rs/pull/7527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Minor: Add a parquet row\_filter test, reduce some test boiler plate [\#7522](https://github.com/apache/arrow-rs/pull/7522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Refactor `build_array_reader` into a struct [\#7521](https://github.com/apache/arrow-rs/pull/7521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- arrow: add concat structs benchmark [\#7520](https://github.com/apache/arrow-rs/pull/7520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- arrow-select: add support for merging primitive dictionary values [\#7519](https://github.com/apache/arrow-rs/pull/7519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- arrow-select: add support for optimized concatenation of struct arrays [\#7517](https://github.com/apache/arrow-rs/pull/7517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- Fix Clippy in CI for Rust 1.87 release [\#7514](https://github.com/apache/arrow-rs/pull/7514) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Simplify `ParquetRecordBatchReader::next` control logic [\#7512](https://github.com/apache/arrow-rs/pull/7512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix record API support for reading INT32 encoded TIME\_MILLIS [\#7511](https://github.com/apache/arrow-rs/pull/7511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([njaremko](https://github.com/njaremko)) +- RecordBatchDecoder: skip RecordBatch validation when `skip_validation` property is enabled [\#7509](https://github.com/apache/arrow-rs/pull/7509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nilskch](https://github.com/nilskch)) +- Introduce `ReadPlan` to encapsulate the calculation of what parquet rows to decode [\#7502](https://github.com/apache/arrow-rs/pull/7502) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update documentation for ParquetReader [\#7501](https://github.com/apache/arrow-rs/pull/7501) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve `Field` docs, add missing `Field::set_*` methods [\#7497](https://github.com/apache/arrow-rs/pull/7497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Speed up arithmetic kernels, reduce `unsafe` usage [\#7493](https://github.com/apache/arrow-rs/pull/7493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Prevent FlightSQL server panics for `do_put` when stream is empty or 1st stream element is an Err [\#7492](https://github.com/apache/arrow-rs/pull/7492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([superserious-dev](https://github.com/superserious-dev)) +- arrow-ipc: add `StreamDecoder::schema` [\#7488](https://github.com/apache/arrow-rs/pull/7488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lidavidm](https://github.com/lidavidm)) +- arrow-select: Implement concat for `RunArray`s [\#7487](https://github.com/apache/arrow-rs/pull/7487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- \[Variant\] Add \(empty\) `parquet-variant` crate, update `parquet-testing` pin [\#7485](https://github.com/apache/arrow-rs/pull/7485) ([alamb](https://github.com/alamb)) +- Improve error messages if schema hint mismatches with parquet schema [\#7481](https://github.com/apache/arrow-rs/pull/7481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `arrow_reader_clickbench` benchmark [\#7470](https://github.com/apache/arrow-rs/pull/7470) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speedup `filter_bytes` ~-20-40%, `filter_native` low selectivity \(~-37%\) [\#7463](https://github.com/apache/arrow-rs/pull/7463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +## [55.2.0](https://github.com/apache/arrow-rs/tree/55.2.0) (2025-06-22) + +[Full Changelog](https://github.com/apache/arrow-rs/compare/55.1.0...55.2.0) + +**Implemented enhancements:** + +- Do not populate nulls for `NullArray` for `MutableArrayData` [\#7725](https://github.com/apache/arrow-rs/issues/7725) +- Implement `PartialEq` for RunArray [\#7691](https://github.com/apache/arrow-rs/issues/7691) +- `interleave_views` is really slow [\#7688](https://github.com/apache/arrow-rs/issues/7688) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add min max aggregates for FixedSizeBinary [\#7674](https://github.com/apache/arrow-rs/issues/7674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Deliver pyarrow as a standalone crate [\#7668](https://github.com/apache/arrow-rs/issues/7668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Implement `VariantObject::field` and `VariantObject::fields` [\#7665](https://github.com/apache/arrow-rs/issues/7665) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Implement read support for remaining primitive types [\#7630](https://github.com/apache/arrow-rs/issues/7630) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Fast and ergonomic method to add metadata to a `RecordBatch` [\#7628](https://github.com/apache/arrow-rs/issues/7628) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add efficient way to change the keys of string dictionary builder [\#7610](https://github.com/apache/arrow-rs/issues/7610) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support `add_nulls` on additional builder types [\#7605](https://github.com/apache/arrow-rs/issues/7605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `into_inner` for `AsyncArrowWriter` [\#7603](https://github.com/apache/arrow-rs/issues/7603) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Optimize `PrimitiveBuilder::append_trusted_len_iter` [\#7591](https://github.com/apache/arrow-rs/issues/7591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Benchmark for filter+concat and take+concat into even sized record batches [\#7589](https://github.com/apache/arrow-rs/issues/7589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `max_statistics_truncate_length` is ignored when writing statistics to data page headers [\#7579](https://github.com/apache/arrow-rs/issues/7579) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Feature Request: Encoding in `parquet-rewrite` [\#7575](https://github.com/apache/arrow-rs/issues/7575) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add a `strong_count` method to `Buffer` [\#7568](https://github.com/apache/arrow-rs/issues/7568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Create version of LexicographicalComparator that compares fixed number of columns [\#7531](https://github.com/apache/arrow-rs/issues/7531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- parquet-show-bloom-filter should work with integer typed columns [\#7528](https://github.com/apache/arrow-rs/issues/7528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Allow merging primitive dictionary values in concat and interleave kernels [\#7518](https://github.com/apache/arrow-rs/issues/7518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add efficient concatenation of StructArrays [\#7516](https://github.com/apache/arrow-rs/issues/7516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Rename `flight-sql-experimental` to `flight-sql` [\#7498](https://github.com/apache/arrow-rs/issues/7498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Consider moving from ryu to lexical-core for string formatting / casting floats to string. [\#7496](https://github.com/apache/arrow-rs/issues/7496) +- Arithmetic kernels can be safer and faster [\#7494](https://github.com/apache/arrow-rs/issues/7494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Speedup `filter_bytes` by precalculating capacity [\#7465](https://github.com/apache/arrow-rs/issues/7465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\]: Rust API to Create Variant Values [\#7424](https://github.com/apache/arrow-rs/issues/7424) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Rust API to Read Variant Values [\#7423](https://github.com/apache/arrow-rs/issues/7423) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Release arrow-rs / parquet Minor version `55.1.0` \(May 2025\) [\#7393](https://github.com/apache/arrow-rs/issues/7393) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support create\_random\_array for Decimal data types [\#7343](https://github.com/apache/arrow-rs/issues/7343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Truncate Parquet page data page statistics [\#7555](https://github.com/apache/arrow-rs/pull/7555) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Fixed bugs:** + +- In arrow\_json, Decoder::decode can panic if it encounters two high surrogates in a row. [\#7712](https://github.com/apache/arrow-rs/issues/7712) +- FlightSQL "GetDbSchemas" and "GetTables" schemas do not fully match the protocol [\#7637](https://github.com/apache/arrow-rs/issues/7637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Cannot read encrypted Parquet file if page index reading is enabled [\#7629](https://github.com/apache/arrow-rs/issues/7629) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `encoding_stats` not present in Parquet generated by `parquet-rewrite` [\#7616](https://github.com/apache/arrow-rs/issues/7616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- When writing parquet plaintext footer files `footer_signing_key_metadata` is not included, encryption alghoritm is always written in footer [\#7599](https://github.com/apache/arrow-rs/issues/7599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- `new_null_array` panics when constructing a struct of a dictionary [\#7571](https://github.com/apache/arrow-rs/issues/7571) +- Parquet derive fails to build when Result is aliased [\#7547](https://github.com/apache/arrow-rs/issues/7547) +- Unable to read `Dictionary(u8, FixedSizeBinary(_))` using datafusion. [\#7545](https://github.com/apache/arrow-rs/issues/7545) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- filter\_record\_batch panics with empty struct array. [\#7538](https://github.com/apache/arrow-rs/issues/7538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Panic in `pretty_format` function when displaying DurationSecondsArray with `i64::MIN` / `i64::MAX` [\#7533](https://github.com/apache/arrow-rs/issues/7533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Record API unable to parse TIME\_MILLIS when encoded as INT32 [\#7510](https://github.com/apache/arrow-rs/issues/7510) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- The `read_record_batch` func of the `RecordBatchDecoder` does not respect the `skip_validation` property [\#7508](https://github.com/apache/arrow-rs/issues/7508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `arrow-55.1.0` breaks `filter_record_batch` [\#7500](https://github.com/apache/arrow-rs/issues/7500) +- Files containing binary data with \>=8\_388\_855 bytes per row written with `arrow-rs` can't be read with `pyarrow` [\#7489](https://github.com/apache/arrow-rs/issues/7489) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Bug\] Ingestion with Arrow Flight Sql panic when the input stream is empty or fallible [\#7329](https://github.com/apache/arrow-rs/issues/7329) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] +- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) + +**Documentation updates:** + +- arrow\_reader\_row\_filter benchmark doesn't capture page cache improvements [\#7460](https://github.com/apache/arrow-rs/issues/7460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Add references for defaults in `WriterPropertiesBuilder` [\#7558](https://github.com/apache/arrow-rs/pull/7558) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Clarify Docs: NullBuffer::len is in bits [\#7556](https://github.com/apache/arrow-rs/pull/7556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- docs: fix typo for `Decimal128Array` [\#7525](https://github.com/apache/arrow-rs/pull/7525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([burmecia](https://github.com/burmecia)) +- Minor: Add examples to ProjectionMask documentation [\#7523](https://github.com/apache/arrow-rs/pull/7523) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve documentation for Parquet `WriterProperties` [\#7491](https://github.com/apache/arrow-rs/pull/7491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) + +**Closed issues:** + +- \[Variant\] More efficient determination of String vs ShortString [\#7700](https://github.com/apache/arrow-rs/issues/7700) +- \[Variant\] Improve API for iterating over values of a VariantList [\#7685](https://github.com/apache/arrow-rs/issues/7685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- \[Variant\] Consider validating variants on creation \(rather than read\) [\#7684](https://github.com/apache/arrow-rs/issues/7684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Miri test\_native\_type\_pow test failing [\#7641](https://github.com/apache/arrow-rs/issues/7641) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Improve performance of `coalesce` and `concat` for views [\#7615](https://github.com/apache/arrow-rs/issues/7615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Bad min value in row group statistics in some special cases [\#7593](https://github.com/apache/arrow-rs/issues/7593) +- Feature Request: BloomFilter Position Flexibility in `parquet-rewrite` [\#7552](https://github.com/apache/arrow-rs/issues/7552) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] + +**Merged pull requests:** + +- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) +- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) +- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. [\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) +- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) +- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) +- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) +- \[Variant\] Simplify creation of Variants from metadata and value [\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) +- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) +- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) +- Add `pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) +- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) +- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) +- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve `coalesce` kernel tests [\#7626](https://github.com/apache/arrow-rs/pull/7626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Revert "Revert "Improve `coalesce` and `concat` performance for views… [\#7625](https://github.com/apache/arrow-rs/pull/7625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Revert "Improve `coalesce` and `concat` performance for views \(\#7614\)" [\#7623](https://github.com/apache/arrow-rs/pull/7623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Improve coalesce\_kernel benchmark to capture inline vs non inline views [\#7619](https://github.com/apache/arrow-rs/pull/7619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Improve `coalesce` and `concat` performance for views [\#7614](https://github.com/apache/arrow-rs/pull/7614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- feat: add constructor to help efficiently upgrade key for GenericBytesDictionaryBuilder [\#7611](https://github.com/apache/arrow-rs/pull/7611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- feat: support append\_nulls on additional builders [\#7606](https://github.com/apache/arrow-rs/pull/7606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) +- feat: add AsyncArrowWriter::into\_inner [\#7604](https://github.com/apache/arrow-rs/pull/7604) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jpopesculian](https://github.com/jpopesculian)) +- Move variant interop test to Rust integration test [\#7602](https://github.com/apache/arrow-rs/pull/7602) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Include footer key metadata when writing encrypted Parquet with a plaintext footer [\#7600](https://github.com/apache/arrow-rs/pull/7600) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rok](https://github.com/rok)) +- Add `coalesce` kernel and`BatchCoalescer` for statefully combining selected b…atches: [\#7597](https://github.com/apache/arrow-rs/pull/7597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add FixedSizeBinary to `take_kernel` benchmark [\#7592](https://github.com/apache/arrow-rs/pull/7592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix GenericBinaryArray docstring. [\#7588](https://github.com/apache/arrow-rs/pull/7588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) +- fix: error reading multiple batches of `Dict(_, FixedSizeBinary(_))` [\#7585](https://github.com/apache/arrow-rs/pull/7585) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) +- Revert "Minor: remove filter code deprecated in 2023 \(\#7554\)" [\#7583](https://github.com/apache/arrow-rs/pull/7583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fixed a warning build build: function never used. [\#7577](https://github.com/apache/arrow-rs/pull/7577) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Adding Encoding argument in `parquet-rewrite` [\#7576](https://github.com/apache/arrow-rs/pull/7576) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- feat: add `row_group_is_[max/min]_value_exact` to StatisticsConverter [\#7574](https://github.com/apache/arrow-rs/pull/7574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([CookiePieWw](https://github.com/CookiePieWw)) +- \[array\] Remove unwrap checks from GenericByteArray::value\_unchecked [\#7573](https://github.com/apache/arrow-rs/pull/7573) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- \[benches/row\_format\] fix typo in array lengths [\#7572](https://github.com/apache/arrow-rs/pull/7572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- Add a strong\_count method to Buffer [\#7569](https://github.com/apache/arrow-rs/pull/7569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([westonpace](https://github.com/westonpace)) +- Minor: Enable byte view for clickbench benchmark [\#7565](https://github.com/apache/arrow-rs/pull/7565) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Optimize length calculation in row encoding for fixed-length columns [\#7564](https://github.com/apache/arrow-rs/pull/7564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) +- Use PR title and description for commit message [\#7563](https://github.com/apache/arrow-rs/pull/7563) ([kou](https://github.com/kou)) +- Use apache/arrow-{go,java,js} in integration test [\#7561](https://github.com/apache/arrow-rs/pull/7561) ([kou](https://github.com/kou)) +- Implement Array Decoding in arrow-avro [\#7559](https://github.com/apache/arrow-rs/pull/7559) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Minor: remove filter code deprecated in 2023 [\#7554](https://github.com/apache/arrow-rs/pull/7554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- fix: Correct docs for `WriterPropertiesBuilder::set_column_index_truncate_length` [\#7553](https://github.com/apache/arrow-rs/pull/7553) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Adding Bloom Filter Position argument in parquet-rewrite [\#7550](https://github.com/apache/arrow-rs/pull/7550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) +- Fix `Result` name collision in parquet\_derive [\#7548](https://github.com/apache/arrow-rs/pull/7548) ([jspaezp](https://github.com/jspaezp)) +- Fix: Converted feature flight-sql-experimental to flight-sql [\#7546](https://github.com/apache/arrow-rs/pull/7546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([kunalsinghdadhwal](https://github.com/kunalsinghdadhwal)) +- Fix CI on main due to logical conflict [\#7542](https://github.com/apache/arrow-rs/pull/7542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix `filter_record_batch` panics with empty struct array [\#7539](https://github.com/apache/arrow-rs/pull/7539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([thorfour](https://github.com/thorfour)) +- \[Variant\] Initial API for reading Variant data and metadata [\#7535](https://github.com/apache/arrow-rs/pull/7535) ([mkarbo](https://github.com/mkarbo)) +- fix: Panic in pretty\_format function when displaying DurationSecondsA… [\#7534](https://github.com/apache/arrow-rs/pull/7534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) +- Create version of LexicographicalComparator that compares fixed number of columns \(~ -15%\) [\#7530](https://github.com/apache/arrow-rs/pull/7530) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Make parquet-show-bloom-filter work with integer typed columns [\#7529](https://github.com/apache/arrow-rs/pull/7529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) +- chore\(deps\): update criterion requirement from 0.5 to 0.6 [\#7527](https://github.com/apache/arrow-rs/pull/7527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) +- Minor: Add a parquet row\_filter test, reduce some test boiler plate [\#7522](https://github.com/apache/arrow-rs/pull/7522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Refactor `build_array_reader` into a struct [\#7521](https://github.com/apache/arrow-rs/pull/7521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- arrow: add concat structs benchmark [\#7520](https://github.com/apache/arrow-rs/pull/7520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- arrow-select: add support for merging primitive dictionary values [\#7519](https://github.com/apache/arrow-rs/pull/7519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- arrow-select: add support for optimized concatenation of struct arrays [\#7517](https://github.com/apache/arrow-rs/pull/7517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- Fix Clippy in CI for Rust 1.87 release [\#7514](https://github.com/apache/arrow-rs/pull/7514) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) +- Simplify `ParquetRecordBatchReader::next` control logic [\#7512](https://github.com/apache/arrow-rs/pull/7512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Fix record API support for reading INT32 encoded TIME\_MILLIS [\#7511](https://github.com/apache/arrow-rs/pull/7511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([njaremko](https://github.com/njaremko)) +- RecordBatchDecoder: skip RecordBatch validation when `skip_validation` property is enabled [\#7509](https://github.com/apache/arrow-rs/pull/7509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nilskch](https://github.com/nilskch)) +- Introduce `ReadPlan` to encapsulate the calculation of what parquet rows to decode [\#7502](https://github.com/apache/arrow-rs/pull/7502) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Update documentation for ParquetReader [\#7501](https://github.com/apache/arrow-rs/pull/7501) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Improve `Field` docs, add missing `Field::set_*` methods [\#7497](https://github.com/apache/arrow-rs/pull/7497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Speed up arithmetic kernels, reduce `unsafe` usage [\#7493](https://github.com/apache/arrow-rs/pull/7493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Prevent FlightSQL server panics for `do_put` when stream is empty or 1st stream element is an Err [\#7492](https://github.com/apache/arrow-rs/pull/7492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([superserious-dev](https://github.com/superserious-dev)) +- arrow-ipc: add `StreamDecoder::schema` [\#7488](https://github.com/apache/arrow-rs/pull/7488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lidavidm](https://github.com/lidavidm)) +- arrow-select: Implement concat for `RunArray`s [\#7487](https://github.com/apache/arrow-rs/pull/7487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- \[Variant\] Add \(empty\) `parquet-variant` crate, update `parquet-testing` pin [\#7485](https://github.com/apache/arrow-rs/pull/7485) ([alamb](https://github.com/alamb)) +- Improve error messages if schema hint mismatches with parquet schema [\#7481](https://github.com/apache/arrow-rs/pull/7481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `arrow_reader_clickbench` benchmark [\#7470](https://github.com/apache/arrow-rs/pull/7470) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Speedup `filter_bytes` ~-20-40%, `filter_native` low selectivity \(~-37%\) [\#7463](https://github.com/apache/arrow-rs/pull/7463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) +- Update arrow\_reader\_row\_filter benchmark to reflect ClickBench distribution [\#7461](https://github.com/apache/arrow-rs/pull/7461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add Map support to arrow-avro [\#7451](https://github.com/apache/arrow-rs/pull/7451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Support Utf8View for Avro [\#7434](https://github.com/apache/arrow-rs/pull/7434) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kumarlokesh](https://github.com/kumarlokesh)) +- Add support for creating random Decimal128 and Decimal256 arrays [\#7427](https://github.com/apache/arrow-rs/pull/7427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) + ## [55.1.0](https://github.com/apache/arrow-rs/tree/55.1.0) (2025-05-09) [Full Changelog](https://github.com/apache/arrow-rs/compare/55.0.0...55.1.0) diff --git a/CHANGELOG.md b/CHANGELOG.md index 03c5f6436fd5..fbbdba7d36ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,177 +19,172 @@ # Changelog -## [55.2.0](https://github.com/apache/arrow-rs/tree/55.2.0) (2025-06-22) +## [57.2.0](https://github.com/apache/arrow-rs/tree/57.2.0) (2026-01-07) -[Full Changelog](https://github.com/apache/arrow-rs/compare/55.1.0...55.2.0) +[Full Changelog](https://github.com/apache/arrow-rs/compare/57.1.0...57.2.0) + +**Breaking changes:** + +- Seal Array trait [\#9092](https://github.com/apache/arrow-rs/pull/9092) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tustvold](https://github.com/tustvold)) +- \[Variant\] Unify the CastOptions usage in parquet-variant-compute [\#8984](https://github.com/apache/arrow-rs/pull/8984) ([klion26](https://github.com/klion26)) **Implemented enhancements:** -- Do not populate nulls for `NullArray` for `MutableArrayData` [\#7725](https://github.com/apache/arrow-rs/issues/7725) -- Implement `PartialEq` for RunArray [\#7691](https://github.com/apache/arrow-rs/issues/7691) -- `interleave_views` is really slow [\#7688](https://github.com/apache/arrow-rs/issues/7688) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add min max aggregates for FixedSizeBinary [\#7674](https://github.com/apache/arrow-rs/issues/7674) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Deliver pyarrow as a standalone crate [\#7668](https://github.com/apache/arrow-rs/issues/7668) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Implement `VariantObject::field` and `VariantObject::fields` [\#7665](https://github.com/apache/arrow-rs/issues/7665) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Implement read support for remaining primitive types [\#7630](https://github.com/apache/arrow-rs/issues/7630) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Fast and ergonomic method to add metadata to a `RecordBatch` [\#7628](https://github.com/apache/arrow-rs/issues/7628) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add efficient way to change the keys of string dictionary builder [\#7610](https://github.com/apache/arrow-rs/issues/7610) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Support `add_nulls` on additional builder types [\#7605](https://github.com/apache/arrow-rs/issues/7605) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add `into_inner` for `AsyncArrowWriter` [\#7603](https://github.com/apache/arrow-rs/issues/7603) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Optimize `PrimitiveBuilder::append_trusted_len_iter` [\#7591](https://github.com/apache/arrow-rs/issues/7591) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Benchmark for filter+concat and take+concat into even sized record batches [\#7589](https://github.com/apache/arrow-rs/issues/7589) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `max_statistics_truncate_length` is ignored when writing statistics to data page headers [\#7579](https://github.com/apache/arrow-rs/issues/7579) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Feature Request: Encoding in `parquet-rewrite` [\#7575](https://github.com/apache/arrow-rs/issues/7575) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Add a `strong_count` method to `Buffer` [\#7568](https://github.com/apache/arrow-rs/issues/7568) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Create version of LexicographicalComparator that compares fixed number of columns [\#7531](https://github.com/apache/arrow-rs/issues/7531) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- parquet-show-bloom-filter should work with integer typed columns [\#7528](https://github.com/apache/arrow-rs/issues/7528) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Allow merging primitive dictionary values in concat and interleave kernels [\#7518](https://github.com/apache/arrow-rs/issues/7518) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Add efficient concatenation of StructArrays [\#7516](https://github.com/apache/arrow-rs/issues/7516) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Rename `flight-sql-experimental` to `flight-sql` [\#7498](https://github.com/apache/arrow-rs/issues/7498) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Consider moving from ryu to lexical-core for string formatting / casting floats to string. [\#7496](https://github.com/apache/arrow-rs/issues/7496) -- Arithmetic kernels can be safer and faster [\#7494](https://github.com/apache/arrow-rs/issues/7494) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Speedup `filter_bytes` by precalculating capacity [\#7465](https://github.com/apache/arrow-rs/issues/7465) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\]: Rust API to Create Variant Values [\#7424](https://github.com/apache/arrow-rs/issues/7424) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- \[Variant\] Rust API to Read Variant Values [\#7423](https://github.com/apache/arrow-rs/issues/7423) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Release arrow-rs / parquet Minor version `55.1.0` \(May 2025\) [\#7393](https://github.com/apache/arrow-rs/issues/7393) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Support create\_random\_array for Decimal data types [\#7343](https://github.com/apache/arrow-rs/issues/7343) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Truncate Parquet page data page statistics [\#7555](https://github.com/apache/arrow-rs/pull/7555) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[parquet\] further relax `LevelInfoBuilder::types_compatible` for `ArrowWriter` [\#9098](https://github.com/apache/arrow-rs/issues/9098) +- Update arrow-row documentation with Union encoding [\#9084](https://github.com/apache/arrow-rs/issues/9084) +- Add code examples for min and max compute functions [\#9055](https://github.com/apache/arrow-rs/issues/9055) +- Add `append_n` to bytes view builder API [\#9034](https://github.com/apache/arrow-rs/issues/9034) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Move `RunArray::get_physical_indices` to `RunEndBuffer` [\#9025](https://github.com/apache/arrow-rs/issues/9025) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Allow quote style in csv writer [\#9003](https://github.com/apache/arrow-rs/issues/9003) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- IPC support for ListView [\#9002](https://github.com/apache/arrow-rs/issues/9002) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement `BinaryArrayType` for `&FixedSizeBinaryArray`s [\#8992](https://github.com/apache/arrow-rs/issues/8992) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-buffer: implement num-traits for i256 [\#8976](https://github.com/apache/arrow-rs/issues/8976) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Support for `Arc` in `ParquetRecordWriter` derive macro [\#8972](https://github.com/apache/arrow-rs/issues/8972) +- \[arrow-avro\] suggest switching from xz to liblzma [\#8970](https://github.com/apache/arrow-rs/issues/8970) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-buffer: add i256::trailing\_zeros [\#8968](https://github.com/apache/arrow-rs/issues/8968) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- arrow-buffer: make i256::leading\_zeros public [\#8965](https://github.com/apache/arrow-rs/issues/8965) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add spark like `ignoreLeadingWhiteSpace` and `ignoreTrailingWhiteSpace` options to the csv writer [\#8961](https://github.com/apache/arrow-rs/issues/8961) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add round trip benchmark for Parquet writer/reader [\#8955](https://github.com/apache/arrow-rs/issues/8955) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Support performant `interleave` for List/LargeList [\#8952](https://github.com/apache/arrow-rs/issues/8952) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Support array access when parsing `VariantPath` [\#8946](https://github.com/apache/arrow-rs/issues/8946) +- Some panic!s could be represented as unimplemented!s [\#8932](https://github.com/apache/arrow-rs/issues/8932) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] easier way to construct a shredded schema [\#8922](https://github.com/apache/arrow-rs/issues/8922) +- Support `DataType::ListView` and `DataType::LargeListView` in `ArrayData::new_null` [\#8908](https://github.com/apache/arrow-rs/issues/8908) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Add `GenericListViewArray::from_iter_primitive` [\#8906](https://github.com/apache/arrow-rs/issues/8906) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Unify the cast option usage in ParquentVariant [\#8873](https://github.com/apache/arrow-rs/issues/8873) +- Blog post about efficient filter representation in Parquet filter pushdown [\#8843](https://github.com/apache/arrow-rs/issues/8843) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- Add comparison support for Union arrays in the `cmp` kernel [\#8837](https://github.com/apache/arrow-rs/issues/8837) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Support array shredding into `List/LargeList/ListView/LargeListView` [\#8830](https://github.com/apache/arrow-rs/issues/8830) +- Support `Union` data types for row format [\#8828](https://github.com/apache/arrow-rs/issues/8828) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- FFI support for ListView [\#8819](https://github.com/apache/arrow-rs/issues/8819) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[Variant\] Support more Arrow Datatypes from Variant primitive types [\#8805](https://github.com/apache/arrow-rs/issues/8805) +- `FixedSizeBinaryBuilder` supports `append_array` [\#8750](https://github.com/apache/arrow-rs/issues/8750) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Implement special case `zip` with scalar for Utf8View [\#8724](https://github.com/apache/arrow-rs/issues/8724) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[geometry\] Wire up arrow reader/writer for `GEOMETRY` and `GEOGRAPHY` [\#8717](https://github.com/apache/arrow-rs/issues/8717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] **Fixed bugs:** -- In arrow\_json, Decoder::decode can panic if it encounters two high surrogates in a row. [\#7712](https://github.com/apache/arrow-rs/issues/7712) -- FlightSQL "GetDbSchemas" and "GetTables" schemas do not fully match the protocol [\#7637](https://github.com/apache/arrow-rs/issues/7637) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Cannot read encrypted Parquet file if page index reading is enabled [\#7629](https://github.com/apache/arrow-rs/issues/7629) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `encoding_stats` not present in Parquet generated by `parquet-rewrite` [\#7616](https://github.com/apache/arrow-rs/issues/7616) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- When writing parquet plaintext footer files `footer_signing_key_metadata` is not included, encryption alghoritm is always written in footer [\#7599](https://github.com/apache/arrow-rs/issues/7599) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- `new_null_array` panics when constructing a struct of a dictionary [\#7571](https://github.com/apache/arrow-rs/issues/7571) -- Parquet derive fails to build when Result is aliased [\#7547](https://github.com/apache/arrow-rs/issues/7547) -- Unable to read `Dictionary(u8, FixedSizeBinary(_))` using datafusion. [\#7545](https://github.com/apache/arrow-rs/issues/7545) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- filter\_record\_batch panics with empty struct array. [\#7538](https://github.com/apache/arrow-rs/issues/7538) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Panic in `pretty_format` function when displaying DurationSecondsArray with `i64::MIN` / `i64::MAX` [\#7533](https://github.com/apache/arrow-rs/issues/7533) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Record API unable to parse TIME\_MILLIS when encoded as INT32 [\#7510](https://github.com/apache/arrow-rs/issues/7510) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- The `read_record_batch` func of the `RecordBatchDecoder` does not respect the `skip_validation` property [\#7508](https://github.com/apache/arrow-rs/issues/7508) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- `arrow-55.1.0` breaks `filter_record_batch` [\#7500](https://github.com/apache/arrow-rs/issues/7500) -- Files containing binary data with \>=8\_388\_855 bytes per row written with `arrow-rs` can't be read with `pyarrow` [\#7489](https://github.com/apache/arrow-rs/issues/7489) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Bug\] Ingestion with Arrow Flight Sql panic when the input stream is empty or fallible [\#7329](https://github.com/apache/arrow-rs/issues/7329) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] -- Ensure page encoding statistics are written to Parquet file [\#7643](https://github.com/apache/arrow-rs/pull/7643) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Soundness Bug in `try_binary` when `Array` is implemented incorrectly in external crate [\#9106](https://github.com/apache/arrow-rs/issues/9106) +- casting `Dict(_, LargeUtf8)` to `Utf8View` \(`StringViewArray`\) panics [\#9101](https://github.com/apache/arrow-rs/issues/9101) +- wrong results for null count of `nullif` kernel [\#9085](https://github.com/apache/arrow-rs/issues/9085) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Empty first line in some code examples [\#9063](https://github.com/apache/arrow-rs/issues/9063) +- GenericByteViewArray::slice is not zero-copy but ought to be [\#9014](https://github.com/apache/arrow-rs/issues/9014) +- Regression in struct casting in 57.2.0 \(not yet released\) [\#9005](https://github.com/apache/arrow-rs/issues/9005) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Fix panic when decoding multiple Union columns in RowConverter [\#8999](https://github.com/apache/arrow-rs/issues/8999) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `take_fixed_size_binary` Does Not Consider NULL Indices [\#8947](https://github.com/apache/arrow-rs/issues/8947) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- \[arrow-avro\] RecordEncoder Bugs [\#8934](https://github.com/apache/arrow-rs/issues/8934) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `FixedSizeBinaryArray::try_new(...)` Panics with Item Length of Zero [\#8926](https://github.com/apache/arrow-rs/issues/8926) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `cargo test -p arrow-cast` fails on main [\#8910](https://github.com/apache/arrow-rs/issues/8910) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `GenericListViewArray::new_null` ignores `len` and returns an empty array [\#8904](https://github.com/apache/arrow-rs/issues/8904) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- `FixedSizeBinaryArray::new_null` Does Not Properly Set the Length of the Values Buffer [\#8900](https://github.com/apache/arrow-rs/issues/8900) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Struct casting requires same order of fields [\#8870](https://github.com/apache/arrow-rs/issues/8870) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] +- Cannot cast string dictionary to binary view [\#8841](https://github.com/apache/arrow-rs/issues/8841) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Documentation updates:** -- arrow\_reader\_row\_filter benchmark doesn't capture page cache improvements [\#7460](https://github.com/apache/arrow-rs/issues/7460) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- chore: fix a typo in `ExtensionType::supports_data_type` docs [\#7682](https://github.com/apache/arrow-rs/pull/7682) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) -- \[Variant\] Add variant docs and examples [\#7661](https://github.com/apache/arrow-rs/pull/7661) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Minor: Add version to deprecation notice for `ParquetMetaDataReader::decode_footer` [\#7639](https://github.com/apache/arrow-rs/pull/7639) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Add references for defaults in `WriterPropertiesBuilder` [\#7558](https://github.com/apache/arrow-rs/pull/7558) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Clarify Docs: NullBuffer::len is in bits [\#7556](https://github.com/apache/arrow-rs/pull/7556) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- docs: fix typo for `Decimal128Array` [\#7525](https://github.com/apache/arrow-rs/pull/7525) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([burmecia](https://github.com/burmecia)) -- Minor: Add examples to ProjectionMask documentation [\#7523](https://github.com/apache/arrow-rs/pull/7523) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Improve documentation for Parquet `WriterProperties` [\#7491](https://github.com/apache/arrow-rs/pull/7491) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Add Union encoding documentation [\#9102](https://github.com/apache/arrow-rs/pull/9102) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([EduardAkhmetshin](https://github.com/EduardAkhmetshin)) +- docs: fix misleading reserve documentation [\#9076](https://github.com/apache/arrow-rs/pull/9076) ([WaterWhisperer](https://github.com/WaterWhisperer)) +- Fix headers and empty lines in code examples [\#9064](https://github.com/apache/arrow-rs/pull/9064) ([EduardAkhmetshin](https://github.com/EduardAkhmetshin)) +- Add examples for min and max functions [\#9062](https://github.com/apache/arrow-rs/pull/9062) ([EduardAkhmetshin](https://github.com/EduardAkhmetshin)) +- Improve arrow-buffer documentation [\#9020](https://github.com/apache/arrow-rs/pull/9020) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Move examples in arrow-csv to docstrings, polish up docs [\#9001](https://github.com/apache/arrow-rs/pull/9001) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add example of parsing field names as VariantPath [\#8945](https://github.com/apache/arrow-rs/pull/8945) ([alamb](https://github.com/alamb)) +- Improve documentation for `prep\_null\_mask\_flter [\#8722](https://github.com/apache/arrow-rs/pull/8722) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) + +**Performance improvements:** + +- \[parquet\] Avoid a clone while resolving the read strategy [\#9056](https://github.com/apache/arrow-rs/pull/9056) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- perf: improve performance of encoding `GenericByteArray` by 8% [\#9054](https://github.com/apache/arrow-rs/pull/9054) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Speed up unary `not` kernel by 50%, add `BooleanBuffer::from_bitwise_unary` [\#8996](https://github.com/apache/arrow-rs/pull/8996) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- arrow-select: improve dictionary interleave fallback performance [\#8978](https://github.com/apache/arrow-rs/pull/8978) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- Add special implementation for zip for Utf8View/BinaryView scalars [\#8963](https://github.com/apache/arrow-rs/pull/8963) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mkleen](https://github.com/mkleen)) +- arrow-select: implement specialized interleave\_list [\#8953](https://github.com/apache/arrow-rs/pull/8953) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) **Closed issues:** -- \[Variant\] More efficient determination of String vs ShortString [\#7700](https://github.com/apache/arrow-rs/issues/7700) -- \[Variant\] Improve API for iterating over values of a VariantList [\#7685](https://github.com/apache/arrow-rs/issues/7685) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- \[Variant\] Consider validating variants on creation \(rather than read\) [\#7684](https://github.com/apache/arrow-rs/issues/7684) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] -- Miri test\_native\_type\_pow test failing [\#7641](https://github.com/apache/arrow-rs/issues/7641) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Improve performance of `coalesce` and `concat` for views [\#7615](https://github.com/apache/arrow-rs/issues/7615) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] -- Bad min value in row group statistics in some special cases [\#7593](https://github.com/apache/arrow-rs/issues/7593) -- Feature Request: BloomFilter Position Flexibility in `parquet-rewrite` [\#7552](https://github.com/apache/arrow-rs/issues/7552) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] +- impl `Index` for `UnionFields` [\#8958](https://github.com/apache/arrow-rs/issues/8958) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] **Merged pull requests:** -- arrow-array: Implement PartialEq for RunArray [\#7727](https://github.com/apache/arrow-rs/pull/7727) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- fix: Do not add null buffer for `NullArray` in MutableArrayData [\#7726](https://github.com/apache/arrow-rs/pull/7726) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([comphead](https://github.com/comphead)) -- fix JSON decoder error checking for UTF16 / surrogate parsing panic [\#7721](https://github.com/apache/arrow-rs/pull/7721) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nicklan](https://github.com/nicklan)) -- \[Variant\] Introduce new type over &str for ShortString [\#7718](https://github.com/apache/arrow-rs/pull/7718) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([friendlymatthew](https://github.com/friendlymatthew)) -- Split out variant code into several new sub-modules [\#7717](https://github.com/apache/arrow-rs/pull/7717) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Support write to buffer api for SerializedFileWriter [\#7714](https://github.com/apache/arrow-rs/pull/7714) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Make variant iterators safely infallible [\#7704](https://github.com/apache/arrow-rs/pull/7704) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Speedup `interleave_views` \(4-7x faster\) [\#7695](https://github.com/apache/arrow-rs/pull/7695) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Define a "arrow-pyrarrow" crate to implement the "pyarrow" feature. [\#7694](https://github.com/apache/arrow-rs/pull/7694) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) -- Document REE row format and add some more tests [\#7680](https://github.com/apache/arrow-rs/pull/7680) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- feat: add min max aggregate support for FixedSizeBinary [\#7675](https://github.com/apache/arrow-rs/pull/7675) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alexwilcoxson-rel](https://github.com/alexwilcoxson-rel)) -- arrow-data: Add REE support for `build_extend` and `build_extend_nulls` [\#7671](https://github.com/apache/arrow-rs/pull/7671) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- Remove `lazy_static` dependency [\#7669](https://github.com/apache/arrow-rs/pull/7669) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Expyron](https://github.com/Expyron)) -- Finish implementing Variant::Object and Variant::List [\#7666](https://github.com/apache/arrow-rs/pull/7666) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([scovich](https://github.com/scovich)) -- Add `RecordBatch::schema_metadata_mut` and `Field::metadata_mut` [\#7664](https://github.com/apache/arrow-rs/pull/7664) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([emilk](https://github.com/emilk)) -- \[Variant\] Simplify creation of Variants from metadata and value [\#7663](https://github.com/apache/arrow-rs/pull/7663) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- chore: group prost dependabot updates [\#7659](https://github.com/apache/arrow-rs/pull/7659) ([mbrobbel](https://github.com/mbrobbel)) -- Initial Builder API for Creating Variant Values [\#7653](https://github.com/apache/arrow-rs/pull/7653) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([PinkCrow007](https://github.com/PinkCrow007)) -- Add `BatchCoalescer::push_filtered_batch` and docs [\#7652](https://github.com/apache/arrow-rs/pull/7652) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Optimize coalesce kernel for StringView \(10-50% faster\) [\#7650](https://github.com/apache/arrow-rs/pull/7650) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- arrow-row: Add support for REE [\#7649](https://github.com/apache/arrow-rs/pull/7649) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- Use approximate comparisons for pow tests [\#7646](https://github.com/apache/arrow-rs/pull/7646) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([adamreeve](https://github.com/adamreeve)) -- \[Variant\] Implement read support for remaining primitive types [\#7644](https://github.com/apache/arrow-rs/pull/7644) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([superserious-dev](https://github.com/superserious-dev)) -- Add `pretty_format_batches_with_schema` function [\#7642](https://github.com/apache/arrow-rs/pull/7642) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lewiszlw](https://github.com/lewiszlw)) -- Deprecate old Parquet page index parsing functions [\#7640](https://github.com/apache/arrow-rs/pull/7640) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Update FlightSQL `GetDbSchemas` and `GetTables` schemas to fully match the protocol [\#7638](https://github.com/apache/arrow-rs/pull/7638) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([sgrebnov](https://github.com/sgrebnov)) -- Minor: Remove outdated FIXME from `ParquetMetaDataReader` [\#7635](https://github.com/apache/arrow-rs/pull/7635) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Fix the error info of `StructArray::try_new` [\#7634](https://github.com/apache/arrow-rs/pull/7634) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xudong963](https://github.com/xudong963)) -- Fix reading encrypted Parquet pages when using the page index [\#7633](https://github.com/apache/arrow-rs/pull/7633) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) -- \[Variant\] Add commented out primitive test casees [\#7631](https://github.com/apache/arrow-rs/pull/7631) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Improve `coalesce` kernel tests [\#7626](https://github.com/apache/arrow-rs/pull/7626) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Revert "Revert "Improve `coalesce` and `concat` performance for views… [\#7625](https://github.com/apache/arrow-rs/pull/7625) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Revert "Improve `coalesce` and `concat` performance for views \(\#7614\)" [\#7623](https://github.com/apache/arrow-rs/pull/7623) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Improve coalesce\_kernel benchmark to capture inline vs non inline views [\#7619](https://github.com/apache/arrow-rs/pull/7619) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Improve `coalesce` and `concat` performance for views [\#7614](https://github.com/apache/arrow-rs/pull/7614) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- feat: add constructor to help efficiently upgrade key for GenericBytesDictionaryBuilder [\#7611](https://github.com/apache/arrow-rs/pull/7611) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) -- feat: support append\_nulls on additional builders [\#7606](https://github.com/apache/arrow-rs/pull/7606) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([albertlockett](https://github.com/albertlockett)) -- feat: add AsyncArrowWriter::into\_inner [\#7604](https://github.com/apache/arrow-rs/pull/7604) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([jpopesculian](https://github.com/jpopesculian)) -- Move variant interop test to Rust integration test [\#7602](https://github.com/apache/arrow-rs/pull/7602) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Include footer key metadata when writing encrypted Parquet with a plaintext footer [\#7600](https://github.com/apache/arrow-rs/pull/7600) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rok](https://github.com/rok)) -- Add `coalesce` kernel and`BatchCoalescer` for statefully combining selected b…atches: [\#7597](https://github.com/apache/arrow-rs/pull/7597) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Add FixedSizeBinary to `take_kernel` benchmark [\#7592](https://github.com/apache/arrow-rs/pull/7592) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Fix GenericBinaryArray docstring. [\#7588](https://github.com/apache/arrow-rs/pull/7588) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brunal](https://github.com/brunal)) -- fix: error reading multiple batches of `Dict(_, FixedSizeBinary(_))` [\#7585](https://github.com/apache/arrow-rs/pull/7585) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([albertlockett](https://github.com/albertlockett)) -- Revert "Minor: remove filter code deprecated in 2023 \(\#7554\)" [\#7583](https://github.com/apache/arrow-rs/pull/7583) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Fixed a warning build build: function never used. [\#7577](https://github.com/apache/arrow-rs/pull/7577) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) -- Adding Encoding argument in `parquet-rewrite` [\#7576](https://github.com/apache/arrow-rs/pull/7576) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) -- feat: add `row_group_is_[max/min]_value_exact` to StatisticsConverter [\#7574](https://github.com/apache/arrow-rs/pull/7574) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([CookiePieWw](https://github.com/CookiePieWw)) -- \[array\] Remove unwrap checks from GenericByteArray::value\_unchecked [\#7573](https://github.com/apache/arrow-rs/pull/7573) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) -- \[benches/row\_format\] fix typo in array lengths [\#7572](https://github.com/apache/arrow-rs/pull/7572) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) -- Add a strong\_count method to Buffer [\#7569](https://github.com/apache/arrow-rs/pull/7569) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([westonpace](https://github.com/westonpace)) -- Minor: Enable byte view for clickbench benchmark [\#7565](https://github.com/apache/arrow-rs/pull/7565) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Optimize length calculation in row encoding for fixed-length columns [\#7564](https://github.com/apache/arrow-rs/pull/7564) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ctsk](https://github.com/ctsk)) -- Use PR title and description for commit message [\#7563](https://github.com/apache/arrow-rs/pull/7563) ([kou](https://github.com/kou)) -- Use apache/arrow-{go,java,js} in integration test [\#7561](https://github.com/apache/arrow-rs/pull/7561) ([kou](https://github.com/kou)) -- Implement Array Decoding in arrow-avro [\#7559](https://github.com/apache/arrow-rs/pull/7559) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- Minor: remove filter code deprecated in 2023 [\#7554](https://github.com/apache/arrow-rs/pull/7554) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- fix: Correct docs for `WriterPropertiesBuilder::set_column_index_truncate_length` [\#7553](https://github.com/apache/arrow-rs/pull/7553) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) -- Adding Bloom Filter Position argument in parquet-rewrite [\#7550](https://github.com/apache/arrow-rs/pull/7550) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([JigaoLuo](https://github.com/JigaoLuo)) -- Fix `Result` name collision in parquet\_derive [\#7548](https://github.com/apache/arrow-rs/pull/7548) ([jspaezp](https://github.com/jspaezp)) -- Fix: Converted feature flight-sql-experimental to flight-sql [\#7546](https://github.com/apache/arrow-rs/pull/7546) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([kunalsinghdadhwal](https://github.com/kunalsinghdadhwal)) -- Fix CI on main due to logical conflict [\#7542](https://github.com/apache/arrow-rs/pull/7542) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Fix `filter_record_batch` panics with empty struct array [\#7539](https://github.com/apache/arrow-rs/pull/7539) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([thorfour](https://github.com/thorfour)) -- \[Variant\] Initial API for reading Variant data and metadata [\#7535](https://github.com/apache/arrow-rs/pull/7535) ([mkarbo](https://github.com/mkarbo)) -- fix: Panic in pretty\_format function when displaying DurationSecondsA… [\#7534](https://github.com/apache/arrow-rs/pull/7534) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([zhuqi-lucas](https://github.com/zhuqi-lucas)) -- Create version of LexicographicalComparator that compares fixed number of columns \(~ -15%\) [\#7530](https://github.com/apache/arrow-rs/pull/7530) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Make parquet-show-bloom-filter work with integer typed columns [\#7529](https://github.com/apache/arrow-rs/pull/7529) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([adamreeve](https://github.com/adamreeve)) -- chore\(deps\): update criterion requirement from 0.5 to 0.6 [\#7527](https://github.com/apache/arrow-rs/pull/7527) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mbrobbel](https://github.com/mbrobbel)) -- Minor: Add a parquet row\_filter test, reduce some test boiler plate [\#7522](https://github.com/apache/arrow-rs/pull/7522) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Refactor `build_array_reader` into a struct [\#7521](https://github.com/apache/arrow-rs/pull/7521) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- arrow: add concat structs benchmark [\#7520](https://github.com/apache/arrow-rs/pull/7520) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) -- arrow-select: add support for merging primitive dictionary values [\#7519](https://github.com/apache/arrow-rs/pull/7519) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) -- arrow-select: add support for optimized concatenation of struct arrays [\#7517](https://github.com/apache/arrow-rs/pull/7517) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) -- Fix Clippy in CI for Rust 1.87 release [\#7514](https://github.com/apache/arrow-rs/pull/7514) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([alamb](https://github.com/alamb)) -- Simplify `ParquetRecordBatchReader::next` control logic [\#7512](https://github.com/apache/arrow-rs/pull/7512) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Fix record API support for reading INT32 encoded TIME\_MILLIS [\#7511](https://github.com/apache/arrow-rs/pull/7511) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([njaremko](https://github.com/njaremko)) -- RecordBatchDecoder: skip RecordBatch validation when `skip_validation` property is enabled [\#7509](https://github.com/apache/arrow-rs/pull/7509) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([nilskch](https://github.com/nilskch)) -- Introduce `ReadPlan` to encapsulate the calculation of what parquet rows to decode [\#7502](https://github.com/apache/arrow-rs/pull/7502) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Update documentation for ParquetReader [\#7501](https://github.com/apache/arrow-rs/pull/7501) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Improve `Field` docs, add missing `Field::set_*` methods [\#7497](https://github.com/apache/arrow-rs/pull/7497) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Speed up arithmetic kernels, reduce `unsafe` usage [\#7493](https://github.com/apache/arrow-rs/pull/7493) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Prevent FlightSQL server panics for `do_put` when stream is empty or 1st stream element is an Err [\#7492](https://github.com/apache/arrow-rs/pull/7492) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([superserious-dev](https://github.com/superserious-dev)) -- arrow-ipc: add `StreamDecoder::schema` [\#7488](https://github.com/apache/arrow-rs/pull/7488) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lidavidm](https://github.com/lidavidm)) -- arrow-select: Implement concat for `RunArray`s [\#7487](https://github.com/apache/arrow-rs/pull/7487) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) -- \[Variant\] Add \(empty\) `parquet-variant` crate, update `parquet-testing` pin [\#7485](https://github.com/apache/arrow-rs/pull/7485) ([alamb](https://github.com/alamb)) -- Improve error messages if schema hint mismatches with parquet schema [\#7481](https://github.com/apache/arrow-rs/pull/7481) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) -- Add `arrow_reader_clickbench` benchmark [\#7470](https://github.com/apache/arrow-rs/pull/7470) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Speedup `filter_bytes` ~-20-40%, `filter_native` low selectivity \(~-37%\) [\#7463](https://github.com/apache/arrow-rs/pull/7463) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Dandandan](https://github.com/Dandandan)) -- Update arrow\_reader\_row\_filter benchmark to reflect ClickBench distribution [\#7461](https://github.com/apache/arrow-rs/pull/7461) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) -- Add Map support to arrow-avro [\#7451](https://github.com/apache/arrow-rs/pull/7451) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) -- Support Utf8View for Avro [\#7434](https://github.com/apache/arrow-rs/pull/7434) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([kumarlokesh](https://github.com/kumarlokesh)) -- Add support for creating random Decimal128 and Decimal256 arrays [\#7427](https://github.com/apache/arrow-rs/pull/7427) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Add `DataType::is_decimal` [\#9100](https://github.com/apache/arrow-rs/pull/9100) ([AdamGS](https://github.com/AdamGS)) +- feat\(parquet\): relax type compatility check in parquet ArrowWriter [\#9099](https://github.com/apache/arrow-rs/pull/9099) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([gruuya](https://github.com/gruuya)) +- \[Variant\] Move `ArrayVariantToArrowRowBuilder` to `variant_to_arrow` [\#9094](https://github.com/apache/arrow-rs/pull/9094) ([liamzwbao](https://github.com/liamzwbao)) +- chore: increase row count and batch size for more deterministic tests [\#9088](https://github.com/apache/arrow-rs/pull/9088) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Weijun-H](https://github.com/Weijun-H)) +- Fix `nullif` kernel [\#9087](https://github.com/apache/arrow-rs/pull/9087) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add `FlightInfo::with_endpoints` method [\#9075](https://github.com/apache/arrow-rs/pull/9075) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- chore: run validation when debug assertion enabled and not only for test [\#9073](https://github.com/apache/arrow-rs/pull/9073) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- Minor: make it clear cache array reader is not cloning arrays [\#9057](https://github.com/apache/arrow-rs/pull/9057) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- Minor: avoid clone in RunArray row decoding via buffer stealing [\#9052](https://github.com/apache/arrow-rs/pull/9052) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lyang24](https://github.com/lyang24)) +- Minor: avoid some clones when reading parquet [\#9048](https://github.com/apache/arrow-rs/pull/9048) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([alamb](https://github.com/alamb)) +- fix: don't generate nulls for `Decimal128` and `Decimal256` when field is non-nullable and have non-zero `null_density` [\#9046](https://github.com/apache/arrow-rs/pull/9046) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- fix: `Rows` `size` should use `capacity` and not `len` [\#9044](https://github.com/apache/arrow-rs/pull/9044) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([rluvaton](https://github.com/rluvaton)) +- fix: integration / Archery test With other arrows container ran out of space [\#9043](https://github.com/apache/arrow-rs/pull/9043) ([lyang24](https://github.com/lyang24)) +- feat: add new `try_append_value_n()` function to `GenericByteViewBuilder` [\#9040](https://github.com/apache/arrow-rs/pull/9040) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lyang24](https://github.com/lyang24)) +- Rename fields in BooleanBuffer for clarity [\#9039](https://github.com/apache/arrow-rs/pull/9039) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Allocate buffers before work in `boolean_kernels` benchmark [\#9035](https://github.com/apache/arrow-rs/pull/9035) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Move RunArray::get\_physical\_indices to RunEndBuffer [\#9027](https://github.com/apache/arrow-rs/pull/9027) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lyang24](https://github.com/lyang24)) +- Improve `RunArray` documentation [\#9019](https://github.com/apache/arrow-rs/pull/9019) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- Add BooleanArray tests for null and slice behavior [\#9013](https://github.com/apache/arrow-rs/pull/9013) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([UtkarshSahay123](https://github.com/UtkarshSahay123)) +- feat: support array indices in VariantPath dot notation [\#9012](https://github.com/apache/arrow-rs/pull/9012) ([foskey51](https://github.com/foskey51)) +- arrow-cast: Bring back in-order field casting for `StructArray` [\#9007](https://github.com/apache/arrow-rs/pull/9007) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- arrow-ipc: Add ListView support [\#9006](https://github.com/apache/arrow-rs/pull/9006) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Add quote style to csv writer [\#9004](https://github.com/apache/arrow-rs/pull/9004) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xanderbailey](https://github.com/xanderbailey)) +- Fix row slice bug in Union column decoding with many columns [\#9000](https://github.com/apache/arrow-rs/pull/9000) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- build\(deps\): bump actions/download-artifact from 6 to 7 [\#8995](https://github.com/apache/arrow-rs/pull/8995) ([dependabot[bot]](https://github.com/apps/dependabot)) +- minor: Add comment blocks to PR template [\#8994](https://github.com/apache/arrow-rs/pull/8994) ([Jefffrey](https://github.com/Jefffrey)) +- Implement `BinaryArrayType` for `&FixedSizeBinaryArray`s [\#8993](https://github.com/apache/arrow-rs/pull/8993) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- feat: impl BatchCoalescer::push\_batch\_with\_indices [\#8991](https://github.com/apache/arrow-rs/pull/8991) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ClSlaid](https://github.com/ClSlaid)) +- \[Arrow\]Configure max deduplication length for `StringView` [\#8990](https://github.com/apache/arrow-rs/pull/8990) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([lichuang](https://github.com/lichuang)) +- feat: implement append\_array for FixedSizeBinaryBuilder [\#8989](https://github.com/apache/arrow-rs/pull/8989) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([ClSlaid](https://github.com/ClSlaid)) +- Add benchmarks for Utf8View scalars for zip [\#8988](https://github.com/apache/arrow-rs/pull/8988) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([mkleen](https://github.com/mkleen)) +- build\(deps\): bump actions/cache from 4 to 5 [\#8986](https://github.com/apache/arrow-rs/pull/8986) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Take fsb null indices [\#8981](https://github.com/apache/arrow-rs/pull/8981) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Add List to `interleave_kernels` benchmark [\#8980](https://github.com/apache/arrow-rs/pull/8980) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([alamb](https://github.com/alamb)) +- Fix ipc errors for `LargeList` containing sliced `StringViews` [\#8979](https://github.com/apache/arrow-rs/pull/8979) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([fabianmurariu](https://github.com/fabianmurariu)) +- arrow-buffer: implement num-traits numeric operations [\#8977](https://github.com/apache/arrow-rs/pull/8977) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([theirix](https://github.com/theirix)) +- Update `xz` crate dependency to use `liblzma` in arrow-avro [\#8975](https://github.com/apache/arrow-rs/pull/8975) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- arrow-data: avoid allocating in get\_last\_run\_end [\#8974](https://github.com/apache/arrow-rs/pull/8974) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([asubiotto](https://github.com/asubiotto)) +- Support for `Arc` in `ParquetRecordWriter` derive macro [\#8973](https://github.com/apache/arrow-rs/pull/8973) ([heilhead](https://github.com/heilhead)) +- feat: support casting `Time32` to `Int64` [\#8971](https://github.com/apache/arrow-rs/pull/8971) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tshauck](https://github.com/tshauck)) +- arrow-buffer: add i256::trailing\_zeros [\#8969](https://github.com/apache/arrow-rs/pull/8969) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([theirix](https://github.com/theirix)) +- Perf: Vectorize check\_bounds\(2x speedup\) [\#8966](https://github.com/apache/arrow-rs/pull/8966) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([gstvg](https://github.com/gstvg)) +- arrow-buffer: make i256::leading\_zeros public and tested [\#8964](https://github.com/apache/arrow-rs/pull/8964) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([theirix](https://github.com/theirix)) +- Add ignore leading and trailing white space to csv parser [\#8960](https://github.com/apache/arrow-rs/pull/8960) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([xanderbailey](https://github.com/xanderbailey)) +- Access `UnionFields` elements by index [\#8959](https://github.com/apache/arrow-rs/pull/8959) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add Parquet roundtrip benchmarks [\#8956](https://github.com/apache/arrow-rs/pull/8956) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- \[Variant\] Add variant to arrow for Date64/Timestamp\(Second/Millisecond\)/Time32/Time64 [\#8950](https://github.com/apache/arrow-rs/pull/8950) ([klion26](https://github.com/klion26)) +- Let `ArrowArrayStreamReader` handle schema with attached metadata + do schema checking [\#8944](https://github.com/apache/arrow-rs/pull/8944) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jonded94](https://github.com/jonded94)) +- Adds ExtensionType for Parquet geospatial WKB arrays [\#8943](https://github.com/apache/arrow-rs/pull/8943) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([BlakeOrth](https://github.com/BlakeOrth)) +- Add builder to help create Schemas for shredding \(`ShreddedSchemaBuilder`\) [\#8940](https://github.com/apache/arrow-rs/pull/8940) ([XiangpengHao](https://github.com/XiangpengHao)) +- build\(deps\): update criterion requirement from 0.7.0 to 0.8.0 [\#8939](https://github.com/apache/arrow-rs/pull/8939) ([dependabot[bot]](https://github.com/apps/dependabot)) +- fix: Resolve Avro RecordEncoder bugs related to nullable Struct fields and Union type ids [\#8935](https://github.com/apache/arrow-rs/pull/8935) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([jecsand838](https://github.com/jecsand838)) +- Some panic!s could more semantically be unimplemented! [\#8933](https://github.com/apache/arrow-rs/pull/8933) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([abacef](https://github.com/abacef)) +- fix: ipc decode panic with invalid data [\#8931](https://github.com/apache/arrow-rs/pull/8931) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([leiysky](https://github.com/leiysky)) +- Allow creating zero-sized FixedSizeBinary arrays [\#8927](https://github.com/apache/arrow-rs/pull/8927) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- Update `test_variant_get_error_when_cast_failure...` tests to uses a valid `VariantArray` [\#8921](https://github.com/apache/arrow-rs/pull/8921) ([alamb](https://github.com/alamb)) +- Make flight sql client generic [\#8915](https://github.com/apache/arrow-rs/pull/8915) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([lewiszlw](https://github.com/lewiszlw)) +- \[minor\] Name Magic Number "8" in `FixedSizeBinaryArray::new_null` [\#8914](https://github.com/apache/arrow-rs/pull/8914) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- fix: cast Binary/String dictionary to view [\#8912](https://github.com/apache/arrow-rs/pull/8912) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([Jefffrey](https://github.com/Jefffrey)) +- \[8910\]Fixed doc test with feature prettyprint [\#8911](https://github.com/apache/arrow-rs/pull/8911) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([manishkr](https://github.com/manishkr)) +- feat: `ArrayData::new_null` for `ListView` / `LargeListView` [\#8909](https://github.com/apache/arrow-rs/pull/8909) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dqkqd](https://github.com/dqkqd)) +- fead: add `GenericListViewArray::from_iter_primitive` [\#8907](https://github.com/apache/arrow-rs/pull/8907) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dqkqd](https://github.com/dqkqd)) +- fix: `GenericListViewArray::new_null` returns empty array [\#8905](https://github.com/apache/arrow-rs/pull/8905) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([dqkqd](https://github.com/dqkqd)) +- Allocate a zeroed buffer for FixedSizeBinaryArray::null [\#8901](https://github.com/apache/arrow-rs/pull/8901) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([tobixdev](https://github.com/tobixdev)) +- build\(deps\): bump actions/checkout from 5 to 6 [\#8899](https://github.com/apache/arrow-rs/pull/8899) ([dependabot[bot]](https://github.com/apps/dependabot)) +- Add getters to `UnionFields` [\#8895](https://github.com/apache/arrow-rs/pull/8895) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add validated constructors for UnionFields [\#8891](https://github.com/apache/arrow-rs/pull/8891) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] [[arrow-flight](https://github.com/apache/arrow-rs/labels/arrow-flight)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add bit width check [\#8888](https://github.com/apache/arrow-rs/pull/8888) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([rambleraptor](https://github.com/rambleraptor)) +- \[Variant\] Improve `variant_get` performance on a perfect shredding [\#8887](https://github.com/apache/arrow-rs/pull/8887) ([XiangpengHao](https://github.com/XiangpengHao)) +- Add UnionArray::fields [\#8884](https://github.com/apache/arrow-rs/pull/8884) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Struct casting field order [\#8871](https://github.com/apache/arrow-rs/pull/8871) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([brancz](https://github.com/brancz)) +- Add support for `Union` types in `RowConverter` [\#8839](https://github.com/apache/arrow-rs/pull/8839) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- Add comparison support for Union arrays [\#8838](https://github.com/apache/arrow-rs/pull/8838) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([friendlymatthew](https://github.com/friendlymatthew)) +- \[Variant\] Support array shredding into `List/LargeList/ListView/LargeListView` [\#8831](https://github.com/apache/arrow-rs/pull/8831) ([liamzwbao](https://github.com/liamzwbao)) +- Add support for using ListView arrays and types through FFI [\#8822](https://github.com/apache/arrow-rs/pull/8822) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([AdamGS](https://github.com/AdamGS)) +- Add ability to skip or transform page encoding statistics in Parquet metadata [\#8797](https://github.com/apache/arrow-rs/pull/8797) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([etseidl](https://github.com/etseidl)) +- Implement a `Vec` wrapper for `pyarrow.Table` convenience [\#8790](https://github.com/apache/arrow-rs/pull/8790) ([jonded94](https://github.com/jonded94)) +- Make Parquet SBBF serialize/deserialize helpers public for external reuse [\#8762](https://github.com/apache/arrow-rs/pull/8762) [[parquet](https://github.com/apache/arrow-rs/labels/parquet)] ([RoseZhang123](https://github.com/RoseZhang123)) +- Add cast support for \(Large\)ListView \<-\> \(Large\)List [\#8735](https://github.com/apache/arrow-rs/pull/8735) [[arrow](https://github.com/apache/arrow-rs/labels/arrow)] ([vegarsti](https://github.com/vegarsti)) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 07ed5e010c40..a375917e3a3b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -89,8 +89,7 @@ You can also use rust's official docker image: docker run --rm -v $(pwd):/arrow-rs -it rust /bin/bash -c "cd /arrow-rs && rustup component add rustfmt && cargo build" ``` -The command above assumes that are in the root directory of the project, not in the same -directory as this README.md. +The command above assumes that are in the root directory of the project. You can also compile specific workspaces: diff --git a/Cargo.toml b/Cargo.toml index a9b00f9537dc..e4f1780d2914 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -39,7 +39,10 @@ members = [ "arrow-select", "arrow-string", "parquet", + "parquet-geospatial", "parquet-variant", + "parquet-variant-compute", + "parquet-variant-json", "parquet_derive", "parquet_derive_test", ] @@ -55,6 +58,9 @@ members = [ resolver = "2" exclude = [ + # arrow-pyarrow-testing is excluded because it requires a Python interpreter with the pyarrow package installed, + # which makes running `cargo test --all` fail if the appropriate Python environment is not set up. + "arrow-pyarrow-testing", # arrow-pyarrow-integration-testing is excluded because it requires different compilation flags, thereby # significantly changing how it is compiled within the workspace, causing the whole workspace to be compiled from # scratch this way, this is a stand-alone package that compiles independently of the others. @@ -62,7 +68,7 @@ exclude = [ ] [workspace.package] -version = "55.2.0" +version = "57.2.0" homepage = "https://github.com/apache/arrow-rs" repository = "https://github.com/apache/arrow-rs" authors = ["Apache Arrow "] @@ -75,29 +81,37 @@ include = [ "LICENSE.txt", "NOTICE.txt", ] -edition = "2021" -rust-version = "1.81" +edition = "2024" +rust-version = "1.85" [workspace.dependencies] -arrow = { version = "55.2.0", path = "./arrow", default-features = false } -arrow-arith = { version = "55.2.0", path = "./arrow-arith" } -arrow-array = { version = "55.2.0", path = "./arrow-array" } -arrow-buffer = { version = "55.2.0", path = "./arrow-buffer" } -arrow-cast = { version = "55.2.0", path = "./arrow-cast" } -arrow-csv = { version = "55.2.0", path = "./arrow-csv" } -arrow-data = { version = "55.2.0", path = "./arrow-data" } -arrow-ipc = { version = "55.2.0", path = "./arrow-ipc" } -arrow-json = { version = "55.2.0", path = "./arrow-json" } -arrow-ord = { version = "55.2.0", path = "./arrow-ord" } -arrow-pyarrow = { version = "55.2.0", path = "./arrow-pyarrow" } -arrow-row = { version = "55.2.0", path = "./arrow-row" } -arrow-schema = { version = "55.2.0", path = "./arrow-schema" } -arrow-select = { version = "55.2.0", path = "./arrow-select" } -arrow-string = { version = "55.2.0", path = "./arrow-string" } -parquet = { version = "55.2.0", path = "./parquet", default-features = false } +arrow = { version = "57.2.0", path = "./arrow", default-features = false } +arrow-arith = { version = "57.2.0", path = "./arrow-arith" } +arrow-array = { version = "57.2.0", path = "./arrow-array" } +arrow-buffer = { version = "57.2.0", path = "./arrow-buffer" } +arrow-cast = { version = "57.2.0", path = "./arrow-cast" } +arrow-csv = { version = "57.2.0", path = "./arrow-csv" } +arrow-data = { version = "57.2.0", path = "./arrow-data" } +arrow-ipc = { version = "57.2.0", path = "./arrow-ipc" } +arrow-json = { version = "57.2.0", path = "./arrow-json" } +arrow-ord = { version = "57.2.0", path = "./arrow-ord" } +arrow-pyarrow = { version = "57.2.0", path = "./arrow-pyarrow" } +arrow-row = { version = "57.2.0", path = "./arrow-row" } +arrow-schema = { version = "57.2.0", path = "./arrow-schema" } +arrow-select = { version = "57.2.0", path = "./arrow-select" } +arrow-string = { version = "57.2.0", path = "./arrow-string" } +parquet = { version = "57.2.0", path = "./parquet", default-features = false } +parquet-geospatial = { version = "57.2.0", path = "./parquet-geospatial" } +parquet-variant = { version = "57.2.0", path = "./parquet-variant" } +parquet-variant-json = { version = "57.2.0", path = "./parquet-variant-json" } +parquet-variant-compute = { version = "57.2.0", path = "./parquet-variant-compute" } chrono = { version = "0.4.40", default-features = false, features = ["clock"] } +simdutf8 = { version = "0.1.5", default-features = false } + +criterion = { version = "0.8.0", default-features = false } + # release inherited profile keeping debug information and symbols # for mem/cpu profiling [profile.profiling] diff --git a/NOTICE.txt b/NOTICE.txt index a609791374c2..68538ffbdb4c 100644 --- a/NOTICE.txt +++ b/NOTICE.txt @@ -1,5 +1,5 @@ Apache Arrow -Copyright 2016-2019 The Apache Software Foundation +Copyright 2016-2026 The Apache Software Foundation This product includes software developed at The Apache Software Foundation (http://www.apache.org/). diff --git a/README.md b/README.md index 6140f9e902ea..901448eb6a92 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,32 @@ # Native Rust implementation of Apache Arrow and Apache Parquet -Welcome to the [Rust][rust] implementation of [Apache Arrow], the popular in-memory columnar format. +Welcome to the [Rust][rust] implementation of [Apache Arrow], a popular +in-memory columnar format and [Apache Parquet], a popular columnar file +format. + +## Community + +We welcome participation from everyone and encourage you to join us, ask +questions, help others, and get involved. All participation in the Apache Arrow +project is governed by the Apache Software Foundation's [code of +conduct](https://www.apache.org/foundation/policies/conduct.html). + +We use GitHub [issues] and [pull requests] for all technical discussions, reviews, +new features, bug fixes and release coordination. This ensures that all communication +is public and archived for future reference. + +The `dev@arrow.apache.org` mailing list is the communication channel for the overall Apache Arrow community. +Instructions for signing up and links to the archives can be found on the [Arrow Community](https://arrow.apache.org/community/) page. + +Some community members also use the [Arrow Rust Discord Server](https://discord.gg/YAb2TdazKQ) and the official [ASF Slack](https://s.apache.org/slack-invite) server for informal discussions and coordination. +This is a great place to meet other contributors and get guidance on where to contribute. +However, all technical designs should also be recorded and formalized in GitHub issues, so that they are accessible to everyone. +In Slack, find us in the `#arrow-rust` channel and feel free to ask for an invite via Discord, GitHub issues, or other means. + +There is more information in the [contributing] guide. + +## Repository Structure This repository contains the following crates: @@ -27,15 +52,16 @@ This repository contains the following crates: | ------------------ | ---------------------------------------------------------------------------- | ------------------------------------------------ | --------------------------------- | | [`arrow`] | Core functionality (memory layout, arrays, low level computations) | [docs.rs](https://docs.rs/arrow/latest) | [(README)][arrow-readme] | | [`arrow-flight`] | Support for Arrow-Flight IPC protocol | [docs.rs](https://docs.rs/arrow-flight/latest) | [(README)][flight-readme] | -| [`parquet`] | Support for Parquet columnar file format | [docs.rs](https://docs.rs/parquet/latest) | [(README)][parquet-readme] | +| [`parquet`] | Support for the [Apache Parquet] columnar file format | [docs.rs](https://docs.rs/parquet/latest) | [(README)][parquet-readme] | | [`parquet_derive`] | A crate for deriving RecordWriter/RecordReader for arbitrary, simple structs | [docs.rs](https://docs.rs/parquet-derive/latest) | [(README)][parquet-derive-readme] | -The current development version the API documentation in this repo can be found [here](https://arrow.apache.org/rust). +The current development version the API documentation can be found [here](https://arrow.apache.org/rust). Note: previously the [`object_store`] crate was also part of this repository, but it has been moved to the [arrow-rs-object-store repository] [apache arrow]: https://arrow.apache.org/ +[apache parquet]: https://parquet.apache.org/ [`arrow`]: https://crates.io/crates/arrow [`parquet`]: https://crates.io/crates/parquet [`parquet_derive`]: https://crates.io/crates/parquet-derive @@ -49,7 +75,7 @@ Versioning]. Due to available maintainer and testing bandwidth, [`arrow`] crates ([`arrow`], [`arrow-flight`], etc.) are released on the same schedule with the same versions -as the [`parquet`] and [`parquet-derive`] crates. +as the [`parquet`] and [`parquet_derive`] crates. This crate releases every month. We release new major versions (with potentially breaking API changes) at most once a quarter, and release incremental minor @@ -65,28 +91,26 @@ Planned Release Schedule | Approximate Date | Version | Notes | | ---------------- | ---------- | --------------------------------------- | -| Apr 2025 | [`55.0.0`] | Major, potentially breaking API changes | -| May 2025 | [`55.1.0`] | Minor, NO breaking API changes | -| June 2025 | [`55.2.0`] | Minor, NO breaking API changes | -| July 2025 | [`56.0.0`] | Major, potentially breaking API changes | - -[`55.0.0`]: https://github.com/apache/arrow-rs/issues/7084 -[`55.1.0`]: https://github.com/apache/arrow-rs/issues/7393 -[`55.2.0`]: https://github.com/apache/arrow-rs/issues/7394 -[`56.0.0`]: https://github.com/apache/arrow-rs/issues/7395 +| December 2025 | [`57.2.0`] | Minor, NO breaking API changes | +| January 2026 | [`58.0.0`] | Major, potentially breaking API changes | +| February 2026 | [`58.1.0`] | Minor, NO breaking API changes | +| March 2026 | [`58.2.0`] | Minor, NO breaking API changes | +| April 2026 | [`59.0.0`] | Major, potentially breaking API changes | + +[`57.2.0`]: https://github.com/apache/arrow-rs/milestone/5 +[`58.0.0`]: https://github.com/apache/arrow-rs/milestone/6 +[`58.1.0`]: https://github.com/apache/arrow-rs/issues/9108 +[`58.2.0`]: https://github.com/apache/arrow-rs/issues/9109 +[`59.0.0`]: https://github.com/apache/arrow-rs/issues/9110 [ticket #5368]: https://github.com/apache/arrow-rs/issues/5368 [semantic versioning]: https://semver.org/ ### Rust Version Compatibility Policy -arrow-rs, parquet and object_store are built and tested with stable Rust, and will keep a rolling MSRV (minimum supported Rust version) that can only be updated in major releases on a need by basis (e.g. project dependencies bump their MSRV or a particular Rust feature is useful for us etc.). The new MSRV if selected will be at least 6 months old. The minor releases are guaranteed to have the same MSRV. +arrow-rs and parquet are built and tested with stable Rust, and will keep a rolling MSRV (minimum supported Rust version) that can only be updated in major releases on an as needed basis (e.g. project dependencies bump their MSRV or a particular Rust feature is useful for us etc.). The new MSRV if selected will be at least 6 months old. The minor releases are guaranteed to have the same MSRV. Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. -E.g. - -in Apr 2025 we will release version 55.0.0 which might have a version bump. But the Rust version selected in this case will be at most version 1.81. - ### Guidelines for `panic` vs `Result` In general, use panics for bad states that are unreachable, unrecoverable or harmful. @@ -112,7 +136,7 @@ The deprecated version is the next version which will be released (please consult the list above). To mark the API as deprecated, use the `#[deprecated(since = "...", note = "...")]` attribute. -Foe example +For example ```rust #[deprecated(since = "51.0.0", note = "Use `date_part` instead")] @@ -154,24 +178,6 @@ including `join`s and window functions. You can find more details about each crate in their respective READMEs. -## Arrow Rust Community - -The `dev@arrow.apache.org` mailing list serves as the core communication channel for the Arrow community. Instructions for signing up and links to the archives can be found on the [Arrow Community](https://arrow.apache.org/community/) page. All major announcements and communications happen there. - -The Rust Arrow community also uses the official [ASF Slack](https://s.apache.org/slack-invite) for informal discussions and coordination. This is -a great place to meet other contributors and get guidance on where to contribute. Join us in the `#arrow-rust` channel and feel free to ask for an invite via: - -1. the `dev@arrow.apache.org` mailing list -2. the [GitHub Discussions][discussions] -3. the [Discord channel](https://discord.gg/YAb2TdazKQ) - -The Rust implementation uses [GitHub issues][issues] as the system of record for new features and bug fixes and -this plays a critical role in the release process. - -For design discussions we generally use GitHub issues. - -There is more information in the [contributing] guide. - [rust]: https://www.rust-lang.org/ [`object_store`]: https://crates.io/crates/object-store [arrow-readme]: arrow/README.md @@ -182,4 +188,5 @@ There is more information in the [contributing] guide. [ballista-readme]: https://github.com/apache/datafusion-ballista/blob/main/README.md [parquet-derive-readme]: parquet_derive/README.md [issues]: https://github.com/apache/arrow-rs/issues +[pull requests]: https://github.com/apache/arrow-rs/pulls [discussions]: https://github.com/apache/arrow-rs/discussions diff --git a/arrow-arith/Cargo.toml b/arrow-arith/Cargo.toml index a3fdafa823a2..f2a4604c116e 100644 --- a/arrow-arith/Cargo.toml +++ b/arrow-arith/Cargo.toml @@ -41,4 +41,4 @@ arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } -num = { version = "0.4", default-features = false, features = ["std"] } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } diff --git a/arrow-arith/src/aggregate.rs b/arrow-arith/src/aggregate.rs index 9a19b5d8a1f1..a043259694c1 100644 --- a/arrow-arith/src/aggregate.rs +++ b/arrow-arith/src/aggregate.rs @@ -45,11 +45,7 @@ trait NumericAccumulator: Copy + Default { /// After verifying the generated assembly this can be a simple `if`. #[inline(always)] fn select(m: bool, a: T, b: T) -> T { - if m { - a - } else { - b - } + if m { a } else { b } } #[derive(Clone, Copy)] @@ -336,10 +332,10 @@ fn aggregate, A: Numeric /// Returns the minimum value in the boolean array. /// +/// # Example /// ``` /// # use arrow_array::BooleanArray; /// # use arrow_arith::aggregate::min_boolean; -/// /// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); /// assert_eq!(min_boolean(&a), Some(false)) /// ``` @@ -394,10 +390,10 @@ pub fn min_boolean(array: &BooleanArray) -> Option { /// Returns the maximum value in the boolean array /// +/// # Example /// ``` /// # use arrow_array::BooleanArray; /// # use arrow_arith::aggregate::max_boolean; -/// /// let a = BooleanArray::from(vec![Some(true), None, Some(false)]); /// assert_eq!(max_boolean(&a), Some(true)) /// ``` @@ -451,11 +447,7 @@ where let idx = nulls.valid_indices().reduce(|acc_idx, idx| { let acc = array.value_unchecked(acc_idx); let item = array.value_unchecked(idx); - if cmp(&acc, &item) { - idx - } else { - acc_idx - } + if cmp(&acc, &item) { idx } else { acc_idx } }); idx.map(|idx| array.value_unchecked(idx)) } @@ -477,11 +469,7 @@ fn min_max_view_helper( let target_idx = (0..array.len()).reduce(|acc, item| { // SAFETY: array's length is correct so item is within bounds let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, item, array, acc) }; - if cmp == swap_cond { - item - } else { - acc - } + if cmp == swap_cond { item } else { acc } }); // SAFETY: idx came from valid range `0..array.len()` unsafe { target_idx.map(|idx| array.value_unchecked(idx)) } @@ -491,11 +479,7 @@ fn min_max_view_helper( let target_idx = nulls.valid_indices().reduce(|acc_idx, idx| { let cmp = unsafe { GenericByteViewArray::compare_unchecked(array, idx, array, acc_idx) }; - if cmp == swap_cond { - idx - } else { - acc_idx - } + if cmp == swap_cond { idx } else { acc_idx } }); // SAFETY: idx came from valid range `0..array.len()` @@ -825,6 +809,15 @@ where /// Returns the minimum value in the array, according to the natural order. /// For floating point arrays any NaN values are considered to be greater than any other non-null value +/// +/// # Example +/// ```rust +/// # use arrow_array::Int32Array; +/// # use arrow_arith::aggregate::min; +/// let array = Int32Array::from(vec![8, 2, 4]); +/// let result = min(&array); +/// assert_eq!(result, Some(2)); +/// ``` pub fn min(array: &PrimitiveArray) -> Option where T::Native: PartialOrd, @@ -834,6 +827,15 @@ where /// Returns the maximum value in the array, according to the natural order. /// For floating point arrays any NaN values are considered to be greater than any other non-null value +/// +/// # Example +/// ```rust +/// # use arrow_array::Int32Array; +/// # use arrow_arith::aggregate::max; +/// let array = Int32Array::from(vec![4, 8, 2]); +/// let result = max(&array); +/// assert_eq!(result, Some(8)); +/// ``` pub fn max(array: &PrimitiveArray) -> Option where T::Native: PartialOrd, diff --git a/arrow-arith/src/arithmetic.rs b/arrow-arith/src/arithmetic.rs index febf5ceabdd9..27efed6fcdb4 100644 --- a/arrow-arith/src/arithmetic.rs +++ b/arrow-arith/src/arithmetic.rs @@ -25,8 +25,8 @@ use crate::arity::*; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::i256; use arrow_buffer::ArrowNativeType; +use arrow_buffer::i256; use arrow_schema::*; use std::cmp::min; use std::sync::Arc; @@ -43,8 +43,7 @@ fn get_fixed_point_info( if required_scale > product_scale { return Err(ArrowError::ComputeError(format!( - "Required scale {} is greater than product scale {}", - required_scale, product_scale + "Required scale {required_scale} is greater than product scale {product_scale}", ))); } @@ -122,7 +121,7 @@ pub fn multiply_fixed_point_checked( let mut mul = a.wrapping_mul(b); mul = divide_and_round::(mul, divisor); mul.to_i128().ok_or_else(|| { - ArrowError::ArithmeticOverflow(format!("Overflow happened on: {:?} * {:?}", a, b)) + ArrowError::ArithmeticOverflow(format!("Overflow happened on: {a:?} * {b:?}")) }) }) .and_then(|a| a.with_precision_and_scale(precision, required_scale)) @@ -209,9 +208,11 @@ mod tests { .unwrap(); let err = mul(&a, &b).unwrap_err(); - assert!(err - .to_string() - .contains("Overflow happened on: 123456789000000000000000000 * 10000000000000000000")); + assert!( + err.to_string().contains( + "Overflow happened on: 123456789000000000000000000 * 10000000000000000000" + ) + ); // Allow precision loss. let result = multiply_fixed_point_checked(&a, &b, 28).unwrap(); @@ -279,9 +280,11 @@ mod tests { // Required scale cannot be larger than the product of the input scales. let result = multiply_fixed_point_checked(&a, &b, 5).unwrap_err(); - assert!(result - .to_string() - .contains("Required scale 5 is greater than product scale 4")); + assert!( + result + .to_string() + .contains("Required scale 5 is greater than product scale 4") + ); } #[test] @@ -323,7 +326,10 @@ mod tests { // `multiply` overflows on this case. let err = mul(&a, &b).unwrap_err(); - assert_eq!(err.to_string(), "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000"); + assert_eq!( + err.to_string(), + "Arithmetic overflow: Overflow happened on: 123456789000000000000000000 * 10000000000000000000" + ); // Avoid overflow by reducing the scale. let result = multiply_fixed_point(&a, &b, 28).unwrap(); diff --git a/arrow-arith/src/arity.rs b/arrow-arith/src/arity.rs index d1bf1abcb269..b9f7a82963c7 100644 --- a/arrow-arith/src/arity.rs +++ b/arrow-arith/src/arity.rs @@ -19,9 +19,9 @@ use arrow_array::builder::BufferBuilder; use arrow_array::*; -use arrow_buffer::buffer::NullBuffer; use arrow_buffer::ArrowNativeType; use arrow_buffer::MutableBuffer; +use arrow_buffer::buffer::NullBuffer; use arrow_data::ArrayData; use arrow_schema::ArrowError; diff --git a/arrow-arith/src/bitwise.rs b/arrow-arith/src/bitwise.rs index a3c18136c5eb..aedeecd5b835 100644 --- a/arrow-arith/src/bitwise.rs +++ b/arrow-arith/src/bitwise.rs @@ -21,7 +21,7 @@ use crate::arity::{binary, unary}; use arrow_array::*; use arrow_buffer::ArrowNativeType; use arrow_schema::ArrowError; -use num::traits::{WrappingShl, WrappingShr}; +use num_traits::{WrappingShl, WrappingShr}; use std::ops::{BitAnd, BitOr, BitXor, Not}; /// The helper function for bitwise operation with two array diff --git a/arrow-arith/src/boolean.rs b/arrow-arith/src/boolean.rs index d8c7cc19323e..6bf438e64618 100644 --- a/arrow-arith/src/boolean.rs +++ b/arrow-arith/src/boolean.rs @@ -23,8 +23,8 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. use arrow_array::*; -use arrow_buffer::buffer::{bitwise_bin_op_helper, bitwise_quaternary_op_helper}; -use arrow_buffer::{buffer_bin_and_not, BooleanBuffer, NullBuffer}; +use arrow_buffer::buffer::bitwise_quaternary_op_helper; +use arrow_buffer::{BooleanBuffer, NullBuffer, buffer_bin_and_not}; use arrow_schema::ArrowError; /// Logical 'and' boolean values with Kleene logic @@ -74,7 +74,7 @@ pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result { // Same as above - Some(bitwise_bin_op_helper( + Some(BooleanBuffer::from_bitwise_binary_op( right_null_buffer.buffer(), right_null_buffer.offset(), left_values.inner(), @@ -100,7 +100,7 @@ pub fn and_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result Result Result { // Same as above - Some(bitwise_bin_op_helper( + Some(BooleanBuffer::from_bitwise_binary_op( right_nulls.buffer(), right_nulls.offset(), left_values.inner(), @@ -195,7 +196,7 @@ pub fn or_kleene(left: &BooleanArray, right: &BooleanArray) -> Result Result Result { Float16 => neg_wrapping!(Float16Type, array), Float32 => neg_wrapping!(Float32Type, array), Float64 => neg_wrapping!(Float64Type, array), + Decimal32(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal32Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } + Decimal64(p, s) => { + let a = array + .as_primitive::() + .try_unary::<_, Decimal64Type, _>(|x| x.neg_checked())?; + + Ok(Arc::new(a.with_precision_and_scale(*p, *s)?)) + } Decimal128(p, s) => { let a = array .as_primitive::() @@ -234,6 +248,8 @@ fn arithmetic_op(op: Op, lhs: &dyn Datum, rhs: &dyn Datum) -> Result interval_op::(op, l, l_scalar, r, r_scalar), (Date32, _) => date_op::(op, l, l_scalar, r, r_scalar), (Date64, _) => date_op::(op, l, l_scalar, r, r_scalar), + (Decimal32(_, _), Decimal32(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), + (Decimal64(_, _), Decimal64(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), (Decimal128(_, _), Decimal128(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), (Decimal256(_, _), Decimal256(_, _)) => decimal_op::(op, l, l_scalar, r, r_scalar), (l_t, r_t) => match (l_t, r_t) { @@ -503,56 +519,123 @@ fn timestamp_op( "Invalid timestamp arithmetic operation: {} {op} {}", l.data_type(), r.data_type() - ))) + ))); } }; Ok(Arc::new(array.with_timezone_opt(l.timezone()))) } /// Arithmetic trait for date arrays -/// -/// Note: these should be fallible (#4456) trait DateOp: ArrowTemporalType { - fn add_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; - fn add_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native; - fn add_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native; + fn add_year_month(timestamp: Self::Native, delta: i32) -> Result; + fn add_day_time( + timestamp: Self::Native, + delta: IntervalDayTime, + ) -> Result; + fn add_month_day_nano( + timestamp: Self::Native, + delta: IntervalMonthDayNano, + ) -> Result; + + fn sub_year_month(timestamp: Self::Native, delta: i32) -> Result; + fn sub_day_time( + timestamp: Self::Native, + delta: IntervalDayTime, + ) -> Result; + fn sub_month_day_nano( + timestamp: Self::Native, + delta: IntervalMonthDayNano, + ) -> Result; +} + +impl DateOp for Date32Type { + fn add_year_month(left: Self::Native, right: i32) -> Result { + // Date32Type functions don't have _opt variants and should be safe + Ok(Self::add_year_months(left, right)) + } + + fn add_day_time( + left: Self::Native, + right: IntervalDayTime, + ) -> Result { + Ok(Self::add_day_time(left, right)) + } + + fn add_month_day_nano( + left: Self::Native, + right: IntervalMonthDayNano, + ) -> Result { + Ok(Self::add_month_day_nano(left, right)) + } + + fn sub_year_month(left: Self::Native, right: i32) -> Result { + Ok(Self::subtract_year_months(left, right)) + } - fn sub_year_month(timestamp: Self::Native, delta: i32) -> Self::Native; - fn sub_day_time(timestamp: Self::Native, delta: IntervalDayTime) -> Self::Native; - fn sub_month_day_nano(timestamp: Self::Native, delta: IntervalMonthDayNano) -> Self::Native; + fn sub_day_time( + left: Self::Native, + right: IntervalDayTime, + ) -> Result { + Ok(Self::subtract_day_time(left, right)) + } + + fn sub_month_day_nano( + left: Self::Native, + right: IntervalMonthDayNano, + ) -> Result { + Ok(Self::subtract_month_day_nano(left, right)) + } } -macro_rules! date { - ($t:ty) => { - impl DateOp for $t { - fn add_year_month(left: Self::Native, right: i32) -> Self::Native { - Self::add_year_months(left, right) - } +impl DateOp for Date64Type { + fn add_year_month(left: Self::Native, right: i32) -> Result { + Self::add_year_months_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} + {right} months",)) + }) + } - fn add_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native { - Self::add_day_time(left, right) - } + fn add_day_time( + left: Self::Native, + right: IntervalDayTime, + ) -> Result { + Self::add_day_time_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} + {right:?}")) + }) + } - fn add_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native { - Self::add_month_day_nano(left, right) - } + fn add_month_day_nano( + left: Self::Native, + right: IntervalMonthDayNano, + ) -> Result { + Self::add_month_day_nano_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} + {right:?}")) + }) + } - fn sub_year_month(left: Self::Native, right: i32) -> Self::Native { - Self::subtract_year_months(left, right) - } + fn sub_year_month(left: Self::Native, right: i32) -> Result { + Self::subtract_year_months_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} - {right} months",)) + }) + } - fn sub_day_time(left: Self::Native, right: IntervalDayTime) -> Self::Native { - Self::subtract_day_time(left, right) - } + fn sub_day_time( + left: Self::Native, + right: IntervalDayTime, + ) -> Result { + Self::subtract_day_time_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} - {right:?}")) + }) + } - fn sub_month_day_nano(left: Self::Native, right: IntervalMonthDayNano) -> Self::Native { - Self::subtract_month_day_nano(left, right) - } - } - }; + fn sub_month_day_nano( + left: Self::Native, + right: IntervalMonthDayNano, + ) -> Result { + Self::subtract_month_day_nano_opt(left, right).ok_or_else(|| { + ArrowError::ComputeError(format!("Date arithmetic overflow: {left} - {right:?}")) + }) + } } -date!(Date32Type); -date!(Date64Type); /// Arithmetic trait for interval arrays trait IntervalOp: ArrowPrimitiveType { @@ -689,29 +772,29 @@ fn date_op( match (op, r_t) { (Op::Add | Op::AddWrapping, Interval(YearMonth)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_year_month(l, r))) } (Op::Sub | Op::SubWrapping, Interval(YearMonth)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_year_month(l, r))) } (Op::Add | Op::AddWrapping, Interval(DayTime)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_day_time(l, r))) } (Op::Sub | Op::SubWrapping, Interval(DayTime)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_day_time(l, r))) } (Op::Add | Op::AddWrapping, Interval(MonthDayNano)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::add_month_day_nano(l, r))) } (Op::Sub | Op::SubWrapping, Interval(MonthDayNano)) => { let r = r.as_primitive::(); - Ok(op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) + Ok(try_op_ref!(T, l, l_s, r, r_s, T::sub_month_day_nano(l, r))) } _ => Err(ArrowError::InvalidArgumentError(format!( @@ -734,6 +817,8 @@ fn decimal_op( let r = r.as_primitive::(); let (p1, s1, p2, s2) = match (l.data_type(), r.data_type()) { + (DataType::Decimal32(p1, s1), DataType::Decimal32(p2, s2)) => (p1, s1, p2, s2), + (DataType::Decimal64(p1, s1), DataType::Decimal64(p2, s2)) => (p1, s1, p2, s2), (DataType::Decimal128(p1, s1), DataType::Decimal128(p2, s2)) => (p1, s1, p2, s2), (DataType::Decimal256(p1, s1), DataType::Decimal256(p2, s2)) => (p1, s1, p2, s2), _ => unreachable!(), @@ -856,7 +941,7 @@ fn decimal_op( mod tests { use super::*; use arrow_array::temporal_conversions::{as_date, as_datetime}; - use arrow_buffer::{i256, ScalarBuffer}; + use arrow_buffer::{ScalarBuffer, i256}; use chrono::{DateTime, NaiveDate}; fn test_neg_primitive( @@ -922,6 +1007,28 @@ mod tests { "Arithmetic overflow: Overflow happened on: - -9223372036854775808" ); + let a = Decimal32Array::from(vec![1, 3, -44, 2, 4]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[-1, -3, 44, -2, -4] + ); + + let a = Decimal64Array::from(vec![1, 3, -44, 2, 4]) + .with_precision_and_scale(9, 6) + .unwrap(); + + let r = neg(&a).unwrap(); + assert_eq!(r.data_type(), a.data_type()); + assert_eq!( + r.as_primitive::().values(), + &[-1, -3, 44, -2, -4] + ); + let a = Decimal128Array::from(vec![1, 3, -44, 2, 4]) .with_precision_and_scale(9, 6) .unwrap(); @@ -1156,7 +1263,10 @@ mod tests { .with_precision_and_scale(37, 37) .unwrap(); let err = mul(&a, &b).unwrap_err().to_string(); - assert_eq!(err, "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38"); + assert_eq!( + err, + "Invalid argument error: Output scale of Decimal128(3, 3) * Decimal128(37, 37) would exceed max scale of 38" + ); let a = Decimal128Array::from(vec![1]) .with_precision_and_scale(3, -2) @@ -1533,4 +1643,536 @@ mod tests { "Arithmetic overflow: Overflow happened on: 9223372036854775807 - -1" ); } + + #[test] + fn test_date64_to_naive_date_opt_boundary_values() { + use arrow_array::types::Date64Type; + + // Date64Type::to_naive_date_opt has boundaries determined by NaiveDate's supported range. + // The valid date range is from January 1, -262143 to December 31, 262142 (Gregorian calendar). + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + let ms_per_day = 24 * 60 * 60 * 1000i64; + + // Define the boundary dates using NaiveDate::from_ymd_opt + let max_valid_date = NaiveDate::from_ymd_opt(262142, 12, 31).unwrap(); + let min_valid_date = NaiveDate::from_ymd_opt(-262143, 1, 1).unwrap(); + + // Calculate their millisecond values from epoch + let max_valid_millis = (max_valid_date - epoch).num_milliseconds(); + let min_valid_millis = (min_valid_date - epoch).num_milliseconds(); + + // Verify these match the expected boundaries in milliseconds + assert_eq!( + max_valid_millis, 8210266790400000i64, + "December 31, 262142 should be 8210266790400000 ms from epoch" + ); + assert_eq!( + min_valid_millis, -8334601228800000i64, + "January 1, -262143 should be -8334601228800000 ms from epoch" + ); + + // Test that the boundary dates work + assert!( + Date64Type::to_naive_date_opt(max_valid_millis).is_some(), + "December 31, 262142 should return Some" + ); + assert!( + Date64Type::to_naive_date_opt(min_valid_millis).is_some(), + "January 1, -262143 should return Some" + ); + + // Test that one day beyond the boundaries fails + assert!( + Date64Type::to_naive_date_opt(max_valid_millis + ms_per_day).is_none(), + "January 1, 262143 should return None" + ); + assert!( + Date64Type::to_naive_date_opt(min_valid_millis - ms_per_day).is_none(), + "December 31, -262144 should return None" + ); + + // Test some values well within the valid range + assert!( + Date64Type::to_naive_date_opt(0).is_some(), + "Epoch (1970-01-01) should return Some" + ); + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + assert!( + Date64Type::to_naive_date_opt(year_2000_millis).is_some(), + "Year 2000 should return Some" + ); + + // Test extreme values that definitely fail due to Duration constraints + assert!( + Date64Type::to_naive_date_opt(i64::MAX).is_none(), + "i64::MAX should return None" + ); + assert!( + Date64Type::to_naive_date_opt(i64::MIN).is_none(), + "i64::MIN should return None" + ); + } + + #[test] + fn test_date64_add_year_months_opt_boundary_values() { + use arrow_array::types::Date64Type; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + assert!( + Date64Type::add_year_months_opt(year_2000_millis, 120).is_some(), + "Adding 10 years to year 2000 should succeed" + ); + + // Test with moderate years that are within chrono's safe range + let large_year = NaiveDate::from_ymd_opt(5000, 1, 1).unwrap(); + let large_year_millis = (large_year - epoch).num_milliseconds(); + assert!( + Date64Type::add_year_months_opt(large_year_millis, 12).is_some(), + "Adding 12 months to year 5000 should succeed" + ); + + let neg_year = NaiveDate::from_ymd_opt(-5000, 12, 31).unwrap(); + let neg_year_millis = (neg_year - epoch).num_milliseconds(); + assert!( + Date64Type::add_year_months_opt(neg_year_millis, -12).is_some(), + "Subtracting 12 months from year -5000 should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::add_year_months_opt(i64::MAX, 1).is_none(), + "Adding months to i64::MAX should fail" + ); + assert!( + Date64Type::add_year_months_opt(i64::MIN, -1).is_none(), + "Subtracting months from i64::MIN should fail" + ); + + // Test edge case: adding zero should always work for valid dates + assert!( + Date64Type::add_year_months_opt(year_2000_millis, 0).is_some(), + "Adding zero months should always succeed for valid dates" + ); + } + + #[test] + fn test_date64_add_day_time_opt_boundary_values() { + use arrow_array::types::Date64Type; + use arrow_buffer::IntervalDayTime; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test with a date far from the boundary but still testing the function + let near_max_date = NaiveDate::from_ymd_opt(200000, 12, 1).unwrap(); + let near_max_millis = (near_max_date - epoch).num_milliseconds(); + + // Adding 30 days should succeed + let interval_30_days = IntervalDayTime::new(30, 0); + assert!( + Date64Type::add_day_time_opt(near_max_millis, interval_30_days).is_some(), + "Adding 30 days to large year should succeed" + ); + + // Adding a very large number of days should fail + let interval_large_days = IntervalDayTime::new(100000000, 0); + assert!( + Date64Type::add_day_time_opt(near_max_millis, interval_large_days).is_none(), + "Adding 100M days to large year should fail" + ); + + // Test with a date far from the boundary in the negative direction + let near_min_date = NaiveDate::from_ymd_opt(-200000, 2, 1).unwrap(); + let near_min_millis = (near_min_date - epoch).num_milliseconds(); + + // Subtracting 30 days should succeed + let interval_minus_30_days = IntervalDayTime::new(-30, 0); + assert!( + Date64Type::add_day_time_opt(near_min_millis, interval_minus_30_days).is_some(), + "Subtracting 30 days from large negative year should succeed" + ); + + // Subtracting a very large number of days should fail + let interval_minus_large_days = IntervalDayTime::new(-100000000, 0); + assert!( + Date64Type::add_day_time_opt(near_min_millis, interval_minus_large_days).is_none(), + "Subtracting 100M days from large negative year should fail" + ); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + let interval_1000_days = IntervalDayTime::new(1000, 12345); + assert!( + Date64Type::add_day_time_opt(year_2000_millis, interval_1000_days).is_some(), + "Adding 1000 days and time to year 2000 should succeed" + ); + + // Test with extreme input values that would cause overflow + let interval_one_day = IntervalDayTime::new(1, 0); + assert!( + Date64Type::add_day_time_opt(i64::MAX, interval_one_day).is_none(), + "Adding interval to i64::MAX should fail" + ); + assert!( + Date64Type::add_day_time_opt(i64::MIN, IntervalDayTime::new(-1, 0)).is_none(), + "Subtracting interval from i64::MIN should fail" + ); + + // Test with extreme interval values + let max_interval = IntervalDayTime::new(i32::MAX, i32::MAX); + assert!( + Date64Type::add_day_time_opt(0, max_interval).is_none(), + "Adding extreme interval should fail" + ); + + let min_interval = IntervalDayTime::new(i32::MIN, i32::MIN); + assert!( + Date64Type::add_day_time_opt(0, min_interval).is_none(), + "Adding extreme negative interval should fail" + ); + + // Test millisecond overflow within a day + let large_ms_interval = IntervalDayTime::new(0, i32::MAX); + assert!( + Date64Type::add_day_time_opt(year_2000_millis, large_ms_interval).is_some(), + "Adding large milliseconds within valid range should succeed" + ); + } + + #[test] + fn test_date64_add_month_day_nano_opt_boundary_values() { + use arrow_array::types::Date64Type; + use arrow_buffer::IntervalMonthDayNano; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test with a large year that is still within chrono's safe range + let near_max_date = NaiveDate::from_ymd_opt(5000, 11, 1).unwrap(); + let near_max_millis = (near_max_date - epoch).num_milliseconds(); + + // Adding 1 month and 30 days should succeed + let interval_safe = IntervalMonthDayNano::new(1, 30, 0); + assert!( + Date64Type::add_month_day_nano_opt(near_max_millis, interval_safe).is_some(), + "Adding 1 month 30 days to large year should succeed" + ); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + + // Test edge case: adding zero should always work for valid dates + let zero_interval = IntervalMonthDayNano::new(0, 0, 0); + assert!( + Date64Type::add_month_day_nano_opt(year_2000_millis, zero_interval).is_some(), + "Adding zero interval should always succeed for valid dates" + ); + + // Test with a negative year that is still within chrono's safe range + let near_min_date = NaiveDate::from_ymd_opt(-5000, 2, 28).unwrap(); + let near_min_millis = (near_min_date - epoch).num_milliseconds(); + + // Subtracting 1 month and 30 days should succeed + let interval_safe_neg = IntervalMonthDayNano::new(-1, -30, 0); + assert!( + Date64Type::add_month_day_nano_opt(near_min_millis, interval_safe_neg).is_some(), + "Subtracting 1 month 30 days from large negative year should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::add_month_day_nano_opt(i64::MAX, IntervalMonthDayNano::new(1, 0, 0)) + .is_none(), + "Adding interval to i64::MAX should fail" + ); + + let interval_normal = IntervalMonthDayNano::new(2, 10, 123_456_789_000); + assert!( + Date64Type::add_month_day_nano_opt(year_2000_millis, interval_normal).is_some(), + "Adding 2 months, 10 days, and nanos to year 2000 should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::add_month_day_nano_opt(i64::MAX, IntervalMonthDayNano::new(1, 0, 0)) + .is_none(), + "Adding interval to i64::MAX should fail" + ); + assert!( + Date64Type::add_month_day_nano_opt(i64::MIN, IntervalMonthDayNano::new(-1, 0, 0)) + .is_none(), + "Subtracting interval from i64::MIN should fail" + ); + + // Test with invalid timestamp input (the _opt function should handle these gracefully) + + // Test nanosecond precision (should not affect boundary since it's < 1ms) + let nano_interval = IntervalMonthDayNano::new(0, 0, 999_999_999); + assert!( + Date64Type::add_month_day_nano_opt(year_2000_millis, nano_interval).is_some(), + "Adding nanoseconds within valid range should succeed" + ); + + // Test large nanosecond values that convert to milliseconds + let large_nano_interval = IntervalMonthDayNano::new(0, 0, 86_400_000_000_000); // 1 day in nanos + assert!( + Date64Type::add_month_day_nano_opt(year_2000_millis, large_nano_interval).is_some(), + "Adding 1 day worth of nanoseconds should succeed" + ); + } + + #[test] + fn test_date64_subtract_year_months_opt_boundary_values() { + use arrow_array::types::Date64Type; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test with a negative year that is still within chrono's safe range + let near_min_date = NaiveDate::from_ymd_opt(-5000, 12, 31).unwrap(); + let near_min_millis = (near_min_date - epoch).num_milliseconds(); + + // Subtracting 12 months should succeed + assert!( + Date64Type::subtract_year_months_opt(near_min_millis, 12).is_some(), + "Subtracting 12 months from year -5000 should succeed" + ); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + + // Test edge case: subtracting zero should always work for valid dates + assert!( + Date64Type::subtract_year_months_opt(year_2000_millis, 0).is_some(), + "Subtracting zero months should always succeed for valid dates" + ); + + // Test with a large year that is still within chrono's safe range + let near_max_date = NaiveDate::from_ymd_opt(5000, 1, 1).unwrap(); + let near_max_millis = (near_max_date - epoch).num_milliseconds(); + + // Adding 12 months (subtracting negative) should succeed + assert!( + Date64Type::subtract_year_months_opt(near_max_millis, -12).is_some(), + "Adding 12 months to year 5000 should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::subtract_year_months_opt(i64::MAX, -1).is_none(), + "Adding months to i64::MAX should fail" + ); + + assert!( + Date64Type::subtract_year_months_opt(year_2000_millis, 12).is_some(), + "Subtracting 1 year from year 2000 should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::subtract_year_months_opt(i64::MAX, -1).is_none(), + "Adding months to i64::MAX should fail" + ); + assert!( + Date64Type::subtract_year_months_opt(i64::MIN, 1).is_none(), + "Subtracting months from i64::MIN should fail" + ); + + // Test edge case: subtracting zero should always work for valid dates + let valid_date = NaiveDate::from_ymd_opt(2020, 6, 15).unwrap(); + let valid_millis = (valid_date - epoch).num_milliseconds(); + assert!( + Date64Type::subtract_year_months_opt(valid_millis, 0).is_some(), + "Subtracting zero months should always succeed for valid dates" + ); + } + + #[test] + fn test_date64_subtract_day_time_opt_boundary_values() { + use arrow_array::types::Date64Type; + use arrow_buffer::IntervalDayTime; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test with a date far from the boundary in the negative direction + let near_min_date = NaiveDate::from_ymd_opt(-200000, 2, 1).unwrap(); + let near_min_millis = (near_min_date - epoch).num_milliseconds(); + + // Subtracting 30 days should succeed + let interval_30_days = IntervalDayTime::new(30, 0); + assert!( + Date64Type::subtract_day_time_opt(near_min_millis, interval_30_days).is_some(), + "Subtracting 30 days from large negative year should succeed" + ); + + // Subtracting a very large number of days should fail + let interval_large_days = IntervalDayTime::new(100000000, 0); + assert!( + Date64Type::subtract_day_time_opt(near_min_millis, interval_large_days).is_none(), + "Subtracting 100M days from large negative year should fail" + ); + + // Test with a date far from the boundary but still testing the function + let near_max_date = NaiveDate::from_ymd_opt(200000, 12, 1).unwrap(); + let near_max_millis = (near_max_date - epoch).num_milliseconds(); + + // Adding 30 days (subtracting negative) should succeed + let interval_minus_30_days = IntervalDayTime::new(-30, 0); + assert!( + Date64Type::subtract_day_time_opt(near_max_millis, interval_minus_30_days).is_some(), + "Adding 30 days to large year should succeed" + ); + + // Adding a very large number of days should fail + let interval_minus_large_days = IntervalDayTime::new(-100000000, 0); + assert!( + Date64Type::subtract_day_time_opt(near_max_millis, interval_minus_large_days).is_none(), + "Adding 100M days to large year should fail" + ); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + let interval_1000_days = IntervalDayTime::new(1000, 12345); + assert!( + Date64Type::subtract_day_time_opt(year_2000_millis, interval_1000_days).is_some(), + "Subtracting 1000 days and time from year 2000 should succeed" + ); + + // Test with extreme input values that would cause overflow + let interval_one_day = IntervalDayTime::new(1, 0); + assert!( + Date64Type::subtract_day_time_opt(i64::MIN, interval_one_day).is_none(), + "Subtracting interval from i64::MIN should fail" + ); + assert!( + Date64Type::subtract_day_time_opt(i64::MAX, IntervalDayTime::new(-1, 0)).is_none(), + "Adding interval to i64::MAX should fail" + ); + + // Test with extreme interval values + let max_interval = IntervalDayTime::new(i32::MAX, i32::MAX); + assert!( + Date64Type::subtract_day_time_opt(0, max_interval).is_none(), + "Subtracting extreme interval should fail" + ); + + let min_interval = IntervalDayTime::new(i32::MIN, i32::MIN); + assert!( + Date64Type::subtract_day_time_opt(0, min_interval).is_none(), + "Subtracting extreme negative interval should fail" + ); + + // Test millisecond precision + let large_ms_interval = IntervalDayTime::new(0, i32::MAX); + assert!( + Date64Type::subtract_day_time_opt(year_2000_millis, large_ms_interval).is_some(), + "Subtracting large milliseconds within valid range should succeed" + ); + + // Test edge case: subtracting zero should always work for valid dates + let zero_interval = IntervalDayTime::new(0, 0); + let valid_date = NaiveDate::from_ymd_opt(2020, 6, 15).unwrap(); + let valid_millis = (valid_date - epoch).num_milliseconds(); + assert!( + Date64Type::subtract_day_time_opt(valid_millis, zero_interval).is_some(), + "Subtracting zero interval should always succeed for valid dates" + ); + } + + #[test] + fn test_date64_subtract_month_day_nano_opt_boundary_values() { + use arrow_array::types::Date64Type; + use arrow_buffer::IntervalMonthDayNano; + + let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); + + // Test with a negative year that is still within chrono's safe range + let near_min_date = NaiveDate::from_ymd_opt(-5000, 2, 28).unwrap(); + let near_min_millis = (near_min_date - epoch).num_milliseconds(); + + // Subtracting 1 month and 30 days should succeed + let interval_safe = IntervalMonthDayNano::new(1, 30, 0); + assert!( + Date64Type::subtract_month_day_nano_opt(near_min_millis, interval_safe).is_some(), + "Subtracting 1 month 30 days from large negative year should succeed" + ); + + // Test normal case within valid range + let year_2000 = NaiveDate::from_ymd_opt(2000, 1, 1).unwrap(); + let year_2000_millis = (year_2000 - epoch).num_milliseconds(); + + // Test edge case: subtracting zero should always work for valid dates + let zero_interval = IntervalMonthDayNano::new(0, 0, 0); + assert!( + Date64Type::subtract_month_day_nano_opt(year_2000_millis, zero_interval).is_some(), + "Subtracting zero interval should always succeed for valid dates" + ); + + // Test with a large year that is still within chrono's safe range + let near_max_date = NaiveDate::from_ymd_opt(5000, 11, 1).unwrap(); + let near_max_millis = (near_max_date - epoch).num_milliseconds(); + + // Adding 1 month and 30 days (subtracting negative) should succeed + let interval_safe_neg = IntervalMonthDayNano::new(-1, -30, 0); + assert!( + Date64Type::subtract_month_day_nano_opt(near_max_millis, interval_safe_neg).is_some(), + "Adding 1 month 30 days to large year should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::subtract_month_day_nano_opt(i64::MIN, IntervalMonthDayNano::new(1, 0, 0)) + .is_none(), + "Subtracting interval from i64::MIN should fail" + ); + + let interval_normal = IntervalMonthDayNano::new(2, 10, 123_456_789_000); + assert!( + Date64Type::subtract_month_day_nano_opt(year_2000_millis, interval_normal).is_some(), + "Subtracting 2 months, 10 days, and nanos from year 2000 should succeed" + ); + + // Test with extreme input values that would cause overflow + assert!( + Date64Type::subtract_month_day_nano_opt(i64::MIN, IntervalMonthDayNano::new(1, 0, 0)) + .is_none(), + "Subtracting interval from i64::MIN should fail" + ); + assert!( + Date64Type::subtract_month_day_nano_opt(i64::MAX, IntervalMonthDayNano::new(-1, 0, 0)) + .is_none(), + "Adding interval to i64::MAX should fail" + ); + + // Test nanosecond precision (should not affect boundary since it's < 1ms) + let nano_interval = IntervalMonthDayNano::new(0, 0, 999_999_999); + assert!( + Date64Type::subtract_month_day_nano_opt(year_2000_millis, nano_interval).is_some(), + "Subtracting nanoseconds within valid range should succeed" + ); + + // Test large nanosecond values that convert to milliseconds + let large_nano_interval = IntervalMonthDayNano::new(0, 0, 86_400_000_000_000); // 1 day in nanos + assert!( + Date64Type::subtract_month_day_nano_opt(year_2000_millis, large_nano_interval) + .is_some(), + "Subtracting 1 day worth of nanoseconds should succeed" + ); + + // Test edge case: subtracting zero should always work for valid dates + let zero_interval = IntervalMonthDayNano::new(0, 0, 0); + let valid_date = NaiveDate::from_ymd_opt(2020, 6, 15).unwrap(); + let valid_millis = (valid_date - epoch).num_milliseconds(); + assert!( + Date64Type::subtract_month_day_nano_opt(valid_millis, zero_interval).is_some(), + "Subtracting zero interval should always succeed for valid dates" + ); + } } diff --git a/arrow-arith/src/temporal.rs b/arrow-arith/src/temporal.rs index 0b2b98b67b93..faff59bc307d 100644 --- a/arrow-arith/src/temporal.rs +++ b/arrow-arith/src/temporal.rs @@ -24,14 +24,14 @@ use cast::as_primitive_array; use chrono::{Datelike, TimeZone, Timelike, Utc}; use arrow_array::temporal_conversions::{ - date32_to_datetime, date64_to_datetime, timestamp_ms_to_datetime, timestamp_ns_to_datetime, - timestamp_s_to_datetime, timestamp_us_to_datetime, MICROSECONDS, MICROSECONDS_IN_DAY, - MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS, NANOSECONDS_IN_DAY, SECONDS_IN_DAY, + MICROSECONDS, MICROSECONDS_IN_DAY, MILLISECONDS, MILLISECONDS_IN_DAY, NANOSECONDS, + NANOSECONDS_IN_DAY, SECONDS_IN_DAY, date32_to_datetime, date64_to_datetime, + timestamp_ms_to_datetime, timestamp_ns_to_datetime, timestamp_s_to_datetime, + timestamp_us_to_datetime, }; use arrow_array::timezone::Tz; use arrow_array::types::*; use arrow_array::*; -use arrow_buffer::ArrowNativeType; use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; /// Valid parts to extract from date/time/timestamp arrays. @@ -79,7 +79,7 @@ pub enum DatePart { impl std::fmt::Display for DatePart { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{:?}", self) + write!(f, "{self:?}") } } @@ -197,16 +197,6 @@ pub fn date_part(array: &dyn Array, part: DatePart) -> Result( - array: &PrimitiveArray, - part: DatePart, -) -> Result { - let array = date_part(array, part)?; - Ok(array.as_primitive::().to_owned()) -} - /// Extract optional [`Tz`] from timestamp data types, returning error /// if called with a non-timestamp type. fn get_tz(dt: &DataType) -> Result, ArrowError> { @@ -660,7 +650,7 @@ impl ExtractDatePartExt for PrimitiveArray { macro_rules! return_compute_error_with { ($msg:expr, $param:expr) => { - return { Err(ArrowError::ComputeError(format!("{}: {:?}", $msg, $param))) } + return { Err(ArrowError::ComputeError(format!("{}: {}", $msg, $param))) } }; } @@ -685,300 +675,26 @@ impl ChronoDateExt for T { } } -/// Extracts the hours of a given array as an array of integers within -/// the range of [0, 23]. If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn hour_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Hour) -} - -/// Extracts the hours of a given temporal primitive array as an array of integers within -/// the range of [0, 23]. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn hour(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Hour) -} - -/// Extracts the years of a given temporal array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn year_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Year) -} - -/// Extracts the years of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn year(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Year) -} - -/// Extracts the quarter of a given temporal array as an array of integersa within -/// the range of [1, 4]. If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn quarter_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Quarter) -} - -/// Extracts the quarter of a given temporal primitive array as an array of integers within -/// the range of [1, 4]. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn quarter(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Quarter) -} - -/// Extracts the month of a given temporal array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn month_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Month) -} - -/// Extracts the month of a given temporal primitive array as an array of integers within -/// the range of [1, 12]. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn month(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Month) -} - -/// Extracts the day of week of a given temporal array as an array of -/// integers. -/// -/// Monday is encoded as `0`, Tuesday as `1`, etc. -/// -/// See also [`num_days_from_sunday`] which starts at Sunday. -/// -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn num_days_from_monday_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::DayOfWeekMonday0) -} - -/// Extracts the day of week of a given temporal primitive array as an array of -/// integers. -/// -/// Monday is encoded as `0`, Tuesday as `1`, etc. -/// -/// See also [`num_days_from_sunday`] which starts at Sunday. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn num_days_from_monday(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::DayOfWeekMonday0) -} - -/// Extracts the day of week of a given temporal array as an array of -/// integers, starting at Sunday. -/// -/// Sunday is encoded as `0`, Monday as `1`, etc. -/// -/// See also [`num_days_from_monday`] which starts at Monday. -/// -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn num_days_from_sunday_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::DayOfWeekSunday0) -} - -/// Extracts the day of week of a given temporal primitive array as an array of -/// integers, starting at Sunday. -/// -/// Sunday is encoded as `0`, Monday as `1`, etc. -/// -/// See also [`num_days_from_monday`] which starts at Monday. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn num_days_from_sunday(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::DayOfWeekSunday0) -} - -/// Extracts the day of a given temporal array as an array of integers. -/// -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn day_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Day) -} - -/// Extracts the day of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn day(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Day) -} - -/// Extracts the day of year of a given temporal array as an array of integers. -/// -/// The day of year that ranges from 1 to 366. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn doy_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::DayOfYear) -} - -/// Extracts the day of year of a given temporal primitive array as an array of integers. -/// -/// The day of year that ranges from 1 to 366 -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn doy(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - T::Native: ArrowNativeType, - i64: From, -{ - date_part_primitive(array, DatePart::DayOfYear) -} - -/// Extracts the minutes of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn minute(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Minute) -} - -/// Extracts the week of a given temporal array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn week_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Week) -} - -/// Extracts the week of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn week(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Week) -} - -/// Extracts the seconds of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn second(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Second) -} - -/// Extracts the nanoseconds of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn nanosecond(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Nanosecond) -} - -/// Extracts the nanoseconds of a given temporal primitive array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn nanosecond_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Nanosecond) -} - -/// Extracts the microseconds of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn microsecond(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Microsecond) -} - -/// Extracts the microseconds of a given temporal primitive array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn microsecond_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Microsecond) -} - -/// Extracts the milliseconds of a given temporal primitive array as an array of integers -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn millisecond(array: &PrimitiveArray) -> Result -where - T: ArrowTemporalType + ArrowNumericType, - i64: From, -{ - date_part_primitive(array, DatePart::Millisecond) -} - -/// Extracts the milliseconds of a given temporal primitive array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn millisecond_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Millisecond) -} - -/// Extracts the minutes of a given temporal array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn minute_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Minute) -} - -/// Extracts the seconds of a given temporal array as an array of integers. -/// If the given array isn't temporal primitive or dictionary array, -/// an `Err` will be returned. -#[deprecated(since = "51.0.0", note = "Use `date_part` instead")] -pub fn second_dyn(array: &dyn Array) -> Result { - date_part(array, DatePart::Second) -} - #[cfg(test)] -#[allow(deprecated)] mod tests { use super::*; + /// Used to integrate new [`date_part()`] method with deprecated shims such as + /// [`hour()`] and [`week()`]. + fn date_part_primitive( + array: &PrimitiveArray, + part: DatePart, + ) -> Result { + let array = date_part(array, part)?; + Ok(array.as_primitive::().to_owned()) + } + #[test] fn test_temporal_array_date64_hour() { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(4, b.value(2)); @@ -988,7 +704,7 @@ mod tests { fn test_temporal_array_date32_hour() { let a: PrimitiveArray = vec![Some(15147), None, Some(15148)].into(); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(0, b.value(2)); @@ -998,7 +714,7 @@ mod tests { fn test_temporal_array_time32_second_hour() { let a: PrimitiveArray = vec![37800, 86339].into(); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(10, b.value(0)); assert_eq!(23, b.value(1)); } @@ -1007,7 +723,7 @@ mod tests { fn test_temporal_array_time64_micro_hour() { let a: PrimitiveArray = vec![37800000000, 86339000000].into(); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(10, b.value(0)); assert_eq!(23, b.value(1)); } @@ -1016,7 +732,7 @@ mod tests { fn test_temporal_array_timestamp_micro_hour() { let a: TimestampMicrosecondArray = vec![37800000000, 86339000000].into(); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(10, b.value(0)); assert_eq!(23, b.value(1)); } @@ -1026,7 +742,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = year(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Year).unwrap(); assert_eq!(2018, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2019, b.value(2)); @@ -1036,7 +752,7 @@ mod tests { fn test_temporal_array_date32_year() { let a: PrimitiveArray = vec![Some(15147), None, Some(15448)].into(); - let b = year(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Year).unwrap(); assert_eq!(2011, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2012, b.value(2)); @@ -1049,7 +765,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1566275025000)].into(); - let b = quarter(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Quarter).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(3, b.value(2)); @@ -1059,7 +775,7 @@ mod tests { fn test_temporal_array_date32_quarter() { let a: PrimitiveArray = vec![Some(1), None, Some(300)].into(); - let b = quarter(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Quarter).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(4, b.value(2)); @@ -1069,10 +785,10 @@ mod tests { fn test_temporal_array_timestamp_quarter_with_timezone() { // 24 * 60 * 60 = 86400 let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("+00:00".to_string()); - let b = quarter(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Quarter).unwrap(); assert_eq!(2, b.value(0)); let a = TimestampSecondArray::from(vec![86400 * 90]).with_timezone("-10:00".to_string()); - let b = quarter(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Quarter).unwrap(); assert_eq!(1, b.value(0)); } @@ -1083,7 +799,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = month(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Month).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2, b.value(2)); @@ -1093,7 +809,7 @@ mod tests { fn test_temporal_array_date32_month() { let a: PrimitiveArray = vec![Some(1), None, Some(31)].into(); - let b = month(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Month).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2, b.value(2)); @@ -1103,10 +819,10 @@ mod tests { fn test_temporal_array_timestamp_month_with_timezone() { // 24 * 60 * 60 = 86400 let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("+00:00".to_string()); - let b = month(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Month).unwrap(); assert_eq!(2, b.value(0)); let a = TimestampSecondArray::from(vec![86400 * 31]).with_timezone("-10:00".to_string()); - let b = month(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Month).unwrap(); assert_eq!(1, b.value(0)); } @@ -1114,10 +830,10 @@ mod tests { fn test_temporal_array_timestamp_day_with_timezone() { // 24 * 60 * 60 = 86400 let a = TimestampSecondArray::from(vec![86400]).with_timezone("+00:00".to_string()); - let b = day(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Day).unwrap(); assert_eq!(2, b.value(0)); let a = TimestampSecondArray::from(vec![86400]).with_timezone("-10:00".to_string()); - let b = day(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Day).unwrap(); assert_eq!(1, b.value(0)); } @@ -1128,7 +844,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = num_days_from_monday(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::DayOfWeekMonday0).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2, b.value(2)); @@ -1147,7 +863,7 @@ mod tests { ] .into(); - let b = num_days_from_sunday(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::DayOfWeekSunday0).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(1, b.value(2)); @@ -1161,7 +877,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = day(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Day).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(20, b.value(2)); @@ -1171,7 +887,7 @@ mod tests { fn test_temporal_array_date32_day() { let a: PrimitiveArray = vec![Some(0), None, Some(31)].into(); - let b = day(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Day).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(1, b.value(2)); @@ -1190,7 +906,7 @@ mod tests { ] .into(); - let b = doy(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::DayOfYear).unwrap(); assert_eq!(1, b.value(0)); assert_eq!(1, b.value(1)); assert!(!b.is_valid(2)); @@ -1202,7 +918,7 @@ mod tests { let a: TimestampMicrosecondArray = vec![Some(1612025847000000), None, Some(1722015847000000)].into(); - let b = year(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Year).unwrap(); assert_eq!(2021, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2024, b.value(2)); @@ -1213,7 +929,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = minute(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Minute).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(23, b.value(2)); @@ -1224,7 +940,7 @@ mod tests { let a: TimestampMicrosecondArray = vec![Some(1612025847000000), None, Some(1722015847000000)].into(); - let b = minute(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Minute).unwrap(); assert_eq!(57, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(44, b.value(2)); @@ -1234,7 +950,7 @@ mod tests { fn test_temporal_array_date32_week() { let a: PrimitiveArray = vec![Some(0), None, Some(7)].into(); - let b = week(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Week).unwrap(); assert_eq!(1, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(2, b.value(2)); @@ -1252,7 +968,7 @@ mod tests { ] .into(); - let b = week(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Week).unwrap(); assert_eq!(9, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(1, b.value(2)); @@ -1266,7 +982,7 @@ mod tests { let a: TimestampMicrosecondArray = vec![Some(1612025847000000), None, Some(1722015847000000)].into(); - let b = week(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Week).unwrap(); assert_eq!(4, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(30, b.value(2)); @@ -1277,7 +993,7 @@ mod tests { let a: PrimitiveArray = vec![Some(1514764800000), None, Some(1550636625000)].into(); - let b = second(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Second).unwrap(); assert_eq!(0, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(45, b.value(2)); @@ -1288,7 +1004,7 @@ mod tests { let a: TimestampMicrosecondArray = vec![Some(1612025847000000), None, Some(1722015847000000)].into(); - let b = second(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Second).unwrap(); assert_eq!(27, b.value(0)); assert!(!b.is_valid(1)); assert_eq!(7, b.value(2)); @@ -1297,7 +1013,7 @@ mod tests { #[test] fn test_temporal_array_timestamp_second_with_timezone() { let a = TimestampSecondArray::from(vec![10, 20]).with_timezone("+00:00".to_string()); - let b = second(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Second).unwrap(); assert_eq!(10, b.value(0)); assert_eq!(20, b.value(1)); } @@ -1305,7 +1021,7 @@ mod tests { #[test] fn test_temporal_array_timestamp_minute_with_timezone() { let a = TimestampSecondArray::from(vec![0, 60]).with_timezone("+00:50".to_string()); - let b = minute(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Minute).unwrap(); assert_eq!(50, b.value(0)); assert_eq!(51, b.value(1)); } @@ -1313,42 +1029,46 @@ mod tests { #[test] fn test_temporal_array_timestamp_minute_with_negative_timezone() { let a = TimestampSecondArray::from(vec![60 * 55]).with_timezone("-00:50".to_string()); - let b = minute(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Minute).unwrap(); assert_eq!(5, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone() { let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01:00".to_string()); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_colon() { let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+0100".to_string()); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_minutes() { let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("+01".to_string()); - let b = hour(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Hour).unwrap(); assert_eq!(11, b.value(0)); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_without_initial_sign() { let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("0100".to_string()); - let err = hour(&a).unwrap_err().to_string(); + let err = date_part_primitive(&a, DatePart::Hour) + .unwrap_err() + .to_string(); assert!(err.contains("Invalid timezone"), "{}", err); } #[test] fn test_temporal_array_timestamp_hour_with_timezone_with_only_colon() { let a = TimestampSecondArray::from(vec![60 * 60 * 10]).with_timezone("01:00".to_string()); - let err = hour(&a).unwrap_err().to_string(); + let err = date_part_primitive(&a, DatePart::Hour) + .unwrap_err() + .to_string(); assert!(err.contains("Invalid timezone"), "{}", err); } @@ -1358,7 +1078,7 @@ mod tests { // 1970-01-01T00:00:00 + 4 days -> 1970-01-05T00:00:00 Monday (week 2) // 1970-01-01T00:00:00 + 4 days - 1 second -> 1970-01-04T23:59:59 Sunday (week 1) let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]); - let b = week(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Week).unwrap(); assert_eq!(1, b.value(0)); assert_eq!(2, b.value(1)); assert_eq!(1, b.value(2)); @@ -1371,7 +1091,7 @@ mod tests { // 1970-01-01T01:00:00+01:00 + 4 days - 1 second -> 1970-01-05T00:59:59+01:00 Monday (week 2) let a = TimestampSecondArray::from(vec![0, 86400 * 4, 86400 * 4 - 1]) .with_timezone("+01:00".to_string()); - let b = week(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Week).unwrap(); assert_eq!(1, b.value(0)); assert_eq!(2, b.value(1)); assert_eq!(2, b.value(2)); @@ -1389,7 +1109,7 @@ mod tests { let keys = Int8Array::from_iter_values([0_i8, 0, 1, 2, 1]); let dict = DictionaryArray::try_new(keys.clone(), Arc::new(a)).unwrap(); - let b = hour_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Hour).unwrap(); let expected_dict = DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![11, 21, 7]))); @@ -1398,7 +1118,7 @@ mod tests { let b = date_part(&dict, DatePart::Minute).unwrap(); - let b_old = minute_dyn(&dict).unwrap(); + let b_old = date_part(&dict, DatePart::Minute).unwrap(); let expected_dict = DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); @@ -1408,7 +1128,7 @@ mod tests { let b = date_part(&dict, DatePart::Second).unwrap(); - let b_old = second_dyn(&dict).unwrap(); + let b_old = date_part(&dict, DatePart::Second).unwrap(); let expected_dict = DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 2, 3]))); @@ -1431,7 +1151,7 @@ mod tests { let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = year_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Year).unwrap(); let expected_dict = DictionaryArray::new( keys, @@ -1450,13 +1170,13 @@ mod tests { let keys = Int8Array::from_iter_values([0_i8, 1, 1, 0]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = quarter_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Quarter).unwrap(); let expected = DictionaryArray::new(keys.clone(), Arc::new(Int32Array::from(vec![1, 3, 3, 1]))); assert_eq!(b.as_ref(), &expected); - let b = month_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Month).unwrap(); let expected = DictionaryArray::new(keys, Arc::new(Int32Array::from(vec![1, 8, 8, 1]))); assert_eq!(b.as_ref(), &expected); @@ -1471,31 +1191,31 @@ mod tests { let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1), Some(0), None]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = num_days_from_monday_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::DayOfWeekMonday0).unwrap(); let a = Int32Array::from(vec![Some(0), Some(2), Some(2), Some(0), None]); let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); assert_eq!(b.as_ref(), &expected); - let b = num_days_from_sunday_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::DayOfWeekSunday0).unwrap(); let a = Int32Array::from(vec![Some(1), Some(3), Some(3), Some(1), None]); let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); assert_eq!(b.as_ref(), &expected); - let b = day_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Day).unwrap(); let a = Int32Array::from(vec![Some(1), Some(20), Some(20), Some(1), None]); let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); assert_eq!(b.as_ref(), &expected); - let b = doy_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::DayOfYear).unwrap(); let a = Int32Array::from(vec![Some(1), Some(51), Some(51), Some(1), None]); let expected = DictionaryArray::new(keys.clone(), Arc::new(a)); assert_eq!(b.as_ref(), &expected); - let b = week_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Week).unwrap(); let a = Int32Array::from(vec![Some(1), Some(8), Some(8), Some(1), None]); let expected = DictionaryArray::new(keys, Arc::new(a)); @@ -1512,13 +1232,13 @@ mod tests { let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); - let b = nanosecond(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Nanosecond).unwrap(); assert!(!b.is_valid(0)); assert_eq!(453_000_000, b.value(1)); let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = nanosecond_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Nanosecond).unwrap(); let a = Int32Array::from(vec![None, Some(453_000_000)]); let expected_dict = DictionaryArray::new(keys, Arc::new(a)); @@ -1530,13 +1250,13 @@ mod tests { fn test_temporal_array_date64_microsecond() { let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); - let b = microsecond(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Microsecond).unwrap(); assert!(!b.is_valid(0)); assert_eq!(453_000, b.value(1)); let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = microsecond_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Microsecond).unwrap(); let a = Int32Array::from(vec![None, Some(453_000)]); let expected_dict = DictionaryArray::new(keys, Arc::new(a)); @@ -1548,13 +1268,13 @@ mod tests { fn test_temporal_array_date64_millisecond() { let a: PrimitiveArray = vec![None, Some(1667328721453)].into(); - let b = millisecond(&a).unwrap(); + let b = date_part_primitive(&a, DatePart::Millisecond).unwrap(); assert!(!b.is_valid(0)); assert_eq!(453, b.value(1)); let keys = Int8Array::from(vec![Some(0_i8), Some(1), Some(1)]); let dict = DictionaryArray::new(keys.clone(), Arc::new(a)); - let b = millisecond_dyn(&dict).unwrap(); + let b = date_part(&dict, DatePart::Millisecond).unwrap(); let a = Int32Array::from(vec![None, Some(453)]); let expected_dict = DictionaryArray::new(keys, Arc::new(a)); diff --git a/arrow-array/Cargo.toml b/arrow-array/Cargo.toml index a65c0c9ca8e6..8ab0bb290e96 100644 --- a/arrow-array/Cargo.toml +++ b/arrow-array/Cargo.toml @@ -44,9 +44,11 @@ arrow-schema = { workspace = true } arrow-data = { workspace = true } chrono = { workspace = true } chrono-tz = { version = "0.10", optional = true } -num = { version = "0.4.1", default-features = false, features = ["std"] } +num-complex = { version = "0.4.6", default-features = false, features = ["std"] } +num-integer = { version = "0.1.46", default-features = false, features = ["std"] } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false, features = ["num-traits"] } -hashbrown = { version = "0.15.1", default-features = false } +hashbrown = { version = "0.16.0", default-features = false } [package.metadata.docs.rs] all-features = true @@ -57,14 +59,14 @@ force_validate = [] [dev-dependencies] rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } -criterion = { version = "0.5", default-features = false } +criterion = { workspace = true, default-features = false } [[bench]] name = "occupancy" harness = false [[bench]] -name = "gc_view_types" +name = "view_types" harness = false [[bench]] @@ -78,3 +80,7 @@ harness = false [[bench]] name = "union_array" harness = false + +[[bench]] +name = "record_batch" +harness = false \ No newline at end of file diff --git a/arrow-array/benches/fixed_size_list_array.rs b/arrow-array/benches/fixed_size_list_array.rs index 2bdb0c252b8a..72319cdb9b3c 100644 --- a/arrow-array/benches/fixed_size_list_array.rs +++ b/arrow-array/benches/fixed_size_list_array.rs @@ -18,7 +18,7 @@ use arrow_array::{Array, FixedSizeListArray, Int32Array}; use arrow_schema::Field; use criterion::*; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use std::{hint, sync::Arc}; fn gen_fsl(len: usize, value_len: usize) -> FixedSizeListArray { diff --git a/arrow-array/benches/occupancy.rs b/arrow-array/benches/occupancy.rs index 283020364199..c088577bc37b 100644 --- a/arrow-array/benches/occupancy.rs +++ b/arrow-array/benches/occupancy.rs @@ -19,7 +19,7 @@ use arrow_array::types::Int32Type; use arrow_array::{DictionaryArray, Int32Array}; use arrow_buffer::NullBuffer; use criterion::*; -use rand::{rng, Rng}; +use rand::{Rng, rng}; use std::{hint, sync::Arc}; fn gen_dict( diff --git a/arrow-array/benches/record_batch.rs b/arrow-array/benches/record_batch.rs new file mode 100644 index 000000000000..5f2ba5d3d7b5 --- /dev/null +++ b/arrow-array/benches/record_batch.rs @@ -0,0 +1,93 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{ArrayRef, Int64Array, RecordBatch, RecordBatchOptions}; +use arrow_schema::{DataType, Field, Schema, SchemaRef}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; +use num_integer::Integer; +use std::hint::black_box; +use std::sync::Arc; + +fn make_record_batch(column_count: usize, row_count: usize) -> RecordBatch { + let fields = (0..column_count) + .map(|i| Field::new(format!("col_{}", i), DataType::Int64, i.is_even())) + .collect::>(); + + let columns = fields + .iter() + .map(|_| { + let array_ref: ArrayRef = Arc::new(Int64Array::from_value(0, row_count)); + array_ref + }) + .collect::>(); + + let schema = Schema::new(fields); + + let mut options = RecordBatchOptions::new(); + options.row_count = Some(row_count); + + RecordBatch::try_new_with_options(SchemaRef::new(schema), columns, &options).unwrap() +} + +fn project_benchmark( + c: &mut Criterion, + column_count: usize, + row_count: usize, + projection_size: usize, +) { + let input = make_input(column_count, row_count, projection_size); + + c.bench_with_input( + BenchmarkId::new( + "project", + format!( + "{:?}x{:?} -> {:?}x{:?}", + input.0.num_columns(), + input.0.num_rows(), + input.1.len(), + input.0.num_rows() + ), + ), + &input, + |b, (rb, projection)| { + b.iter(|| black_box(rb.project(projection).unwrap())); + }, + ); +} + +fn make_input( + column_count: usize, + row_count: usize, + projection_size: usize, +) -> (RecordBatch, Vec) { + let rb = make_record_batch(column_count, row_count); + let projection = (0..projection_size).collect::>(); + (rb, projection) +} + +fn criterion_benchmark(c: &mut Criterion) { + [10, 100, 1000].iter().for_each(|&column_count| { + [1, column_count / 2, column_count - 1] + .iter() + .for_each(|&projection_size| { + project_benchmark(c, column_count, 8192, projection_size); + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/benches/union_array.rs b/arrow-array/benches/union_array.rs index f3894e249f4c..414529882a29 100644 --- a/arrow-array/benches/union_array.rs +++ b/arrow-array/benches/union_array.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::{ - hint, - iter::{repeat, repeat_with}, - sync::Arc, -}; +use std::{hint, iter::repeat_with, sync::Arc}; use arrow_array::{Array, ArrayRef, Int32Array, UnionArray}; use arrow_buffer::{NullBuffer, ScalarBuffer}; use arrow_schema::{DataType, Field, UnionFields}; use criterion::*; -use rand::{rng, Rng}; +use rand::{Rng, rng}; fn array_with_nulls() -> ArrayRef { let mut rng = rng(); @@ -58,18 +54,17 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { let type_ids = 0..with_nulls+without_nulls; - let fields = UnionFields::new( + let fields = UnionFields::try_new( type_ids.clone(), type_ids.clone().map(|i| Field::new(format!("f{i}"), DataType::Int32, true)), - ); + ).unwrap(); let array = UnionArray::try_new( fields, type_ids.cycle().take(4096).collect(), None, - repeat(array_with_nulls()) - .take(with_nulls as usize) - .chain(repeat(array_without_nulls()).take(without_nulls as usize)) + std::iter::repeat_n(array_with_nulls(), with_nulls as usize) + .chain(std::iter::repeat_n(array_without_nulls(), without_nulls as usize)) .collect(), ) .unwrap(); diff --git a/arrow-array/benches/view_types.rs b/arrow-array/benches/view_types.rs new file mode 100644 index 000000000000..e194c268c19d --- /dev/null +++ b/arrow-array/benches/view_types.rs @@ -0,0 +1,112 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::StringViewArray; +use criterion::{Criterion, criterion_group, criterion_main}; +use std::hint::black_box; + +fn gen_view_array(size: usize) -> StringViewArray { + StringViewArray::from_iter((0..size).map(|v| match v % 3 { + 0 => Some("small"), + 1 => Some("larger than 12 bytes array"), + 2 => None, + _ => unreachable!("unreachable"), + })) +} + +fn gen_view_array_without_nulls(size: usize) -> StringViewArray { + StringViewArray::from_iter((0..size).map(|v| { + let s = match v % 3 { + 0 => "small".to_string(), // < 12 bytes + 1 => "larger than 12 bytes array".to_string(), // >12 bytes + 2 => "x".repeat(300), // 300 bytes (>256) + _ => unreachable!(), + }; + Some(s) + })) +} + +fn criterion_benchmark(c: &mut Criterion) { + let array = gen_view_array(100_000); + + c.bench_function("view types slice", |b| { + b.iter(|| { + black_box(array.slice(0, 100_000 / 2)); + }); + }); + + c.bench_function("gc view types all[100000]", |b| { + b.iter(|| { + black_box(array.gc()); + }); + }); + + let sliced = array.slice(0, 100_000 / 2); + c.bench_function("gc view types slice half[100000]", |b| { + b.iter(|| { + black_box(sliced.gc()); + }); + }); + + let array = gen_view_array_without_nulls(100_000); + + c.bench_function("gc view types all without nulls[100000]", |b| { + b.iter(|| { + black_box(array.gc()); + }); + }); + + let sliced = array.slice(0, 100_000 / 2); + c.bench_function("gc view types slice half without nulls[100000]", |b| { + b.iter(|| { + black_box(sliced.gc()); + }); + }); + + let array = gen_view_array(8000); + + c.bench_function("gc view types all[8000]", |b| { + b.iter(|| { + black_box(array.gc()); + }); + }); + + let sliced = array.slice(0, 8000 / 2); + c.bench_function("gc view types slice half[8000]", |b| { + b.iter(|| { + black_box(sliced.gc()); + }); + }); + + let array = gen_view_array_without_nulls(8000); + + c.bench_function("gc view types all without nulls[8000]", |b| { + b.iter(|| { + black_box(array.gc()); + }); + }); + + let sliced = array.slice(0, 8000 / 2); + c.bench_function("gc view types slice half without nulls[8000]", |b| { + b.iter(|| { + black_box(sliced.gc()); + }); + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-array/src/arithmetic.rs b/arrow-array/src/arithmetic.rs index b5f4a106f5ad..52708da7810f 100644 --- a/arrow-array/src/arithmetic.rs +++ b/arrow-array/src/arithmetic.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use arrow_buffer::{i256, ArrowNativeType, IntervalDayTime, IntervalMonthDayNano}; +use arrow_buffer::{ArrowNativeType, IntervalDayTime, IntervalMonthDayNano, i256}; use arrow_schema::ArrowError; use half::f16; -use num::complex::ComplexFloat; +use num_complex::ComplexFloat; use std::cmp::Ordering; /// Trait for [`ArrowNativeType`] that adds checked and unchecked arithmetic operations, @@ -288,7 +288,7 @@ native_type_op!(u8); native_type_op!(u16); native_type_op!(u32); native_type_op!(u64); -native_type_op!(i256, i256::ZERO, i256::ONE, i256::MIN, i256::MAX); +native_type_op!(i256, i256::ZERO, i256::ONE); native_type_op!(IntervalDayTime, IntervalDayTime::ZERO, IntervalDayTime::ONE); native_type_op!( @@ -418,15 +418,35 @@ native_type_float_op!( f32, 0., 1., - unsafe { std::mem::transmute(-1_i32) }, - unsafe { std::mem::transmute(i32::MAX) } + unsafe { + // Need to allow in clippy because + // current MSRV (Minimum Supported Rust Version) is `1.85.0` but this item is stable since `1.87.0` + #[allow(unnecessary_transmutes)] + std::mem::transmute(-1_i32) + }, + unsafe { + // Need to allow in clippy because + // current MSRV (Minimum Supported Rust Version) is `1.85.0` but this item is stable since `1.87.0` + #[allow(unnecessary_transmutes)] + std::mem::transmute(i32::MAX) + } ); native_type_float_op!( f64, 0., 1., - unsafe { std::mem::transmute(-1_i64) }, - unsafe { std::mem::transmute(i64::MAX) } + unsafe { + // Need to allow in clippy because + // current MSRV (Minimum Supported Rust Version) is `1.85.0` but this item is stable since `1.87.0` + #[allow(unnecessary_transmutes)] + std::mem::transmute(-1_i64) + }, + unsafe { + // Need to allow in clippy because + // current MSRV (Minimum Supported Rust Version) is `1.85.0` but this item is stable since `1.87.0` + #[allow(unnecessary_transmutes)] + std::mem::transmute(i64::MAX) + } ); #[cfg(test)] @@ -434,9 +454,7 @@ mod tests { use super::*; macro_rules! assert_approx_eq { - ( $x: expr, $y: expr ) => {{ - assert_approx_eq!($x, $y, 1.0e-4) - }}; + ( $x: expr, $y: expr ) => {{ assert_approx_eq!($x, $y, 1.0e-4) }}; ( $x: expr, $y: expr, $tol: expr ) => {{ let x_val = $x; let y_val = $y; diff --git a/arrow-array/src/array/binary_array.rs b/arrow-array/src/array/binary_array.rs index 8e2158416f49..7cfa1b52728e 100644 --- a/arrow-array/src/array/binary_array.rs +++ b/arrow-array/src/array/binary_array.rs @@ -90,7 +90,7 @@ impl GenericBinaryArray { &'a self, indexes: impl Iterator> + 'a, ) -> impl Iterator> { - indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + unsafe { indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) } } } diff --git a/arrow-array/src/array/boolean_array.rs b/arrow-array/src/array/boolean_array.rs index fcebf5a0f718..acea680ae374 100644 --- a/arrow-array/src/array/boolean_array.rs +++ b/arrow-array/src/array/boolean_array.rs @@ -19,7 +19,7 @@ use crate::array::print_long_array; use crate::builder::BooleanBuilder; use crate::iterator::BooleanIter; use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; -use arrow_buffer::{bit_util, BooleanBuffer, Buffer, MutableBuffer, NullBuffer}; +use arrow_buffer::{BooleanBuffer, Buffer, MutableBuffer, NullBuffer, bit_util}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::DataType; use std::any::Any; @@ -178,13 +178,20 @@ impl BooleanArray { /// Returns the boolean value at index `i`. /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// This doesn't check bounds, the caller must ensure that index < self.len() pub unsafe fn value_unchecked(&self, i: usize) -> bool { - self.values.value_unchecked(i) + unsafe { self.values.value_unchecked(i) } } /// Returns the boolean value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> bool { @@ -215,7 +222,7 @@ impl BooleanArray { &'a self, indexes: impl Iterator> + 'a, ) -> impl Iterator> + 'a { - indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + indexes.map(|opt_index| opt_index.map(|index| unsafe { self.value_unchecked(index) })) } /// Create a [`BooleanArray`] by evaluating the operation for @@ -279,6 +286,8 @@ impl BooleanArray { } } +impl super::private::Sealed for BooleanArray {} + impl Array for BooleanArray { fn as_any(&self) -> &dyn Any { self @@ -348,7 +357,7 @@ impl ArrayAccessor for &BooleanArray { } unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - BooleanArray::value_unchecked(self, index) + unsafe { BooleanArray::value_unchecked(self, index) } } } @@ -429,11 +438,84 @@ impl<'a> BooleanArray { } } -impl>> FromIterator for BooleanArray { +/// An optional boolean value +/// +/// This struct is used as an adapter when creating `BooleanArray` from an iterator. +/// `FromIterator` for `BooleanArray` takes an iterator where the elements can be `into` +/// this struct. So once implementing `From` or `Into` trait for a type, an iterator of +/// the type can be collected to `BooleanArray`. +/// +/// See also [NativeAdapter](crate::array::NativeAdapter). +#[derive(Debug)] +struct BooleanAdapter { + /// Corresponding Rust native type if available + pub native: Option, +} + +impl From for BooleanAdapter { + fn from(value: bool) -> Self { + BooleanAdapter { + native: Some(value), + } + } +} + +impl From<&bool> for BooleanAdapter { + fn from(value: &bool) -> Self { + BooleanAdapter { + native: Some(*value), + } + } +} + +impl From> for BooleanAdapter { + fn from(value: Option) -> Self { + BooleanAdapter { native: value } + } +} + +impl From<&Option> for BooleanAdapter { + fn from(value: &Option) -> Self { + BooleanAdapter { native: *value } + } +} + +impl> FromIterator for BooleanArray { fn from_iter>(iter: I) -> Self { let iter = iter.into_iter(); - let (_, data_len) = iter.size_hint(); - let data_len = data_len.expect("Iterator must be sized"); // panic if no upper bound. + let capacity = match iter.size_hint() { + (lower, Some(upper)) if lower == upper => lower, + _ => 0, + }; + let mut builder = BooleanBuilder::with_capacity(capacity); + builder.extend(iter.map(|item| item.into().native)); + builder.finish() + } +} + +impl BooleanArray { + /// Creates a [`BooleanArray`] from an iterator of trusted length. + /// + /// # Safety + /// + /// The iterator must be [`TrustedLen`](https://doc.rust-lang.org/std/iter/trait.TrustedLen.html). + /// I.e. that `size_hint().1` correctly reports its length. Note that this is a stronger + /// guarantee that `ExactSizeIterator` provides which could still report a wrong length. + /// + /// # Panics + /// + /// Panics if the iterator does not report an upper bound on `size_hint()`. + #[inline] + #[allow( + private_bounds, + reason = "We will expose BooleanAdapter if there is a need" + )] + pub unsafe fn from_trusted_len_iter(iter: I) -> Self + where + P: Into, + I: ExactSizeIterator, + { + let data_len = iter.len(); let num_bytes = bit_util::ceil(data_len, 8); let mut null_builder = MutableBuffer::from_len_zeroed(num_bytes); @@ -443,10 +525,14 @@ impl>> FromIterator for BooleanArray let null_slice = null_builder.as_slice_mut(); iter.enumerate().for_each(|(i, item)| { - if let Some(a) = item.borrow() { - bit_util::set_bit(null_slice, i); - if *a { - bit_util::set_bit(data, i); + if let Some(a) = item.into().native { + unsafe { + // SAFETY: There will be enough space in the buffers due to the trusted len size + // hint + bit_util::set_bit_raw(null_slice.as_mut_ptr(), i); + if a { + bit_util::set_bit_raw(data.as_mut_ptr(), i); + } } } }); @@ -479,7 +565,7 @@ impl From for BooleanArray { mod tests { use super::*; use arrow_buffer::Buffer; - use rand::{rng, Rng}; + use rand::{Rng, rng}; #[test] fn test_boolean_fmt_debug() { @@ -592,6 +678,20 @@ mod tests { } } + #[test] + fn test_boolean_array_from_non_nullable_iter() { + let v = vec![true, false, true]; + let arr = v.into_iter().collect::(); + assert_eq!(3, arr.len()); + assert_eq!(0, arr.offset()); + assert_eq!(0, arr.null_count()); + assert!(arr.nulls().is_none()); + + assert!(arr.value(0)); + assert!(!arr.value(1)); + assert!(arr.value(2)); + } + #[test] fn test_boolean_array_from_nullable_iter() { let v = vec![Some(true), None, Some(false), None]; @@ -610,6 +710,29 @@ mod tests { assert!(!arr.value(2)); } + #[test] + fn test_boolean_array_from_nullable_trusted_len_iter() { + // Should exhibit the same behavior as `from_iter`, which is tested above. + let v = vec![Some(true), None, Some(false), None]; + let expected = v.clone().into_iter().collect::(); + let actual = unsafe { + // SAFETY: `v` has trusted length + BooleanArray::from_trusted_len_iter(v.into_iter()) + }; + assert_eq!(expected, actual); + } + + #[test] + fn test_boolean_array_from_iter_with_larger_upper_bound() { + // See https://github.com/apache/arrow-rs/issues/8505 + // This returns an upper size hint of 4 + let iterator = vec![Some(true), None, Some(false), None] + .into_iter() + .filter(Option::is_some); + let arr = iterator.collect::(); + assert_eq!(2, arr.len()); + } + #[test] fn test_boolean_array_builder() { // Test building a boolean array with ArrayData builder and offset @@ -708,4 +831,32 @@ mod tests { assert_eq!(values.values(), &[0b1000_0000]); assert!(nulls.is_none()); } + + #[test] + fn test_new_null_array() { + let arr = BooleanArray::new_null(5); + + assert_eq!(arr.len(), 5); + assert_eq!(arr.null_count(), 5); + assert_eq!(arr.true_count(), 0); + assert_eq!(arr.false_count(), 0); + + for i in 0..5 { + assert!(arr.is_null(i)); + assert!(!arr.is_valid(i)); + } + } + + #[test] + fn test_slice_with_nulls() { + let arr = BooleanArray::from(vec![Some(true), None, Some(false)]); + let sliced = arr.slice(1, 2); + + assert_eq!(sliced.len(), 2); + assert_eq!(sliced.null_count(), 1); + + assert!(sliced.is_null(0)); + assert!(sliced.is_valid(1)); + assert!(!sliced.value(1)); + } } diff --git a/arrow-array/src/array/byte_array.rs b/arrow-array/src/array/byte_array.rs index 192c9654b055..bd85bffcfe44 100644 --- a/arrow-array/src/array/byte_array.rs +++ b/arrow-array/src/array/byte_array.rs @@ -18,8 +18,8 @@ use crate::array::{get_offsets, print_long_array}; use crate::builder::GenericByteBuilder; use crate::iterator::ArrayIter; -use crate::types::bytes::ByteArrayNativeType; use crate::types::ByteArrayType; +use crate::types::bytes::ByteArrayNativeType; use crate::{Array, ArrayAccessor, ArrayRef, OffsetSizeTrait, Scalar}; use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_buffer::{NullBuffer, OffsetBuffer}; @@ -190,6 +190,29 @@ impl GenericByteArray { Scalar::new(Self::from_iter_values(std::iter::once(value))) } + /// Create a new [`GenericByteArray`] where `value` is repeated `repeat_count` times. + /// + /// # Panics + /// This will panic if value's length multiplied by `repeat_count` overflows usize. + /// + pub fn new_repeated(value: impl AsRef, repeat_count: usize) -> Self { + let s: &[u8] = value.as_ref().as_ref(); + let value_offsets = OffsetBuffer::from_repeated_length(s.len(), repeat_count); + let bytes: Buffer = { + let mut mutable_buffer = MutableBuffer::with_capacity(0); + mutable_buffer.repeat_slice_n_times(s, repeat_count); + + mutable_buffer.into() + }; + + Self { + data_type: T::DATA_TYPE, + value_data: bytes, + value_offsets, + nulls: None, + } + } + /// Creates a [`GenericByteArray`] based on an iterator of values without nulls pub fn from_iter_values(iter: I) -> Self where @@ -276,11 +299,15 @@ impl GenericByteArray { } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller is responsible for ensuring that the index is within the bounds of the array pub unsafe fn value_unchecked(&self, i: usize) -> &T::Native { - let end = *self.value_offsets().get_unchecked(i + 1); - let start = *self.value_offsets().get_unchecked(i); + let end = *unsafe { self.value_offsets().get_unchecked(i + 1) }; + let start = *unsafe { self.value_offsets().get_unchecked(i) }; // Soundness // pointer alignment & location is ensured by RawPtrBox @@ -291,19 +318,25 @@ impl GenericByteArray { // OffsetSizeTrait. Currently, only i32 and i64 implement OffsetSizeTrait, // both of which should cleanly cast to isize on an architecture that supports // 32/64-bit offsets - let b = std::slice::from_raw_parts( - self.value_data - .as_ptr() - .offset(start.to_isize().unwrap_unchecked()), - (end - start).to_usize().unwrap_unchecked(), - ); + let b = unsafe { + std::slice::from_raw_parts( + self.value_data + .as_ptr() + .offset(start.to_isize().unwrap_unchecked()), + (end - start).to_usize().unwrap_unchecked(), + ) + }; // SAFETY: // ArrayData is valid - T::Native::from_bytes_unchecked(b) + unsafe { T::Native::from_bytes_unchecked(b) } } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &T::Native { @@ -429,6 +462,8 @@ impl std::fmt::Debug for GenericByteArray { } } +impl super::private::Sealed for GenericByteArray {} + impl Array for GenericByteArray { fn as_any(&self) -> &dyn Any { self @@ -501,7 +536,7 @@ impl<'a, T: ByteArrayType> ArrayAccessor for &'a GenericByteArray { } unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - GenericByteArray::value_unchecked(self, index) + unsafe { GenericByteArray::value_unchecked(self, index) } } } @@ -583,7 +618,7 @@ where #[cfg(test)] mod tests { - use crate::{BinaryArray, StringArray}; + use crate::{Array, BinaryArray, StringArray}; use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer}; #[test] @@ -595,14 +630,23 @@ mod tests { let nulls = NullBuffer::new_null(3); let err = StringArray::try_new(offsets.clone(), data.clone(), Some(nulls.clone())).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for StringArray, expected 2 got 3"); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for StringArray, expected 2 got 3" + ); let err = BinaryArray::try_new(offsets.clone(), data.clone(), Some(nulls)).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for BinaryArray, expected 2 got 3"); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for BinaryArray, expected 2 got 3" + ); let non_utf8_data = Buffer::from_slice_ref(b"he\xFFloworld"); let err = StringArray::try_new(offsets.clone(), non_utf8_data.clone(), None).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2"); + assert_eq!( + err.to_string(), + "Invalid argument error: Encountered non UTF-8 data: invalid utf-8 sequence of 1 bytes from index 2" + ); BinaryArray::new(offsets, non_utf8_data, None); @@ -632,4 +676,42 @@ mod tests { BinaryArray::new(offsets, non_ascii_data, None); } + + #[test] + fn create_repeated() { + let arr = BinaryArray::new_repeated(b"hello", 3); + assert_eq!(arr.len(), 3); + assert_eq!(arr.value(0), b"hello"); + assert_eq!(arr.value(1), b"hello"); + assert_eq!(arr.value(2), b"hello"); + + let arr = StringArray::new_repeated("world", 2); + assert_eq!(arr.len(), 2); + assert_eq!(arr.value(0), "world"); + assert_eq!(arr.value(1), "world"); + } + + #[test] + #[should_panic(expected = "usize overflow")] + fn create_repeated_usize_overflow_1() { + let _arr = BinaryArray::new_repeated(b"hello", (usize::MAX / "hello".len()) + 1); + } + + #[test] + #[should_panic(expected = "usize overflow")] + fn create_repeated_usize_overflow_2() { + let _arr = BinaryArray::new_repeated(b"hello", usize::MAX); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn create_repeated_i32_offset_overflow_1() { + let _arr = BinaryArray::new_repeated(b"hello", usize::MAX / "hello".len()); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn create_repeated_i32_offset_overflow_2() { + let _arr = BinaryArray::new_repeated(b"hello", ((i32::MAX as usize) / "hello".len()) + 1); + } } diff --git a/arrow-array/src/array/byte_view_array.rs b/arrow-array/src/array/byte_view_array.rs index 713e275d186c..ca8ddfbe2ad5 100644 --- a/arrow-array/src/array/byte_view_array.rs +++ b/arrow-array/src/array/byte_view_array.rs @@ -22,11 +22,12 @@ use crate::types::bytes::ByteArrayNativeType; use crate::types::{BinaryViewType, ByteViewType, StringViewType}; use crate::{Array, ArrayAccessor, ArrayRef, GenericByteArray, OffsetSizeTrait, Scalar}; use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; -use arrow_data::{ArrayData, ArrayDataBuilder, ByteView}; +use arrow_data::{ArrayData, ArrayDataBuilder, ByteView, MAX_INLINE_VIEW_LEN}; use arrow_schema::{ArrowError, DataType}; use core::str; -use num::ToPrimitive; +use num_traits::ToPrimitive; use std::any::Any; +use std::cmp::Ordering; use std::fmt::Debug; use std::marker::PhantomData; use std::sync::Arc; @@ -77,8 +78,9 @@ use super::ByteArrayType; /// 0 31 63 95 127 /// ``` /// -/// * Strings with length <= 12 are stored directly in the view. See -/// [`Self::inline_value`] to access the inlined prefix from a short view. +/// * Strings with length <= 12 ([`MAX_INLINE_VIEW_LEN`]) are stored directly in +/// the view. See [`Self::inline_value`] to access the inlined prefix from a +/// short view. /// /// * Strings with length > 12: The first four bytes are stored inline in the /// view and the entire string is stored in one of the buffers. See [`ByteView`] @@ -128,6 +130,7 @@ use super::ByteArrayType; /// assert_eq!(value, "this string is also longer than 12 bytes"); /// ``` /// +/// [`MAX_INLINE_VIEW_LEN`]: arrow_data::MAX_INLINE_VIEW_LEN /// [`arrow_compute`]: https://docs.rs/arrow/latest/arrow/compute/index.html /// /// Unlike [`GenericByteArray`], there are no constraints on the offsets other @@ -162,7 +165,7 @@ use super::ByteArrayType; pub struct GenericByteViewArray { data_type: DataType, views: ScalarBuffer, - buffers: Vec, + buffers: Arc<[Buffer]>, phantom: PhantomData, nulls: Option, } @@ -185,7 +188,10 @@ impl GenericByteViewArray { /// # Panics /// /// Panics if [`GenericByteViewArray::try_new`] returns an error - pub fn new(views: ScalarBuffer, buffers: Vec, nulls: Option) -> Self { + pub fn new(views: ScalarBuffer, buffers: U, nulls: Option) -> Self + where + U: Into>, + { Self::try_new(views, buffers, nulls).unwrap() } @@ -195,11 +201,16 @@ impl GenericByteViewArray { /// /// * `views.len() != nulls.len()` /// * [ByteViewType::validate] fails - pub fn try_new( + pub fn try_new( views: ScalarBuffer, - buffers: Vec, + buffers: U, nulls: Option, - ) -> Result { + ) -> Result + where + U: Into>, + { + let buffers: Arc<[Buffer]> = buffers.into(); + T::validate(&views, &buffers)?; if let Some(n) = nulls.as_ref() { @@ -227,11 +238,14 @@ impl GenericByteViewArray { /// # Safety /// /// Safe if [`Self::try_new`] would not error - pub unsafe fn new_unchecked( + pub unsafe fn new_unchecked( views: ScalarBuffer, - buffers: Vec, + buffers: U, nulls: Option, - ) -> Self { + ) -> Self + where + U: Into>, + { if cfg!(feature = "force_validate") { return Self::new(views, buffers, nulls); } @@ -240,7 +254,7 @@ impl GenericByteViewArray { data_type: T::DATA_TYPE, phantom: Default::default(), views, - buffers, + buffers: buffers.into(), nulls, } } @@ -250,7 +264,7 @@ impl GenericByteViewArray { Self { data_type: T::DATA_TYPE, views: vec![0; len].into(), - buffers: vec![], + buffers: vec![].into(), nulls: Some(NullBuffer::new_null(len)), phantom: Default::default(), } @@ -276,7 +290,7 @@ impl GenericByteViewArray { } /// Deconstruct this array into its constituent parts - pub fn into_parts(self) -> (ScalarBuffer, Vec, Option) { + pub fn into_parts(self) -> (ScalarBuffer, Arc<[Buffer]>, Option) { (self.views, self.buffers, self.nulls) } @@ -293,6 +307,10 @@ impl GenericByteViewArray { } /// Returns the element at index `i` + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &T::Native { @@ -309,33 +327,38 @@ impl GenericByteViewArray { /// Returns the element at index `i` without bounds checking /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// /// Caller is responsible for ensuring that the index is within the bounds /// of the array pub unsafe fn value_unchecked(&self, idx: usize) -> &T::Native { - let v = self.views.get_unchecked(idx); + let v = unsafe { self.views.get_unchecked(idx) }; let len = *v as u32; - let b = if len <= 12 { - Self::inline_value(v, len as usize) + let b = if len <= MAX_INLINE_VIEW_LEN { + unsafe { Self::inline_value(v, len as usize) } } else { let view = ByteView::from(*v); - let data = self.buffers.get_unchecked(view.buffer_index as usize); + let data = unsafe { self.buffers.get_unchecked(view.buffer_index as usize) }; let offset = view.offset as usize; - data.get_unchecked(offset..offset + len as usize) + unsafe { data.get_unchecked(offset..offset + len as usize) } }; - T::Native::from_bytes_unchecked(b) + unsafe { T::Native::from_bytes_unchecked(b) } } /// Returns the first `len` bytes the inline value of the view. /// /// # Safety /// - The `view` must be a valid element from `Self::views()` that adheres to the view layout. - /// - The `len` must be the length of the inlined value. It should never be larger than 12. + /// - The `len` must be the length of the inlined value. It should never be larger than [`MAX_INLINE_VIEW_LEN`]. #[inline(always)] pub unsafe fn inline_value(view: &u128, len: usize) -> &[u8] { - debug_assert!(len <= 12); - std::slice::from_raw_parts((view as *const u128 as *const u8).wrapping_add(4), len) + debug_assert!(len <= MAX_INLINE_VIEW_LEN as usize); + unsafe { + std::slice::from_raw_parts((view as *const u128 as *const u8).wrapping_add(4), len) + } } /// Constructs a new iterator for iterating over the values of this array @@ -347,7 +370,7 @@ impl GenericByteViewArray { pub fn bytes_iter(&self) -> impl Iterator { self.views.iter().map(move |v| { let len = *v as u32; - if len <= 12 { + if len <= MAX_INLINE_VIEW_LEN { unsafe { Self::inline_value(v, len as usize) } } else { let view = ByteView::from(*v); @@ -371,7 +394,7 @@ impl GenericByteViewArray { return &[] as &[u8]; } - if prefix_len <= 4 || len <= 12 { + if prefix_len <= 4 || len as u32 <= MAX_INLINE_VIEW_LEN { unsafe { StringViewArray::inline_value(v, prefix_len) } } else { let view = ByteView::from(*v); @@ -401,7 +424,7 @@ impl GenericByteViewArray { return &[] as &[u8]; } - if len <= 12 { + if len as u32 <= MAX_INLINE_VIEW_LEN { unsafe { &StringViewArray::inline_value(v, len)[len - suffix_len..] } } else { let view = ByteView::from(*v); @@ -415,6 +438,26 @@ impl GenericByteViewArray { }) } + /// Return an iterator over the length of each array element, including null values. + /// + /// Null values length would equal to the underlying bytes length and NOT 0 + /// + /// Example of getting 0 for null values + /// ```rust + /// # use arrow_array::StringViewArray; + /// # use arrow_array::Array; + /// use arrow_data::ByteView; + /// + /// fn lengths_with_zero_for_nulls(view: &StringViewArray) -> impl Iterator { + /// view.lengths() + /// .enumerate() + /// .map(|(index, length)| if view.is_null(index) { 0 } else { length }) + /// } + /// ``` + pub fn lengths(&self) -> impl ExactSizeIterator + Clone { + self.views().iter().map(|v| *v as u32) + } + /// Returns a zero-copy slice of this array with the indicated offset and length. pub fn slice(&self, offset: usize, length: usize) -> Self { Self { @@ -470,13 +513,161 @@ impl GenericByteViewArray { /// Note: this function does not attempt to canonicalize / deduplicate values. For this /// feature see [`GenericByteViewBuilder::with_deduplicate_strings`]. pub fn gc(&self) -> Self { - let mut builder = GenericByteViewBuilder::::with_capacity(self.len()); + // 1) Read basic properties once + let len = self.len(); // number of elements + let nulls = self.nulls().cloned(); // reuse & clone existing null bitmap + + // 1.5) Fast path: if there are no buffers, just reuse original views and no data blocks + if self.data_buffers().is_empty() { + return unsafe { + GenericByteViewArray::new_unchecked( + self.views().clone(), + vec![], // empty data blocks + nulls, + ) + }; + } - for v in self.iter() { - builder.append_option(v); + // 2) Calculate total size of all non-inline data and detect if any exists + let total_large = self.total_buffer_bytes_used(); + + // 2.5) Fast path: if there is no non-inline data, avoid buffer allocation & processing + if total_large == 0 { + // Views are inline-only or all null; just reuse original views and no data blocks + return unsafe { + GenericByteViewArray::new_unchecked( + self.views().clone(), + vec![], // empty data blocks + nulls, + ) + }; } - builder.finish() + let (views_buf, data_blocks) = if total_large < i32::MAX as usize { + // fast path, the entire data fits in a single buffer + // 3) Allocate exactly capacity for all non-inline data + let mut data_buf = Vec::with_capacity(total_large); + + // 4) Iterate over views and process each inline/non-inline view + let views_buf: Vec = (0..len) + .map(|i| unsafe { self.copy_view_to_buffer(i, 0, &mut data_buf) }) + .collect(); + let data_block = Buffer::from_vec(data_buf); + let data_blocks = vec![data_block]; + (views_buf, data_blocks) + } else { + // slow path, need to split into multiple buffers + + struct GcCopyGroup { + total_buffer_bytes: usize, + total_len: usize, + } + + impl GcCopyGroup { + fn new(total_buffer_bytes: u32, total_len: usize) -> Self { + Self { + total_buffer_bytes: total_buffer_bytes as usize, + total_len, + } + } + } + + let mut groups = Vec::new(); + let mut current_length = 0; + let mut current_elements = 0; + + for view in self.views() { + let len = *view as u32; + if len > MAX_INLINE_VIEW_LEN { + if current_length + len > i32::MAX as u32 { + // Start a new group + groups.push(GcCopyGroup::new(current_length, current_elements)); + current_length = 0; + current_elements = 0; + } + current_length += len; + current_elements += 1; + } + } + if current_elements != 0 { + groups.push(GcCopyGroup::new(current_length, current_elements)); + } + debug_assert!(groups.len() <= i32::MAX as usize); + + // 3) Copy the buffers group by group + let mut views_buf = Vec::with_capacity(len); + let mut data_blocks = Vec::with_capacity(groups.len()); + + let mut current_view_idx = 0; + + for (group_idx, gc_copy_group) in groups.iter().enumerate() { + let mut data_buf = Vec::with_capacity(gc_copy_group.total_buffer_bytes); + + // Directly push views to avoid intermediate Vec allocation + let new_views = (current_view_idx..current_view_idx + gc_copy_group.total_len).map( + |view_idx| { + // safety: the view index came from iterating over valid range + unsafe { + self.copy_view_to_buffer(view_idx, group_idx as i32, &mut data_buf) + } + }, + ); + views_buf.extend(new_views); + + data_blocks.push(Buffer::from_vec(data_buf)); + current_view_idx += gc_copy_group.total_len; + } + (views_buf, data_blocks) + }; + + // 5) Wrap up views buffer + let views_scalar = ScalarBuffer::from(views_buf); + + // SAFETY: views_scalar, data_blocks, and nulls are correctly aligned and sized + unsafe { GenericByteViewArray::new_unchecked(views_scalar, data_blocks, nulls) } + } + + /// Copy the i‑th view into `data_buf` if it refers to an out‑of‑line buffer. + /// + /// # Safety + /// + /// - `i < self.len()`. + /// - Every element in `self.views()` must currently refer to a valid slice + /// inside one of `self.buffers`. + /// - `data_buf` must be ready to have additional bytes appended. + /// - After this call, the returned view will have its + /// `buffer_index` reset to `buffer_idx` and its `offset` updated so that it points + /// into the bytes just appended at the end of `data_buf`. + #[inline(always)] + unsafe fn copy_view_to_buffer( + &self, + i: usize, + buffer_idx: i32, + data_buf: &mut Vec, + ) -> u128 { + // SAFETY: `i < self.len()` ensures this is in‑bounds. + let raw_view = unsafe { *self.views().get_unchecked(i) }; + let mut bv = ByteView::from(raw_view); + + // Inline‑small views stay as‑is. + if bv.length <= MAX_INLINE_VIEW_LEN { + raw_view + } else { + // SAFETY: `bv.buffer_index` and `bv.offset..bv.offset+bv.length` + // must both lie within valid ranges for `self.buffers`. + let buffer = unsafe { self.buffers.get_unchecked(bv.buffer_index as usize) }; + let start = bv.offset as usize; + let end = start + bv.length as usize; + let slice = unsafe { buffer.get_unchecked(start..end) }; + + // Copy out‑of‑line data into our single “0” buffer. + let new_offset = data_buf.len() as u32; + data_buf.extend_from_slice(slice); + + bv.buffer_index = buffer_idx as u32; + bv.offset = new_offset; + bv.into() + } } /// Returns the total number of bytes used by all non inlined views in all @@ -495,9 +686,9 @@ impl GenericByteViewArray { self.views() .iter() .map(|v| { - let len = (*v as u32) as usize; - if len > 12 { - len + let len = *v as u32; + if len > MAX_INLINE_VIEW_LEN { + len as usize } else { 0 } @@ -511,11 +702,11 @@ impl GenericByteViewArray { /// It takes a bit of patience to understand why we don't just compare two &[u8] directly. /// /// ByteView types give us the following two advantages, and we need to be careful not to lose them: - /// (1) For string/byte smaller than 12 bytes, the entire data is inlined in the view. + /// (1) For string/byte smaller than [`MAX_INLINE_VIEW_LEN`] bytes, the entire data is inlined in the view. /// Meaning that reading one array element requires only one memory access /// (two memory access required for StringArray, one for offset buffer, the other for value buffer). /// - /// (2) For string/byte larger than 12 bytes, we can still be faster than (for certain operations) StringArray/ByteArray, + /// (2) For string/byte larger than [`MAX_INLINE_VIEW_LEN`] bytes, we can still be faster than (for certain operations) StringArray/ByteArray, /// thanks to the inlined 4 bytes. /// Consider equality check: /// If the first four bytes of the two strings are different, we can return false immediately (with just one memory access). @@ -525,8 +716,8 @@ impl GenericByteViewArray { /// e.g., if the inlined 4 bytes are different, we can directly return unequal without looking at the full string. /// /// # Order check flow - /// (1) if both string are smaller than 12 bytes, we can directly compare the data inlined to the view. - /// (2) if any of the string is larger than 12 bytes, we need to compare the full string. + /// (1) if both string are smaller than [`MAX_INLINE_VIEW_LEN`] bytes, we can directly compare the data inlined to the view. + /// (2) if any of the string is larger than [`MAX_INLINE_VIEW_LEN`] bytes, we need to compare the full string. /// (2.1) if the inlined 4 bytes are different, we can return the result immediately. /// (2.2) o.w., we need to compare the full string. /// @@ -537,25 +728,30 @@ impl GenericByteViewArray { left_idx: usize, right: &GenericByteViewArray, right_idx: usize, - ) -> std::cmp::Ordering { - let l_view = left.views().get_unchecked(left_idx); - let l_len = *l_view as u32; + ) -> Ordering { + let l_view = unsafe { left.views().get_unchecked(left_idx) }; + let l_byte_view = ByteView::from(*l_view); + + let r_view = unsafe { right.views().get_unchecked(right_idx) }; + let r_byte_view = ByteView::from(*r_view); - let r_view = right.views().get_unchecked(right_idx); - let r_len = *r_view as u32; + let l_len = l_byte_view.length; + let r_len = r_byte_view.length; if l_len <= 12 && r_len <= 12 { - let l_data = unsafe { GenericByteViewArray::::inline_value(l_view, l_len as usize) }; - let r_data = unsafe { GenericByteViewArray::::inline_value(r_view, r_len as usize) }; - return l_data.cmp(r_data); + return Self::inline_key_fast(*l_view).cmp(&Self::inline_key_fast(*r_view)); } // one of the string is larger than 12 bytes, // we then try to compare the inlined data first - let l_inlined_data = unsafe { GenericByteViewArray::::inline_value(l_view, 4) }; - let r_inlined_data = unsafe { GenericByteViewArray::::inline_value(r_view, 4) }; - if r_inlined_data != l_inlined_data { - return l_inlined_data.cmp(r_inlined_data); + + // Note: In theory, ByteView is only used for string which is larger than 12 bytes, + // but we can still use it to get the inlined prefix for shorter strings. + // The prefix is always the first 4 bytes of the view, for both short and long strings. + let l_inlined_be = l_byte_view.prefix.swap_bytes(); + let r_inlined_be = r_byte_view.prefix.swap_bytes(); + if l_inlined_be != r_inlined_be { + return l_inlined_be.cmp(&r_inlined_be); } // unfortunately, we need to compare the full data @@ -564,6 +760,119 @@ impl GenericByteViewArray { l_full_data.cmp(r_full_data) } + + /// Builds a 128-bit composite key for an inline value: + /// + /// - High 96 bits: the inline data in big-endian byte order (for correct lexicographical sorting). + /// - Low 32 bits: the length in big-endian byte order, acting as a tiebreaker so shorter strings + /// (or those with fewer meaningful bytes) always numerically sort before longer ones. + /// + /// This function extracts the length and the 12-byte inline string data from the raw + /// little-endian `u128` representation, converts them to big-endian ordering, and packs them + /// into a single `u128` value suitable for fast, branchless comparisons. + /// + /// # Why include length? + /// + /// A pure 96-bit content comparison can’t distinguish between two values whose inline bytes + /// compare equal—either because one is a true prefix of the other or because zero-padding + /// hides extra bytes. By tucking the 32-bit length into the lower bits, a single `u128` compare + /// handles both content and length in one go. + /// + /// Example: comparing "bar" (3 bytes) vs "bar\0" (4 bytes) + /// + /// | String | Bytes 0–4 (length LE) | Bytes 4–16 (data + padding) | + /// |------------|-----------------------|---------------------------------| + /// | `"bar"` | `03 00 00 00` | `62 61 72` + 9 × `00` | + /// | `"bar\0"`| `04 00 00 00` | `62 61 72 00` + 8 × `00` | + /// + /// Both inline parts become `62 61 72 00…00`, so they tie on content. The length field + /// then differentiates: + /// + /// ```text + /// key("bar") = 0x0000000000000000000062617200000003 + /// key("bar\0") = 0x0000000000000000000062617200000004 + /// ⇒ key("bar") < key("bar\0") + /// ``` + /// # Inlining and Endianness + /// + /// - We start by calling `.to_le_bytes()` on the `raw` `u128`, because Rust’s native in‑memory + /// representation is little‑endian on x86/ARM. + /// - We extract the low 32 bits numerically (`raw as u32`)—this step is endianness‑free. + /// - We copy the 12 bytes of inline data (original order) into `buf[0..12]`. + /// - We serialize `length` as big‑endian into `buf[12..16]`. + /// - Finally, `u128::from_be_bytes(buf)` treats `buf[0]` as the most significant byte + /// and `buf[15]` as the least significant, producing a `u128` whose integer value + /// directly encodes “inline data then length” in big‑endian form. + /// + /// This ensures that a simple `u128` comparison is equivalent to the desired + /// lexicographical comparison of the inline bytes followed by length. + #[inline(always)] + pub fn inline_key_fast(raw: u128) -> u128 { + // 1. Decompose `raw` into little‑endian bytes: + // - raw_bytes[0..4] = length in LE + // - raw_bytes[4..16] = inline string data + let raw_bytes = raw.to_le_bytes(); + + // 2. Numerically truncate to get the low 32‑bit length (endianness‑free). + let length = raw as u32; + + // 3. Build a 16‑byte buffer in big‑endian order: + // - buf[0..12] = inline string bytes (in original order) + // - buf[12..16] = length.to_be_bytes() (BE) + let mut buf = [0u8; 16]; + buf[0..12].copy_from_slice(&raw_bytes[4..16]); // inline data + + // Why convert length to big-endian for comparison? + // + // Rust (on most platforms) stores integers in little-endian format, + // meaning the least significant byte is at the lowest memory address. + // For example, an u32 value like 0x22345677 is stored in memory as: + // + // [0x77, 0x56, 0x34, 0x22] // little-endian layout + // ^ ^ ^ ^ + // LSB ↑↑↑ MSB + // + // This layout is efficient for arithmetic but *not* suitable for + // lexicographic (dictionary-style) comparison of byte arrays. + // + // To compare values by byte order—e.g., for sorted keys or binary trees— + // we must convert them to **big-endian**, where: + // + // - The most significant byte (MSB) comes first (index 0) + // - The least significant byte (LSB) comes last (index N-1) + // + // In big-endian, the same u32 = 0x22345677 would be represented as: + // + // [0x22, 0x34, 0x56, 0x77] + // + // This ordering aligns with natural string/byte sorting, so calling + // `.to_be_bytes()` allows us to construct + // keys where standard numeric comparison (e.g., `<`, `>`) behaves + // like lexicographic byte comparison. + buf[12..16].copy_from_slice(&length.to_be_bytes()); // length in BE + + // 4. Deserialize the buffer as a big‑endian u128: + // buf[0] is MSB, buf[15] is LSB. + // Details: + // Note on endianness and layout: + // + // Although `buf[0]` is stored at the lowest memory address, + // calling `u128::from_be_bytes(buf)` interprets it as the **most significant byte (MSB)**, + // and `buf[15]` as the **least significant byte (LSB)**. + // + // This is the core principle of **big-endian decoding**: + // - Byte at index 0 maps to bits 127..120 (highest) + // - Byte at index 1 maps to bits 119..112 + // - ... + // - Byte at index 15 maps to bits 7..0 (lowest) + // + // So even though memory layout goes from low to high (left to right), + // big-endian treats the **first byte** as highest in value. + // + // This guarantees that comparing two `u128` keys is equivalent to lexicographically + // comparing the original inline bytes, followed by length. + u128::from_be_bytes(buf) + } } impl Debug for GenericByteViewArray { @@ -576,6 +885,8 @@ impl Debug for GenericByteViewArray { } } +impl super::private::Sealed for GenericByteViewArray {} + impl Array for GenericByteViewArray { fn as_any(&self) -> &dyn Any { self @@ -607,8 +918,21 @@ impl Array for GenericByteViewArray { fn shrink_to_fit(&mut self) { self.views.shrink_to_fit(); - self.buffers.iter_mut().for_each(|b| b.shrink_to_fit()); - self.buffers.shrink_to_fit(); + + // The goal of `shrink_to_fit` is to minimize the space used by any of + // its allocations. The use of `Arc::get_mut` over `Arc::make_mut` is + // because if the reference count is greater than 1, `Arc::make_mut` + // will first clone its contents. So, any large allocations will first + // be cloned before being shrunk, leaving the pre-cloned allocations + // intact, before adding the extra (used) space of the new clones. + if let Some(buffers) = Arc::get_mut(&mut self.buffers) { + buffers.iter_mut().for_each(|b| b.shrink_to_fit()); + } + + // With the assumption that this is a best-effort function, no attempt + // is made to shrink `self.buffers`, which it can't because it's type + // does not expose a `shrink_to_fit` method. + if let Some(nulls) = &mut self.nulls { nulls.shrink_to_fit(); } @@ -649,7 +973,7 @@ impl<'a, T: ByteViewType + ?Sized> ArrayAccessor for &'a GenericByteViewArray } unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - GenericByteViewArray::value_unchecked(self, index) + unsafe { GenericByteViewArray::value_unchecked(self, index) } } } @@ -663,15 +987,16 @@ impl<'a, T: ByteViewType + ?Sized> IntoIterator for &'a GenericByteViewArray } impl From for GenericByteViewArray { - fn from(value: ArrayData) -> Self { - let views = value.buffers()[0].clone(); - let views = ScalarBuffer::new(views, value.offset(), value.len()); - let buffers = value.buffers()[1..].to_vec(); + fn from(data: ArrayData) -> Self { + let (_data_type, len, nulls, offset, mut buffers, _child_data) = data.into_parts(); + let views = buffers.remove(0); // need to maintain order of remaining buffers + let buffers = Arc::from(buffers); + let views = ScalarBuffer::new(views, offset, len); Self { data_type: T::DATA_TYPE, views, buffers, - nulls: value.nulls().cloned(), + nulls, phantom: Default::default(), } } @@ -734,12 +1059,15 @@ where } impl From> for ArrayData { - fn from(mut array: GenericByteViewArray) -> Self { + fn from(array: GenericByteViewArray) -> Self { let len = array.len(); - array.buffers.insert(0, array.views.into_inner()); + + let mut buffers = array.buffers.to_vec(); + buffers.insert(0, array.views.into_inner()); + let builder = ArrayDataBuilder::new(T::DATA_TYPE) .len(len) - .buffers(array.buffers) + .buffers(buffers) .nulls(array.nulls); unsafe { builder.build_unchecked() } @@ -795,7 +1123,7 @@ impl BinaryViewArray { /// # Safety /// Caller is responsible for ensuring that items in array are utf8 data. pub unsafe fn to_string_view_unchecked(self) -> StringViewArray { - StringViewArray::new_unchecked(self.views, self.buffers, self.nulls) + unsafe { StringViewArray::new_unchecked(self.views, self.buffers, self.nulls) } } } @@ -872,9 +1200,16 @@ impl From>> for StringViewArray { #[cfg(test)] mod tests { use crate::builder::{BinaryViewBuilder, StringViewBuilder}; - use crate::{Array, BinaryViewArray, StringViewArray}; - use arrow_buffer::{Buffer, ScalarBuffer}; - use arrow_data::ByteView; + use crate::types::BinaryViewType; + use crate::{ + Array, BinaryViewArray, GenericBinaryArray, GenericByteViewArray, StringViewArray, + }; + use arrow_buffer::{Buffer, NullBuffer, ScalarBuffer}; + use arrow_data::{ByteView, MAX_INLINE_VIEW_LEN}; + use rand::prelude::StdRng; + use rand::{Rng, SeedableRng}; + + const BLOCK_SIZE: u32 = 8; #[test] fn try_new_string() { @@ -960,7 +1295,10 @@ mod tests { builder.finish() }; assert_eq!(array.value(0), "large payload over 12 bytes"); - assert_eq!(array.value(1), "another large payload over 12 bytes that double than the first one, so that we can trigger the in_progress in builder re-created"); + assert_eq!( + array.value(1), + "another large payload over 12 bytes that double than the first one, so that we can trigger the in_progress in builder re-created" + ); assert_eq!(2, array.buffers.len()); } @@ -1064,6 +1402,180 @@ mod tests { check_gc(&array.slice(3, 1)); } + /// 1) Empty array: no elements, expect gc to return empty with no data buffers + #[test] + fn test_gc_empty_array() { + let array = StringViewBuilder::new() + .with_fixed_block_size(BLOCK_SIZE) + .finish(); + let gced = array.gc(); + // length and null count remain zero + assert_eq!(gced.len(), 0); + assert_eq!(gced.null_count(), 0); + // no underlying data buffers should be allocated + assert!( + gced.data_buffers().is_empty(), + "Expected no data buffers for empty array" + ); + } + + /// 2) All inline values (<= INLINE_LEN): capacity-only data buffer, same values + #[test] + fn test_gc_all_inline() { + let mut builder = StringViewBuilder::new().with_fixed_block_size(BLOCK_SIZE); + // append many short strings, each exactly INLINE_LEN long + for _ in 0..100 { + let s = "A".repeat(MAX_INLINE_VIEW_LEN as usize); + builder.append_option(Some(&s)); + } + let array = builder.finish(); + let gced = array.gc(); + // Since all views fit inline, data buffer is empty + assert_eq!( + gced.data_buffers().len(), + 0, + "Should have no data buffers for inline values" + ); + assert_eq!(gced.len(), 100); + // verify element-wise equality + array.iter().zip(gced.iter()).for_each(|(orig, got)| { + assert_eq!(orig, got, "Inline value mismatch after gc"); + }); + } + + /// 3) All large values (> INLINE_LEN): each must be copied into the new data buffer + #[test] + fn test_gc_all_large() { + let mut builder = StringViewBuilder::new().with_fixed_block_size(BLOCK_SIZE); + let large_str = "X".repeat(MAX_INLINE_VIEW_LEN as usize + 5); + // append multiple large strings + for _ in 0..50 { + builder.append_option(Some(&large_str)); + } + let array = builder.finish(); + let gced = array.gc(); + // New data buffers should be populated (one or more blocks) + assert!( + !gced.data_buffers().is_empty(), + "Expected data buffers for large values" + ); + assert_eq!(gced.len(), 50); + // verify that every large string emerges unchanged + array.iter().zip(gced.iter()).for_each(|(orig, got)| { + assert_eq!(orig, got, "Large view mismatch after gc"); + }); + } + + /// 4) All null elements: ensure null bitmap handling path is correct + #[test] + fn test_gc_all_nulls() { + let mut builder = StringViewBuilder::new().with_fixed_block_size(BLOCK_SIZE); + for _ in 0..20 { + builder.append_null(); + } + let array = builder.finish(); + let gced = array.gc(); + // length and null count match + assert_eq!(gced.len(), 20); + assert_eq!(gced.null_count(), 20); + // data buffers remain empty for null-only array + assert!( + gced.data_buffers().is_empty(), + "No data should be stored for nulls" + ); + } + + /// 5) Random mix of inline, large, and null values with slicing tests + #[test] + fn test_gc_random_mixed_and_slices() { + let mut rng = StdRng::seed_from_u64(42); + let mut builder = StringViewBuilder::new().with_fixed_block_size(BLOCK_SIZE); + // Keep a Vec of original Option for later comparison + let mut original: Vec> = Vec::new(); + + for _ in 0..200 { + if rng.random_bool(0.1) { + // 10% nulls + builder.append_null(); + original.push(None); + } else { + // random length between 0 and twice the inline limit + let len = rng.random_range(0..(MAX_INLINE_VIEW_LEN * 2)); + let s: String = "A".repeat(len as usize); + builder.append_option(Some(&s)); + original.push(Some(s)); + } + } + + let array = builder.finish(); + // Test multiple slice ranges to ensure offset logic is correct + for (offset, slice_len) in &[(0, 50), (10, 100), (150, 30)] { + let sliced = array.slice(*offset, *slice_len); + let gced = sliced.gc(); + // Build expected slice of Option<&str> + let expected: Vec> = original[*offset..(*offset + *slice_len)] + .iter() + .map(|opt| opt.as_deref()) + .collect(); + + assert_eq!(gced.len(), *slice_len, "Slice length mismatch"); + // Compare element-wise + gced.iter().zip(expected.iter()).for_each(|(got, expect)| { + assert_eq!(got, *expect, "Value mismatch in mixed slice after gc"); + }); + } + } + + #[test] + #[cfg_attr(miri, ignore)] // Takes too long + fn test_gc_huge_array() { + // Construct multiple 128 MiB BinaryView entries so total > 4 GiB + let block_len: usize = 128 * 1024 * 1024; // 128 MiB per view + let num_views: usize = 36; + + // Create a single 128 MiB data block with a simple byte pattern + let buffer = Buffer::from_vec(vec![0xAB; block_len]); + let buffer2 = Buffer::from_vec(vec![0xFF; block_len]); + + // Append this block and then add many views pointing to it + let mut builder = BinaryViewBuilder::new(); + let block_id = builder.append_block(buffer); + for _ in 0..num_views / 2 { + builder + .try_append_view(block_id, 0, block_len as u32) + .expect("append view into 128MiB block"); + } + let block_id2 = builder.append_block(buffer2); + for _ in 0..num_views / 2 { + builder + .try_append_view(block_id2, 0, block_len as u32) + .expect("append view into 128MiB block"); + } + + let array = builder.finish(); + let total = array.total_buffer_bytes_used(); + assert!( + total > u32::MAX as usize, + "Expected total non-inline bytes to exceed 4 GiB, got {}", + total + ); + + // Run gc and verify correctness + let gced = array.gc(); + assert_eq!(gced.len(), num_views, "Length mismatch after gc"); + assert_eq!(gced.null_count(), 0, "Null count mismatch after gc"); + assert_ne!( + gced.data_buffers().len(), + 1, + "gc with huge buffer should not consolidate data into a single buffer" + ); + + // Element-wise equality check across the entire array + array.iter().zip(gced.iter()).for_each(|(orig, got)| { + assert_eq!(orig, got, "Value mismatch after gc on huge array"); + }); + } + #[test] fn test_eq() { let test_data = [ @@ -1088,4 +1600,218 @@ mod tests { assert_eq!(array2, array2.clone()); assert_eq!(array1, array2); } + + /// Integration tests for `inline_key_fast` covering: + /// + /// 1. Monotonic ordering across increasing lengths and lexical variations. + /// 2. Cross-check against `GenericBinaryArray` comparison to ensure semantic equivalence. + /// + /// This also includes a specific test for the “bar” vs. “bar\0” case, demonstrating why + /// the length field is required even when all inline bytes fit in 12 bytes. + /// + /// The test includes strings that verify correct byte order (prevent reversal bugs), + /// and length-based tie-breaking in the composite key. + /// + /// The test confirms that `inline_key_fast` produces keys which sort consistently + /// with the expected lexicographical order of the raw byte arrays. + #[test] + fn test_inline_key_fast_various_lengths_and_lexical() { + /// Helper to create a raw u128 value representing an inline ByteView: + /// - `length`: number of meaningful bytes (must be ≤ 12) + /// - `data`: the actual inline data bytes + /// + /// The first 4 bytes encode length in little-endian, + /// the following 12 bytes contain the inline string data (unpadded). + fn make_raw_inline(length: u32, data: &[u8]) -> u128 { + assert!(length as usize <= 12, "Inline length must be ≤ 12"); + assert!( + data.len() == length as usize, + "Data length must match `length`" + ); + + let mut raw_bytes = [0u8; 16]; + raw_bytes[0..4].copy_from_slice(&length.to_le_bytes()); // length stored little-endian + raw_bytes[4..(4 + data.len())].copy_from_slice(data); // inline data + u128::from_le_bytes(raw_bytes) + } + + // Test inputs: various lengths and lexical orders, + // plus special cases for byte order and length tie-breaking + let test_inputs: Vec<&[u8]> = vec![ + b"a", + b"aa", + b"aaa", + b"aab", + b"abcd", + b"abcde", + b"abcdef", + b"abcdefg", + b"abcdefgh", + b"abcdefghi", + b"abcdefghij", + b"abcdefghijk", + b"abcdefghijkl", + // Tests for byte-order reversal bug: + // Without the fix, "backend one" would compare as "eno dnekcab", + // causing incorrect sort order relative to "backend two". + b"backend one", + b"backend two", + // Tests length-tiebreaker logic: + // "bar" (3 bytes) and "bar\0" (4 bytes) have identical inline data, + // so only the length differentiates their ordering. + b"bar", + b"bar\0", + // Additional lexical and length tie-breaking cases with same prefix, in correct lex order: + b"than12Byt", + b"than12Bytes", + b"than12Bytes\0", + b"than12Bytesx", + b"than12Bytex", + b"than12Bytez", + // Additional lexical tests + b"xyy", + b"xyz", + b"xza", + ]; + + // Create a GenericBinaryArray for cross-comparison of lex order + let array: GenericBinaryArray = + GenericBinaryArray::from(test_inputs.iter().map(|s| Some(*s)).collect::>()); + + for i in 0..array.len() - 1 { + let v1 = array.value(i); + let v2 = array.value(i + 1); + + // Assert the array's natural lexical ordering is correct + assert!(v1 < v2, "Array compare failed: {v1:?} !< {v2:?}"); + + // Assert the keys produced by inline_key_fast reflect the same ordering + let key1 = GenericByteViewArray::::inline_key_fast(make_raw_inline( + v1.len() as u32, + v1, + )); + let key2 = GenericByteViewArray::::inline_key_fast(make_raw_inline( + v2.len() as u32, + v2, + )); + + assert!( + key1 < key2, + "Key compare failed: key({v1:?})=0x{key1:032x} !< key({v2:?})=0x{key2:032x}", + ); + } + } + + #[test] + fn empty_array_should_return_empty_lengths_iterator() { + let empty = GenericByteViewArray::::from(Vec::<&[u8]>::new()); + + let mut lengths_iter = empty.lengths(); + assert_eq!(lengths_iter.len(), 0); + assert_eq!(lengths_iter.next(), None); + } + + #[test] + fn array_lengths_should_return_correct_length_for_both_inlined_and_non_inlined() { + let cases = GenericByteViewArray::::from(vec![ + // Not inlined as longer than 12 bytes + b"Supercalifragilisticexpialidocious" as &[u8], + // Inlined as shorter than 12 bytes + b"Hello", + // Empty value + b"", + // Exactly 12 bytes + b"abcdefghijkl", + ]); + + let mut lengths_iter = cases.lengths(); + + assert_eq!(lengths_iter.len(), cases.len()); + + let cases_iter = cases.iter(); + + for case in cases_iter { + let case_value = case.unwrap(); + let length = lengths_iter.next().expect("Should have a length"); + + assert_eq!(case_value.len(), length as usize); + } + + assert_eq!(lengths_iter.next(), None, "Should not have more lengths"); + } + + #[test] + fn array_lengths_should_return_the_underlying_length_for_null_values() { + let cases = GenericByteViewArray::::from(vec![ + // Not inlined as longer than 12 bytes + b"Supercalifragilisticexpialidocious" as &[u8], + // Inlined as shorter than 12 bytes + b"Hello", + // Empty value + b"", + // Exactly 12 bytes + b"abcdefghijkl", + ]); + + let (views, buffer, _) = cases.clone().into_parts(); + + // Keeping the values but just adding nulls on top + let cases_with_all_nulls = GenericByteViewArray::::new( + views, + buffer, + Some(NullBuffer::new_null(cases.len())), + ); + + let lengths_iter = cases.lengths(); + let mut all_nulls_lengths_iter = cases_with_all_nulls.lengths(); + + assert_eq!(lengths_iter.len(), all_nulls_lengths_iter.len()); + + for expected_length in lengths_iter { + let actual_length = all_nulls_lengths_iter.next().expect("Should have a length"); + + assert_eq!(expected_length, actual_length); + } + + assert_eq!( + all_nulls_lengths_iter.next(), + None, + "Should not have more lengths" + ); + } + + #[test] + fn array_lengths_on_sliced_should_only_return_lengths_for_sliced_data() { + let array = GenericByteViewArray::::from(vec![ + b"aaaaaaaaaaaaaaaaaaaaaaaaaaa" as &[u8], + b"Hello", + b"something great", + b"is", + b"coming soon!", + b"when you find what it is", + b"let me know", + b"cause", + b"I", + b"have no idea", + b"what it", + b"is", + ]); + + let sliced_array = array.slice(2, array.len() - 3); + + let mut lengths_iter = sliced_array.lengths(); + + assert_eq!(lengths_iter.len(), sliced_array.len()); + + let values_iter = sliced_array.iter(); + + for value in values_iter { + let value = value.unwrap(); + let length = lengths_iter.next().expect("Should have a length"); + + assert_eq!(value.len(), length as usize); + } + + assert_eq!(lengths_iter.next(), None, "Should not have more lengths"); + } } diff --git a/arrow-array/src/array/dictionary_array.rs b/arrow-array/src/array/dictionary_array.rs index acbdcb8b60fa..be7703b13c5c 100644 --- a/arrow-array/src/array/dictionary_array.rs +++ b/arrow-array/src/array/dictionary_array.rs @@ -20,8 +20,8 @@ use crate::cast::AsArray; use crate::iterator::ArrayIter; use crate::types::*; use crate::{ - make_array, Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, PrimitiveArray, Scalar, - StringArray, + Array, ArrayAccessor, ArrayRef, ArrowNativeTypeOp, PrimitiveArray, Scalar, StringArray, + make_array, }; use arrow_buffer::bit_util::set_bit; use arrow_buffer::buffer::NullBuffer; @@ -697,6 +697,8 @@ impl<'a, T: ArrowDictionaryKeyType> FromIterator<&'a str> for DictionaryArray } } +impl super::private::Sealed for DictionaryArray {} + impl Array for DictionaryArray { fn as_any(&self) -> &dyn Any { self @@ -856,6 +858,8 @@ impl<'a, K: ArrowDictionaryKeyType, V> TypedDictionaryArray<'a, K, V> { } } +impl super::private::Sealed for TypedDictionaryArray<'_, K, V> {} + impl Array for TypedDictionaryArray<'_, K, V> { fn as_any(&self) -> &dyn Any { self.dictionary @@ -947,13 +951,13 @@ where } unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - let val = self.dictionary.keys.value_unchecked(index); + let val = unsafe { self.dictionary.keys.value_unchecked(index) }; let value_idx = val.as_usize(); // As dictionary keys are only verified for non-null indexes // we must check the value is within bounds match value_idx < self.values.len() { - true => self.values.value_unchecked(value_idx), + true => unsafe { self.values.value_unchecked(value_idx) }, false => Default::default(), } } @@ -1051,7 +1055,7 @@ impl AnyDictionaryArray for DictionaryArray { mod tests { use super::*; use crate::cast::as_dictionary_array; - use crate::{Int16Array, Int32Array, Int8Array, RunArray}; + use crate::{Int8Array, Int16Array, Int32Array, RunArray}; use arrow_buffer::{Buffer, ToByteSlice}; #[test] diff --git a/arrow-array/src/array/fixed_size_binary_array.rs b/arrow-array/src/array/fixed_size_binary_array.rs index 576b8012491b..b94e168cfe7c 100644 --- a/arrow-array/src/array/fixed_size_binary_array.rs +++ b/arrow-array/src/array/fixed_size_binary_array.rs @@ -19,7 +19,7 @@ use crate::array::print_long_array; use crate::iterator::FixedSizeBinaryIter; use crate::{Array, ArrayAccessor, ArrayRef, FixedSizeListArray, Scalar}; use arrow_buffer::buffer::NullBuffer; -use arrow_buffer::{bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer}; +use arrow_buffer::{ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, bit_util}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; use std::any::Any; @@ -76,10 +76,14 @@ impl FixedSizeBinaryArray { /// Create a new [`FixedSizeBinaryArray`] from the provided parts, returning an error on failure /// + /// Creating an arrow with `size == 0` will try to get the length from the null buffer. If + /// no null buffer is provided, the resulting array will have length zero. + /// /// # Errors /// /// * `size < 0` /// * `values.len() / size != nulls.len()` + /// * `size == 0 && values.len() != 0` pub fn try_new( size: i32, values: Buffer, @@ -87,10 +91,21 @@ impl FixedSizeBinaryArray { ) -> Result { let data_type = DataType::FixedSizeBinary(size); let s = size.to_usize().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {size}")) })?; - let len = values.len() / s; + let len = if s == 0 { + if !values.is_empty() { + return Err(ArrowError::InvalidArgumentError( + "Buffer cannot have non-zero length if the item size is zero".to_owned(), + )); + } + + // If the item size is zero, try to determine the length from the null buffer + nulls.as_ref().map(|n| n.len()).unwrap_or(0) + } else { + values.len() / s + }; if let Some(n) = nulls.as_ref() { if n.len() != len { return Err(ArrowError::InvalidArgumentError(format!( @@ -119,10 +134,11 @@ impl FixedSizeBinaryArray { /// * `size < 0` /// * `size * len` would overflow `usize` pub fn new_null(size: i32, len: usize) -> Self { - let capacity = size.to_usize().unwrap().checked_mul(len).unwrap(); + const BITS_IN_A_BYTE: usize = 8; + let capacity_in_bytes = size.to_usize().unwrap().checked_mul(len).unwrap(); Self { data_type: DataType::FixedSizeBinary(size), - value_data: MutableBuffer::new(capacity).into(), + value_data: MutableBuffer::new_null(capacity_in_bytes * BITS_IN_A_BYTE).into(), nulls: Some(NullBuffer::new_null(len)), value_length: size, len, @@ -135,6 +151,10 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as a byte slice. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds. pub fn value(&self, i: usize) -> &[u8] { @@ -155,15 +175,23 @@ impl FixedSizeBinaryArray { } /// Returns the element at index `i` as a byte slice. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety - /// Caller is responsible for ensuring that the index is within the bounds of the array + /// + /// Caller is responsible for ensuring that the index is within the bounds + /// of the array pub unsafe fn value_unchecked(&self, i: usize) -> &[u8] { let offset = i + self.offset(); let pos = self.value_offset_at(offset); - std::slice::from_raw_parts( - self.value_data.as_ptr().offset(pos as isize), - (self.value_offset_at(offset + 1) - pos) as usize, - ) + unsafe { + std::slice::from_raw_parts( + self.value_data.as_ptr().offset(pos as isize), + (self.value_offset_at(offset + 1) - pos) as usize, + ) + } } /// Returns the offset for the element at index `i`. @@ -574,6 +602,8 @@ impl std::fmt::Debug for FixedSizeBinaryArray { } } +impl super::private::Sealed for FixedSizeBinaryArray {} + impl Array for FixedSizeBinaryArray { fn as_any(&self) -> &dyn Any { self @@ -644,7 +674,7 @@ impl<'a> ArrayAccessor for &'a FixedSizeBinaryArray { } unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - FixedSizeBinaryArray::value_unchecked(self, index) + unsafe { FixedSizeBinaryArray::value_unchecked(self, index) } } } @@ -659,11 +689,10 @@ impl<'a> IntoIterator for &'a FixedSizeBinaryArray { #[cfg(test)] mod tests { + use super::*; use crate::RecordBatch; use arrow_schema::{Field, Schema}; - use super::*; - #[test] fn test_fixed_size_binary_array() { let values: [u8; 15] = *b"hellotherearrow"; @@ -971,6 +1000,10 @@ mod tests { let nulls = NullBuffer::new_null(5); FixedSizeBinaryArray::new(2, buffer.clone(), Some(nulls)); + let null_array = FixedSizeBinaryArray::new_null(4, 3); + assert_eq!(null_array.len(), 3); + assert_eq!(null_array.values().len(), 12); + let a = FixedSizeBinaryArray::new(3, buffer.clone(), None); assert_eq!(a.len(), 3); @@ -985,7 +1018,24 @@ mod tests { ); let nulls = NullBuffer::new_null(3); - let err = FixedSizeBinaryArray::try_new(2, buffer, Some(nulls)).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeBinaryArray, expected 5 got 3"); + let err = FixedSizeBinaryArray::try_new(2, buffer.clone(), Some(nulls)).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of null buffer for FixedSizeBinaryArray, expected 5 got 3" + ); + + let zero_sized = FixedSizeBinaryArray::new(0, Buffer::default(), None); + assert_eq!(zero_sized.len(), 0); + + let nulls = NullBuffer::new_null(3); + let zero_sized_with_nulls = FixedSizeBinaryArray::new(0, Buffer::default(), Some(nulls)); + assert_eq!(zero_sized_with_nulls.len(), 3); + + let zero_sized_with_non_empty_buffer_err = + FixedSizeBinaryArray::try_new(0, buffer, None).unwrap_err(); + assert_eq!( + zero_sized_with_non_empty_buffer_err.to_string(), + "Invalid argument error: Buffer cannot have non-zero length if the item size is zero" + ); } } diff --git a/arrow-array/src/array/fixed_size_list_array.rs b/arrow-array/src/array/fixed_size_list_array.rs index 44be442c9f85..f53b042f873b 100644 --- a/arrow-array/src/array/fixed_size_list_array.rs +++ b/arrow-array/src/array/fixed_size_list_array.rs @@ -18,9 +18,9 @@ use crate::array::print_long_array; use crate::builder::{FixedSizeListBuilder, PrimitiveBuilder}; use crate::iterator::FixedSizeListIter; -use crate::{make_array, Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType}; -use arrow_buffer::buffer::NullBuffer; +use crate::{Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, make_array}; use arrow_buffer::ArrowNativeType; +use arrow_buffer::buffer::NullBuffer; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, FieldRef}; use std::any::Any; @@ -114,7 +114,7 @@ use std::sync::Arc; /// ``` /// /// [`StringArray`]: crate::array::StringArray -/// [fixed size arrays](https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout) +/// [fixed length lists]: https://arrow.apache.org/docs/format/Columnar.html#fixed-size-list-layout #[derive(Clone)] pub struct FixedSizeListArray { data_type: DataType, // Must be DataType::FixedSizeList(value_length) @@ -125,7 +125,15 @@ pub struct FixedSizeListArray { } impl FixedSizeListArray { - /// Create a new [`FixedSizeListArray`] with `size` element size, panicking on failure + /// Create a new [`FixedSizeListArray`] with `size` element size, panicking on failure. + /// + /// Note that if `size == 0` and `nulls` is `None` (a degenerate, non-nullable + /// `FixedSizeListArray`), this function will set the length of the array to 0. + /// + /// If you would like to have a degenerate, non-nullable `FixedSizeListArray` with arbitrary + /// length, use the [`try_new_with_length()`] constructor. + /// + /// [`try_new_with_length()`]: Self::try_new_with_length /// /// # Panics /// @@ -134,12 +142,20 @@ impl FixedSizeListArray { Self::try_new(field, size, values, nulls).unwrap() } - /// Create a new [`FixedSizeListArray`] from the provided parts, returning an error on failure + /// Create a new [`FixedSizeListArray`] from the provided parts, returning an error on failure. + /// + /// Note that if `size == 0` and `nulls` is `None` (a degenerate, non-nullable + /// `FixedSizeListArray`), this function will set the length of the array to 0. + /// + /// If you would like to have a degenerate, non-nullable `FixedSizeListArray` with arbitrary + /// length, use the [`try_new_with_length()`] constructor. + /// + /// [`try_new_with_length()`]: Self::try_new_with_length /// /// # Errors /// /// * `size < 0` - /// * `values.len() / size != nulls.len()` + /// * `values.len() != nulls.len() * size` if `nulls` is `Some` /// * `values.data_type() != field.data_type()` /// * `!field.is_nullable() && !nulls.expand(size).contains(values.logical_nulls())` pub fn try_new( @@ -149,25 +165,91 @@ impl FixedSizeListArray { nulls: Option, ) -> Result { let s = size.to_usize().ok_or_else(|| { - ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {}", size)) + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {size}")) })?; - let len = match s { - 0 => nulls.as_ref().map(|x| x.len()).unwrap_or_default(), - _ => { - let len = values.len() / s.max(1); - if let Some(n) = nulls.as_ref() { - if n.len() != len { - return Err(ArrowError::InvalidArgumentError(format!( - "Incorrect length of null buffer for FixedSizeListArray, expected {} got {}", - len, - n.len(), - ))); - } + if s == 0 { + // Note that for degenerate (`size == 0`) and non-nullable `FixedSizeList`s, we will set + // the length to 0 (`_or_default`). + let len = nulls.as_ref().map(|x| x.len()).unwrap_or_default(); + + Self::try_new_with_length(field, size, values, nulls, len) + } else { + if values.len() % s != 0 { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of values buffer for FixedSizeListArray, \ + expected a multiple of {s} got {}", + values.len(), + ))); + } + + let len = values.len() / s; + + // Check that the null buffer length is correct (if it exists). + if let Some(null_buffer) = &nulls { + if s * null_buffer.len() != values.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of values buffer for FixedSizeListArray, \ + expected {} got {}", + s * null_buffer.len(), + values.len(), + ))); } - len } - }; + + Self::try_new_with_length(field, size, values, nulls, len) + } + } + + /// Create a new [`FixedSizeListArray`] from the provided parts, returning an error on failure. + /// + /// This method exists to allow the construction of arbitrary length degenerate (`size == 0`) + /// and non-nullable `FixedSizeListArray`s. If you want a nullable `FixedSizeListArray`, then + /// you can use [`try_new()`] instead. + /// + /// [`try_new()`]: Self::try_new + /// + /// # Errors + /// + /// * `size < 0` + /// * `nulls.len() != len` if `nulls` is `Some` + /// * `values.len() != len * size` + /// * `values.data_type() != field.data_type()` + /// * `!field.is_nullable() && !nulls.expand(size).contains(values.logical_nulls())` + pub fn try_new_with_length( + field: FieldRef, + size: i32, + values: ArrayRef, + nulls: Option, + len: usize, + ) -> Result { + let s = size.to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Size cannot be negative, got {size}")) + })?; + + if let Some(null_buffer) = &nulls { + if null_buffer.len() != len { + return Err(ArrowError::InvalidArgumentError(format!( + "Invalid null buffer for FixedSizeListArray, expected {len} found {}", + null_buffer.len() + ))); + } + } + + if s == 0 && !values.is_empty() { + return Err(ArrowError::InvalidArgumentError(format!( + "An degenerate FixedSizeListArray should have no underlying values, found {} values", + values.len() + ))); + } + + if values.len() != len * s { + return Err(ArrowError::InvalidArgumentError(format!( + "Incorrect length of values buffer for FixedSizeListArray, expected {} got {}", + len * s, + values.len(), + ))); + } if field.data_type() != values.data_type() { return Err(ArrowError::InvalidArgumentError(format!( @@ -243,6 +325,12 @@ impl FixedSizeListArray { } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { self.values .slice(self.value_offset_at(i), self.value_length() as usize) @@ -343,8 +431,10 @@ impl From for FixedSizeListArray { fn from(data: ArrayData) -> Self { let value_length = match data.data_type() { DataType::FixedSizeList(_, len) => *len, - _ => { - panic!("FixedSizeListArray data should contain a FixedSizeList data type") + data_type => { + panic!( + "FixedSizeListArray data should contain a FixedSizeList data type, got {data_type}" + ) } }; @@ -372,6 +462,8 @@ impl From for ArrayData { } } +impl super::private::Sealed for FixedSizeListArray {} + impl Array for FixedSizeListArray { fn as_any(&self) -> &dyn Any { self @@ -474,12 +566,12 @@ impl ArrayAccessor for &FixedSizeListArray { #[cfg(test)] mod tests { - use arrow_buffer::{bit_util, BooleanBuffer, Buffer}; + use arrow_buffer::{BooleanBuffer, Buffer, bit_util}; use arrow_schema::Field; use crate::cast::AsArray; use crate::types::Int32Type; - use crate::{new_empty_array, Int32Array}; + use crate::{Int32Array, new_empty_array}; use super::*; @@ -665,8 +757,23 @@ mod tests { let list = FixedSizeListArray::new(field.clone(), 2, values.clone(), Some(nulls)); assert_eq!(list.len(), 3); - let list = FixedSizeListArray::new(field.clone(), 4, values.clone(), None); - assert_eq!(list.len(), 1); + let list = FixedSizeListArray::new(field.clone(), 3, values.clone(), None); + assert_eq!(list.len(), 2); + + let err = FixedSizeListArray::try_new(field.clone(), 4, values.clone(), None).unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of values buffer for FixedSizeListArray, \ + expected a multiple of 4 got 6", + ); + + let err = + FixedSizeListArray::try_new_with_length(field.clone(), 4, values.clone(), None, 1) + .unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of values buffer for FixedSizeListArray, expected 4 got 6" + ); let err = FixedSizeListArray::try_new(field.clone(), -1, values.clone(), None).unwrap_err(); assert_eq!( @@ -674,16 +781,19 @@ mod tests { "Invalid argument error: Size cannot be negative, got -1" ); - let list = FixedSizeListArray::new(field.clone(), 0, values.clone(), None); - assert_eq!(list.len(), 0); - let nulls = NullBuffer::new_null(2); let err = FixedSizeListArray::try_new(field, 2, values.clone(), Some(nulls)).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Incorrect length of null buffer for FixedSizeListArray, expected 3 got 2"); + assert_eq!( + err.to_string(), + "Invalid argument error: Incorrect length of values buffer for FixedSizeListArray, expected 4 got 6" + ); let field = Arc::new(Field::new_list_field(DataType::Int32, false)); let err = FixedSizeListArray::try_new(field.clone(), 2, values.clone(), None).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\""); + assert_eq!( + err.to_string(), + "Invalid argument error: Found unmasked nulls for non-nullable FixedSizeListArray field \"item\"" + ); // Valid as nulls in child masked by parent let nulls = NullBuffer::new(BooleanBuffer::new(Buffer::from([0b0000101]), 0, 3)); @@ -691,15 +801,49 @@ mod tests { let field = Arc::new(Field::new_list_field(DataType::Int64, true)); let err = FixedSizeListArray::try_new(field, 2, values, None).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: FixedSizeListArray expected data type Int64 got Int32 for \"item\""); + assert_eq!( + err.to_string(), + "Invalid argument error: FixedSizeListArray expected data type Int64 got Int32 for \"item\"" + ); } #[test] - fn empty_fixed_size_list() { + fn degenerate_fixed_size_list() { let field = Arc::new(Field::new_list_field(DataType::Int32, true)); let nulls = NullBuffer::new_null(2); let values = new_empty_array(&DataType::Int32); - let list = FixedSizeListArray::new(field.clone(), 0, values, Some(nulls)); + let list = FixedSizeListArray::new(field.clone(), 0, values.clone(), Some(nulls.clone())); assert_eq!(list.len(), 2); + + // Test invalid null buffer length. + let err = FixedSizeListArray::try_new_with_length( + field.clone(), + 0, + values.clone(), + Some(nulls), + 5, + ) + .unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: Invalid null buffer for FixedSizeListArray, expected 5 found 2" + ); + + // Test non-empty values for degenerate list. + let non_empty_values = Arc::new(Int32Array::from(vec![1, 2, 3])); + let err = + FixedSizeListArray::try_new_with_length(field.clone(), 0, non_empty_values, None, 3) + .unwrap_err(); + assert_eq!( + err.to_string(), + "Invalid argument error: An degenerate FixedSizeListArray should have no underlying values, found 3 values" + ); + } + + #[test] + fn test_fixed_size_list_new_null_len() { + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let array = FixedSizeListArray::new_null(field, 2, 5); + assert_eq!(array.len(), 5); } } diff --git a/arrow-array/src/array/list_array.rs b/arrow-array/src/array/list_array.rs index 79627776569b..225be14ae365 100644 --- a/arrow-array/src/array/list_array.rs +++ b/arrow-array/src/array/list_array.rs @@ -18,13 +18,13 @@ use crate::array::{get_offsets, make_array, print_long_array}; use crate::builder::{GenericListBuilder, PrimitiveBuilder}; use crate::{ - iterator::GenericListArrayIter, new_empty_array, Array, ArrayAccessor, ArrayRef, - ArrowPrimitiveType, FixedSizeListArray, + Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, + iterator::GenericListArrayIter, new_empty_array, }; use arrow_buffer::{ArrowNativeType, NullBuffer, OffsetBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, FieldRef}; -use num::Integer; +use num_integer::Integer; use std::any::Any; use std::sync::Arc; @@ -37,7 +37,9 @@ use std::sync::Arc; /// [`LargeBinaryArray`]: crate::array::LargeBinaryArray /// [`StringArray`]: crate::array::StringArray /// [`LargeStringArray`]: crate::array::LargeStringArray -pub trait OffsetSizeTrait: ArrowNativeType + std::ops::AddAssign + Integer { +pub trait OffsetSizeTrait: + ArrowNativeType + std::ops::AddAssign + Integer + num_traits::CheckedAdd +{ /// True for 64 bit offset size and false for 32 bit offset size const IS_LARGE: bool; /// Prefix for the offset size @@ -108,21 +110,21 @@ impl OffsetSizeTrait for i64 { /// ┌─────────────┐ ┌───────┐ │ ┌───┐ ┌───┐ ┌───┐ ┌───┐ /// │ [A,B,C] │ │ (0,3) │ │ 1 │ │ 0 │ │ │ 1 │ │ A │ │ 0 │ /// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ -/// │ [] │ │ (3,3) │ │ 1 │ │ 3 │ │ │ 1 │ │ B │ │ 1 │ +/// │ [] (empty) │ │ (3,3) │ │ 1 │ │ 3 │ │ │ 1 │ │ B │ │ 1 │ /// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ -/// │ NULL │ │ (3,4) │ │ 0 │ │ 3 │ │ │ 1 │ │ C │ │ 2 │ +/// │ NULL │ │ (3,3) │ │ 0 │ │ 3 │ │ │ 1 │ │ C │ │ 2 │ /// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ -/// │ [D] │ │ (4,5) │ │ 1 │ │ 4 │ │ │ ? │ │ ? │ │ 3 │ +/// │ [D] │ │ (3,4) │ │ 1 │ │ 3 │ │ │ 1 │ │ D │ │ 3 │ /// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ├───┤ ├───┤ -/// │ [NULL, F] │ │ (5,7) │ │ 1 │ │ 5 │ │ │ 1 │ │ D │ │ 4 │ +/// │ [NULL, F] │ │ (4,6) │ │ 1 │ │ 4 │ │ │ 0 │ │ ? │ │ 4 │ /// └─────────────┘ └───────┘ │ └───┘ ├───┤ ├───┤ ├───┤ -/// │ 7 │ │ │ 0 │ │ ? │ │ 5 │ -/// │ Validity └───┘ ├───┤ ├───┤ -/// Logical Logical (nulls) Offsets │ │ 1 │ │ F │ │ 6 │ -/// Values Offsets │ └───┘ └───┘ -/// │ Values │ │ -/// (offsets[i], │ ListArray (Array) -/// offsets[i+1]) └ ─ ─ ─ ─ ─ ─ ┘ │ +/// │ 6 │ │ │ 1 │ │ F │ │ 5 │ +/// │ Validity └───┘ └───┘ └───┘ +/// Logical Logical (nulls) Offsets │ Values │ │ +/// Values Offsets │ (Array) +/// └ ─ ─ ─ ─ ─ ─ ┘ │ +/// (offsets[i], │ ListArray +/// offsets[i+1]) │ /// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ /// ``` /// @@ -145,19 +147,19 @@ impl OffsetSizeTrait for i64 { /// ┌─────────────┐ ┌───────┐ │ ┌───┐ ┌───┐ ╠═══╣ ╠═══╣ /// │ [] (empty) │ │ (3,3) │ │ 1 │ │ 3 │ │ ║ 1 ║ ║ B ║ │ 1 │ /// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ╠═══╣ ╠═══╣ -/// │ NULL │ │ (3,4) │ │ 0 │ │ 3 │ │ ║ 1 ║ ║ C ║ │ 2 │ -/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ╠───╣ ╠───╣ -/// │ [D] │ │ (4,5) │ │ 1 │ │ 4 │ │ │ 0 │ │ ? │ │ 3 │ -/// └─────────────┘ └───────┘ │ └───┘ ├───┤ ├───┤ ├───┤ -/// │ 5 │ │ │ 1 │ │ D │ │ 4 │ -/// │ └───┘ ├───┤ ├───┤ -/// │ │ 0 │ │ ? │ │ 5 │ -/// │ Validity ╠═══╣ ╠═══╣ -/// Logical Logical (nulls) Offsets │ ║ 1 ║ ║ F ║ │ 6 │ -/// Values Offsets │ ╚═══╝ ╚═══╝ -/// │ Values │ │ -/// (offsets[i], │ ListArray (Array) -/// offsets[i+1]) └ ─ ─ ─ ─ ─ ─ ┘ │ +/// │ NULL │ │ (3,3) │ │ 0 │ │ 3 │ │ ║ 1 ║ ║ C ║ │ 2 │ +/// ├─────────────┤ ├───────┤ │ ├───┤ ├───┤ ╚═══╝ ╚═══╝ +/// │ [D] │ │ (3,4) │ │ 1 │ │ 3 │ │ │ 1 │ │ D │ │ 3 │ +/// └─────────────┘ └───────┘ │ └───┘ ├───┤ ╔═══╗ ╔═══╗ +/// │ 4 │ │ ║ 0 ║ ║ ? ║ │ 4 │ +/// │ └───┘ ╠═══╣ ╠═══╣ +/// │ ║ 1 ║ ║ F ║ │ 5 │ +/// │ Validity ╚═══╝ ╚═══╝ +/// Logical Logical (nulls) Offsets │ Values │ │ +/// Values Offsets │ (Array) +/// └ ─ ─ ─ ─ ─ ─ ┘ │ +/// (offsets[i], │ ListArray +/// offsets[i+1]) │ /// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ /// ``` /// @@ -327,15 +329,25 @@ impl GenericListArray { } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { - let end = self.value_offsets().get_unchecked(i + 1).as_usize(); - let start = self.value_offsets().get_unchecked(i).as_usize(); + let end = unsafe { self.value_offsets().get_unchecked(i + 1).as_usize() }; + let start = unsafe { self.value_offsets().get_unchecked(i).as_usize() }; self.values.slice(start, end - start) } /// Returns ith value of this list array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { let end = self.value_offsets()[i + 1].as_usize(); let start = self.value_offsets()[i].as_usize(); @@ -454,7 +466,7 @@ impl From for GenericListArray< _ => unreachable!(), }; - let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(value.len())); + let offsets = OffsetBuffer::from_repeated_length(size, value.len()); Self { data_type: Self::DATA_TYPE_CONSTRUCTOR(field.clone()), @@ -513,6 +525,8 @@ impl GenericListArray { } } +impl super::private::Sealed for GenericListArray {} + impl Array for GenericListArray { fn as_any(&self) -> &dyn Any { self @@ -623,7 +637,7 @@ mod tests { use crate::cast::AsArray; use crate::types::Int32Type; use crate::{Int32Array, Int64Array}; - use arrow_buffer::{bit_util, Buffer, ScalarBuffer}; + use arrow_buffer::{Buffer, ScalarBuffer, bit_util}; use arrow_schema::Field; fn create_from_buffers() -> ListArray { @@ -1272,4 +1286,11 @@ mod tests { let field = Arc::new(Field::new("element", values.data_type().clone(), false)); ListArray::new(field.clone(), offsets, Arc::new(values), None); } + + #[test] + fn test_list_new_null_len() { + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let array = ListArray::new_null(field, 5); + assert_eq!(array.len(), 5); + } } diff --git a/arrow-array/src/array/list_view_array.rs b/arrow-array/src/array/list_view_array.rs index 6118607bcbbf..52c88d581d20 100644 --- a/arrow-array/src/array/list_view_array.rs +++ b/arrow-array/src/array/list_view_array.rs @@ -23,8 +23,12 @@ use std::ops::Add; use std::sync::Arc; use crate::array::{make_array, print_long_array}; +use crate::builder::{GenericListViewBuilder, PrimitiveBuilder}; use crate::iterator::GenericListViewArrayIter; -use crate::{new_empty_array, Array, ArrayAccessor, ArrayRef, FixedSizeListArray, OffsetSizeTrait}; +use crate::{ + Array, ArrayAccessor, ArrayRef, ArrowPrimitiveType, FixedSizeListArray, GenericListArray, + OffsetSizeTrait, new_empty_array, +}; /// A [`GenericListViewArray`] of variable size lists, storing offsets as `i32`. pub type ListViewArray = GenericListViewArray; @@ -89,9 +93,9 @@ pub type LargeListViewArray = GenericListViewArray; /// │ │ 1 │ │ D │ │ 5 │ /// Logical Logical │ Validity Offsets Sizes └───┘ └───┘ /// Values Offset (nulls) │ Values │ │ -/// & Size │ (Array) +/// & Size │ (Array) /// └ ─ ─ ─ ─ ─ ─ ┘ │ -/// (offsets[i], │ ListViewArray +/// (offsets[i], │ ListViewArray /// sizes[i]) │ /// └ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ /// ``` @@ -154,7 +158,8 @@ impl GenericListViewArray { if len != sizes.len() { return Err(ArrowError::InvalidArgumentError(format!( "Length of offsets buffer and sizes buffer must be equal for {}ListViewArray, got {len} and {}", - OffsetSize::PREFIX, sizes.len() + OffsetSize::PREFIX, + sizes.len() ))); } @@ -224,8 +229,8 @@ impl GenericListViewArray { Self { data_type: Self::DATA_TYPE_CONSTRUCTOR(field), nulls: Some(NullBuffer::new_null(len)), - value_offsets: ScalarBuffer::from(vec![]), - value_sizes: ScalarBuffer::from(vec![]), + value_offsets: ScalarBuffer::from(vec![OffsetSize::usize_as(0); len]), + value_sizes: ScalarBuffer::from(vec![OffsetSize::usize_as(0); len]), values, } } @@ -283,15 +288,23 @@ impl GenericListViewArray { } /// Returns ith value of this list view array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> ArrayRef { - let offset = self.value_offsets().get_unchecked(i).as_usize(); - let length = self.value_sizes().get_unchecked(i).as_usize(); + let offset = unsafe { self.value_offsets().get_unchecked(i).as_usize() }; + let length = unsafe { self.value_sizes().get_unchecked(i).as_usize() }; self.values.slice(offset, length) } /// Returns ith value of this list view array. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if the index is out of bounds pub fn value(&self, i: usize) -> ArrayRef { @@ -348,6 +361,46 @@ impl GenericListViewArray { value_sizes: self.value_sizes.slice(offset, length), } } + + /// Creates a [`GenericListViewArray`] from an iterator of primitive values + /// # Example + /// ``` + /// # use arrow_array::ListViewArray; + /// # use arrow_array::types::Int32Type; + /// + /// let data = vec![ + /// Some(vec![Some(0), Some(1), Some(2)]), + /// None, + /// Some(vec![Some(3), None, Some(5)]), + /// Some(vec![Some(6), Some(7)]), + /// ]; + /// let list_array = ListViewArray::from_iter_primitive::(data); + /// println!("{:?}", list_array); + /// ``` + pub fn from_iter_primitive(iter: I) -> Self + where + T: ArrowPrimitiveType, + P: IntoIterator::Native>>, + I: IntoIterator>, + { + let iter = iter.into_iter(); + let size_hint = iter.size_hint().0; + let mut builder = + GenericListViewBuilder::with_capacity(PrimitiveBuilder::::new(), size_hint); + + for i in iter { + match i { + Some(p) => { + for t in p { + builder.values().append_option(t); + } + builder.append(true); + } + None => builder.append(false), + } + } + builder.finish() + } } impl ArrayAccessor for &GenericListViewArray { @@ -358,10 +411,12 @@ impl ArrayAccessor for &GenericListViewArray Self::Item { - GenericListViewArray::value_unchecked(self, index) + unsafe { GenericListViewArray::value_unchecked(self, index) } } } +impl super::private::Sealed for GenericListViewArray {} + impl Array for GenericListViewArray { fn as_any(&self) -> &dyn Any { self @@ -445,6 +500,29 @@ impl std::fmt::Debug for GenericListViewArray From> + for GenericListViewArray +{ + fn from(value: GenericListArray) -> Self { + let (field, offsets, values, nulls) = value.into_parts(); + let len = offsets.len() - 1; + let mut sizes = Vec::with_capacity(len); + let mut view_offsets = Vec::with_capacity(len); + for (i, offset) in offsets.iter().enumerate().take(len) { + view_offsets.push(*offset); + sizes.push(offsets[i + 1] - offsets[i]); + } + + Self::new( + field, + ScalarBuffer::from(view_offsets), + ScalarBuffer::from(sizes), + values, + nulls, + ) + } +} + impl From> for ArrayData { fn from(array: GenericListViewArray) -> Self { let len = array.len(); @@ -475,7 +553,7 @@ impl From for GenericListViewAr _ => unreachable!(), }; let mut acc = 0_usize; - let iter = std::iter::repeat(size).take(value.len()); + let iter = std::iter::repeat_n(size, value.len()); let mut sizes = Vec::with_capacity(iter.size_hint().0); let mut offsets = Vec::with_capacity(iter.size_hint().0); @@ -550,7 +628,7 @@ impl GenericListViewArray { #[cfg(test)] mod tests { - use arrow_buffer::{bit_util, BooleanBuffer, Buffer, ScalarBuffer}; + use arrow_buffer::{BooleanBuffer, Buffer, NullBufferBuilder, ScalarBuffer, bit_util}; use arrow_schema::Field; use crate::builder::{FixedSizeListBuilder, Int32Builder}; @@ -1111,4 +1189,36 @@ mod tests { .collect(); assert_eq!(values, vec![Some(vec![]), Some(vec![]), Some(vec![])]); } + + #[test] + fn test_list_view_new_null_len() { + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + let array = ListViewArray::new_null(field, 5); + assert_eq!(array.len(), 5); + } + + #[test] + fn test_from_iter_primitive() { + let data = vec![ + Some(vec![Some(0), Some(1), Some(2)]), + None, + Some(vec![Some(3), Some(4), Some(5)]), + Some(vec![Some(6), Some(7)]), + ]; + let list_array = ListViewArray::from_iter_primitive::(data); + + // [[0, 1, 2], NULL, [3, 4, 5], [6, 7]] + let values = Int32Array::from(vec![0, 1, 2, 3, 4, 5, 6, 7]); + let offsets = ScalarBuffer::from(vec![0, 3, 3, 6]); + let sizes = ScalarBuffer::from(vec![3, 0, 3, 2]); + let field = Arc::new(Field::new_list_field(DataType::Int32, true)); + + let mut nulls = NullBufferBuilder::new(4); + nulls.append(true); + nulls.append(false); + nulls.append_n_non_nulls(2); + let another = ListViewArray::new(field, offsets, sizes, Arc::new(values), nulls.finish()); + + assert_eq!(list_array, another) + } } diff --git a/arrow-array/src/array/map_array.rs b/arrow-array/src/array/map_array.rs index 18a7c491aa16..86608d586f34 100644 --- a/arrow-array/src/array/map_array.rs +++ b/arrow-array/src/array/map_array.rs @@ -17,7 +17,7 @@ use crate::array::{get_offsets, print_long_array}; use crate::iterator::MapArrayIter; -use crate::{make_array, Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray}; +use crate::{Array, ArrayAccessor, ArrayRef, ListArray, StringArray, StructArray, make_array}; use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, OffsetBuffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, FieldRef}; @@ -173,6 +173,15 @@ impl MapArray { &self.entries } + /// Returns a reference to the fields of the [`StructArray`] that backs this map. + pub fn entries_fields(&self) -> (&Field, &Field) { + let fields = self.entries.fields().iter().collect::>(); + let fields = TryInto::<[&FieldRef; 2]>::try_into(fields) + .expect("Every map has a key and value field"); + + (fields[0].as_ref(), fields[1].as_ref()) + } + /// Returns the data type of the map's keys. pub fn key_type(&self) -> &DataType { self.keys().data_type() @@ -185,11 +194,14 @@ impl MapArray { /// Returns ith value of this map array. /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// Caller must ensure that the index is within the array bounds pub unsafe fn value_unchecked(&self, i: usize) -> StructArray { - let end = *self.value_offsets().get_unchecked(i + 1); - let start = *self.value_offsets().get_unchecked(i); + let end = *unsafe { self.value_offsets().get_unchecked(i + 1) }; + let start = *unsafe { self.value_offsets().get_unchecked(i) }; self.entries .slice(start.to_usize().unwrap(), (end - start).to_usize().unwrap()) } @@ -197,6 +209,12 @@ impl MapArray { /// Returns ith value of this map array. /// /// This is a [`StructArray`] containing two fields + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// + /// # Panics + /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> StructArray { let end = self.value_offsets()[i + 1] as usize; let start = self.value_offsets()[i] as usize; @@ -343,6 +361,8 @@ impl MapArray { } } +impl super::private::Sealed for MapArray {} + impl Array for MapArray { fn as_any(&self) -> &dyn Any { self diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index 29d284e3c5c4..aae382ace7b4 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -78,8 +78,18 @@ pub use list_view_array::*; use crate::iterator::ArrayIter; +mod private { + /// Private marker trait to ensure [`super::Array`] can not be implemented outside this crate + pub trait Sealed {} + + impl Sealed for &T {} +} + /// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) -pub trait Array: std::fmt::Debug + Send + Sync { +/// +/// This trait is sealed as it is not intended for custom array types, rather only +/// those defined in this crate. +pub trait Array: std::fmt::Debug + Send + Sync + private::Sealed { /// Returns the array as [`Any`] so that it can be /// downcasted to a specific implementation. /// @@ -341,6 +351,8 @@ pub trait Array: std::fmt::Debug + Send + Sync { /// A reference-counted reference to a generic `Array` pub type ArrayRef = Arc; +impl private::Sealed for ArrayRef {} + /// Ergonomics: Allow use of an ArrayRef as an `&dyn Array` impl Array for ArrayRef { fn as_any(&self) -> &dyn Any { @@ -620,10 +632,11 @@ impl<'a> StringArrayType<'a> for &'a StringViewArray { } } -/// A trait for Arrow String Arrays, currently three types are supported: +/// A trait for Arrow Binary Arrays, currently four types are supported: /// - `BinaryArray` /// - `LargeBinaryArray` /// - `BinaryViewArray` +/// - `FixedSizeBinaryArray` /// /// This trait helps to abstract over the different types of binary arrays /// so that we don't need to duplicate the implementation for each type. @@ -642,6 +655,11 @@ impl<'a> BinaryArrayType<'a> for &'a BinaryViewArray { BinaryViewArray::iter(self) } } +impl<'a> BinaryArrayType<'a> for &'a FixedSizeBinaryArray { + fn iter(&self) -> ArrayIter { + FixedSizeBinaryArray::iter(self) + } +} impl PartialEq for dyn Array + '_ { fn eq(&self, other: &Self) -> bool { @@ -739,8 +757,36 @@ impl PartialEq for RunArray { } } -/// Constructs an array using the input `data`. -/// Returns a reference-counted `Array` instance. +/// Constructs an [`ArrayRef`] from an [`ArrayData`]. +/// +/// # Notes: +/// +/// It is more efficient to directly construct the concrete array type rather +/// than using this function as creating an `ArrayData` requires at least one +/// additional allocation (the Vec of buffers). +/// +/// # Example: +/// ``` +/// # use std::sync::Arc; +/// # use arrow_data::ArrayData; +/// # use arrow_array::{make_array, ArrayRef, Int32Array}; +/// # use arrow_buffer::{Buffer, ScalarBuffer}; +/// # use arrow_schema::DataType; +/// // Create an Int32Array with values [1, 2, 3] +/// let values_buffer = Buffer::from_slice_ref(&[1, 2, 3]); +/// // ArrayData can be constructed using ArrayDataBuilder +/// let builder = ArrayData::builder(DataType::Int32) +/// .len(3) +/// .add_buffer(values_buffer.clone()); +/// let array_data = builder.build().unwrap(); +/// // Create the ArrayRef from the ArrayData +/// let array = make_array(array_data); +/// +/// // It is equivalent to directly constructing the Int32Array +/// let scalar_buffer = ScalarBuffer::from(values_buffer); +/// let int32_array: ArrayRef = Arc::new(Int32Array::new(scalar_buffer, None)); +/// assert_eq!(&array, &int32_array); +/// ``` pub fn make_array(data: ArrayData) -> ArrayRef { match data.data_type() { DataType::Boolean => Arc::new(BooleanArray::from(data)) as ArrayRef, @@ -815,7 +861,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Map(_, _) => Arc::new(MapArray::from(data)) as ArrayRef, DataType::Union(_, _) => Arc::new(UnionArray::from(data)) as ArrayRef, DataType::FixedSizeList(_, _) => Arc::new(FixedSizeListArray::from(data)) as ArrayRef, - DataType::Dictionary(ref key_type, _) => match key_type.as_ref() { + DataType::Dictionary(key_type, _) => match key_type.as_ref() { DataType::Int8 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, DataType::Int16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, DataType::Int32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, @@ -824,18 +870,20 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::UInt16 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, DataType::UInt32 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, DataType::UInt64 => Arc::new(DictionaryArray::::from(data)) as ArrayRef, - dt => panic!("Unexpected dictionary key type {dt:?}"), + dt => unimplemented!("Unexpected dictionary key type {dt}"), }, - DataType::RunEndEncoded(ref run_ends_type, _) => match run_ends_type.data_type() { + DataType::RunEndEncoded(run_ends_type, _) => match run_ends_type.data_type() { DataType::Int16 => Arc::new(RunArray::::from(data)) as ArrayRef, DataType::Int32 => Arc::new(RunArray::::from(data)) as ArrayRef, DataType::Int64 => Arc::new(RunArray::::from(data)) as ArrayRef, - dt => panic!("Unexpected data type for run_ends array {dt:?}"), + dt => unimplemented!("Unexpected data type for run_ends array {dt}"), }, DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, + DataType::Decimal32(_, _) => Arc::new(Decimal32Array::from(data)) as ArrayRef, + DataType::Decimal64(_, _) => Arc::new(Decimal64Array::from(data)) as ArrayRef, DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, - dt => panic!("Unexpected data type {dt:?}"), + dt => unimplemented!("Unexpected data type {dt}"), } } @@ -1065,13 +1113,14 @@ mod tests { fn test_null_union() { for mode in [UnionMode::Sparse, UnionMode::Dense] { let data_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![2, 1], vec![ Field::new("foo", DataType::Int32, true), Field::new("bar", DataType::Int64, true), ], - ), + ) + .unwrap(), mode, ); let array = new_null_array(&data_type, 4); diff --git a/arrow-array/src/array/null_array.rs b/arrow-array/src/array/null_array.rs index 2dd9570a0e94..b682466b6738 100644 --- a/arrow-array/src/array/null_array.rs +++ b/arrow-array/src/array/null_array.rs @@ -76,6 +76,8 @@ impl NullArray { } } +impl super::private::Sealed for NullArray {} + impl Array for NullArray { fn as_any(&self) -> &dyn Any { self @@ -170,7 +172,7 @@ impl std::fmt::Debug for NullArray { #[cfg(test)] mod tests { use super::*; - use crate::{make_array, Int64Array, StructArray}; + use crate::{Int64Array, StructArray, make_array}; use arrow_data::transform::MutableArrayData; use arrow_schema::Field; diff --git a/arrow-array/src/array/primitive_array.rs b/arrow-array/src/array/primitive_array.rs index 073ad9774459..87de5f61605f 100644 --- a/arrow-array/src/array/primitive_array.rs +++ b/arrow-array/src/array/primitive_array.rs @@ -25,7 +25,7 @@ use crate::timezone::Tz; use crate::trusted_len::trusted_len_unzip; use crate::types::*; use crate::{Array, ArrayAccessor, ArrayRef, Scalar}; -use arrow_buffer::{i256, ArrowNativeType, Buffer, NullBuffer, ScalarBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, NullBuffer, ScalarBuffer, i256}; use arrow_data::bit_iterator::try_for_each_valid_idx; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType}; @@ -410,6 +410,44 @@ pub type DurationMicrosecondArray = PrimitiveArray; /// A [`PrimitiveArray`] of elapsed durations in nanoseconds pub type DurationNanosecondArray = PrimitiveArray; +/// A [`PrimitiveArray`] of 32-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal32Array; +/// // Create from Vec> +/// let arr = Decimal32Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Decimal32Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Decimal32Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal32Array = PrimitiveArray; + +/// A [`PrimitiveArray`] of 64-bit fixed point decimals +/// +/// # Examples +/// +/// Construction +/// +/// ``` +/// # use arrow_array::Decimal64Array; +/// // Create from Vec> +/// let arr = Decimal64Array::from(vec![Some(1), None, Some(2)]); +/// // Create from Vec +/// let arr = Decimal64Array::from(vec![1, 2, 3]); +/// // Create iter/collect +/// let arr: Decimal64Array = std::iter::repeat(42).take(10).collect(); +/// ``` +/// +/// See [`PrimitiveArray`] for more information and examples +pub type Decimal64Array = PrimitiveArray; + /// A [`PrimitiveArray`] of 128-bit fixed point decimals /// /// # Examples @@ -455,6 +493,9 @@ pub use crate::types::ArrowPrimitiveType; /// /// # Example: From a Vec /// +/// *Note*: Converting a `Vec` to a `PrimitiveArray` does not copy the data. +/// The new `PrimitiveArray` uses the same underlying allocation from the `Vec`. +/// /// ``` /// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; /// let arr: PrimitiveArray = vec![1, 2, 3, 4].into(); @@ -463,6 +504,33 @@ pub use crate::types::ArrowPrimitiveType; /// assert_eq!(arr.values(), &[1, 2, 3, 4]) /// ``` /// +/// # Example: To a `Vec` +/// +/// *Note*: In some cases, converting `PrimitiveArray` to a `Vec` is zero-copy +/// and does not copy the data (see [`Buffer::into_vec`] for conditions). In +/// such cases, the `Vec` will use the same underlying memory allocation from +/// the `PrimitiveArray`. +/// +/// The Rust compiler generates highly optimized code for operations on +/// Vec, so using a Vec can often be faster than using a PrimitiveArray directly. +/// +/// ``` +/// # use arrow_array::{Array, PrimitiveArray, types::Int32Type}; +/// let arr = PrimitiveArray::::from(vec![1, 2, 3, 4]); +/// let starting_ptr = arr.values().as_ptr(); +/// // split into its parts +/// let (datatype, buffer, nulls) = arr.into_parts(); +/// // Convert the buffer to a Vec (zero copy) +/// // (note this requires that there are no other references) +/// let mut vec: Vec = buffer.into(); +/// vec[2] = 300; +/// // put the parts back together +/// let arr = PrimitiveArray::::try_new(vec.into(), nulls).unwrap(); +/// assert_eq!(arr.values(), &[1, 2, 300, 4]); +/// // The same allocation was used +/// assert_eq!(starting_ptr, arr.values().as_ptr()); +/// ``` +/// /// # Example: From an optional Vec /// /// ``` @@ -672,6 +740,8 @@ impl PrimitiveArray { DataType::Timestamp(t1, _) => { matches!(data_type, DataType::Timestamp(t2, _) if &t1 == t2) } + DataType::Decimal32(_, _) => matches!(data_type, DataType::Decimal32(_, _)), + DataType::Decimal64(_, _) => matches!(data_type, DataType::Decimal64(_, _)), DataType::Decimal128(_, _) => matches!(data_type, DataType::Decimal128(_, _)), DataType::Decimal256(_, _) => matches!(data_type, DataType::Decimal256(_, _)), _ => T::DATA_TYPE.eq(data_type), @@ -680,15 +750,22 @@ impl PrimitiveArray { /// Returns the primitive value at index `i`. /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Safety /// /// caller must ensure that the passed in offset is less than the array len() #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> T::Native { - *self.values.get_unchecked(i) + unsafe { *self.values.get_unchecked(i) } } /// Returns the primitive value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds #[inline] @@ -749,7 +826,7 @@ impl PrimitiveArray { &'a self, indexes: impl Iterator> + 'a, ) -> impl Iterator> + 'a { - indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + indexes.map(|opt_index| opt_index.map(|index| unsafe { self.value_unchecked(index) })) } /// Returns a zero-copy slice of this array with the indicated offset and length. @@ -782,11 +859,7 @@ impl PrimitiveArray { where K: ArrowPrimitiveType, { - let d = self.to_data().into_builder().data_type(K::DATA_TYPE); - - // SAFETY: - // Native type is the same - PrimitiveArray::from(unsafe { d.build_unchecked() }) + PrimitiveArray::new(self.values.clone(), self.nulls.clone()) } /// Applies a unary infallible function to a primitive array, producing a @@ -1113,6 +1186,8 @@ impl From> for ArrayData { } } +impl super::private::Sealed for PrimitiveArray {} + impl Array for PrimitiveArray { fn as_any(&self) -> &dyn Any { self @@ -1183,7 +1258,7 @@ impl ArrayAccessor for &PrimitiveArray { #[inline] unsafe fn value_unchecked(&self, index: usize) -> Self::Item { - PrimitiveArray::value_unchecked(self, index) + unsafe { PrimitiveArray::value_unchecked(self, index) } } } @@ -1195,6 +1270,8 @@ where /// /// If a data type cannot be converted to `NaiveDateTime`, a `None` is returned. /// A valid value is expected, thus the user should first check for validity. + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_datetime(&self, i: usize) -> Option { as_datetime::(i64::from(self.value(i))) } @@ -1203,6 +1280,8 @@ where /// /// functionally it is same as `value_as_datetime`, however it adds /// the passed tz to the to-be-returned NaiveDateTime + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_datetime_with_tz(&self, i: usize, tz: Tz) -> Option> { as_datetime_with_timezone::(i64::from(self.value(i)), tz) } @@ -1210,6 +1289,8 @@ where /// Returns value as a chrono `NaiveDate` by using `Self::datetime()` /// /// If a data type cannot be converted to `NaiveDate`, a `None` is returned + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_date(&self, i: usize) -> Option { self.value_as_datetime(i).map(|datetime| datetime.date()) } @@ -1217,6 +1298,8 @@ where /// Returns a value as a chrono `NaiveTime` /// /// `Date32` and `Date64` return UTC midnight as they do not have time resolution + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_time(&self, i: usize) -> Option { as_time::(i64::from(self.value(i))) } @@ -1224,6 +1307,8 @@ where /// Returns a value as a chrono `Duration` /// /// If a data type cannot be converted to `Duration`, a `None` is returned + /// + /// See notes on [`PrimitiveArray::value`] regarding nulls and panics pub fn value_as_duration(&self, i: usize) -> Option { as_duration::(i64::from(self.value(i))) } @@ -1233,7 +1318,7 @@ impl std::fmt::Debug for PrimitiveArray { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { let data_type = self.data_type(); - write!(f, "PrimitiveArray<{data_type:?}>\n[\n")?; + write!(f, "PrimitiveArray<{data_type}>\n[\n")?; print_long_array(self, f, |array, index, f| match data_type { DataType::Date32 | DataType::Date64 => { let v = self.value(index).to_i64().unwrap(); @@ -1242,7 +1327,7 @@ impl std::fmt::Debug for PrimitiveArray { None => { write!( f, - "Cast error: Failed to convert {v} to temporal for {data_type:?}" + "Cast error: Failed to convert {v} to temporal for {data_type}" ) } } @@ -1254,7 +1339,7 @@ impl std::fmt::Debug for PrimitiveArray { None => { write!( f, - "Cast error: Failed to convert {v} to temporal for {data_type:?}" + "Cast error: Failed to convert {v} to temporal for {data_type}" ) } } @@ -1343,6 +1428,8 @@ def_from_for_primitive!(UInt64Type, u64); def_from_for_primitive!(Float16Type, f16); def_from_for_primitive!(Float32Type, f32); def_from_for_primitive!(Float64Type, f64); +def_from_for_primitive!(Decimal32Type, i32); +def_from_for_primitive!(Decimal64Type, i64); def_from_for_primitive!(Decimal128Type, i128); def_from_for_primitive!(Decimal256Type, i256); @@ -1412,10 +1499,11 @@ impl PrimitiveArray { let (_, upper) = iterator.size_hint(); let len = upper.expect("trusted_len_unzip requires an upper limit"); - let (null, buffer) = trusted_len_unzip(iterator); + let (null, buffer) = unsafe { trusted_len_unzip(iterator) }; - let data = - ArrayData::new_unchecked(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]); + let data = unsafe { + ArrayData::new_unchecked(T::DATA_TYPE, len, None, Some(null), 0, vec![buffer], vec![]) + }; PrimitiveArray::from(data) } } @@ -1455,6 +1543,8 @@ def_numeric_from_vec!(UInt64Type); def_numeric_from_vec!(Float16Type); def_numeric_from_vec!(Float32Type); def_numeric_from_vec!(Float64Type); +def_numeric_from_vec!(Decimal32Type); +def_numeric_from_vec!(Decimal64Type); def_numeric_from_vec!(Decimal128Type); def_numeric_from_vec!(Decimal256Type); @@ -1539,10 +1629,16 @@ impl PrimitiveArray { /// Validates values in this array can be properly interpreted /// with the specified precision. pub fn validate_decimal_precision(&self, precision: u8) -> Result<(), ArrowError> { + if precision < self.scale() as u8 { + return Err(ArrowError::InvalidArgumentError(format!( + "Decimal precision {precision} is less than scale {}", + self.scale() + ))); + } (0..self.len()).try_for_each(|idx| { if self.is_valid(idx) { let decimal = unsafe { self.value_unchecked(idx) }; - T::validate_decimal_precision(decimal, precision) + T::validate_decimal_precision(decimal, precision, self.scale()) } else { Ok(()) } @@ -1563,6 +1659,26 @@ impl PrimitiveArray { /// Returns the decimal precision of this array pub fn precision(&self) -> u8 { match T::BYTE_LENGTH { + 4 => { + if let DataType::Decimal32(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal32Array datatype is not DataType::Decimal32 but {}", + self.data_type() + ) + } + } + 8 => { + if let DataType::Decimal64(p, _) = self.data_type() { + *p + } else { + unreachable!( + "Decimal64Array datatype is not DataType::Decimal64 but {}", + self.data_type() + ) + } + } 16 => { if let DataType::Decimal128(p, _) = self.data_type() { *p @@ -1590,6 +1706,26 @@ impl PrimitiveArray { /// Returns the decimal scale of this array pub fn scale(&self) -> i8 { match T::BYTE_LENGTH { + 4 => { + if let DataType::Decimal32(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal32Array datatype is not DataType::Decimal32 but {}", + self.data_type() + ) + } + } + 8 => { + if let DataType::Decimal64(_, s) = self.data_type() { + *s + } else { + unreachable!( + "Decimal64Array datatype is not DataType::Decimal64 but {}", + self.data_type() + ) + } + } 16 => { if let DataType::Decimal128(_, s) = self.data_type() { *s @@ -1618,9 +1754,11 @@ impl PrimitiveArray { #[cfg(test)] mod tests { use super::*; - use crate::builder::{Decimal128Builder, Decimal256Builder}; - use crate::cast::downcast_array; use crate::BooleanArray; + use crate::builder::{ + Decimal32Builder, Decimal64Builder, Decimal128Builder, Decimal256Builder, + }; + use crate::cast::downcast_array; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use arrow_schema::TimeUnit; @@ -1990,7 +2128,7 @@ mod tests { let arr: PrimitiveArray = TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T00:00:00,\n 2018-12-31T00:00:00,\n 1921-01-02T00:00:00,\n]", + "PrimitiveArray\n[\n 2018-12-31T00:00:00,\n 2018-12-31T00:00:00,\n 1921-01-02T00:00:00,\n]", format!("{arr:?}") ); } @@ -2001,7 +2139,7 @@ mod tests { TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) .with_timezone_utc(); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T00:00:00+00:00,\n 2018-12-31T00:00:00+00:00,\n 1921-01-02T00:00:00+00:00,\n]", + "PrimitiveArray\n[\n 2018-12-31T00:00:00+00:00,\n 2018-12-31T00:00:00+00:00,\n 1921-01-02T00:00:00+00:00,\n]", format!("{arr:?}") ); } @@ -2013,8 +2151,8 @@ mod tests { TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) .with_timezone("Asia/Taipei".to_string()); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", - format!("{:?}", arr) + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + format!("{arr:?}") ); } @@ -2028,7 +2166,7 @@ mod tests { println!("{arr:?}"); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n]", + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'Asia/Taipei'),\n]", format!("{arr:?}") ); } @@ -2039,7 +2177,7 @@ mod tests { TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) .with_timezone("+08:00".to_string()); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", + "PrimitiveArray\n[\n 2018-12-31T08:00:00+08:00,\n 2018-12-31T08:00:00+08:00,\n 1921-01-02T08:00:00+08:00,\n]", format!("{arr:?}") ); } @@ -2050,7 +2188,7 @@ mod tests { TimestampMillisecondArray::from(vec![1546214400000, 1546214400000, -1546214400000]) .with_timezone("xxx".to_string()); assert_eq!( - "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'xxx'),\n]", + "PrimitiveArray\n[\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 2018-12-31T00:00:00 (Unknown Time Zone 'xxx'),\n 1921-01-02T00:00:00 (Unknown Time Zone 'xxx'),\n]", format!("{arr:?}") ); } @@ -2066,8 +2204,8 @@ mod tests { ]) .with_timezone("America/Denver".to_string()); assert_eq!( - "PrimitiveArray\n[\n 2022-03-13T01:59:59-07:00,\n 2022-03-13T03:00:00-06:00,\n 2022-11-06T00:59:59-06:00,\n 2022-11-06T01:00:00-06:00,\n]", - format!("{:?}", arr) + "PrimitiveArray\n[\n 2022-03-13T01:59:59-07:00,\n 2022-03-13T03:00:00-06:00,\n 2022-11-06T00:59:59-06:00,\n 2022-11-06T01:00:00-06:00,\n]", + format!("{arr:?}") ); } @@ -2084,7 +2222,7 @@ mod tests { fn test_time32second_fmt_debug() { let arr: PrimitiveArray = vec![7201, 60054].into(); assert_eq!( - "PrimitiveArray\n[\n 02:00:01,\n 16:40:54,\n]", + "PrimitiveArray\n[\n 02:00:01,\n 16:40:54,\n]", format!("{arr:?}") ); } @@ -2094,8 +2232,8 @@ mod tests { // chrono::NaiveDatetime::from_timestamp_opt returns None while input is invalid let arr: PrimitiveArray = vec![-7201, -60054].into(); assert_eq!( - "PrimitiveArray\n[\n Cast error: Failed to convert -7201 to temporal for Time32(Second),\n Cast error: Failed to convert -60054 to temporal for Time32(Second),\n]", - // "PrimitiveArray\n[\n null,\n null,\n]", + "PrimitiveArray\n[\n Cast error: Failed to convert -7201 to temporal for Time32(s),\n Cast error: Failed to convert -60054 to temporal for Time32(s),\n]", + // "PrimitiveArray\n[\n null,\n null,\n]", format!("{arr:?}") ) } @@ -2105,7 +2243,7 @@ mod tests { // replicate the issue from https://github.com/apache/arrow-datafusion/issues/3832 let arr: PrimitiveArray = vec![9065525203050843594].into(); assert_eq!( - "PrimitiveArray\n[\n null,\n]", + "PrimitiveArray\n[\n null,\n]", format!("{arr:?}") ) } @@ -2228,6 +2366,42 @@ mod tests { let _ = PrimitiveArray::::from(foo.into_data()); } + #[test] + fn test_decimal32() { + let values: Vec<_> = vec![0, 1, -1, i32::MIN, i32::MAX]; + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + + #[test] + fn test_decimal64() { + let values: Vec<_> = vec![0, 1, -1, i64::MIN, i64::MAX]; + let array: PrimitiveArray = + PrimitiveArray::from_iter(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array: PrimitiveArray = + PrimitiveArray::from_iter_values(values.iter().copied()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(values.clone()); + assert_eq!(array.values(), &values); + + let array = PrimitiveArray::::from(array.to_data()); + assert_eq!(array.values(), &values); + } + #[test] fn test_decimal128() { let values: Vec<_> = vec![0, 1, -1, i128::MIN, i128::MAX]; @@ -2297,7 +2471,7 @@ mod tests { let result = arr.validate_decimal_precision(5); let error = result.unwrap_err(); assert_eq!( - "Invalid argument error: 123456 is too large to store in a Decimal128 of precision 5. Max is 99999", + "Invalid argument error: 123.456 is too large to store in a Decimal128 of precision 5. Max is 99.999", error.to_string() ); @@ -2316,7 +2490,7 @@ mod tests { let result = arr.validate_decimal_precision(2); let error = result.unwrap_err(); assert_eq!( - "Invalid argument error: 100 is too large to store in a Decimal128 of precision 2. Max is 99", + "Invalid argument error: 10.0 is too large to store in a Decimal128 of precision 2. Max is 9.9", error.to_string() ); } @@ -2402,7 +2576,7 @@ mod tests { #[test] #[should_panic( - expected = "-123223423432432 is too small to store in a Decimal128 of precision 5. Min is -99999" + expected = "-1232234234324.32 is too small to store in a Decimal128 of precision 5. Min is -999.99" )] fn test_decimal_array_with_precision_and_scale_out_of_range() { let arr = Decimal128Array::from_iter_values([12345, 456, 7890, -123223423432432]) @@ -2499,6 +2673,74 @@ mod tests { assert!(!array.is_null(2)); } + #[test] + fn test_decimal64_iter() { + let mut builder = Decimal64Builder::with_capacity(30); + let decimal1 = 12345; + builder.append_value(decimal1); + + builder.append_null(); + + let decimal2 = 56789; + builder.append_value(decimal2); + + let array: Decimal64Array = builder.finish().with_precision_and_scale(18, 4).unwrap(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } + + #[test] + fn test_from_iter_decimal64array() { + let value1 = 12345; + let value2 = 56789; + + let mut array: Decimal64Array = + vec![Some(value1), None, Some(value2)].into_iter().collect(); + array = array.with_precision_and_scale(18, 4).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal64(18, 4)); + assert_eq!(value1, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(value2, array.value(2)); + assert!(!array.is_null(2)); + } + + #[test] + fn test_decimal32_iter() { + let mut builder = Decimal32Builder::with_capacity(30); + let decimal1 = 12345; + builder.append_value(decimal1); + + builder.append_null(); + + let decimal2 = 56789; + builder.append_value(decimal2); + + let array: Decimal32Array = builder.finish().with_precision_and_scale(9, 2).unwrap(); + + let collected: Vec<_> = array.iter().collect(); + assert_eq!(vec![Some(decimal1), None, Some(decimal2)], collected); + } + + #[test] + fn test_from_iter_decimal32array() { + let value1 = 12345; + let value2 = 56789; + + let mut array: Decimal32Array = + vec![Some(value1), None, Some(value2)].into_iter().collect(); + array = array.with_precision_and_scale(9, 2).unwrap(); + assert_eq!(array.len(), 3); + assert_eq!(array.data_type(), &DataType::Decimal32(9, 2)); + assert_eq!(value1, array.value(0)); + assert!(!array.is_null(0)); + assert!(array.is_null(1)); + assert_eq!(value2, array.value(2)); + assert!(!array.is_null(2)); + } + #[test] fn test_unary_opt() { let array = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7]); @@ -2641,10 +2883,11 @@ mod tests { None, ] .into(); - let debug_str = format!("{:?}", array); - assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(Second),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400 to temporal for Time32(Second),\n Cast error: Failed to convert 86401 to temporal for Time32(Second),\n null,\n]", - debug_str - ); + let debug_str = format!("{array:?}"); + assert_eq!( + "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(s),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400 to temporal for Time32(s),\n Cast error: Failed to convert 86401 to temporal for Time32(s),\n null,\n]", + debug_str + ); } #[test] @@ -2658,8 +2901,9 @@ mod tests { None, ] .into(); - let debug_str = format!("{:?}", array); - assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(Millisecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000 to temporal for Time32(Millisecond),\n Cast error: Failed to convert 86401000 to temporal for Time32(Millisecond),\n null,\n]", + let debug_str = format!("{array:?}"); + assert_eq!( + "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time32(ms),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000 to temporal for Time32(ms),\n Cast error: Failed to convert 86401000 to temporal for Time32(ms),\n null,\n]", debug_str ); } @@ -2675,9 +2919,9 @@ mod tests { None, ] .into(); - let debug_str = format!("{:?}", array); + let debug_str = format!("{array:?}"); assert_eq!( - "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(Nanosecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000000 to temporal for Time64(Nanosecond),\n Cast error: Failed to convert 86401000000000 to temporal for Time64(Nanosecond),\n null,\n]", + "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(ns),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000000 to temporal for Time64(ns),\n Cast error: Failed to convert 86401000000000 to temporal for Time64(ns),\n null,\n]", debug_str ); } @@ -2693,8 +2937,11 @@ mod tests { None, ] .into(); - let debug_str = format!("{:?}", array); - assert_eq!("PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(Microsecond),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000 to temporal for Time64(Microsecond),\n Cast error: Failed to convert 86401000000 to temporal for Time64(Microsecond),\n null,\n]", debug_str); + let debug_str = format!("{array:?}"); + assert_eq!( + "PrimitiveArray\n[\n Cast error: Failed to convert -1 to temporal for Time64(µs),\n 00:00:00,\n 23:59:59,\n Cast error: Failed to convert 86400000000 to temporal for Time64(µs),\n Cast error: Failed to convert 86401000000 to temporal for Time64(µs),\n null,\n]", + debug_str + ); } #[test] diff --git a/arrow-array/src/array/run_array.rs b/arrow-array/src/array/run_array.rs index 05cfa2d17135..9ca1af943d27 100644 --- a/arrow-array/src/array/run_array.rs +++ b/arrow-array/src/array/run_array.rs @@ -23,23 +23,22 @@ use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field}; use crate::{ + Array, ArrayAccessor, ArrayRef, PrimitiveArray, builder::StringRunBuilder, make_array, run_iterator::RunArrayIter, types::{Int16Type, Int32Type, Int64Type, RunEndIndexType}, - Array, ArrayAccessor, ArrayRef, PrimitiveArray, }; -/// An array of [run-end encoded values](https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout) -/// -/// This encoding is variation on [run-length encoding (RLE)](https://en.wikipedia.org/wiki/Run-length_encoding) -/// and is good for representing data containing same values repeated consecutively. +/// An array of [run-end encoded values]. /// -/// [`RunArray`] contains `run_ends` array and `values` array of same length. -/// The `run_ends` array stores the indexes at which the run ends. The `values` array -/// stores the value of each run. Below example illustrates how a logical array is represented in -/// [`RunArray`] +/// This encoding is variation on [run-length encoding (RLE)] and is good for representing +/// data containing the same values repeated consecutively. /// +/// A [`RunArray`] consists of a `run_ends` buffer and a `values` array of equivalent +/// lengths. The `run_ends` buffer stores the indexes at which the run ends. The +/// `values` array stores the corresponding value of each run. The below example +/// illustrates how a logical array is represented by a [`RunArray`]: /// /// ```text /// ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┐ @@ -60,6 +59,9 @@ use crate::{ /// Logical array /// Contents /// ``` +/// +/// [run-end encoded values]: https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout +/// [run-length encoding (RLE)]: https://en.wikipedia.org/wiki/Run-length_encoding pub struct RunArray { data_type: DataType, run_ends: RunEndBuffer, @@ -77,8 +79,8 @@ impl Clone for RunArray { } impl RunArray { - /// Calculates the logical length of the array encoded - /// by the given run_ends array. + /// Calculates the logical length of the array encoded by treating the `run_ends` + /// array as if it were a [`RunEndBuffer`]. pub fn logical_len(run_ends: &PrimitiveArray) -> usize { let len = run_ends.len(); if len == 0 { @@ -87,9 +89,13 @@ impl RunArray { run_ends.value(len - 1).as_usize() } - /// Attempts to create RunArray using given run_ends (index where a run ends) - /// and the values (value of the run). Returns an error if the given data is not compatible - /// with RunEndEncoded specification. + /// Attempts to create a [`RunArray`] using the given `run_ends` and `values`. + /// + /// # Errors + /// + /// - If `run_ends` and `values` have different lengths + /// - If `run_ends` has any null values + /// - If `run_ends` doesn't consist of strictly increasing positive integers pub fn try_new(run_ends: &PrimitiveArray, values: &dyn Array) -> Result { let run_ends_type = run_ends.data_type().clone(); let values_type = values.data_type().clone(); @@ -117,25 +123,39 @@ impl RunArray { Ok(array_data.into()) } - /// Returns a reference to [`RunEndBuffer`] + /// Returns a reference to the [`RunEndBuffer`]. pub fn run_ends(&self) -> &RunEndBuffer { &self.run_ends } - /// Returns a reference to values array + /// Returns a reference to the values array. /// - /// Note: any slicing of this [`RunArray`] array is not applied to the returned array - /// and must be handled separately + /// Any slicing of this [`RunArray`] array is **not** applied to the returned + /// values here and must be handled separately. pub fn values(&self) -> &ArrayRef { &self.values } + /// Similar to [`values`] but accounts for logical slicing, returning only the values + /// that are part of the logical slice of this array. + /// + /// [`values`]: Self::values + pub fn values_slice(&self) -> ArrayRef { + let start = self.get_start_physical_index(); + let end = self.get_end_physical_index(); + self.values.slice(start, end - start + 1) + } + /// Returns the physical index at which the array slice starts. + /// + /// See [`RunEndBuffer::get_start_physical_index`]. pub fn get_start_physical_index(&self) -> usize { self.run_ends.get_start_physical_index() } /// Returns the physical index at which the array slice ends. + /// + /// See [`RunEndBuffer::get_end_physical_index`]. pub fn get_end_physical_index(&self) -> usize { self.run_ends.get_end_physical_index() } @@ -152,7 +172,6 @@ impl RunArray { /// assert_eq!(typed.value(1), "b"); /// assert!(typed.values().is_null(2)); /// ``` - /// pub fn downcast(&self) -> Option> { let values = self.values.as_any().downcast_ref()?; Some(TypedRunArray { @@ -161,89 +180,37 @@ impl RunArray { }) } - /// Returns index to the physical array for the given index to the logical array. - /// This function adjusts the input logical index based on `ArrayData::offset` - /// Performs a binary search on the run_ends array for the input index. + /// Calls [`RunEndBuffer::get_physical_index`]. /// /// The result is arbitrary if `logical_index >= self.len()` pub fn get_physical_index(&self, logical_index: usize) -> usize { self.run_ends.get_physical_index(logical_index) } - /// Returns the physical indices of the input logical indices. Returns error if any of the logical - /// index cannot be converted to physical index. The logical indices are sorted and iterated along - /// with run_ends array to find matching physical index. The approach used here was chosen over - /// finding physical index for each logical index using binary search using the function - /// `get_physical_index`. Running benchmarks on both approaches showed that the approach used here - /// scaled well for larger inputs. - /// See for more details. + /// Returns the physical indices corresponding to the provided logical indices. + /// + /// See [`RunEndBuffer::get_physical_indices`] for more details. #[inline] pub fn get_physical_indices(&self, logical_indices: &[I]) -> Result, ArrowError> where I: ArrowNativeType, { - let len = self.run_ends().len(); - let offset = self.run_ends().offset(); - - let indices_len = logical_indices.len(); - - if indices_len == 0 { - return Ok(vec![]); - } - - // `ordered_indices` store index into `logical_indices` and can be used - // to iterate `logical_indices` in sorted order. - let mut ordered_indices: Vec = (0..indices_len).collect(); - - // Instead of sorting `logical_indices` directly, sort the `ordered_indices` - // whose values are index of `logical_indices` - ordered_indices.sort_unstable_by(|lhs, rhs| { - logical_indices[*lhs] - .partial_cmp(&logical_indices[*rhs]) - .unwrap() - }); - - // Return early if all the logical indices cannot be converted to physical indices. - let largest_logical_index = logical_indices[*ordered_indices.last().unwrap()].as_usize(); - if largest_logical_index >= len { - return Err(ArrowError::InvalidArgumentError(format!( - "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {largest_logical_index}.", - ))); - } - - // Skip some physical indices based on offset. - let skip_value = self.get_start_physical_index(); - - let mut physical_indices = vec![0; indices_len]; - - let mut ordered_index = 0_usize; - for (physical_index, run_end) in self.run_ends.values().iter().enumerate().skip(skip_value) - { - // Get the run end index (relative to offset) of current physical index - let run_end_value = run_end.as_usize() - offset; - - // All the `logical_indices` that are less than current run end index - // belongs to current physical index. - while ordered_index < indices_len - && logical_indices[ordered_indices[ordered_index]].as_usize() < run_end_value - { - physical_indices[ordered_indices[ordered_index]] = physical_index; - ordered_index += 1; - } - } - - // If there are input values >= run_ends.last_value then we'll not be able to convert - // all logical indices to physical indices. - if ordered_index < logical_indices.len() { - let logical_index = logical_indices[ordered_indices[ordered_index]].as_usize(); - return Err(ArrowError::InvalidArgumentError(format!( - "Cannot convert all logical indices to physical indices. The logical index cannot be converted is {logical_index}.", - ))); - } - Ok(physical_indices) + self.run_ends() + .get_physical_indices(logical_indices) + .map_err(|index| { + ArrowError::InvalidArgumentError(format!( + "Logical index {} is out of bounds for RunArray of length {}", + index.as_usize(), + self.len() + )) + }) } /// Returns a zero-copy slice of this array with the indicated offset and length. + /// + /// # Panics + /// + /// - Specified slice (`offset` + `length`) exceeds existing length pub fn slice(&self, offset: usize, length: usize) -> Self { Self { data_type: self.data_type.clone(), @@ -259,7 +226,9 @@ impl From for RunArray { match data.data_type() { DataType::RunEndEncoded(_, _) => {} _ => { - panic!("Invalid data type for RunArray. The data type should be DataType::RunEndEncoded"); + panic!( + "Invalid data type for RunArray. The data type should be DataType::RunEndEncoded" + ); } } @@ -301,6 +270,8 @@ impl From> for ArrayData { } } +impl super::private::Sealed for RunArray {} + impl Array for RunArray { fn as_any(&self) -> &dyn Any { self @@ -560,6 +531,8 @@ impl<'a, R: RunEndIndexType, V> TypedRunArray<'a, R, V> { } } +impl super::private::Sealed for TypedRunArray<'_, R, V> {} + impl Array for TypedRunArray<'_, R, V> { fn as_any(&self) -> &dyn Any { self.run_array @@ -641,7 +614,7 @@ where unsafe fn value_unchecked(&self, logical_index: usize) -> Self::Item { let physical_index = self.run_array.get_physical_index(logical_index); - self.values().value_unchecked(physical_index) + unsafe { self.values().value_unchecked(physical_index) } } } @@ -662,9 +635,9 @@ where #[cfg(test)] mod tests { + use rand::Rng; use rand::rng; use rand::seq::SliceRandom; - use rand::Rng; use super::*; use crate::builder::PrimitiveRunBuilder; @@ -1169,4 +1142,35 @@ mod tests { assert_eq!(array_i16_1, array_i16_2); } + + #[test] + fn test_run_array_values_slice() { + // 0, 0, 1, 1, 1, 2...2 (15 2s) + let run_ends: PrimitiveArray = vec![2, 5, 20].into(); + let values: PrimitiveArray = vec![0, 1, 2].into(); + let array = RunArray::::try_new(&run_ends, &values).unwrap(); + + let slice = array.slice(1, 4); // 0 | 1, 1, 1 | + // logical indices: 1, 2, 3, 4 + // physical indices: 0, 1, 1, 1 + // values at 0 is 0 + // values at 1 is 1 + // values slice should be [0, 1] + assert_eq!(slice.get_start_physical_index(), 0); + assert_eq!(slice.get_end_physical_index(), 1); + + let values_slice = slice.values_slice(); + let values_slice = values_slice.as_primitive::(); + assert_eq!(values_slice.values(), &[0, 1]); + + let slice2 = array.slice(2, 3); // 1, 1, 1 + // logical indices: 2, 3, 4 + // physical indices: 1, 1, 1 + assert_eq!(slice2.get_start_physical_index(), 1); + assert_eq!(slice2.get_end_physical_index(), 1); + + let values_slice2 = slice2.values_slice(); + let values_slice2 = values_slice2.as_primitive::(); + assert_eq!(values_slice2.values(), &[1]); + } } diff --git a/arrow-array/src/array/string_array.rs b/arrow-array/src/array/string_array.rs index ed70e5744fff..80f3153eceed 100644 --- a/arrow-array/src/array/string_array.rs +++ b/arrow-array/src/array/string_array.rs @@ -48,7 +48,7 @@ impl GenericStringArray { &'a self, indexes: impl Iterator> + 'a, ) -> impl Iterator> { - indexes.map(|opt_index| opt_index.map(|index| self.value_unchecked(index))) + indexes.map(|opt_index| opt_index.map(|index| unsafe { self.value_unchecked(index) })) } /// Fallibly creates a [`GenericStringArray`] from a [`GenericBinaryArray`] returning @@ -156,9 +156,9 @@ pub type LargeStringArray = GenericStringArray; #[cfg(test)] mod tests { use super::*; + use crate::Array; use crate::builder::{ListBuilder, PrimitiveBuilder, StringBuilder}; use crate::types::UInt8Type; - use crate::Array; use arrow_buffer::Buffer; use arrow_data::ArrayData; use arrow_schema::{DataType, Field}; diff --git a/arrow-array/src/array/struct_array.rs b/arrow-array/src/array/struct_array.rs index fbc34ef0c85b..a738a733218a 100644 --- a/arrow-array/src/array/struct_array.rs +++ b/arrow-array/src/array/struct_array.rs @@ -16,7 +16,7 @@ // under the License. use crate::array::print_long_array; -use crate::{make_array, new_null_array, Array, ArrayRef, RecordBatch}; +use crate::{Array, ArrayRef, RecordBatch, make_array, new_null_array}; use arrow_buffer::{BooleanBuffer, Buffer, NullBuffer}; use arrow_data::{ArrayData, ArrayDataBuilder}; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields}; @@ -347,25 +347,26 @@ impl StructArray { impl From for StructArray { fn from(data: ArrayData) -> Self { - let parent_offset = data.offset(); - let parent_len = data.len(); + let (data_type, len, nulls, offset, _buffers, child_data) = data.into_parts(); - let fields = data - .child_data() - .iter() + let parent_offset = offset; + let parent_len = len; + + let fields = child_data + .into_iter() .map(|cd| { if parent_offset != 0 || parent_len != cd.len() { make_array(cd.slice(parent_offset, parent_len)) } else { - make_array(cd.clone()) + make_array(cd) } }) .collect(); Self { - len: data.len(), - data_type: data.data_type().clone(), - nulls: data.nulls().cloned(), + len, + data_type, + nulls, fields, } } @@ -401,6 +402,8 @@ impl TryFrom> for StructArray { } } +impl super::private::Sealed for StructArray {} + impl Array for StructArray { fn as_any(&self) -> &dyn Any { self @@ -922,7 +925,10 @@ mod tests { (0..30).map(|i| i % 2 == 0).collect::>(), ))), ); - assert_eq!(format!("{arr:?}"), "StructArray\n-- validity:\n[\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n ...10 elements...,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n]\n[\n-- child 0: \"c\" (Int32)\nPrimitiveArray\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n ...10 elements...,\n 20,\n 21,\n 22,\n 23,\n 24,\n 25,\n 26,\n 27,\n 28,\n 29,\n]\n]") + assert_eq!( + format!("{arr:?}"), + "StructArray\n-- validity:\n[\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n ...10 elements...,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n valid,\n null,\n]\n[\n-- child 0: \"c\" (Int32)\nPrimitiveArray\n[\n 0,\n 1,\n 2,\n 3,\n 4,\n 5,\n 6,\n 7,\n 8,\n 9,\n ...10 elements...,\n 20,\n 21,\n 22,\n 23,\n 24,\n 25,\n 26,\n 27,\n 28,\n 29,\n]\n]" + ) } #[test] diff --git a/arrow-array/src/array/union_array.rs b/arrow-array/src/array/union_array.rs index 2afe9af47327..e08542bc8638 100644 --- a/arrow-array/src/array/union_array.rs +++ b/arrow-array/src/array/union_array.rs @@ -16,7 +16,7 @@ // under the License. #![allow(clippy::enum_clike_unportable_variant)] -use crate::{make_array, Array, ArrayRef}; +use crate::{Array, ArrayRef, make_array}; use arrow_buffer::bit_chunk_iterator::{BitChunkIterator, BitChunks}; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::{BooleanBuffer, MutableBuffer, ScalarBuffer}; @@ -137,11 +137,11 @@ impl UnionArray { /// /// # Safety /// - /// The `type_ids` values should be positive and must match one of the type ids of the fields provided in `fields`. + /// The `type_ids` values should be non-negative and must match one of the type ids of the fields provided in `fields`. /// These values are used to index into the `children` arrays. /// /// The `offsets` is provided in the case of a dense union, sparse unions should use `None`. - /// If provided the `offsets` values should be positive and must be less than the length of the + /// If provided the `offsets` values should be non-negative and must be less than the length of the /// corresponding array. /// /// In both cases above we use signed integer types to maintain compatibility with other @@ -165,8 +165,8 @@ impl UnionArray { .len(len); let data = match offsets { - Some(offsets) => builder.add_buffer(offsets.into_inner()).build_unchecked(), - None => builder.build_unchecked(), + Some(offsets) => unsafe { builder.add_buffer(offsets.into_inner()).build_unchecked() }, + None => unsafe { builder.build_unchecked() }, }; Self::from(data) } @@ -219,7 +219,7 @@ impl UnionArray { _ => { return Err(ArrowError::InvalidArgumentError( "Type Ids values must match one of the field type ids".to_owned(), - )) + )); } } } @@ -230,7 +230,7 @@ impl UnionArray { if iter.any(|(type_id, &offset)| offset < 0 || offset >= array_lens[*type_id as usize]) { return Err(ArrowError::InvalidArgumentError( - "Offsets must be positive and within the length of the Array".to_owned(), + "Offsets must be non-negative and within the length of the Array".to_owned(), )); } } @@ -287,6 +287,10 @@ impl UnionArray { } /// Returns the array's value at index `i`. + /// + /// Note: This method does not check for nulls and the value is arbitrary + /// (but still well-defined) if [`is_null`](Self::is_null) returns true for the index. + /// /// # Panics /// Panics if index `i` is out of bounds pub fn value(&self, i: usize) -> ArrayRef { @@ -307,8 +311,16 @@ impl UnionArray { } } + /// Returns the [`UnionFields`] for the union. + pub fn fields(&self) -> &UnionFields { + match self.data_type() { + DataType::Union(fields, _) => fields, + _ => unreachable!("Union array's data type is not a union!"), + } + } + /// Returns whether the `UnionArray` is dense (or sparse if `false`). - fn is_dense(&self) -> bool { + pub fn is_dense(&self) -> bool { match self.data_type() { DataType::Union(_, mode) => mode == &UnionMode::Dense, _ => unreachable!("Union array's data type is not a union!"), @@ -726,6 +738,8 @@ impl From for ArrayData { } } +impl super::private::Sealed for UnionArray {} + impl Array for UnionArray { fn as_any(&self) -> &dyn Any { self @@ -781,13 +795,18 @@ impl Array for UnionArray { }; if fields.len() <= 1 { - return self - .fields - .iter() - .flatten() - .map(Array::logical_nulls) - .next() - .flatten(); + return self.fields.iter().find_map(|field_opt| { + field_opt + .as_ref() + .and_then(|field| field.logical_nulls()) + .map(|logical_nulls| { + if self.is_dense() { + self.gather_nulls(vec![(0, logical_nulls)]).into() + } else { + logical_nulls + } + }) + }); } let logical_nulls = self.fields_logical_nulls(); @@ -940,7 +959,7 @@ impl std::fmt::Debug for UnionArray { if let Some(offsets) = &self.offsets { writeln!(f, "-- offsets buffer:")?; - writeln!(f, "{:?}", offsets)?; + writeln!(f, "{offsets:?}")?; } let fields = match self.data_type() { @@ -1074,6 +1093,30 @@ mod tests { } } + #[test] + fn slice_union_array_single_field() { + // Dense Union + // [1, null, 3, null, 4] + let union_array = { + let mut builder = UnionBuilder::new_dense(); + builder.append::("a", 1).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 3).unwrap(); + builder.append_null::("a").unwrap(); + builder.append::("a", 4).unwrap(); + builder.build().unwrap() + }; + + // [null, 3, null] + let union_slice = union_array.slice(1, 3); + let logical_nulls = union_slice.logical_nulls().unwrap(); + + assert_eq!(logical_nulls.len(), 3); + assert!(logical_nulls.is_null(0)); + assert!(logical_nulls.is_valid(1)); + assert!(logical_nulls.is_null(2)); + } + #[test] #[cfg_attr(miri, ignore)] fn test_dense_i32_large() { @@ -1641,14 +1684,15 @@ mod tests { #[test] fn test_custom_type_ids() { let data_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![8, 4, 9], vec![ Field::new("strings", DataType::Utf8, false), Field::new("integers", DataType::Int32, false), Field::new("floats", DataType::Float64, false), ], - ), + ) + .unwrap(), UnionMode::Dense, ); @@ -1755,14 +1799,15 @@ mod tests { fn into_parts_custom_type_ids() { let set_field_type_ids: [i8; 3] = [8, 4, 9]; let data_type = DataType::Union( - UnionFields::new( + UnionFields::try_new( set_field_type_ids, [ Field::new("strings", DataType::Utf8, false), Field::new("integers", DataType::Int32, false), Field::new("floats", DataType::Float64, false), ], - ), + ) + .unwrap(), UnionMode::Dense, ); let string_array = StringArray::from(vec!["foo", "bar", "baz"]); @@ -1795,13 +1840,14 @@ mod tests { #[test] fn test_invalid() { - let fields = UnionFields::new( + let fields = UnionFields::try_new( [3, 2], [ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Utf8, false), ], - ); + ) + .unwrap(); let children = vec![ Arc::new(StringArray::from_iter_values(["a", "b"])) as _, Arc::new(StringArray::from_iter_values(["c", "d"])) as _, @@ -1844,7 +1890,7 @@ mod tests { assert_eq!( err.to_string(), - "Invalid argument error: Offsets must be positive and within the length of the Array" + "Invalid argument error: Offsets must be non-negative and within the length of the Array" ); let offsets = Some(vec![0, 1].into()); @@ -1871,13 +1917,14 @@ mod tests { assert_eq!(array.logical_nulls(), None); - let fields = UnionFields::new( + let fields = UnionFields::try_new( [1, 3], [ Field::new("a", DataType::Int8, false), // non nullable Field::new("b", DataType::Int8, false), // non nullable ], - ); + ) + .unwrap(); let array = UnionArray::try_new( fields, vec![1].into(), @@ -1891,13 +1938,14 @@ mod tests { assert_eq!(array.logical_nulls(), None); - let nullable_fields = UnionFields::new( + let nullable_fields = UnionFields::try_new( [1, 3], [ Field::new("a", DataType::Int8, true), // nullable but without nulls Field::new("b", DataType::Int8, true), // nullable but without nulls ], - ); + ) + .unwrap(); let array = UnionArray::try_new( nullable_fields.clone(), vec![1, 1].into(), diff --git a/arrow-array/src/builder/boolean_builder.rs b/arrow-array/src/builder/boolean_builder.rs index a0bd5745d21d..275aa8c9e56a 100644 --- a/arrow-array/src/builder/boolean_builder.rs +++ b/arrow-array/src/builder/boolean_builder.rs @@ -234,9 +234,12 @@ impl ArrayBuilder for BooleanBuilder { impl Extend> for BooleanBuilder { #[inline] fn extend>>(&mut self, iter: T) { - for v in iter { - self.append_option(v) - } + let buffered = iter.into_iter().collect::>(); + let array = unsafe { + // SAFETY: std::vec::IntoIter implements TrustedLen + BooleanArray::from_trusted_len_iter(buffered.into_iter()) + }; + self.append_array(&array) } } diff --git a/arrow-array/src/builder/buffer_builder.rs b/arrow-array/src/builder/buffer_builder.rs index c0cabb1f7353..d183aae86551 100644 --- a/arrow-array/src/builder/buffer_builder.rs +++ b/arrow-array/src/builder/buffer_builder.rs @@ -45,6 +45,10 @@ pub type Float32BufferBuilder = BufferBuilder; /// Buffer builder for 64-bit floating point type. pub type Float64BufferBuilder = BufferBuilder; +/// Buffer builder for 32-bit decimal type. +pub type Decimal32BufferBuilder = BufferBuilder<::Native>; +/// Buffer builder for 64-bit decimal type. +pub type Decimal64BufferBuilder = BufferBuilder<::Native>; /// Buffer builder for 128-bit decimal type. pub type Decimal128BufferBuilder = BufferBuilder<::Native>; /// Buffer builder for 256-bit decimal type. @@ -106,8 +110,8 @@ pub type DurationNanosecondBufferBuilder = #[cfg(test)] mod tests { - use crate::builder::{ArrayBuilder, Int32BufferBuilder, Int8Builder, UInt8BufferBuilder}; use crate::Array; + use crate::builder::{ArrayBuilder, Int8Builder, Int32BufferBuilder, UInt8BufferBuilder}; #[test] fn test_builder_i32_empty() { diff --git a/arrow-array/src/builder/fixed_size_binary_builder.rs b/arrow-array/src/builder/fixed_size_binary_builder.rs index b5f268917c92..f6b4c33d9454 100644 --- a/arrow-array/src/builder/fixed_size_binary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_builder.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, UInt8BufferBuilder}; +use crate::array::Array; +use crate::builder::ArrayBuilder; use crate::{ArrayRef, FixedSizeBinaryArray}; use arrow_buffer::Buffer; use arrow_buffer::NullBufferBuilder; @@ -42,7 +43,7 @@ use std::sync::Arc; /// ``` #[derive(Debug)] pub struct FixedSizeBinaryBuilder { - values_builder: UInt8BufferBuilder, + values_builder: Vec, null_buffer_builder: NullBufferBuilder, value_length: i32, } @@ -61,7 +62,7 @@ impl FixedSizeBinaryBuilder { "value length ({byte_width}) of the array must >= 0" ); Self { - values_builder: UInt8BufferBuilder::new(capacity * byte_width as usize), + values_builder: Vec::with_capacity(capacity * byte_width as usize), null_buffer_builder: NullBufferBuilder::new(capacity), value_length: byte_width, } @@ -79,7 +80,7 @@ impl FixedSizeBinaryBuilder { .to_string(), )) } else { - self.values_builder.append_slice(value.as_ref()); + self.values_builder.extend_from_slice(value.as_ref()); self.null_buffer_builder.append_non_null(); Ok(()) } @@ -89,7 +90,7 @@ impl FixedSizeBinaryBuilder { #[inline] pub fn append_null(&mut self) { self.values_builder - .append_slice(&vec![0u8; self.value_length as usize][..]); + .extend(std::iter::repeat_n(0u8, self.value_length as usize)); self.null_buffer_builder.append_null(); } @@ -97,10 +98,27 @@ impl FixedSizeBinaryBuilder { #[inline] pub fn append_nulls(&mut self, n: usize) { self.values_builder - .append_slice(&vec![0u8; self.value_length as usize * n][..]); + .extend(std::iter::repeat_n(0u8, self.value_length as usize * n)); self.null_buffer_builder.append_n_nulls(n); } + /// Appends all elements in array into the builder. + pub fn append_array(&mut self, array: &FixedSizeBinaryArray) -> Result<(), ArrowError> { + if self.value_length != array.value_length() { + return Err(ArrowError::InvalidArgumentError( + "Cannot append FixedSizeBinaryArray with different value length".to_string(), + )); + } + let buffer = array.value_data(); + self.values_builder.extend_from_slice(buffer); + if let Some(validity) = array.nulls() { + self.null_buffer_builder.append_buffer(validity); + } else { + self.null_buffer_builder.append_n_non_nulls(array.len()); + } + Ok(()) + } + /// Returns the current values buffer as a slice pub fn values_slice(&self) -> &[u8] { self.values_builder.as_slice() @@ -110,7 +128,7 @@ impl FixedSizeBinaryBuilder { pub fn finish(&mut self) -> FixedSizeBinaryArray { let array_length = self.len(); let array_data_builder = ArrayData::builder(DataType::FixedSizeBinary(self.value_length)) - .add_buffer(self.values_builder.finish()) + .add_buffer(std::mem::take(&mut self.values_builder).into()) .nulls(self.null_buffer_builder.finish()) .len(array_length); let array_data = unsafe { array_data_builder.build_unchecked() }; @@ -270,4 +288,45 @@ mod tests { fn test_fixed_size_binary_builder_invalid_value_length() { let _ = FixedSizeBinaryBuilder::with_capacity(15, -1); } + + #[test] + fn test_fixed_size_binary_builder_append_array() { + let mut other_builder = FixedSizeBinaryBuilder::with_capacity(3, 5); + other_builder.append_value(b"hello").unwrap(); + other_builder.append_null(); + other_builder.append_value(b"arrow").unwrap(); + let other_array = other_builder.finish(); + + let mut builder = FixedSizeBinaryBuilder::with_capacity(6, 5); + builder.append_array(&other_array).unwrap(); + // Append again to test if breaks when appending multiple times + builder.append_array(&other_array).unwrap(); + let array = builder.finish(); + + assert_eq!(array.value_length(), other_array.value_length()); + assert_eq!(&DataType::FixedSizeBinary(5), array.data_type()); + assert_eq!(6, array.len()); + assert_eq!(2, array.null_count()); + for i in 0..6 { + assert_eq!(i * 5, array.value_offset(i as usize)); + } + + assert_eq!(b"hello", array.value(0)); + assert!(array.is_null(1)); + assert_eq!(b"arrow", array.value(2)); + + assert_eq!(b"hello", array.value(3)); + assert!(array.is_null(4)); + assert_eq!(b"arrow", array.value(5)); + } + + #[test] + #[should_panic(expected = "Cannot append FixedSizeBinaryArray with different value length")] + fn test_fixed_size_binary_builder_append_array_invalid_value_length() { + let mut other_builder = FixedSizeBinaryBuilder::with_capacity(3, 4); + other_builder.append_value(b"test").unwrap(); + let other_array = other_builder.finish(); + let mut builder = FixedSizeBinaryBuilder::with_capacity(3, 5); + builder.append_array(&other_array).unwrap(); + } } diff --git a/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs b/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs index f3460353b164..fa3066b7e11e 100644 --- a/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs +++ b/arrow-array/src/builder/fixed_size_binary_dictionary_builder.rs @@ -17,11 +17,12 @@ use crate::builder::{ArrayBuilder, FixedSizeBinaryBuilder, PrimitiveBuilder}; use crate::types::ArrowDictionaryKeyType; -use crate::{Array, ArrayRef, DictionaryArray}; +use crate::{Array, ArrayRef, DictionaryArray, PrimitiveArray}; use arrow_buffer::ArrowNativeType; use arrow_schema::DataType::FixedSizeBinary; use arrow_schema::{ArrowError, DataType}; use hashbrown::HashTable; +use num_traits::NumCast; use std::any::Any; use std::sync::Arc; @@ -100,6 +101,71 @@ where byte_width, } } + + /// Creates a new `FixedSizeBinaryDictionaryBuilder` from the existing builder with the same + /// keys and values, but with a new data type for the keys. + /// + /// # Example + /// ``` + /// # use arrow_array::builder::FixedSizeBinaryDictionaryBuilder; + /// # use arrow_array::types::{UInt8Type, UInt16Type, UInt64Type}; + /// # use arrow_array::UInt16Array; + /// # use arrow_schema::ArrowError; + /// + /// let mut u8_keyed_builder = FixedSizeBinaryDictionaryBuilder::::new(2); + /// // appending too many values causes the dictionary to overflow + /// for i in 0..=255 { + /// u8_keyed_builder.append_value(vec![0, i]); + /// } + /// let result = u8_keyed_builder.append(vec![1, 0]); + /// assert!(matches!(result, Err(ArrowError::DictionaryKeyOverflowError{}))); + /// + /// // we need to upgrade to a larger key type + /// let mut u16_keyed_builder = FixedSizeBinaryDictionaryBuilder::::try_new_from_builder(u8_keyed_builder).unwrap(); + /// let dictionary_array = u16_keyed_builder.finish(); + /// let keys = dictionary_array.keys(); + /// + /// assert_eq!(keys, &UInt16Array::from_iter(0..256)); + /// ``` + pub fn try_new_from_builder( + mut source: FixedSizeBinaryDictionaryBuilder, + ) -> Result + where + K::Native: NumCast, + K2: ArrowDictionaryKeyType, + K2::Native: NumCast, + { + let state = source.state; + let dedup = source.dedup; + let values_builder = source.values_builder; + let byte_width = source.byte_width; + + let source_keys = source.keys_builder.finish(); + let new_keys: PrimitiveArray = source_keys.try_unary(|value| { + num_traits::cast::cast::(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Can't cast dictionary keys from source type {:?} to type {:?}", + K2::DATA_TYPE, + K::DATA_TYPE + )) + }) + })?; + + // drop source key here because currently source_keys and new_keys are holding reference to + // the same underlying null_buffer. Below we want to call new_keys.into_builder() it must + // be the only reference holder. + drop(source_keys); + + Ok(Self { + state, + dedup, + keys_builder: new_keys + .into_builder() + .expect("underlying buffer has no references"), + values_builder, + byte_width, + }) + } } impl ArrayBuilder for FixedSizeBinaryDictionaryBuilder @@ -186,6 +252,28 @@ where } } + /// Append a value multiple times to the array. + /// This is the same as [`Self::append`] but allows to append the same value multiple times without doing multiple lookups. + /// + /// Returns an error if the new index would overflow the key type. + pub fn append_n( + &mut self, + value: impl AsRef<[u8]>, + count: usize, + ) -> Result { + if self.byte_width != value.as_ref().len() as i32 { + Err(ArrowError::InvalidArgumentError(format!( + "Invalid input length passed to FixedSizeBinaryBuilder. Expected {} got {}", + self.byte_width, + value.as_ref().len() + ))) + } else { + let key = self.get_or_insert_key(value)?; + self.keys_builder.append_value_n(key, count); + Ok(key) + } + } + /// Appends a null slot into the builder #[inline] pub fn append_null(&mut self) { @@ -245,6 +333,41 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. + /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. + pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary( + Box::new(K::DATA_TYPE), + Box::new(FixedSizeBinary(self.byte_width)), + ); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } } fn get_bytes(values: &FixedSizeBinaryBuilder, byte_width: i32, idx: usize) -> &[u8] { @@ -258,8 +381,8 @@ fn get_bytes(values: &FixedSizeBinaryBuilder, byte_width: i32, idx: usize) -> &[ mod tests { use super::*; - use crate::types::Int8Type; - use crate::{FixedSizeBinaryArray, Int8Array}; + use crate::types::{Int8Type, Int16Type, Int32Type, UInt8Type, UInt16Type}; + use crate::{ArrowPrimitiveType, FixedSizeBinaryArray, Int8Array}; #[test] fn test_fixed_size_dictionary_builder() { @@ -300,13 +423,57 @@ mod tests { assert_eq!(ava.value(1), values[1].as_bytes()); } + #[test] + fn test_fixed_size_dictionary_builder_append_n() { + let values = ["abc", "def"]; + let mut b = FixedSizeBinaryDictionaryBuilder::::new(3); + assert_eq!(b.append_n(values[0], 2).unwrap(), 0); + assert_eq!(b.append_n(values[1], 3).unwrap(), 1); + assert_eq!(b.append_n(values[0], 2).unwrap(), 0); + let array = b.finish(); + + assert_eq!( + array.keys(), + &Int8Array::from(vec![ + Some(0), + Some(0), + Some(1), + Some(1), + Some(1), + Some(0), + Some(0), + ]), + ); + + // Values are polymorphic and so require a downcast. + let ava = array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!(ava.value(0), values[0].as_bytes()); + assert_eq!(ava.value(1), values[1].as_bytes()); + } + #[test] fn test_fixed_size_dictionary_builder_wrong_size() { let mut b = FixedSizeBinaryDictionaryBuilder::::new(3); let err = b.append(b"too long").unwrap_err().to_string(); - assert_eq!(err, "Invalid argument error: Invalid input length passed to FixedSizeBinaryBuilder. Expected 3 got 8"); + assert_eq!( + err, + "Invalid argument error: Invalid input length passed to FixedSizeBinaryBuilder. Expected 3 got 8" + ); let err = b.append("").unwrap_err().to_string(); - assert_eq!(err, "Invalid argument error: Invalid input length passed to FixedSizeBinaryBuilder. Expected 3 got 0"); + assert_eq!( + err, + "Invalid argument error: Invalid input length passed to FixedSizeBinaryBuilder. Expected 3 got 0" + ); + let err = b.append_n("a", 3).unwrap_err().to_string(); + assert_eq!( + err, + "Invalid argument error: Invalid input length passed to FixedSizeBinaryBuilder. Expected 3 got 1" + ); } #[test] @@ -368,4 +535,136 @@ mod tests { assert_eq!(ava2.value(1), values[1].as_bytes()); assert_eq!(ava2.value(2), values[2].as_bytes()); } + + fn _test_try_new_from_builder_generic_for_key_types(values: Vec<[u8; 3]>) + where + K1: ArrowDictionaryKeyType, + K1::Native: NumCast, + K2: ArrowDictionaryKeyType, + K2::Native: NumCast + From, + { + let mut source = FixedSizeBinaryDictionaryBuilder::::new(3); + source.append_value(values[0]); + source.append_null(); + source.append_value(values[1]); + source.append_value(values[2]); + + let mut result = + FixedSizeBinaryDictionaryBuilder::::try_new_from_builder(source).unwrap(); + let array = result.finish(); + + let mut expected_keys_builder = PrimitiveBuilder::::new(); + expected_keys_builder + .append_value(<::Native as From>::from(0u8)); + expected_keys_builder.append_null(); + expected_keys_builder + .append_value(<::Native as From>::from(1u8)); + expected_keys_builder + .append_value(<::Native as From>::from(2u8)); + let expected_keys = expected_keys_builder.finish(); + assert_eq!(array.keys(), &expected_keys); + + let av = array.values(); + let ava = av.as_any().downcast_ref::().unwrap(); + assert_eq!(ava.value(0), values[0]); + assert_eq!(ava.value(1), values[1]); + assert_eq!(ava.value(2), values[2]); + } + + #[test] + fn test_try_new_from_builder() { + let values = vec![[1, 2, 3], [5, 6, 7], [6, 7, 8]]; + // test cast to bigger size unsigned + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test cast going to smaller size unsigned + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test cast going to bigger size signed + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test cast going to smaller size signed + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test going from signed to signed for different size changes + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + } + + #[test] + fn test_try_new_from_builder_cast_fails() { + let mut source_builder = FixedSizeBinaryDictionaryBuilder::::new(2); + for i in 0u16..257u16 { + source_builder.append_value(vec![(i >> 8) as u8, i as u8]); + } + + // there should be too many values that we can't downcast to the underlying type + // we have keys that wouldn't fit into UInt8Type + let result = + FixedSizeBinaryDictionaryBuilder::::try_new_from_builder(source_builder); + assert!(result.is_err()); + if let Err(e) = result { + assert!(matches!(e, ArrowError::CastError(_))); + assert_eq!( + e.to_string(), + "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8" + ); + } + } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = FixedSizeBinaryDictionaryBuilder::::new(3); + builder.append_value("aaa"); + builder.append_value("bbb"); + builder.append_value("ccc"); + let dict = builder.finish_preserve_values(); + assert_eq!(dict.keys().values(), &[0, 1, 2]); + let values = dict + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + values, + vec![ + Some("aaa".as_bytes()), + Some("bbb".as_bytes()), + Some("ccc".as_bytes()) + ] + ); + + // Create a new dictionary + builder.append_value("ddd"); + builder.append_value("eee"); + let dict2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones and we have the + // right values + assert_eq!(dict2.keys().values(), &[3, 4]); + let values = dict2 + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("ddd".as_bytes()), Some("eee".as_bytes())]); + + // Check that we have all of the expected values + let all_values = dict2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + all_values, + [ + Some("aaa".as_bytes()), + Some("bbb".as_bytes()), + Some("ccc".as_bytes()), + Some("ddd".as_bytes()), + Some("eee".as_bytes()) + ] + ); + } } diff --git a/arrow-array/src/builder/fixed_size_list_builder.rs b/arrow-array/src/builder/fixed_size_list_builder.rs index 5c142b277d14..6eb48fc0527c 100644 --- a/arrow-array/src/builder/fixed_size_list_builder.rs +++ b/arrow-array/src/builder/fixed_size_list_builder.rs @@ -172,7 +172,8 @@ where let nulls = self.null_buffer_builder.finish(); assert_eq!( - values.len(), len * self.list_len as usize, + values.len(), + len * self.list_len as usize, "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", values.len(), self.list_len, @@ -194,7 +195,8 @@ where let nulls = self.null_buffer_builder.finish_cloned(); assert_eq!( - values.len(), len * self.list_len as usize, + values.len(), + len * self.list_len as usize, "Length of the child array ({}) must be the multiple of the value length ({}) and the array length ({}).", values.len(), self.list_len, @@ -220,9 +222,9 @@ mod tests { use super::*; use arrow_schema::DataType; - use crate::builder::Int32Builder; use crate::Array; use crate::Int32Array; + use crate::builder::Int32Builder; fn make_list_builder( include_null_element: bool, diff --git a/arrow-array/src/builder/generic_byte_run_builder.rs b/arrow-array/src/builder/generic_byte_run_builder.rs index 0bf5658b297e..18544f7e75c9 100644 --- a/arrow-array/src/builder/generic_byte_run_builder.rs +++ b/arrow-array/src/builder/generic_byte_run_builder.rs @@ -19,8 +19,8 @@ use crate::types::bytes::ByteArrayNativeType; use std::{any::Any, sync::Arc}; use crate::{ - types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, Utf8Type}, ArrayRef, ArrowPrimitiveType, RunArray, + types::{BinaryType, ByteArrayType, LargeBinaryType, LargeUtf8Type, RunEndIndexType, Utf8Type}, }; use super::{ArrayBuilder, GenericByteBuilder, PrimitiveBuilder}; @@ -375,11 +375,11 @@ pub type LargeBinaryRunBuilder = GenericByteRunBuilder; mod tests { use super::*; + use crate::GenericByteArray; + use crate::Int16RunArray; use crate::array::Array; use crate::cast::AsArray; use crate::types::{Int16Type, Int32Type}; - use crate::GenericByteArray; - use crate::Int16RunArray; fn test_bytes_run_builder(values: Vec<&T::Native>) where diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs index 91ac2a483ef4..7ed4bc5826c0 100644 --- a/arrow-array/src/builder/generic_bytes_builder.rs +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType}; use crate::{Array, ArrayRef, GenericByteArray, OffsetSizeTrait}; -use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, NullBufferBuilder, ScalarBuffer}; use arrow_data::ArrayDataBuilder; +use arrow_schema::ArrowError; use std::any::Any; use std::sync::Arc; @@ -29,8 +29,8 @@ use std::sync::Arc; /// For building strings, see docs on [`GenericStringBuilder`]. /// For building binary, see docs on [`GenericBinaryBuilder`]. pub struct GenericByteBuilder { - value_builder: UInt8BufferBuilder, - offsets_builder: BufferBuilder, + value_builder: Vec, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, } @@ -47,10 +47,10 @@ impl GenericByteBuilder { /// - `data_capacity` is the total number of bytes of data to pre-allocate /// (for all items, not per item). pub fn with_capacity(item_capacity: usize, data_capacity: usize) -> Self { - let mut offsets_builder = BufferBuilder::::new(item_capacity + 1); - offsets_builder.append(T::Offset::from_usize(0).unwrap()); + let mut offsets_builder = Vec::with_capacity(item_capacity + 1); + offsets_builder.push(T::Offset::from_usize(0).unwrap()); Self { - value_builder: UInt8BufferBuilder::new(data_capacity), + value_builder: Vec::with_capacity(data_capacity), offsets_builder, null_buffer_builder: NullBufferBuilder::new(item_capacity), } @@ -67,8 +67,9 @@ impl GenericByteBuilder { value_buffer: MutableBuffer, null_buffer: Option, ) -> Self { - let offsets_builder = BufferBuilder::::new_from_buffer(offsets_buffer); - let value_builder = BufferBuilder::::new_from_buffer(value_buffer); + let offsets_builder: Vec = + ScalarBuffer::::from(offsets_buffer).into(); + let value_builder: Vec = ScalarBuffer::::from(value_buffer).into(); let null_buffer_builder = null_buffer .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, offsets_builder.len() - 1)) @@ -103,9 +104,10 @@ impl GenericByteBuilder { /// [`BinaryArray`]: crate::BinaryArray #[inline] pub fn append_value(&mut self, value: impl AsRef) { - self.value_builder.append_slice(value.as_ref().as_ref()); + self.value_builder + .extend_from_slice(value.as_ref().as_ref()); self.null_buffer_builder.append(true); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); } /// Append an `Option` value into the builder. @@ -126,7 +128,7 @@ impl GenericByteBuilder { #[inline] pub fn append_null(&mut self) { self.null_buffer_builder.append(false); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); } /// Appends `n` `null`s into the builder. @@ -134,15 +136,17 @@ impl GenericByteBuilder { pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); let next_offset = self.next_offset(); - self.offsets_builder.append_n(n, next_offset); + self.offsets_builder + .extend(std::iter::repeat_n(next_offset, n)); } /// Appends array values and null to this builder as is /// (this means that underlying null values are copied as is). #[inline] - pub fn append_array(&mut self, array: &GenericByteArray) { + pub fn append_array(&mut self, array: &GenericByteArray) -> Result<(), ArrowError> { + use num_traits::CheckedAdd; if array.len() == 0 { - return; + return Ok(()); } let offsets = array.offsets(); @@ -150,25 +154,23 @@ impl GenericByteBuilder { // If the offsets are contiguous, we can append them directly avoiding the need to align // for example, when the first appended array is not sliced (starts at offset 0) if self.next_offset() == offsets[0] { - self.offsets_builder.append_slice(&offsets[1..]); + self.offsets_builder.extend_from_slice(&offsets[1..]); } else { // Shifting all the offsets let shift: T::Offset = self.next_offset() - offsets[0]; - // Creating intermediate offsets instead of pushing each offset is faster - // (even if we make MutableBuffer to avoid updating length on each push - // and reserve the necessary capacity, it's still slower) - let mut intermediate = Vec::with_capacity(offsets.len() - 1); - - for &offset in &offsets[1..] { - intermediate.push(offset + shift) + if shift.checked_add(&offsets[offsets.len() - 1]).is_none() { + return Err(ArrowError::OffsetOverflowError( + shift.as_usize() + offsets[offsets.len() - 1].as_usize(), + )); } - self.offsets_builder.append_slice(&intermediate); + self.offsets_builder + .extend(offsets[1..].iter().map(|&offset| offset + shift)); } // Append underlying values, starting from the first offset and ending at the last offset - self.value_builder.append_slice( + self.value_builder.extend_from_slice( &array.values().as_slice()[offsets[0].as_usize()..offsets[array.len()].as_usize()], ); @@ -177,6 +179,7 @@ impl GenericByteBuilder { } else { self.null_buffer_builder.append_n_non_nulls(array.len()); } + Ok(()) } /// Builds the [`GenericByteArray`] and reset this builder. @@ -184,11 +187,11 @@ impl GenericByteBuilder { let array_type = T::DATA_TYPE; let array_builder = ArrayDataBuilder::new(array_type) .len(self.len()) - .add_buffer(self.offsets_builder.finish()) - .add_buffer(self.value_builder.finish()) + .add_buffer(std::mem::take(&mut self.offsets_builder).into()) + .add_buffer(std::mem::take(&mut self.value_builder).into()) .nulls(self.null_buffer_builder.finish()); - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); let array_data = unsafe { array_builder.build_unchecked() }; GenericByteArray::from(array_data) } @@ -340,11 +343,99 @@ pub type GenericStringBuilder = GenericByteBuilder>; impl std::fmt::Write for GenericStringBuilder { fn write_str(&mut self, s: &str) -> std::fmt::Result { - self.value_builder.append_slice(s.as_bytes()); + self.value_builder.extend_from_slice(s.as_bytes()); Ok(()) } } +/// A byte size value representing the number of bytes to allocate per string in [`GenericStringBuilder`] +/// +/// To create a [`GenericStringBuilder`] using `.with_capacity` we are required to provide: \ +/// - `item_capacity` - the row count \ +/// - `data_capacity` - total string byte count \ +/// +/// We will use the `AVERAGE_STRING_LENGTH` * row_count for `data_capacity`. \ +/// +/// These capacities are preallocation hints used to improve performance, +/// but consequences of passing a hint too large or too small should be negligible. +const AVERAGE_STRING_LENGTH: usize = 16; +/// Trait for string-like array builders +/// +/// This trait provides unified interface for builders that append string-like data +/// such as [`GenericStringBuilder`] and [`crate::builder::StringViewBuilder`] +pub trait StringLikeArrayBuilder: ArrayBuilder { + /// Returns a human-readable type name for the builder. + fn type_name() -> &'static str; + + /// Creates a new builder with the given row capacity. + fn with_capacity(capacity: usize) -> Self; + + /// Appends a non-null string value to the builder. + fn append_value(&mut self, value: &str); + + /// Appends a null value to the builder. + fn append_null(&mut self); +} + +impl StringLikeArrayBuilder for GenericStringBuilder { + fn type_name() -> &'static str { + std::any::type_name::() + } + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity(capacity, capacity * AVERAGE_STRING_LENGTH) + } + fn append_value(&mut self, value: &str) { + Self::append_value(self, value); + } + fn append_null(&mut self) { + Self::append_null(self); + } +} + +/// A byte size value representing the number of bytes to allocate per binary in [`GenericBinaryBuilder`] +/// +/// To create a [`GenericBinaryBuilder`] using `.with_capacity` we are required to provide: \ +/// - `item_capacity` - the row count \ +/// - `data_capacity` - total binary byte count \ +/// +/// We will use the `AVERAGE_BINARY_LENGTH` * row_count for `data_capacity`. \ +/// +/// These capacities are preallocation hints used to improve performance, +/// but consequences of passing a hint too large or too small should be negligible. +const AVERAGE_BINARY_LENGTH: usize = 128; +/// Trait for binary-like array builders +/// +/// This trait provides unified interface for builders that append binary-like data +/// such as [`GenericBinaryBuilder`] and [`crate::builder::BinaryViewBuilder`] +pub trait BinaryLikeArrayBuilder: ArrayBuilder { + /// Returns a human-readable type name for the builder. + fn type_name() -> &'static str; + + /// Creates a new builder with the given row capacity. + fn with_capacity(capacity: usize) -> Self; + + /// Appends a non-null string value to the builder. + fn append_value(&mut self, value: &[u8]); + + /// Appends a null value to the builder. + fn append_null(&mut self); +} + +impl BinaryLikeArrayBuilder for GenericBinaryBuilder { + fn type_name() -> &'static str { + std::any::type_name::() + } + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity(capacity, capacity * AVERAGE_BINARY_LENGTH) + } + fn append_value(&mut self, value: &[u8]) { + Self::append_value(self, value); + } + fn append_null(&mut self) { + Self::append_null(self); + } +} + /// Array builder for [`GenericBinaryArray`][crate::GenericBinaryArray] /// /// Values can be appended using [`GenericByteBuilder::append_value`], and nulls with @@ -394,7 +485,7 @@ pub type GenericBinaryBuilder = GenericByteBuilder>; impl std::io::Write for GenericBinaryBuilder { fn write(&mut self, bs: &[u8]) -> std::io::Result { - self.value_builder.append_slice(bs); + self.value_builder.extend_from_slice(bs); Ok(bs.len()) } @@ -406,8 +497,8 @@ impl std::io::Write for GenericBinaryBuilder { #[cfg(test)] mod tests { use super::*; - use crate::array::Array; use crate::GenericStringArray; + use crate::array::Array; use arrow_buffer::NullBuffer; use std::fmt::Write as _; use std::io::Write as _; @@ -671,9 +762,9 @@ mod tests { let arr3 = GenericStringArray::::from(input[7..].to_vec()); let mut builder = GenericStringBuilder::::new(); - builder.append_array(&arr1); - builder.append_array(&arr2); - builder.append_array(&arr3); + builder.append_array(&arr1).unwrap(); + builder.append_array(&arr2).unwrap(); + builder.append_array(&arr3).unwrap(); let actual = builder.finish(); let expected = GenericStringArray::::from(input); @@ -701,9 +792,9 @@ mod tests { let arr3 = GenericStringArray::::from(input[7..].to_vec()); let mut builder = GenericStringBuilder::::new(); - builder.append_array(&arr1); - builder.append_array(&arr2); - builder.append_array(&arr3); + builder.append_array(&arr1).unwrap(); + builder.append_array(&arr2).unwrap(); + builder.append_array(&arr3).unwrap(); let actual = builder.finish(); let expected = GenericStringArray::::from(input); @@ -715,7 +806,7 @@ mod tests { fn test_append_empty_array() { let arr = GenericStringArray::::from(Vec::<&str>::new()); let mut builder = GenericStringBuilder::::new(); - builder.append_array(&arr); + builder.append_array(&arr).unwrap(); let result = builder.finish(); assert_eq!(result.len(), 0); } @@ -742,7 +833,7 @@ mod tests { assert_ne!(sliced.offsets().last(), full_array.offsets().last()); let mut builder = GenericStringBuilder::::new(); - builder.append_array(&sliced); + builder.append_array(&sliced).unwrap(); let actual = builder.finish(); let expected = GenericStringArray::::from(vec![None, Some("how"), None, None]); @@ -778,8 +869,8 @@ mod tests { }; let mut builder = GenericStringBuilder::::new(); - builder.append_array(&input_1_array_with_nulls); - builder.append_array(&input_2_array_with_nulls); + builder.append_array(&input_1_array_with_nulls).unwrap(); + builder.append_array(&input_2_array_with_nulls).unwrap(); let actual = builder.finish(); let expected = GenericStringArray::::from(vec![ @@ -825,12 +916,27 @@ mod tests { let slice3 = full_array.slice(7, full_array.len() - 7); let mut builder = GenericStringBuilder::::new(); - builder.append_array(&slice1); - builder.append_array(&slice2); - builder.append_array(&slice3); + builder.append_array(&slice1).unwrap(); + builder.append_array(&slice2).unwrap(); + builder.append_array(&slice3).unwrap(); let actual = builder.finish(); assert_eq!(actual, full_array); } + + #[test] + fn test_append_array_offset_overflow_precise() { + let mut builder = GenericStringBuilder::::new(); + + let initial_string = "x".repeat(i32::MAX as usize - 100); + builder.append_value(&initial_string); + + let overflow_string = "y".repeat(200); + let overflow_array = GenericStringArray::::from(vec![overflow_string.as_str()]); + + let result = builder.append_array(&overflow_array); + + assert!(matches!(result, Err(ArrowError::OffsetOverflowError(_)))); + } } diff --git a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs index 3713a411232f..35c7bfced1fd 100644 --- a/arrow-array/src/builder/generic_bytes_dictionary_builder.rs +++ b/arrow-array/src/builder/generic_bytes_dictionary_builder.rs @@ -23,7 +23,7 @@ use crate::{ use arrow_buffer::ArrowNativeType; use arrow_schema::{ArrowError, DataType}; use hashbrown::HashTable; -use num::NumCast; +use num_traits::NumCast; use std::any::Any; use std::sync::Arc; @@ -197,7 +197,7 @@ where let source_keys = source.keys_builder.finish(); let new_keys: PrimitiveArray = source_keys.try_unary(|value| { - num::cast::cast::(value).ok_or_else(|| { + num_traits::cast::cast::(value).ok_or_else(|| { ArrowError::CastError(format!( "Can't cast dictionary keys from source type {:?} to type {:?}", K2::DATA_TYPE, @@ -463,6 +463,38 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. + /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. + pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(T::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + /// Returns the current null buffer as a slice pub fn validity_slice(&self) -> Option<&[u8]> { self.keys_builder.validity_slice() @@ -571,7 +603,7 @@ mod tests { use crate::array::Int8Array; use crate::cast::AsArray; - use crate::types::{Int16Type, Int32Type, Int8Type, UInt16Type, UInt8Type, Utf8Type}; + use crate::types::{Int8Type, Int16Type, Int32Type, UInt8Type, UInt16Type, Utf8Type}; use crate::{ArrowPrimitiveType, BinaryArray, StringArray}; fn test_bytes_dictionary_builder(values: Vec<&T::Native>) @@ -757,7 +789,7 @@ mod tests { fn test_try_new_from_builder_cast_fails() { let mut source_builder = StringDictionaryBuilder::::new(); for i in 0..257 { - source_builder.append_value(format!("val{}", i)); + source_builder.append_value(format!("val{i}")); } // there should be too many values that we can't downcast to the underlying type @@ -1006,4 +1038,51 @@ mod tests { assert_eq!(values, [None, None]); } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = GenericByteDictionaryBuilder::::new(); + builder.append("a").unwrap(); + builder.append("b").unwrap(); + builder.append("c").unwrap(); + let dict = builder.finish_preserve_values(); + assert_eq!(dict.keys().values(), &[0, 1, 2]); + assert_eq!(dict.values().len(), 3); + let values = dict + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("a"), Some("b"), Some("c")]); + + // Create a new dictionary + builder.append("d").unwrap(); + builder.append("e").unwrap(); + let dict2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones and we have the + // right values + assert_eq!(dict2.keys().values(), &[3, 4]); + let values = dict2 + .downcast_dict::>() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, [Some("d"), Some("e")]); + + // Check that we have all of the expected values + assert_eq!(dict2.values().len(), 5); + let all_values = dict2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!( + all_values, + [Some("a"), Some("b"), Some("c"), Some("d"), Some("e"),] + ); + } } diff --git a/arrow-array/src/builder/generic_bytes_view_builder.rs b/arrow-array/src/builder/generic_bytes_view_builder.rs index ae7355433f81..2d60187344cf 100644 --- a/arrow-array/src/builder/generic_bytes_view_builder.rs +++ b/arrow-array/src/builder/generic_bytes_view_builder.rs @@ -20,12 +20,12 @@ use std::marker::PhantomData; use std::sync::Arc; use arrow_buffer::{Buffer, NullBufferBuilder, ScalarBuffer}; -use arrow_data::ByteView; +use arrow_data::{ByteView, MAX_INLINE_VIEW_LEN}; use arrow_schema::ArrowError; -use hashbrown::hash_table::Entry; use hashbrown::HashTable; +use hashbrown::hash_table::Entry; -use crate::builder::ArrayBuilder; +use crate::builder::{ArrayBuilder, BinaryLikeArrayBuilder, StringLikeArrayBuilder}; use crate::types::bytes::ByteArrayNativeType; use crate::types::{BinaryViewType, ByteViewType, StringViewType}; use crate::{Array, ArrayRef, GenericByteViewArray}; @@ -68,8 +68,8 @@ impl BlockSizeGrowthStrategy { /// /// To avoid bump allocating, this builder allocates data in fixed size blocks, configurable /// using [`GenericByteViewBuilder::with_fixed_block_size`]. [`GenericByteViewBuilder::append_value`] -/// writes values larger than 12 bytes to the current in-progress block, with values smaller -/// than 12 bytes inlined into the views. If a value is appended that will not fit in the +/// writes values larger than [`MAX_INLINE_VIEW_LEN`] bytes to the current in-progress block, with values smaller +/// than [`MAX_INLINE_VIEW_LEN`] bytes inlined into the views. If a value is appended that will not fit in the /// in-progress block, it will be closed, and a new block of sufficient size allocated /// /// # Append Views @@ -87,6 +87,7 @@ pub struct GenericByteViewBuilder { /// Some if deduplicating strings /// map ` -> ` string_tracker: Option<(HashTable, ahash::RandomState)>, + max_deduplication_len: Option, phantom: PhantomData, } @@ -107,21 +108,39 @@ impl GenericByteViewBuilder { current_size: STARTING_BLOCK_SIZE, }, string_tracker: None, + max_deduplication_len: None, phantom: Default::default(), } } + /// Configure max deduplication length when deduplicating strings while building the array. + /// Default is None. + /// + /// When [`Self::with_deduplicate_strings`] is enabled, the builder attempts to deduplicate + /// any strings longer than 12 bytes. However, since it takes time proportional to the length + /// of the string to deduplicate, setting this option limits the CPU overhead for this option. + pub fn with_max_deduplication_len(self, max_deduplication_len: u32) -> Self { + debug_assert!( + max_deduplication_len > 0, + "max_deduplication_len must be greater than 0" + ); + Self { + max_deduplication_len: Some(max_deduplication_len), + ..self + } + } + /// Set a fixed buffer size for variable length strings /// /// The block size is the size of the buffer used to store values greater - /// than 12 bytes. The builder allocates new buffers when the current + /// than [`MAX_INLINE_VIEW_LEN`] bytes. The builder allocates new buffers when the current /// buffer is full. /// /// By default the builder balances buffer size and buffer count by /// growing buffer size exponentially from 8KB up to 2MB. The /// first buffer allocated is 8KB, then 16KB, then 32KB, etc up to 2MB. /// - /// If this method is used, any new buffers allocated are + /// If this method is used, any new buffers allocated are /// exactly this size. This can be useful for advanced users /// that want to control the memory usage and buffer count. /// @@ -134,13 +153,6 @@ impl GenericByteViewBuilder { } } - /// Override the size of buffers to allocate for holding string data - /// Use `with_fixed_block_size` instead. - #[deprecated(since = "53.0.0", note = "Use `with_fixed_block_size` instead")] - pub fn with_block_size(self, block_size: u32) -> Self { - self.with_fixed_block_size(block_size) - } - /// Deduplicate strings while building the array /// /// This will potentially decrease the memory usage if the array have repeated strings @@ -195,10 +207,10 @@ impl GenericByteViewBuilder { /// (2) The range `offset..offset+length` must be within the bounds of the block /// (3) The data in the block must be valid of type `T` pub unsafe fn append_view_unchecked(&mut self, block: u32, offset: u32, len: u32) { - let b = self.completed.get_unchecked(block as usize); + let b = unsafe { self.completed.get_unchecked(block as usize) }; let start = offset as usize; let end = start.saturating_add(len as usize); - let b = b.get_unchecked(start..end); + let b = unsafe { b.get_unchecked(start..end) }; let view = make_view(b, block, offset); self.views_buffer.push(view); @@ -221,7 +233,7 @@ impl GenericByteViewBuilder { } else { self.views_buffer.extend(array.views().iter().map(|v| { let mut byte_view = ByteView::from(*v); - if byte_view.length > 12 { + if byte_view.length > MAX_INLINE_VIEW_LEN { // Small views (<=12 bytes) are inlined, so only need to update large views byte_view.buffer_index += starting_buffer; }; @@ -289,7 +301,7 @@ impl GenericByteViewBuilder { pub fn get_value(&self, index: usize) -> &[u8] { let view = self.views_buffer.as_slice().get(index).unwrap(); let len = *view as u32; - if len <= 12 { + if len <= MAX_INLINE_VIEW_LEN { // # Safety // The view is valid from the builder unsafe { GenericByteViewArray::::inline_value(view, len as usize) } @@ -313,48 +325,70 @@ impl GenericByteViewBuilder { /// - String length exceeds `u32::MAX` #[inline] pub fn append_value(&mut self, value: impl AsRef) { + self.try_append_value(value).unwrap() + } + + /// Appends a value into the builder + /// + /// # Errors + /// + /// Returns an error if: + /// - String buffer count exceeds `u32::MAX` + /// - String length exceeds `u32::MAX` + #[inline] + pub fn try_append_value(&mut self, value: impl AsRef) -> Result<(), ArrowError> { let v: &[u8] = value.as_ref().as_ref(); - let length: u32 = v.len().try_into().unwrap(); - if length <= 12 { + let length: u32 = v.len().try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!("String length {} exceeds u32::MAX", v.len())) + })?; + + if length <= MAX_INLINE_VIEW_LEN { let mut view_buffer = [0; 16]; view_buffer[0..4].copy_from_slice(&length.to_le_bytes()); view_buffer[4..4 + v.len()].copy_from_slice(v); self.views_buffer.push(u128::from_le_bytes(view_buffer)); self.null_buffer_builder.append_non_null(); - return; + return Ok(()); } // Deduplication if: // (1) deduplication is enabled. - // (2) len > 12 - if let Some((mut ht, hasher)) = self.string_tracker.take() { - let hash_val = hasher.hash_one(v); - let hasher_fn = |v: &_| hasher.hash_one(v); - - let entry = ht.entry( - hash_val, - |idx| { - let stored_value = self.get_value(*idx); - v == stored_value - }, - hasher_fn, - ); - match entry { - Entry::Occupied(occupied) => { - // If the string already exists, we will directly use the view - let idx = occupied.get(); - self.views_buffer.push(self.views_buffer[*idx]); - self.null_buffer_builder.append_non_null(); - self.string_tracker = Some((ht, hasher)); - return; - } - Entry::Vacant(vacant) => { - // o.w. we insert the (string hash -> view index) - // the idx is current length of views_builder, as we are inserting a new view - vacant.insert(self.views_buffer.len()); + // (2) len > `MAX_INLINE_VIEW_LEN` and len <= `max_deduplication_len` + let can_deduplicate = self.string_tracker.is_some() + && self + .max_deduplication_len + .map(|max_length| length <= max_length) + .unwrap_or(true); + if can_deduplicate { + if let Some((mut ht, hasher)) = self.string_tracker.take() { + let hash_val = hasher.hash_one(v); + let hasher_fn = |v: &_| hasher.hash_one(v); + + let entry = ht.entry( + hash_val, + |idx| { + let stored_value = self.get_value(*idx); + v == stored_value + }, + hasher_fn, + ); + match entry { + Entry::Occupied(occupied) => { + // If the string already exists, we will directly use the view + let idx = occupied.get(); + self.views_buffer.push(self.views_buffer[*idx]); + self.null_buffer_builder.append_non_null(); + self.string_tracker = Some((ht, hasher)); + return Ok(()); + } + Entry::Vacant(vacant) => { + // o.w. we insert the (string hash -> view index) + // the idx is current length of views_builder, as we are inserting a new view + vacant.insert(self.views_buffer.len()); + } } + self.string_tracker = Some((ht, hasher)); } - self.string_tracker = Some((ht, hasher)); } let required_cap = self.in_progress.len() + v.len(); @@ -363,17 +397,28 @@ impl GenericByteViewBuilder { let to_reserve = v.len().max(self.block_size.next_size() as usize); self.in_progress.reserve(to_reserve); }; + let offset = self.in_progress.len() as u32; self.in_progress.extend_from_slice(v); + let buffer_index: u32 = self.completed.len().try_into().map_err(|_| { + ArrowError::InvalidArgumentError(format!( + "Buffer count {} exceeds u32::MAX", + self.completed.len() + )) + })?; + let view = ByteView { length, + // This won't panic as we checked the length of prefix earlier. prefix: u32::from_le_bytes(v[0..4].try_into().unwrap()), - buffer_index: self.completed.len() as u32, + buffer_index, offset, }; self.views_buffer.push(view.into()); self.null_buffer_builder.append_non_null(); + + Ok(()) } /// Append an `Option` value into the builder @@ -385,6 +430,53 @@ impl GenericByteViewBuilder { }; } + /// Append the same value `n` times into the builder + /// + /// This is more efficient than calling [`Self::try_append_value`] `n` times, + /// especially when deduplication is enabled, as it only hashes the value once. + /// + /// # Errors + /// + /// Returns an error if + /// - String buffer count exceeds `u32::MAX` + /// - String length exceeds `u32::MAX` + /// + /// # Example + /// ``` + /// # use arrow_array::builder::StringViewBuilder; + /// # use arrow_array::Array; + /// let mut builder = StringViewBuilder::new().with_deduplicate_strings(); + /// + /// // Append "hello" 1000 times efficiently + /// builder.try_append_value_n("hello", 1000)?; + /// + /// let array = builder.finish(); + /// assert_eq!(array.len(), 1000); + /// + /// // All values are "hello" + /// for value in array.iter() { + /// assert_eq!(value, Some("hello")); + /// } + /// # Ok::<(), arrow_schema::ArrowError>(()) + /// ``` + #[inline] + pub fn try_append_value_n( + &mut self, + value: impl AsRef, + n: usize, + ) -> Result<(), ArrowError> { + if n == 0 { + return Ok(()); + } + // Process value once (handles deduplication, buffer management, view creation) + self.try_append_value(value)?; + // Reuse the view (n-1) times + let view = *self.views_buffer.last().unwrap(); + self.views_buffer.extend(std::iter::repeat_n(view, n - 1)); + self.null_buffer_builder.append_n_non_nulls(n - 1); + Ok(()) + } + /// Append a null value into the builder #[inline] pub fn append_null(&mut self) { @@ -397,7 +489,7 @@ impl GenericByteViewBuilder { self.flush_in_progress(); let completed = std::mem::take(&mut self.completed); let nulls = self.null_buffer_builder.finish(); - if let Some((ref mut ht, _)) = self.string_tracker.as_mut() { + if let Some((ht, _)) = self.string_tracker.as_mut() { ht.clear(); } let views = std::mem::take(&mut self.views_buffer); @@ -514,6 +606,21 @@ impl> Extend> /// ``` pub type StringViewBuilder = GenericByteViewBuilder; +impl StringLikeArrayBuilder for StringViewBuilder { + fn type_name() -> &'static str { + std::any::type_name::() + } + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity(capacity) + } + fn append_value(&mut self, value: &str) { + Self::append_value(self, value); + } + fn append_null(&mut self) { + Self::append_null(self); + } +} + /// Array builder for [`BinaryViewArray`][crate::BinaryViewArray] /// /// Values can be appended using [`GenericByteViewBuilder::append_value`], and nulls with @@ -536,6 +643,21 @@ pub type StringViewBuilder = GenericByteViewBuilder; /// pub type BinaryViewBuilder = GenericByteViewBuilder; +impl BinaryLikeArrayBuilder for BinaryViewBuilder { + fn type_name() -> &'static str { + std::any::type_name::() + } + fn with_capacity(capacity: usize) -> Self { + Self::with_capacity(capacity) + } + fn append_value(&mut self, value: &[u8]) { + Self::append_value(self, value); + } + fn append_null(&mut self) { + Self::append_null(self); + } +} + /// Creates a view from a fixed length input (the compiler can generate /// specialized code for this) fn make_inlined_view(data: &[u8]) -> u128 { @@ -587,8 +709,52 @@ pub fn make_view(data: &[u8], block_id: u32, offset: u32) -> u128 { mod tests { use core::str; + use arrow_buffer::ArrowNativeType; + use super::*; - use crate::Array; + + #[test] + fn test_string_max_deduplication_len() { + let value_1 = "short"; + let value_2 = "not so similar string but long"; + let value_3 = "1234567890123"; + + let max_deduplication_len = MAX_INLINE_VIEW_LEN * 2; + + let mut builder = StringViewBuilder::new() + .with_deduplicate_strings() + .with_max_deduplication_len(max_deduplication_len); + + assert!(value_1.len() < MAX_INLINE_VIEW_LEN.as_usize()); + assert!(value_2.len() > max_deduplication_len.as_usize()); + assert!( + value_3.len() > MAX_INLINE_VIEW_LEN.as_usize() + && value_3.len() < max_deduplication_len.as_usize() + ); + + // append value1 (short), expect it is inlined and not deduplicated + builder.append_value(value_1); // view 0 + builder.append_value(value_1); // view 1 + // append value2, expect second copy is not deduplicated as it exceeds max_deduplication_len + builder.append_value(value_2); // view 2 + builder.append_value(value_2); // view 3 + // append value3, expect second copy is deduplicated + builder.append_value(value_3); // view 4 + builder.append_value(value_3); // view 5 + + let array = builder.finish(); + + // verify + let v2 = ByteView::from(array.views()[2]); + let v3 = ByteView::from(array.views()[3]); + assert_eq!(v2.buffer_index, v3.buffer_index); // stored in same buffer + assert_ne!(v2.offset, v3.offset); // different offsets --> not deduplicated + + let v4 = ByteView::from(array.views()[4]); + let v5 = ByteView::from(array.views()[5]); + assert_eq!(v4.buffer_index, v5.buffer_index); // stored in same buffer + assert_eq!(v4.offset, v5.offset); // same offsets --> deduplicated + } #[test] fn test_string_view_deduplicate() { @@ -695,7 +861,10 @@ mod tests { ); let err = v.try_append_view(0, u32::MAX, 1).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: Range 4294967295..4294967296 out of bounds for block of length 17"); + assert_eq!( + err.to_string(), + "Invalid argument error: Range 4294967295..4294967296 out of bounds for block of length 17" + ); let err = v.try_append_view(0, 1, u32::MAX).unwrap_err(); assert_eq!( @@ -746,10 +915,12 @@ mod tests { assert_eq!(fixed_builder.completed.len(), 2_usize.pow(i + 1) - 1); // Every buffer is fixed size - assert!(fixed_builder - .completed - .iter() - .all(|b| b.len() == STARTING_BLOCK_SIZE as usize)); + assert!( + fixed_builder + .completed + .iter() + .all(|b| b.len() == STARTING_BLOCK_SIZE as usize) + ); } // Add one more value, and the buffer stop growing. @@ -760,4 +931,76 @@ mod tests { MAX_BLOCK_SIZE as usize ); } + + #[test] + fn test_append_value_n() { + // Test with inline strings (<=12 bytes) + let mut builder = StringViewBuilder::new(); + + builder.try_append_value_n("hello", 100).unwrap(); + builder.append_value("world"); + builder.try_append_value_n("foo", 50).unwrap(); + + let array = builder.finish(); + assert_eq!(array.len(), 151); + assert_eq!(array.null_count(), 0); + + // Verify the values + for i in 0..100 { + assert_eq!(array.value(i), "hello"); + } + assert_eq!(array.value(100), "world"); + for i in 101..151 { + assert_eq!(array.value(i), "foo"); + } + + // All inline strings should have no data buffers + assert_eq!(array.data_buffers().len(), 0); + } + + #[test] + fn test_append_value_n_with_deduplication() { + let long_string = "This is a very long string that exceeds the inline length"; + + // Test with deduplication enabled + let mut builder = StringViewBuilder::new().with_deduplicate_strings(); + + // First append the string once to add it to the hash map + builder.append_value(long_string); + + // Then append_n the same string - should deduplicate and reuse the existing value + builder.try_append_value_n(long_string, 999).unwrap(); + + let array = builder.finish(); + assert_eq!(array.len(), 1000); + assert_eq!(array.null_count(), 0); + + // Verify all values are the same + for i in 0..1000 { + assert_eq!(array.value(i), long_string); + } + + // With deduplication, should only have 1 data buffer containing the string once + assert_eq!(array.data_buffers().len(), 1); + + // All views should be identical + let first_view = array.views()[0]; + for view in array.views().iter() { + assert_eq!(*view, first_view); + } + } + + #[test] + fn test_append_value_n_zero() { + let mut builder = StringViewBuilder::new(); + + builder.append_value("first"); + builder.try_append_value_n("should not appear", 0).unwrap(); + builder.append_value("second"); + + let array = builder.finish(); + assert_eq!(array.len(), 2); + assert_eq!(array.value(0), "first"); + assert_eq!(array.value(1), "second"); + } } diff --git a/arrow-array/src/builder/generic_list_builder.rs b/arrow-array/src/builder/generic_list_builder.rs index 463b498c55ba..cabf7a514050 100644 --- a/arrow-array/src/builder/generic_list_builder.rs +++ b/arrow-array/src/builder/generic_list_builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::{Array, ArrayRef, GenericListArray, OffsetSizeTrait}; use arrow_buffer::NullBufferBuilder; use arrow_buffer::{Buffer, OffsetBuffer}; @@ -86,7 +86,7 @@ use std::sync::Arc; /// [`LargeListArray`]: crate::array::LargeListArray #[derive(Debug)] pub struct GenericListBuilder { - offsets_builder: BufferBuilder, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, values_builder: T, field: Option, @@ -108,8 +108,8 @@ impl GenericListBuilder Self { - let mut offsets_builder = BufferBuilder::::new(capacity + 1); - offsets_builder.append(OffsetSize::zero()); + let mut offsets_builder = Vec::with_capacity(capacity + 1); + offsets_builder.push(OffsetSize::zero()); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -192,7 +192,7 @@ where /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` #[inline] pub fn append(&mut self, is_valid: bool) { - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); self.null_buffer_builder.append(is_valid); } @@ -266,7 +266,7 @@ where /// See [`Self::append_value`] for an example use. #[inline] pub fn append_null(&mut self) { - self.offsets_builder.append(self.next_offset()); + self.offsets_builder.push(self.next_offset()); self.null_buffer_builder.append_null(); } @@ -274,7 +274,8 @@ where #[inline] pub fn append_nulls(&mut self, n: usize) { let next_offset = self.next_offset(); - self.offsets_builder.append_n(n, next_offset); + self.offsets_builder + .extend(std::iter::repeat_n(next_offset, n)); self.null_buffer_builder.append_n_nulls(n); } @@ -298,10 +299,10 @@ where let values = self.values_builder.finish(); let nulls = self.null_buffer_builder.finish(); - let offsets = self.offsets_builder.finish(); + let offsets = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); // Safety: Safe by construction let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; - self.offsets_builder.append(OffsetSize::zero()); + self.offsets_builder.push(OffsetSize::zero()); let field = match &self.field { Some(f) => f.clone(), @@ -362,10 +363,10 @@ where #[cfg(test)] mod tests { use super::*; - use crate::builder::{make_builder, Int32Builder, ListBuilder}; + use crate::Int32Array; + use crate::builder::{Int32Builder, ListBuilder, make_builder}; use crate::cast::AsArray; use crate::types::Int32Type; - use crate::Int32Array; use arrow_schema::DataType; fn _test_generic_list_array_builder() { diff --git a/arrow-array/src/builder/generic_list_view_builder.rs b/arrow-array/src/builder/generic_list_view_builder.rs index 5aaf9efefe24..c13c21cb988b 100644 --- a/arrow-array/src/builder/generic_list_view_builder.rs +++ b/arrow-array/src/builder/generic_list_view_builder.rs @@ -17,7 +17,7 @@ use crate::builder::ArrayBuilder; use crate::{ArrayRef, GenericListViewArray, OffsetSizeTrait}; -use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; +use arrow_buffer::{Buffer, NullBufferBuilder, ScalarBuffer}; use arrow_schema::{Field, FieldRef}; use std::any::Any; use std::sync::Arc; @@ -25,8 +25,8 @@ use std::sync::Arc; /// Builder for [`GenericListViewArray`] #[derive(Debug)] pub struct GenericListViewBuilder { - offsets_builder: BufferBuilder, - sizes_builder: BufferBuilder, + offsets_builder: Vec, + sizes_builder: Vec, null_buffer_builder: NullBufferBuilder, values_builder: T, field: Option, @@ -83,8 +83,8 @@ impl GenericListViewBuilder Self { - let offsets_builder = BufferBuilder::::new(capacity); - let sizes_builder = BufferBuilder::::new(capacity); + let offsets_builder = Vec::with_capacity(capacity); + let sizes_builder = Vec::with_capacity(capacity); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -132,8 +132,8 @@ where /// Panics if the length of [`Self::values`] exceeds `OffsetSize::MAX` #[inline] pub fn append(&mut self, is_valid: bool) { - self.offsets_builder.append(self.current_offset); - self.sizes_builder.append( + self.offsets_builder.push(self.current_offset); + self.sizes_builder.push( OffsetSize::from_usize( self.values_builder.len() - self.current_offset.to_usize().unwrap(), ) @@ -158,9 +158,8 @@ where /// See [`Self::append_value`] for an example use. #[inline] pub fn append_null(&mut self) { - self.offsets_builder.append(self.current_offset); - self.sizes_builder - .append(OffsetSize::from_usize(0).unwrap()); + self.offsets_builder.push(self.current_offset); + self.sizes_builder.push(OffsetSize::from_usize(0).unwrap()); self.null_buffer_builder.append_null(); } @@ -183,12 +182,12 @@ where pub fn finish(&mut self) -> GenericListViewArray { let values = self.values_builder.finish(); let nulls = self.null_buffer_builder.finish(); - let offsets = self.offsets_builder.finish(); + let offsets = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); self.current_offset = OffsetSize::zero(); // Safety: Safe by construction let offsets = ScalarBuffer::from(offsets); - let sizes = self.sizes_builder.finish(); + let sizes = Buffer::from_vec(std::mem::take(&mut self.sizes_builder)); let sizes = ScalarBuffer::from(sizes); let field = match &self.field { Some(f) => f.clone(), @@ -246,7 +245,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::builder::{make_builder, Int32Builder, ListViewBuilder}; + use crate::builder::{Int32Builder, ListViewBuilder, make_builder}; use crate::cast::AsArray; use crate::types::Int32Type; use crate::{Array, Int32Array}; diff --git a/arrow-array/src/builder/map_builder.rs b/arrow-array/src/builder/map_builder.rs index 012a454e76c9..b70d4b73880b 100644 --- a/arrow-array/src/builder/map_builder.rs +++ b/arrow-array/src/builder/map_builder.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::{Array, ArrayRef, MapArray, StructArray}; use arrow_buffer::Buffer; use arrow_buffer::{NullBuffer, NullBufferBuilder}; @@ -56,7 +56,7 @@ use std::sync::Arc; /// ``` #[derive(Debug)] pub struct MapBuilder { - offsets_builder: BufferBuilder, + offsets_builder: Vec, null_buffer_builder: NullBufferBuilder, field_names: MapFieldNames, key_builder: K, @@ -100,8 +100,8 @@ impl MapBuilder { value_builder: V, capacity: usize, ) -> Self { - let mut offsets_builder = BufferBuilder::::new(capacity + 1); - offsets_builder.append(0); + let mut offsets_builder = Vec::with_capacity(capacity + 1); + offsets_builder.push(0); Self { offsets_builder, null_buffer_builder: NullBufferBuilder::new(capacity), @@ -166,7 +166,7 @@ impl MapBuilder { self.value_builder.len() ))); } - self.offsets_builder.append(self.key_builder.len() as i32); + self.offsets_builder.push(self.key_builder.len() as i32); self.null_buffer_builder.append(is_valid); Ok(()) } @@ -177,8 +177,8 @@ impl MapBuilder { // Build the keys let keys_arr = self.key_builder.finish(); let values_arr = self.value_builder.finish(); - let offset_buffer = self.offsets_builder.finish(); - self.offsets_builder.append(0); + let offset_buffer = Buffer::from_vec(std::mem::take(&mut self.offsets_builder)); + self.offsets_builder.push(0); let null_bit_buffer = self.null_buffer_builder.finish(); self.finish_helper(keys_arr, values_arr, offset_buffer, null_bit_buffer, len) @@ -284,7 +284,7 @@ impl ArrayBuilder for MapBuilder { #[cfg(test)] mod tests { use super::*; - use crate::builder::{make_builder, Int32Builder, StringBuilder}; + use crate::builder::{Int32Builder, StringBuilder, make_builder}; use crate::{Int32Array, StringArray}; use std::collections::HashMap; diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs index 680563c6cfc3..02c6df453b6c 100644 --- a/arrow-array/src/builder/mod.rs +++ b/arrow-array/src/builder/mod.rs @@ -273,8 +273,8 @@ mod union_builder; pub use union_builder::*; -use crate::types::{Int16Type, Int32Type, Int64Type, Int8Type}; use crate::ArrayRef; +use crate::types::{Int8Type, Int16Type, Int32Type, Int64Type}; use arrow_schema::{DataType, IntervalUnit, TimeUnit}; use std::any::Any; @@ -447,9 +447,16 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box Box::new(Float64Builder::with_capacity(capacity)), DataType::Binary => Box::new(BinaryBuilder::with_capacity(capacity, 1024)), DataType::LargeBinary => Box::new(LargeBinaryBuilder::with_capacity(capacity, 1024)), + DataType::BinaryView => Box::new(BinaryViewBuilder::with_capacity(capacity)), DataType::FixedSizeBinary(len) => { Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *len)) } + DataType::Decimal32(p, s) => Box::new( + Decimal32Builder::with_capacity(capacity).with_data_type(DataType::Decimal32(*p, *s)), + ), + DataType::Decimal64(p, s) => Box::new( + Decimal64Builder::with_capacity(capacity).with_data_type(DataType::Decimal64(*p, *s)), + ), DataType::Decimal128(p, s) => Box::new( Decimal128Builder::with_capacity(capacity).with_data_type(DataType::Decimal128(*p, *s)), ), @@ -458,6 +465,7 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box Box::new(StringBuilder::with_capacity(capacity, 1024)), DataType::LargeUtf8 => Box::new(LargeStringBuilder::with_capacity(capacity, 1024)), + DataType::Utf8View => Box::new(StringViewBuilder::with_capacity(capacity)), DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)), DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)), DataType::Time32(TimeUnit::Second) => { @@ -559,7 +567,7 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box panic!("The field of Map data type {t:?} should have a child Struct field"), + t => panic!("The field of Map data type {t} should have a child Struct field"), }, DataType::Struct(fields) => Box::new(StructBuilder::from_fields(fields.clone(), capacity)), t @ DataType::Dictionary(key_type, value_type) => { @@ -586,7 +594,7 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box panic!("Dictionary value type {t:?} is not currently supported"), + t => unimplemented!("Dictionary value type {t} is not currently supported"), } }; } @@ -596,10 +604,12 @@ pub fn make_builder(datatype: &DataType, capacity: usize) -> Box dict_builder!(Int32Type), DataType::Int64 => dict_builder!(Int64Type), _ => { - panic!("Data type {t:?} with key type {key_type:?} is not currently supported") + unimplemented!( + "Data type {t} with key type {key_type} is not currently supported" + ) } } } - t => panic!("Data type {t:?} is not currently supported"), + t => unimplemented!("Data type {t} is not currently supported"), } } diff --git a/arrow-array/src/builder/null_builder.rs b/arrow-array/src/builder/null_builder.rs index 59086dffa907..489822065b56 100644 --- a/arrow-array/src/builder/null_builder.rs +++ b/arrow-array/src/builder/null_builder.rs @@ -59,18 +59,6 @@ impl NullBuilder { Self { len: 0 } } - /// Creates a new null builder with space for `capacity` elements without re-allocating - #[deprecated = "there is no actual notion of capacity in the NullBuilder, so emulating it makes little sense"] - pub fn with_capacity(_capacity: usize) -> Self { - Self::new() - } - - /// Returns the capacity of this builder measured in slots of type `T` - #[deprecated = "there is no actual notion of capacity in the NullBuilder, so emulating it makes little sense"] - pub fn capacity(&self) -> usize { - self.len - } - /// Appends a null slot into the builder #[inline] pub fn append_null(&mut self) { diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs index 41c65fe34e35..049cef241c83 100644 --- a/arrow-array/src/builder/primitive_builder.rs +++ b/arrow-array/src/builder/primitive_builder.rs @@ -15,11 +15,10 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::builder::ArrayBuilder; use crate::types::*; use crate::{Array, ArrayRef, PrimitiveArray}; -use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{Buffer, MutableBuffer}; +use arrow_buffer::{Buffer, MutableBuffer, NullBufferBuilder, ScalarBuffer}; use arrow_data::ArrayData; use arrow_schema::{ArrowError, DataType}; use std::any::Any; @@ -87,6 +86,10 @@ pub type DurationMicrosecondBuilder = PrimitiveBuilder; /// An elapsed time in nanoseconds array builder. pub type DurationNanosecondBuilder = PrimitiveBuilder; +/// A decimal 32 array builder +pub type Decimal32Builder = PrimitiveBuilder; +/// A decimal 64 array builder +pub type Decimal64Builder = PrimitiveBuilder; /// A decimal 128 array builder pub type Decimal128Builder = PrimitiveBuilder; /// A decimal 256 array builder @@ -95,7 +98,7 @@ pub type Decimal256Builder = PrimitiveBuilder; /// Builder for [`PrimitiveArray`] #[derive(Debug)] pub struct PrimitiveBuilder { - values_builder: BufferBuilder, + values_builder: Vec, null_buffer_builder: NullBufferBuilder, data_type: DataType, } @@ -147,7 +150,7 @@ impl PrimitiveBuilder { /// Creates a new primitive array builder with capacity no of items pub fn with_capacity(capacity: usize) -> Self { Self { - values_builder: BufferBuilder::::new(capacity), + values_builder: Vec::with_capacity(capacity), null_buffer_builder: NullBufferBuilder::new(capacity), data_type: T::DATA_TYPE, } @@ -158,7 +161,7 @@ impl PrimitiveBuilder { values_buffer: MutableBuffer, null_buffer: Option, ) -> Self { - let values_builder = BufferBuilder::::new_from_buffer(values_buffer); + let values_builder: Vec = ScalarBuffer::::from(values_buffer).into(); let null_buffer_builder = null_buffer .map(|buffer| NullBufferBuilder::new_from_buffer(buffer, values_builder.len())) @@ -175,7 +178,8 @@ impl PrimitiveBuilder { /// data type of the generated array. /// /// This method allows overriding the data type, to allow specifying timezones - /// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal128`] and [`DataType::Decimal256`] + /// for [`DataType::Timestamp`] or precision and scale for [`DataType::Decimal32`], + /// [`DataType::Decimal64`], [`DataType::Decimal128`] and [`DataType::Decimal256`] /// /// # Panics /// @@ -199,28 +203,29 @@ impl PrimitiveBuilder { #[inline] pub fn append_value(&mut self, v: T::Native) { self.null_buffer_builder.append_non_null(); - self.values_builder.append(v); + self.values_builder.push(v); } /// Appends a value of type `T` into the builder `n` times #[inline] pub fn append_value_n(&mut self, v: T::Native, n: usize) { self.null_buffer_builder.append_n_non_nulls(n); - self.values_builder.append_n(n, v); + self.values_builder.extend(std::iter::repeat_n(v, n)); } /// Appends a null slot into the builder #[inline] pub fn append_null(&mut self) { self.null_buffer_builder.append_null(); - self.values_builder.advance(1); + self.values_builder.push(T::Native::default()); } /// Appends `n` no. of null's into the builder #[inline] pub fn append_nulls(&mut self, n: usize) { self.null_buffer_builder.append_n_nulls(n); - self.values_builder.advance(n); + self.values_builder + .extend(std::iter::repeat_n(T::Native::default(), n)); } /// Appends an `Option` into the builder @@ -236,7 +241,7 @@ impl PrimitiveBuilder { #[inline] pub fn append_slice(&mut self, v: &[T::Native]) { self.null_buffer_builder.append_n_non_nulls(v.len()); - self.values_builder.append_slice(v); + self.values_builder.extend_from_slice(v); } /// Appends values from a slice of type `T` and a validity boolean slice @@ -252,7 +257,7 @@ impl PrimitiveBuilder { "Value and validity lengths must be equal" ); self.null_buffer_builder.append_slice(is_valid); - self.values_builder.append_slice(values); + self.values_builder.extend_from_slice(values); } /// Appends array values and null to this builder as is @@ -269,7 +274,7 @@ impl PrimitiveBuilder { "array data type mismatch" ); - self.values_builder.append_slice(array.values()); + self.values_builder.extend_from_slice(array.values()); if let Some(null_buffer) = array.nulls() { self.null_buffer_builder.append_buffer(null_buffer); } else { @@ -291,7 +296,7 @@ impl PrimitiveBuilder { .expect("append_trusted_len_iter requires an upper bound"); self.null_buffer_builder.append_n_non_nulls(len); - self.values_builder.append_trusted_len_iter(iter); + self.values_builder.extend(iter); } /// Builds the [`PrimitiveArray`] and reset this builder. @@ -300,7 +305,7 @@ impl PrimitiveBuilder { let nulls = self.null_buffer_builder.finish(); let builder = ArrayData::builder(self.data_type.clone()) .len(len) - .add_buffer(self.values_builder.finish()) + .add_buffer(std::mem::take(&mut self.values_builder).into()) .nulls(nulls); let array_data = unsafe { builder.build_unchecked() }; @@ -328,7 +333,7 @@ impl PrimitiveBuilder { /// Returns the current values buffer as a mutable slice pub fn values_slice_mut(&mut self) -> &mut [T::Native] { - self.values_builder.as_slice_mut() + self.values_builder.as_mut_slice() } /// Returns the current null buffer as a slice @@ -344,7 +349,7 @@ impl PrimitiveBuilder { /// Returns the current values buffer and null buffer as a slice pub fn slices_mut(&mut self) -> (&mut [T::Native], Option<&mut [u8]>) { ( - self.values_builder.as_slice_mut(), + self.values_builder.as_mut_slice(), self.null_buffer_builder.as_slice_mut(), ) } diff --git a/arrow-array/src/builder/primitive_dictionary_builder.rs b/arrow-array/src/builder/primitive_dictionary_builder.rs index f4a6662462e0..d9544aec3b9d 100644 --- a/arrow-array/src/builder/primitive_dictionary_builder.rs +++ b/arrow-array/src/builder/primitive_dictionary_builder.rs @@ -22,6 +22,7 @@ use crate::{ }; use arrow_buffer::{ArrowNativeType, ToByteSlice}; use arrow_schema::{ArrowError, DataType}; +use num_traits::NumCast; use std::any::Any; use std::collections::HashMap; use std::sync::Arc; @@ -169,6 +170,68 @@ where map: HashMap::with_capacity(values_capacity), } } + + /// Creates a new `PrimitiveDictionaryBuilder` from the existing builder with the same + /// keys and values, but with a new data type for the keys. + /// + /// # Example + /// ``` + /// # + /// # use arrow_array::builder::PrimitiveDictionaryBuilder; + /// # use arrow_array::types::{UInt8Type, UInt16Type, UInt64Type}; + /// # use arrow_array::UInt16Array; + /// # use arrow_schema::ArrowError; + /// + /// let mut u8_keyed_builder = PrimitiveDictionaryBuilder::::new(); + /// + /// // appending too many values causes the dictionary to overflow + /// for i in 0..256 { + /// u8_keyed_builder.append_value(i); + /// } + /// let result = u8_keyed_builder.append(256); + /// assert!(matches!(result, Err(ArrowError::DictionaryKeyOverflowError{}))); + /// + /// // we need to upgrade to a larger key type + /// let mut u16_keyed_builder = PrimitiveDictionaryBuilder::::try_new_from_builder(u8_keyed_builder).unwrap(); + /// let dictionary_array = u16_keyed_builder.finish(); + /// let keys = dictionary_array.keys(); + /// + /// assert_eq!(keys, &UInt16Array::from_iter(0..256)); + pub fn try_new_from_builder( + mut source: PrimitiveDictionaryBuilder, + ) -> Result + where + K::Native: NumCast, + K2: ArrowDictionaryKeyType, + K2::Native: NumCast, + { + let map = source.map; + let values_builder = source.values_builder; + + let source_keys = source.keys_builder.finish(); + let new_keys: PrimitiveArray = source_keys.try_unary(|value| { + num_traits::cast::cast::(value).ok_or_else(|| { + ArrowError::CastError(format!( + "Can't cast dictionary keys from source type {:?} to type {:?}", + K2::DATA_TYPE, + K::DATA_TYPE + )) + }) + })?; + + // drop source key here because currently source_keys and new_keys are holding reference to + // the same underlying null_buffer. Below we want to call new_keys.into_builder() it must + // be the only reference holder. + drop(source_keys); + + Ok(Self { + map, + keys_builder: new_keys + .into_builder() + .expect("underlying buffer has no references"), + values_builder, + }) + } } impl ArrayBuilder for PrimitiveDictionaryBuilder @@ -397,6 +460,38 @@ where DictionaryArray::from(unsafe { builder.build_unchecked() }) } + /// Builds the `DictionaryArray` without resetting the values builder or + /// the internal de-duplication map. + /// + /// The advantage of doing this is that the values will represent the entire + /// set of what has been built so-far by this builder and ensures + /// consistency in the assignment of keys to values across multiple calls + /// to `finish_preserve_values`. This enables ipc writers to efficiently + /// emit delta dictionaries. + /// + /// The downside to this is that building the record requires creating a + /// copy of the values, which can become slowly more expensive if the + /// dictionary grows. + /// + /// Additionally, if record batches from multiple different dictionary + /// builders for the same column are fed into a single ipc writer, beware + /// that entire dictionaries are likely to be re-sent frequently even when + /// the majority of the values are not used by the current record batch. + pub fn finish_preserve_values(&mut self) -> DictionaryArray { + let values = self.values_builder.finish_cloned(); + let keys = self.keys_builder.finish(); + + let data_type = DataType::Dictionary(Box::new(K::DATA_TYPE), Box::new(V::DATA_TYPE)); + + let builder = keys + .into_data() + .into_builder() + .data_type(data_type) + .child_data(vec![values.into_data()]); + + DictionaryArray::from(unsafe { builder.build_unchecked() }) + } + /// Returns the current dictionary values buffer as a slice pub fn values_slice(&self) -> &[V::Native] { self.values_builder.values_slice() @@ -428,10 +523,14 @@ impl Extend> mod tests { use super::*; - use crate::array::{Int32Array, UInt32Array, UInt8Array}; + use crate::array::{Int32Array, UInt8Array, UInt32Array}; use crate::builder::Decimal128Builder; use crate::cast::AsArray; - use crate::types::{Decimal128Type, Int32Type, UInt32Type, UInt8Type}; + use crate::types::{ + Date32Type, Decimal128Type, DurationNanosecondType, Float32Type, Float64Type, Int8Type, + Int16Type, Int32Type, Int64Type, TimestampNanosecondType, UInt8Type, UInt16Type, + UInt32Type, UInt64Type, + }; #[test] fn test_primitive_dictionary_builder() { @@ -649,4 +748,146 @@ mod tests { builder.values_builder.capacity() ) } + + fn _test_try_new_from_builder_generic_for_key_types(values: Vec) + where + K1: ArrowDictionaryKeyType, + K1::Native: NumCast, + K2: ArrowDictionaryKeyType, + K2::Native: NumCast + From, + V: ArrowPrimitiveType, + { + let mut source = PrimitiveDictionaryBuilder::::new(); + source.append(values[0]).unwrap(); + source.append_null(); + source.append(values[1]).unwrap(); + source.append(values[2]).unwrap(); + + let mut result = PrimitiveDictionaryBuilder::::try_new_from_builder(source).unwrap(); + let array = result.finish(); + + let mut expected_keys_builder = PrimitiveBuilder::::new(); + expected_keys_builder + .append_value(<::Native as From>::from(0u8)); + expected_keys_builder.append_null(); + expected_keys_builder + .append_value(<::Native as From>::from(1u8)); + expected_keys_builder + .append_value(<::Native as From>::from(2u8)); + let expected_keys = expected_keys_builder.finish(); + assert_eq!(array.keys(), &expected_keys); + + let av = array.values(); + let ava = av.as_any().downcast_ref::>().unwrap(); + assert_eq!(ava.value(0), values[0]); + assert_eq!(ava.value(1), values[1]); + assert_eq!(ava.value(2), values[2]); + } + + fn _test_try_new_from_builder_generic_for_value(values: Vec) + where + T: ArrowPrimitiveType, + { + // test cast to bigger size unsigned + _test_try_new_from_builder_generic_for_key_types::( + values.clone(), + ); + // test cast going to smaller size unsigned + _test_try_new_from_builder_generic_for_key_types::( + values.clone(), + ); + // test cast going to bigger size signed + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test cast going to smaller size signed + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + // test going from signed to signed for different size changes + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + _test_try_new_from_builder_generic_for_key_types::(values.clone()); + } + + #[test] + fn test_try_new_from_builder() { + // test unsigned types + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + // test signed types + _test_try_new_from_builder_generic_for_value::(vec![-1, 0, 1]); + _test_try_new_from_builder_generic_for_value::(vec![-1, 0, 1]); + _test_try_new_from_builder_generic_for_value::(vec![-1, 0, 1]); + _test_try_new_from_builder_generic_for_value::(vec![-1, 0, 1]); + // test some date types + _test_try_new_from_builder_generic_for_value::(vec![5, 6, 7]); + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + _test_try_new_from_builder_generic_for_value::(vec![1, 2, 3]); + // test some floating point types + _test_try_new_from_builder_generic_for_value::(vec![0.1, 0.2, 0.3]); + _test_try_new_from_builder_generic_for_value::(vec![-0.1, 0.2, 0.3]); + } + + #[test] + fn test_try_new_from_builder_cast_fails() { + let mut source_builder = PrimitiveDictionaryBuilder::::new(); + for i in 0..257 { + source_builder.append_value(i); + } + + // there should be too many values that we can't downcast to the underlying type + // we have keys that wouldn't fit into UInt8Type + let result = PrimitiveDictionaryBuilder::::try_new_from_builder( + source_builder, + ); + assert!(result.is_err()); + if let Err(e) = result { + assert!(matches!(e, ArrowError::CastError(_))); + assert_eq!( + e.to_string(), + "Cast error: Can't cast dictionary keys from source type UInt16 to type UInt8" + ); + } + } + + #[test] + fn test_finish_preserve_values() { + // Create the first dictionary + let mut builder = PrimitiveDictionaryBuilder::::new(); + builder.append(10).unwrap(); + builder.append(20).unwrap(); + let array = builder.finish_preserve_values(); + assert_eq!(array.keys(), &UInt8Array::from(vec![Some(0), Some(1)])); + let values: &[u32] = array + .values() + .as_any() + .downcast_ref::() + .unwrap() + .values(); + assert_eq!(values, &[10, 20]); + + // Create a new dictionary + builder.append(30).unwrap(); + builder.append(40).unwrap(); + let array2 = builder.finish_preserve_values(); + + // Make sure the keys are assigned after the old ones + // and that we have the right values + assert_eq!(array2.keys(), &UInt8Array::from(vec![Some(2), Some(3)])); + let values = array2 + .downcast_dict::() + .unwrap() + .into_iter() + .collect::>(); + assert_eq!(values, vec![Some(30), Some(40)]); + + // Check that we have all of the expected values + let all_values: &[u32] = array2 + .values() + .as_any() + .downcast_ref::() + .unwrap() + .values(); + assert_eq!(all_values, &[10, 20, 30, 40]); + } } diff --git a/arrow-array/src/builder/primitive_run_builder.rs b/arrow-array/src/builder/primitive_run_builder.rs index 1db9c91e081d..52bdaa6f40e4 100644 --- a/arrow-array/src/builder/primitive_run_builder.rs +++ b/arrow-array/src/builder/primitive_run_builder.rs @@ -17,7 +17,7 @@ use std::{any::Any, sync::Arc}; -use crate::{types::RunEndIndexType, ArrayRef, ArrowPrimitiveType, RunArray}; +use crate::{ArrayRef, ArrowPrimitiveType, RunArray, types::RunEndIndexType}; use super::{ArrayBuilder, PrimitiveBuilder}; diff --git a/arrow-array/src/builder/struct_builder.rs b/arrow-array/src/builder/struct_builder.rs index 3afee5863f52..4fb312739cb5 100644 --- a/arrow-array/src/builder/struct_builder.rs +++ b/arrow-array/src/builder/struct_builder.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::*; use crate::StructArray; +use crate::builder::*; use arrow_buffer::NullBufferBuilder; use arrow_schema::{Fields, SchemaBuilder}; use std::sync::Arc; @@ -62,7 +62,7 @@ use std::sync::Arc; /// /// // We can't obtain the ListBuilder with the expected generic types, because under the hood /// // the StructBuilder was returned as a Box and passed as such to the ListBuilder constructor -/// +/// /// // This panics in runtime, even though we know that the builder is a ListBuilder. /// // let sb = col_struct_builder /// // .field_builder::>(0) @@ -201,6 +201,11 @@ impl StructBuilder { self.field_builders.len() } + /// Returns the fields for the struct this builder is building. + pub fn fields(&self) -> &Fields { + &self.fields + } + /// Appends an element (either null or non-null) to the struct. The actual elements /// should be appended for each child sub-array in a consistent way. #[inline] @@ -267,7 +272,7 @@ impl StructBuilder { let schema = builder.finish(); panic!("{}", format!( - "StructBuilder ({:?}) and field_builder with index {} ({:?}) are of unequal lengths: ({} != {}).", + "StructBuilder ({}) and field_builder with index {} ({}) are of unequal lengths: ({} != {}).", schema, idx, self.fields[idx].data_type(), @@ -440,11 +445,13 @@ mod tests { match builder { Some(builder) => { assert_eq!(builder.value_length(), LIST_LENGTH); - assert!(builder - .values() - .as_any_mut() - .downcast_mut::() - .is_some()); + assert!( + builder + .values() + .as_any_mut() + .downcast_mut::() + .is_some() + ); } None => panic!("expected FixedSizeListBuilder, got a different builder type"), } @@ -648,7 +655,7 @@ mod tests { #[test] #[should_panic( - expected = "StructBuilder (Schema { fields: [Field { name: \"f1\", data_type: Int32, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: \"f2\", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }], metadata: {} }) and field_builder with index 1 (Boolean) are of unequal lengths: (2 != 1)." + expected = "StructBuilder (Field { \"f1\": Int32 }, Field { \"f2\": Boolean }) and field_builder with index 1 (Boolean) are of unequal lengths: (2 != 1)." )] fn test_struct_array_builder_unequal_field_builders_lengths() { let mut int_builder = Int32Builder::with_capacity(10); @@ -690,7 +697,7 @@ mod tests { #[test] #[should_panic( - expected = "Incorrect datatype for StructArray field \\\"timestamp\\\", expected Timestamp(Nanosecond, Some(\\\"UTC\\\")) got Timestamp(Nanosecond, None)" + expected = "Incorrect datatype for StructArray field \\\"timestamp\\\", expected Timestamp(ns, \\\"UTC\\\") got Timestamp(ns)" )] fn test_struct_array_mismatch_builder() { let fields = vec![Field::new( diff --git a/arrow-array/src/builder/union_builder.rs b/arrow-array/src/builder/union_builder.rs index e6184f4ac6d2..3b8934f2ebf4 100644 --- a/arrow-array/src/builder/union_builder.rs +++ b/arrow-array/src/builder/union_builder.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use crate::builder::buffer_builder::{Int32BufferBuilder, Int8BufferBuilder}; -use crate::builder::BufferBuilder; -use crate::{make_array, ArrowPrimitiveType, UnionArray}; +use crate::builder::buffer_builder::{Int8BufferBuilder, Int32BufferBuilder}; +use crate::builder::{ArrayBuilder, BufferBuilder}; +use crate::{ArrayRef, ArrowPrimitiveType, UnionArray, make_array}; use arrow_buffer::NullBufferBuilder; -use arrow_buffer::{ArrowNativeType, Buffer}; +use arrow_buffer::{ArrowNativeType, Buffer, ScalarBuffer}; use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, Field}; use std::any::Any; @@ -42,12 +42,14 @@ struct FieldData { } /// A type-erased [`BufferBuilder`] used by [`FieldData`] -trait FieldDataValues: std::fmt::Debug { +trait FieldDataValues: std::fmt::Debug + Send + Sync { fn as_mut_any(&mut self) -> &mut dyn Any; fn append_null(&mut self); fn finish(&mut self) -> Buffer; + + fn finish_cloned(&self) -> Buffer; } impl FieldDataValues for BufferBuilder { @@ -62,6 +64,10 @@ impl FieldDataValues for BufferBuilder { fn finish(&mut self) -> Buffer { self.finish() } + + fn finish_cloned(&self) -> Buffer { + Buffer::from_slice_ref(self.as_slice()) + } } impl FieldData { @@ -138,7 +144,7 @@ impl FieldData { /// assert_eq!(union.value_offset(1), 1); /// assert_eq!(union.value_offset(2), 2); /// ``` -#[derive(Debug)] +#[derive(Debug, Default)] pub struct UnionBuilder { /// The current number of slots in the array len: usize, @@ -310,4 +316,172 @@ impl UnionBuilder { children, ) } + + /// Builds this builder creating a new `UnionArray` without consuming the builder. + /// + /// This is used for the `finish_cloned` implementation in `ArrayBuilder`. + fn build_cloned(&self) -> Result { + let mut children = Vec::with_capacity(self.fields.len()); + let union_fields: Vec<_> = self + .fields + .iter() + .map(|(name, field_data)| { + let FieldData { + type_id, + data_type, + values_buffer, + slots, + null_buffer_builder, + } = field_data; + + let array_ref = make_array(unsafe { + ArrayDataBuilder::new(data_type.clone()) + .add_buffer(values_buffer.finish_cloned()) + .len(*slots) + .nulls(null_buffer_builder.finish_cloned()) + .build_unchecked() + }); + children.push(array_ref); + ( + *type_id, + Arc::new(Field::new(name.clone(), data_type.clone(), false)), + ) + }) + .collect(); + UnionArray::try_new( + union_fields.into_iter().collect(), + ScalarBuffer::from(self.type_id_builder.as_slice().to_vec()), + self.value_offset_builder + .as_ref() + .map(|builder| ScalarBuffer::from(builder.as_slice().to_vec())), + children, + ) + } +} + +impl ArrayBuilder for UnionBuilder { + /// Returns the number of array slots in the builder + fn len(&self) -> usize { + self.len + } + + /// Builds the array + fn finish(&mut self) -> ArrayRef { + // Even simpler - just move the builder using mem::take and replace with default + let builder = std::mem::take(self); + + // Since UnionBuilder controls all invariants, this should never fail + Arc::new(builder.build().unwrap()) + } + + /// Builds the array without resetting the underlying builder + fn finish_cloned(&self) -> ArrayRef { + // We construct the UnionArray carefully to ensure try_new cannot fail. + // Since UnionBuilder controls all the invariants, this should never panic. + Arc::new(self.build_cloned().unwrap_or_else(|err| { + panic!("UnionBuilder::build_cloned failed unexpectedly: {}", err) + })) + } + + /// Returns the builder as a non-mutable `Any` reference + fn as_any(&self) -> &dyn Any { + self + } + + /// Returns the builder as a mutable `Any` reference + fn as_any_mut(&mut self) -> &mut dyn Any { + self + } + + /// Returns the boxed builder as a box of `Any` + fn into_box_any(self: Box) -> Box { + self + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::array::Array; + use crate::cast::AsArray; + use crate::types::{Float64Type, Int32Type}; + + #[test] + fn test_union_builder_array_builder_trait() { + // Test that UnionBuilder implements ArrayBuilder trait + let mut builder = UnionBuilder::new_dense(); + + // Add some data + builder.append::("a", 1).unwrap(); + builder.append::("b", 3.0).unwrap(); + builder.append::("a", 4).unwrap(); + + assert_eq!(builder.len(), 3); + + // Test finish_cloned (non-destructive) + let array1 = builder.finish_cloned(); + assert_eq!(array1.len(), 3); + + // Verify values in cloned array + let union1 = array1.as_any().downcast_ref::().unwrap(); + assert_eq!(union1.type_ids(), &[0, 1, 0]); + assert_eq!(union1.offsets().unwrap().as_ref(), &[0, 0, 1]); + let int_array1 = union1.child(0).as_primitive::(); + let float_array1 = union1.child(1).as_primitive::(); + assert_eq!(int_array1.value(0), 1); + assert_eq!(int_array1.value(1), 4); + assert_eq!(float_array1.value(0), 3.0); + + // Builder should still be usable after finish_cloned + builder.append::("b", 5.0).unwrap(); + assert_eq!(builder.len(), 4); + + // Test finish (destructive) + let array2 = builder.finish(); + assert_eq!(array2.len(), 4); + + // Verify values in final array + let union2 = array2.as_any().downcast_ref::().unwrap(); + assert_eq!(union2.type_ids(), &[0, 1, 0, 1]); + assert_eq!(union2.offsets().unwrap().as_ref(), &[0, 0, 1, 1]); + let int_array2 = union2.child(0).as_primitive::(); + let float_array2 = union2.child(1).as_primitive::(); + assert_eq!(int_array2.value(0), 1); + assert_eq!(int_array2.value(1), 4); + assert_eq!(float_array2.value(0), 3.0); + assert_eq!(float_array2.value(1), 5.0); + } + + #[test] + fn test_union_builder_type_erased() { + // Test type-erased usage with Box + let mut builders: Vec> = vec![Box::new(UnionBuilder::new_sparse())]; + + // Downcast and use + let union_builder = builders[0] + .as_any_mut() + .downcast_mut::() + .unwrap(); + union_builder.append::("x", 10).unwrap(); + union_builder.append::("y", 20.0).unwrap(); + + assert_eq!(builders[0].len(), 2); + + let result = builders + .into_iter() + .map(|mut b| b.finish()) + .collect::>(); + assert_eq!(result[0].len(), 2); + + // Verify sparse union values + let union = result[0].as_any().downcast_ref::().unwrap(); + assert_eq!(union.type_ids(), &[0, 1]); + assert!(union.offsets().is_none()); // Sparse union has no offsets + let int_array = union.child(0).as_primitive::(); + let float_array = union.child(1).as_primitive::(); + assert_eq!(int_array.value(0), 10); + assert!(int_array.is_null(1)); // Null in sparse layout + assert!(float_array.is_null(0)); // Null in sparse layout + assert_eq!(float_array.value(1), 20.0); + } } diff --git a/arrow-array/src/cast.rs b/arrow-array/src/cast.rs index c9b92efe6c0e..de590ff87c77 100644 --- a/arrow-array/src/cast.rs +++ b/arrow-array/src/cast.rs @@ -365,6 +365,12 @@ macro_rules! downcast_primitive { $crate::repeat_pat!($crate::cast::__private::DataType::Float64, $($data_type),+) => { $m!($crate::types::Float64Type $(, $args)*) } + $crate::repeat_pat!($crate::cast::__private::DataType::Decimal32(_, _), $($data_type),+) => { + $m!($crate::types::Decimal32Type $(, $args)*) + } + $crate::repeat_pat!($crate::cast::__private::DataType::Decimal64(_, _), $($data_type),+) => { + $m!($crate::types::Decimal64Type $(, $args)*) + } $crate::repeat_pat!($crate::cast::__private::DataType::Decimal128(_, _), $($data_type),+) => { $m!($crate::types::Decimal128Type $(, $args)*) } @@ -1126,6 +1132,18 @@ mod tests { assert!(!as_string_array(&array).is_empty()) } + #[test] + fn test_decimal32array() { + let a = Decimal32Array::from_iter_values([1, 2, 4, 5]); + assert!(!as_primitive_array::(&a).is_empty()); + } + + #[test] + fn test_decimal64array() { + let a = Decimal64Array::from_iter_values([1, 2, 4, 5]); + assert!(!as_primitive_array::(&a).is_empty()); + } + #[test] fn test_decimal128array() { let a = Decimal128Array::from_iter_values([1, 2, 4, 5]); diff --git a/arrow-array/src/ffi.rs b/arrow-array/src/ffi.rs index ac28289e652b..f50dd3420baa 100644 --- a/arrow-array/src/ffi.rs +++ b/arrow-array/src/ffi.rs @@ -103,9 +103,9 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new]. use std::{mem::size_of, ptr::NonNull, sync::Arc}; -use arrow_buffer::{bit_util, Buffer, MutableBuffer}; +use arrow_buffer::{Buffer, MutableBuffer, bit_util}; pub use arrow_data::ffi::FFI_ArrowArray; -use arrow_data::{layout, ArrayData}; +use arrow_data::{ArrayData, layout}; pub use arrow_schema::ffi::FFI_ArrowSchema; use arrow_schema::{ArrowError, DataType, UnionMode}; @@ -134,23 +134,23 @@ pub unsafe fn export_array_into_raw( let array = FFI_ArrowArray::new(&data); let schema = FFI_ArrowSchema::try_from(data.data_type())?; - std::ptr::write_unaligned(out_array, array); - std::ptr::write_unaligned(out_schema, schema); + unsafe { std::ptr::write_unaligned(out_array, array) }; + unsafe { std::ptr::write_unaligned(out_schema, schema) }; Ok(()) } -// returns the number of bits that buffer `i` (in the C data interface) is expected to have. -// This is set by the Arrow specification +/// returns the number of bits that buffer `i` (in the C data interface) is expected to have. +/// This is set by the Arrow specification fn bit_width(data_type: &DataType, i: usize) -> Result { if let Some(primitive) = data_type.primitive_width() { return match i { 0 => Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." + "The datatype \"{data_type}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." ))), 1 => Ok(primitive * 8), i => Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + "The datatype \"{data_type}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." ))), }; } @@ -159,75 +159,84 @@ fn bit_width(data_type: &DataType, i: usize) -> Result { (DataType::Boolean, 1) => 1, (DataType::Boolean, _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } (DataType::FixedSizeBinary(num_bytes), 1) => *num_bytes as usize * u8::BITS as usize, (DataType::FixedSizeList(f, num_elems), 1) => { let child_bit_width = bit_width(f.data_type(), 1)?; child_bit_width * (*num_elems as usize) - }, + } (DataType::FixedSizeBinary(_), _) | (DataType::FixedSizeList(_, _), _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) - }, + "The datatype \"{data_type}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); + } // Variable-size list and map have one i32 buffer. // Variable-sized binaries: have two buffers. // "small": first buffer is i32, second is in bytes - (DataType::Utf8, 1) | (DataType::Binary, 1) | (DataType::List(_), 1) | (DataType::Map(_, _), 1) => i32::BITS as _, + (DataType::Utf8, 1) + | (DataType::Binary, 1) + | (DataType::List(_), 1) + | (DataType::Map(_, _), 1) => i32::BITS as _, (DataType::Utf8, 2) | (DataType::Binary, 2) => u8::BITS as _, + // List views have two i32 buffers, offsets and sizes + (DataType::ListView(_), 1) | (DataType::ListView(_), 2) => i32::BITS as _, + // Large list views have two i64 buffers, offsets and sizes + (DataType::LargeListView(_), 1) | (DataType::LargeListView(_), 2) => i64::BITS as _, (DataType::List(_), _) | (DataType::Map(_, _), _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 2 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } (DataType::Utf8, _) | (DataType::Binary, _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } // Variable-sized binaries: have two buffers. // LargeUtf8: first buffer is i64, second is in bytes - (DataType::LargeUtf8, 1) | (DataType::LargeBinary, 1) | (DataType::LargeList(_), 1) => i64::BITS as _, - (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) | (DataType::LargeList(_), 2)=> u8::BITS as _, - (DataType::LargeUtf8, _) | (DataType::LargeBinary, _) | (DataType::LargeList(_), _)=> { + (DataType::LargeUtf8, 1) | (DataType::LargeBinary, 1) | (DataType::LargeList(_), 1) => { + i64::BITS as _ + } + (DataType::LargeUtf8, 2) | (DataType::LargeBinary, 2) | (DataType::LargeList(_), 2) => { + u8::BITS as _ + } + (DataType::LargeUtf8, _) | (DataType::LargeBinary, _) | (DataType::LargeList(_), _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 3 buffers, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } // Variable-sized views: have 3 or more buffers. // Buffer 1 are the u128 views // Buffers 2...N-1 are u8 byte buffers - (DataType::Utf8View, 1) | (DataType::BinaryView,1) => u128::BITS as _, - (DataType::Utf8View, _) | (DataType::BinaryView, _) => { - u8::BITS as _ - } + (DataType::Utf8View, 1) | (DataType::BinaryView, 1) => u128::BITS as _, + (DataType::Utf8View, _) | (DataType::BinaryView, _) => u8::BITS as _, // type ids. UnionArray doesn't have null bitmap so buffer index begins with 0. (DataType::Union(_, _), 0) => i8::BITS as _, // Only DenseUnion has 2nd buffer (DataType::Union(_, UnionMode::Dense), 1) => i32::BITS as _, (DataType::Union(_, UnionMode::Sparse), _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 1 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 1 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } (DataType::Union(_, UnionMode::Dense), _) => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" expects 2 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" expects 2 buffer, but requested {i}. Please verify that the C data interface is correctly implemented." + ))); } (_, 0) => { // We don't call this `bit_width` to compute buffer length for null buffer. If any types that don't have null buffer like // UnionArray, they should be handled above. return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." - ))) + "The datatype \"{data_type}\" doesn't expect buffer at index 0. Please verify that the C data interface is correctly implemented." + ))); } _ => { return Err(ArrowError::CDataInterface(format!( - "The datatype \"{data_type:?}\" is still not supported in Rust implementation" - ))) + "The datatype \"{data_type}\" is still not supported in Rust implementation" + ))); } }) } @@ -249,7 +258,7 @@ unsafe fn create_buffer( return None; } NonNull::new(array.buffer(index) as _) - .map(|ptr| Buffer::from_custom_allocation(ptr, len, owner)) + .map(|ptr| unsafe { Buffer::from_custom_allocation(ptr, len, owner) }) } /// Export to the C Data Interface @@ -346,6 +355,8 @@ impl ImportedArrowArray<'_> { DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) + | DataType::ListView(field) + | DataType::LargeListView(field) | DataType::Map(field, _) => Ok([self.consume_child(0, field.data_type())?].to_vec()), DataType::Struct(fields) => { assert!(fields.len() == self.array.num_children()); @@ -408,7 +419,17 @@ impl ImportedArrowArray<'_> { .map(|index| { let len = self.buffer_len(index, variadic_buffer_lens, &self.data_type)?; match unsafe { create_buffer(self.owner.clone(), self.array, index, len) } { - Some(buf) => Ok(buf), + Some(buf) => { + // External libraries may use a dangling pointer for a buffer with length 0. + // We respect the array length specified in the C Data Interface. Actually, + // if the length is incorrect, we cannot create a correct buffer even if + // the pointer is valid. + if buf.is_empty() { + Ok(MutableBuffer::new(0).into()) + } else { + Ok(buf) + } + } None if len == 0 => { // Null data buffer, which Rust doesn't allow. So create // an empty buffer. @@ -456,6 +477,14 @@ impl ImportedArrowArray<'_> { debug_assert_eq!(bits % 8, 0); (length + 1) * (bits / 8) } + (DataType::ListView(_), 1) + | (DataType::ListView(_), 2) + | (DataType::LargeListView(_), 1) + | (DataType::LargeListView(_), 2) => { + let bits = bit_width(data_type, i)?; + debug_assert_eq!(bits % 8, 0); + length * (bits / 8) + } (DataType::Utf8, 2) | (DataType::Binary, 2) => { if self.array.is_empty() { return Ok(0); @@ -515,7 +544,7 @@ impl ImportedArrowArray<'_> { unsafe { create_buffer(self.owner.clone(), self.array, 0, buffer_len) } } - fn dictionary(&self) -> Result> { + fn dictionary(&self) -> Result>> { match (self.array.dictionary(), &self.data_type) { (Some(array), DataType::Dictionary(_, value_type)) => Ok(Some(ImportedArrowArray { array, @@ -538,12 +567,12 @@ mod tests_to_then_from_ffi { use std::collections::HashMap; use std::mem::ManuallyDrop; - use arrow_buffer::NullBuffer; + use arrow_buffer::{ArrowNativeType, NullBuffer}; use arrow_schema::Field; use crate::builder::UnionBuilder; use crate::cast::AsArray; - use crate::types::{Float64Type, Int32Type, Int8Type}; + use crate::types::{Float64Type, Int8Type, Int32Type}; use crate::*; use super::*; @@ -768,6 +797,71 @@ mod tests_to_then_from_ffi { test_generic_list::() } + fn test_generic_list_view() -> Result<()> { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int16) + .len(8) + .add_buffer(Buffer::from_slice_ref([0_i16, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = [0_usize, 3, 6] + .iter() + .map(|i| Offset::from_usize(*i).unwrap()) + .collect::(); + + let sizes_buffer = [3_usize, 3, 2] + .iter() + .map(|i| Offset::from_usize(*i).unwrap()) + .collect::(); + + // Construct a list array from the above two + let list_view_dt = GenericListViewArray::::DATA_TYPE_CONSTRUCTOR(Arc::new( + Field::new_list_field(DataType::Int16, false), + )); + + let list_data = ArrayData::builder(list_view_dt) + .len(3) + .add_buffer(value_offsets) + .add_buffer(sizes_buffer) + .add_child_data(value_data) + .build() + .unwrap(); + + let original = GenericListViewArray::::from(list_data.clone()); + + // export it + let (array, schema) = to_ffi(&original.to_data())?; + + // (simulate consumer) import it + let data = unsafe { from_ffi(array, &schema) }?; + let array = make_array(data); + + // downcast + let array = array + .as_any() + .downcast_ref::>() + .unwrap(); + + assert_eq!(&array.value(0), &original.value(0)); + assert_eq!(&array.value(1), &original.value(1)); + assert_eq!(&array.value(2), &original.value(2)); + + Ok(()) + } + + #[test] + fn test_list_view() -> Result<()> { + test_generic_list_view::() + } + + #[test] + fn test_large_list_view() -> Result<()> { + test_generic_list_view::() + } + fn test_generic_binary() -> Result<()> { // create an array natively let array: Vec> = vec![Some(b"a"), None, Some(b"aaa")]; @@ -1296,23 +1390,32 @@ mod tests_to_then_from_ffi { #[cfg(test)] mod tests_from_ffi { + #[cfg(not(feature = "force_validate"))] + use std::ptr::NonNull; use std::sync::Arc; + use arrow_buffer::NullBuffer; + #[cfg(not(feature = "force_validate"))] + use arrow_buffer::{ScalarBuffer, bit_util, buffer::Buffer}; + #[cfg(feature = "force_validate")] use arrow_buffer::{bit_util, buffer::Buffer}; - use arrow_data::transform::MutableArrayData; + use arrow_data::ArrayData; + use arrow_data::transform::MutableArrayData; use arrow_schema::{DataType, Field}; use super::Result; + use crate::builder::GenericByteViewBuilder; use crate::types::{BinaryViewType, ByteViewType, Int32Type, StringViewType}; use crate::{ + ArrayRef, GenericByteViewArray, ListArray, array::{ Array, BooleanArray, DictionaryArray, FixedSizeBinaryArray, FixedSizeListArray, Int32Array, Int64Array, StringArray, StructArray, UInt32Array, UInt64Array, }, - ffi::{from_ffi, FFI_ArrowArray, FFI_ArrowSchema}, - make_array, ArrayRef, GenericByteViewArray, ListArray, + ffi::{FFI_ArrowArray, FFI_ArrowSchema, from_ffi}, + make_array, }; fn test_round_trip(expected: &ArrayData) -> Result<()> { @@ -1506,6 +1609,65 @@ mod tests_from_ffi { test_round_trip(&data) } + #[test] + fn test_list_view() -> Result<()> { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int16) + .len(8) + .add_buffer(Buffer::from_slice_ref([0_i16, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7]] + let value_offsets = Buffer::from(vec![0_i32, 3, 6]); + let sizes_buffer = Buffer::from(vec![3_i32, 3, 2]); + + // Construct a list array from the above two + let list_view_dt = + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int16, false))); + + let list_view_data = ArrayData::builder(list_view_dt) + .len(3) + .add_buffer(value_offsets) + .add_buffer(sizes_buffer) + .add_child_data(value_data) + .build() + .unwrap(); + + test_round_trip(&list_view_data) + } + + #[test] + fn test_list_view_with_nulls() -> Result<()> { + // Construct a value array + let value_data = ArrayData::builder(DataType::Int16) + .len(8) + .add_buffer(Buffer::from_slice_ref([0_i16, 1, 2, 3, 4, 5, 6, 7])) + .build() + .unwrap(); + + // Construct a buffer for value offsets, for the nested array: + // [[0, 1, 2], [3, 4, 5], [6, 7], null] + let value_offsets = Buffer::from(vec![0_i32, 3, 6, 8]); + let sizes_buffer = Buffer::from(vec![3_i32, 3, 2, 0]); + + // Construct a list array from the above two + let list_view_dt = + DataType::ListView(Arc::new(Field::new_list_field(DataType::Int16, true))); + + let list_view_data = ArrayData::builder(list_view_dt) + .len(4) + .add_buffer(value_offsets) + .add_buffer(sizes_buffer) + .add_child_data(value_data) + .nulls(Some(NullBuffer::from(vec![true, true, true, false]))) + .build() + .unwrap(); + + test_round_trip(&list_view_data) + } + #[test] #[cfg(not(feature = "force_validate"))] fn test_empty_string_with_non_zero_offset() -> Result<()> { @@ -1576,7 +1738,7 @@ mod tests_from_ffi { let mut strings = vec![]; for i in 0..1000 { - strings.push(format!("string: {}", i)); + strings.push(format!("string: {i}")); } let string_array = StringArray::from(strings); @@ -1660,6 +1822,25 @@ mod tests_from_ffi { } } + #[test] + #[cfg(not(feature = "force_validate"))] + fn test_utf8_view_ffi_from_dangling_pointer() { + let empty = GenericByteViewBuilder::::new().finish(); + let buffers = empty.data_buffers().to_vec(); + let nulls = empty.nulls().cloned(); + + // Create a dangling pointer to a view buffer with zero length. + let alloc = Arc::new(1); + let buffer = unsafe { Buffer::from_custom_allocation(NonNull::::dangling(), 0, alloc) }; + let views = unsafe { ScalarBuffer::new_unchecked(buffer) }; + + let str_view: GenericByteViewArray = + unsafe { GenericByteViewArray::new_unchecked(views, buffers, nulls) }; + let imported = roundtrip_byte_view_array(str_view); + assert_eq!(imported.len(), 0); + assert_eq!(&imported, &empty); + } + #[test] fn test_round_trip_byte_view() { fn test_case() diff --git a/arrow-array/src/ffi_stream.rs b/arrow-array/src/ffi_stream.rs index 3d4e89e80b89..c46943682914 100644 --- a/arrow-array/src/ffi_stream.rs +++ b/arrow-array/src/ffi_stream.rs @@ -64,7 +64,7 @@ use std::{ }; use arrow_data::ffi::FFI_ArrowArray; -use arrow_schema::{ffi::FFI_ArrowSchema, ArrowError, Schema, SchemaRef}; +use arrow_schema::{ArrowError, Schema, SchemaRef, ffi::FFI_ArrowSchema}; use crate::array::Array; use crate::array::StructArray; @@ -105,13 +105,13 @@ unsafe extern "C" fn release_stream(stream: *mut FFI_ArrowArrayStream) { if stream.is_null() { return; } - let stream = &mut *stream; + let stream = unsafe { &mut *stream }; stream.get_schema = None; stream.get_next = None; stream.get_last_error = None; - let private_data = Box::from_raw(stream.private_data as *mut StreamPrivateData); + let private_data = unsafe { Box::from_raw(stream.private_data as *mut StreamPrivateData) }; drop(private_data); stream.release = None; @@ -188,7 +188,7 @@ impl FFI_ArrowArrayStream { /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Self { - std::ptr::replace(raw_stream, Self::empty()) + unsafe { std::ptr::replace(raw_stream, Self::empty()) } } /// Creates a new empty [FFI_ArrowArrayStream]. Used to import from the C Stream Interface. @@ -330,7 +330,7 @@ impl ArrowArrayStreamReader { /// /// See [`FFI_ArrowArrayStream::from_raw`] pub unsafe fn from_raw(raw_stream: *mut FFI_ArrowArrayStream) -> Result { - Self::try_new(FFI_ArrowArrayStream::from_raw(raw_stream)) + Self::try_new(unsafe { FFI_ArrowArrayStream::from_raw(raw_stream) }) } /// Get the last error from `ArrowArrayStreamReader` @@ -364,7 +364,9 @@ impl Iterator for ArrowArrayStreamReader { let result = unsafe { from_ffi_and_data_type(array, DataType::Struct(self.schema().fields().clone())) }; - Some(result.map(|data| RecordBatch::from(StructArray::from(data)))) + Some(result.and_then(|data| { + RecordBatch::try_new(self.schema.clone(), StructArray::from(data).into_parts().1) + })) } else { let last_error = self.get_stream_last_error(); let err = ArrowError::CDataInterface(last_error.unwrap()); @@ -382,6 +384,7 @@ impl RecordBatchReader for ArrowArrayStreamReader { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; use arrow_schema::Field; @@ -417,11 +420,18 @@ mod tests { } fn _test_round_trip_export(arrays: Vec>) -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", arrays[0].data_type().clone(), true), - Field::new("b", arrays[1].data_type().clone(), true), - Field::new("c", arrays[2].data_type().clone(), true), - ])); + let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]); + let schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("a", arrays[0].data_type().clone(), true) + .with_metadata(metadata.clone()), + Field::new("b", arrays[1].data_type().clone(), true) + .with_metadata(metadata.clone()), + Field::new("c", arrays[2].data_type().clone(), true) + .with_metadata(metadata.clone()), + ], + metadata, + )); let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _; @@ -452,7 +462,11 @@ mod tests { let array = unsafe { from_ffi(ffi_array, &ffi_schema) }.unwrap(); - let record_batch = RecordBatch::from(StructArray::from(array)); + let record_batch = RecordBatch::try_new( + SchemaRef::from(exported_schema.clone()), + StructArray::from(array).into_parts().1, + ) + .unwrap(); produced_batches.push(record_batch); } @@ -462,11 +476,18 @@ mod tests { } fn _test_round_trip_import(arrays: Vec>) -> Result<()> { - let schema = Arc::new(Schema::new(vec![ - Field::new("a", arrays[0].data_type().clone(), true), - Field::new("b", arrays[1].data_type().clone(), true), - Field::new("c", arrays[2].data_type().clone(), true), - ])); + let metadata = HashMap::from([("foo".to_owned(), "bar".to_owned())]); + let schema = Arc::new(Schema::new_with_metadata( + vec![ + Field::new("a", arrays[0].data_type().clone(), true) + .with_metadata(metadata.clone()), + Field::new("b", arrays[1].data_type().clone(), true) + .with_metadata(metadata.clone()), + Field::new("c", arrays[2].data_type().clone(), true) + .with_metadata(metadata.clone()), + ], + metadata, + )); let batch = RecordBatch::try_new(schema.clone(), arrays).unwrap(); let iter = Box::new(vec![batch.clone(), batch.clone()].into_iter().map(Ok)) as _; diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs index 6708da3d5dd6..c281231a2e79 100644 --- a/arrow-array/src/iterator.rs +++ b/arrow-array/src/iterator.rs @@ -44,7 +44,7 @@ use arrow_buffer::NullBuffer; /// [`PrimitiveArray`]: crate::PrimitiveArray /// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html /// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct ArrayIter { array: T, logical_nulls: Option, @@ -56,7 +56,7 @@ impl ArrayIter { /// create a new iterator pub fn new(array: T) -> Self { let len = array.len(); - let logical_nulls = array.logical_nulls(); + let logical_nulls = array.logical_nulls().filter(|x| x.null_count() > 0); ArrayIter { array, logical_nulls, @@ -98,10 +98,42 @@ impl Iterator for ArrayIter { fn size_hint(&self) -> (usize, Option) { ( - self.array.len() - self.current, - Some(self.array.len() - self.current), + self.current_end - self.current, + Some(self.current_end - self.current), ) } + + #[inline] + fn nth(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset + match self.current.checked_add(n) { + // Yes, and still within bounds + Some(new_current) if new_current < self.current_end => { + self.current = new_current; + } + + // Either overflow or would exceed current_end + _ => { + self.current = self.current_end; + return None; + } + } + + self.next() + } + + #[inline] + fn last(mut self) -> Option { + self.next_back() + } + + #[inline] + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } } impl DoubleEndedIterator for ArrayIter { @@ -122,6 +154,25 @@ impl DoubleEndedIterator for ArrayIter { }) } } + + #[inline] + fn nth_back(&mut self, n: usize) -> Option { + // Check if we advance to the one before the desired offset + match self.current_end.checked_sub(n) { + // Yes, and still within bounds + Some(new_offset) if self.current < new_offset => { + self.current_end = new_offset; + } + + // Either underflow or would exceed current + _ => { + self.current = self.current_end; + return None; + } + } + + self.next_back() + } } /// all arrays have known size. @@ -147,9 +198,12 @@ pub type MapArrayIter<'a> = ArrayIter<&'a MapArray>; pub type GenericListViewArrayIter<'a, O> = ArrayIter<&'a GenericListViewArray>; #[cfg(test)] mod tests { - use std::sync::Arc; - use crate::array::{ArrayRef, BinaryArray, BooleanArray, Int32Array, StringArray}; + use crate::iterator::ArrayIter; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::sync::Arc; #[test] fn test_primitive_array_iter_round_trip() { @@ -264,4 +318,875 @@ mod tests { // check if ExactSizeIterator is implemented let _ = array.iter().rposition(|opt_b| opt_b == Some(true)); } + + trait SharedBetweenArrayIterAndSliceIter: + ExactSizeIterator> + DoubleEndedIterator> + Clone + { + } + impl> + DoubleEndedIterator>> + SharedBetweenArrayIterAndSliceIter for T + { + } + + fn get_int32_iterator_cases() -> impl Iterator>)> { + let mut rng = StdRng::seed_from_u64(42); + + let no_nulls_and_no_duplicates = (0..10).map(Some).collect::>>(); + let no_nulls_random_values = (0..10) + .map(|_| rng.random::()) + .map(Some) + .collect::>>(); + + let all_nulls = (0..10).map(|_| None).collect::>>(); + let only_start_nulls = (0..10) + .map(|item| if item < 4 { None } else { Some(item) }) + .collect::>>(); + let only_end_nulls = (0..10) + .map(|item| if item > 8 { None } else { Some(item) }) + .collect::>>(); + let only_middle_nulls = (0..10) + .map(|item| { + if (4..=8).contains(&item) && rng.random_bool(0.9) { + None + } else { + Some(item) + } + }) + .collect::>>(); + let random_values_with_random_nulls = (0..10) + .map(|_| { + if rng.random_bool(0.3) { + None + } else { + Some(rng.random::()) + } + }) + .collect::>>(); + + let no_nulls_and_some_duplicates = (0..10) + .map(|item| item % 3) + .map(Some) + .collect::>>(); + let no_nulls_and_all_same_value = + (0..10).map(|_| 9).map(Some).collect::>>(); + let no_nulls_and_continues_duplicates = [0, 0, 0, 1, 1, 2, 2, 2, 2, 3] + .map(Some) + .into_iter() + .collect::>>(); + + let single_null_and_no_duplicates = (0..10) + .map(|item| if item == 4 { None } else { Some(item) }) + .collect::>>(); + let multiple_nulls_and_no_duplicates = (0..10) + .map(|item| if item % 3 == 2 { None } else { Some(item) }) + .collect::>>(); + let continues_nulls_and_no_duplicates = [ + Some(0), + Some(1), + None, + None, + Some(2), + Some(3), + None, + Some(4), + Some(5), + None, + ] + .into_iter() + .collect::>>(); + + [ + no_nulls_and_no_duplicates, + no_nulls_random_values, + no_nulls_and_some_duplicates, + no_nulls_and_all_same_value, + no_nulls_and_continues_duplicates, + all_nulls, + only_start_nulls, + only_end_nulls, + only_middle_nulls, + random_values_with_random_nulls, + single_null_and_no_duplicates, + multiple_nulls_and_no_duplicates, + continues_nulls_and_no_duplicates, + ] + .map(|case| (Int32Array::from(case.clone()), case)) + .into_iter() + } + + trait SetupIter { + fn description(&self) -> String; + fn setup(&self, iter: &mut I); + } + + struct NoSetup; + impl SetupIter for NoSetup { + fn description(&self) -> String { + "no setup".to_string() + } + fn setup(&self, _iter: &mut I) { + // none + } + } + + fn setup_and_assert_cases_on_single_operation( + o: &impl ConsumingArrayIteratorOp, + setup_iterator: impl SetupIter, + ) { + for (array, source) in get_int32_iterator_cases() { + let mut actual = ArrayIter::new(&array); + let mut expected = source.iter().copied(); + + setup_iterator.setup(&mut actual); + setup_iterator.setup(&mut expected); + + let current_iterator_values: Vec> = expected.clone().collect(); + + assert_eq!( + o.get_value(actual), + o.get_value(expected), + "Failed on op {} for {} (left actual, right expected) ({current_iterator_values:?})", + o.name(), + setup_iterator.description(), + ); + } + } + + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator + /// + /// this is for consuming operations (e.g. `count`, `last`, etc) + trait ConsumingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + /// + /// Example implementation: + /// 1. for `last` it will be the last value + /// 2. for `count` it will be the returned length + fn get_value(&self, iter: T) -> Self::Output; + } + + /// Trait representing an operation on a [`ArrayIter`] + /// that can be compared against a slice iterator. + /// + /// This is for mutating operations (e.g. `position`, `any`, `find`, etc) + trait MutatingArrayIteratorOp { + /// What the operation returns (e.g. Option for last, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + fn name(&self) -> String; + + /// Get the value of the operation for the provided iterator + /// This will be either a [`ArrayIter`] or a slice iterator to make sure they produce the same result + /// + /// Example implementation: + /// 1. for `for_each` it will be the iterator element that the function was called with + /// 2. for `fold` it will be the accumulator and the iterator element from each call, as well as the final result + fn get_value(&self, iter: &mut T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both [`ArrayIter`] and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_array_iterator_cases(o: O) { + setup_and_assert_cases_on_single_operation(&o, NoSetup); + + struct Next; + impl SetupIter for Next { + fn description(&self) -> String { + "new iter after consuming 1 element from the start".to_string() + } + fn setup(&self, iter: &mut I) { + iter.next(); + } + } + setup_and_assert_cases_on_single_operation(&o, Next); + + struct NextBack; + impl SetupIter for NextBack { + fn description(&self) -> String { + "new iter after consuming 1 element from the end".to_string() + } + + fn setup(&self, iter: &mut I) { + iter.next_back(); + } + } + + setup_and_assert_cases_on_single_operation(&o, NextBack); + + struct NextAndBack; + impl SetupIter for NextAndBack { + fn description(&self) -> String { + "new iter after consuming 1 element from start and end".to_string() + } + + fn setup(&self, iter: &mut I) { + iter.next(); + iter.next_back(); + } + } + + setup_and_assert_cases_on_single_operation(&o, NextAndBack); + + struct NextUntilLast; + impl SetupIter for NextUntilLast { + fn description(&self) -> String { + "new iter after consuming all from the start but 1".to_string() + } + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth(len - 2); + } + } + } + setup_and_assert_cases_on_single_operation(&o, NextUntilLast); + + struct NextBackUntilFirst; + impl SetupIter for NextBackUntilFirst { + fn description(&self) -> String { + "new iter after consuming all from the end but 1".to_string() + } + + fn setup(&self, iter: &mut I) { + let len = iter.len(); + if len > 1 { + iter.nth_back(len - 2); + } + } + } + setup_and_assert_cases_on_single_operation(&o, NextBackUntilFirst); + + struct NextFinish; + impl SetupIter for NextFinish { + fn description(&self) -> String { + "new iter after consuming all from the start".to_string() + } + fn setup(&self, iter: &mut I) { + iter.nth(iter.len()); + } + } + setup_and_assert_cases_on_single_operation(&o, NextFinish); + + struct NextBackFinish; + impl SetupIter for NextBackFinish { + fn description(&self) -> String { + "new iter after consuming all from the end".to_string() + } + fn setup(&self, iter: &mut I) { + iter.nth_back(iter.len()); + } + } + setup_and_assert_cases_on_single_operation(&o, NextBackFinish); + + struct NextUntilLastNone; + impl SetupIter for NextUntilLastNone { + fn description(&self) -> String { + "new iter that have no nulls left".to_string() + } + fn setup(&self, iter: &mut I) { + let last_null_position = iter.clone().rposition(|item| item.is_none()); + + // move the iterator to the location where there are no nulls anymore + if let Some(last_null_position) = last_null_position { + iter.nth(last_null_position); + } + } + } + setup_and_assert_cases_on_single_operation(&o, NextUntilLastNone); + + struct NextUntilLastSome; + impl SetupIter for NextUntilLastSome { + fn description(&self) -> String { + "iter that only have nulls left".to_string() + } + fn setup(&self, iter: &mut I) { + let last_some_position = iter.clone().rposition(|item| item.is_some()); + + // move the iterator to the location where there are only nulls + if let Some(last_some_position) = last_some_position { + iter.nth(last_some_position); + } + } + } + setup_and_assert_cases_on_single_operation(&o, NextUntilLastSome); + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both [`ArrayIter`] and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + /// + /// this is different from [`assert_array_iterator_cases`] as this also check that the state after the call is correct + /// to make sure we don't leave the iterator in incorrect state + fn assert_array_iterator_cases_mutate(o: O) { + struct Adapter { + o: O, + } + + #[derive(Debug, PartialEq)] + struct AdapterOutput { + value: Value, + /// collect on the iterator after running the operation + leftover: Vec>, + } + + impl ConsumingArrayIteratorOp for Adapter { + type Output = AdapterOutput; + + fn name(&self) -> String { + self.o.name() + } + + fn get_value( + &self, + mut iter: T, + ) -> Self::Output { + let value = self.o.get_value(&mut iter); + + // Get the rest of the iterator to make sure we leave the iterator in a valid state + let leftover: Vec<_> = iter.collect(); + + AdapterOutput { value, leftover } + } + } + + assert_array_iterator_cases(Adapter { o }) + } + + #[derive(Debug, PartialEq)] + struct CallTrackingAndResult { + result: Result, + calls: Vec, + } + type CallTrackingWithInputType = CallTrackingAndResult>; + type CallTrackingOnly = CallTrackingWithInputType<()>; + + #[test] + fn assert_position() { + struct PositionOp { + reverse: bool, + number_of_false: usize, + } + + impl MutatingArrayIteratorOp for PositionOp { + type Output = CallTrackingWithInputType>; + fn name(&self) -> String { + if self.reverse { + format!("rposition with {} false returned", self.number_of_false) + } else { + format!("position with {} false returned", self.number_of_false) + } + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let cb = |item| { + items.push(item); + + if count < self.number_of_false { + count += 1; + false + } else { + true + } + }; + + let position_result = if self.reverse { + iter.rposition(cb) + } else { + iter.position(cb) + }; + + CallTrackingAndResult { + result: position_result, + calls: items, + } + } + } + + for reverse in [false, true] { + for number_of_false in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(PositionOp { + reverse, + number_of_false, + }); + } + } + } + + #[test] + fn assert_nth() { + for (array, source) in get_int32_iterator_cases() { + let actual = ArrayIter::new(&array); + let expected = source.iter().copied(); + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + } + } + + #[test] + fn assert_nth_back() { + for (array, source) in get_int32_iterator_cases() { + let actual = ArrayIter::new(&array); + let expected = source.iter().copied(); + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + } + } + + #[test] + fn assert_last() { + for (array, source) in get_int32_iterator_cases() { + let mut actual_forward = ArrayIter::new(&array); + let mut expected_forward = source.iter().copied(); + + for _ in 0..source.len() + 1 { + { + let actual_forward_clone = actual_forward.clone(); + let expected_forward_clone = expected_forward.clone(); + + assert_eq!(actual_forward_clone.last(), expected_forward_clone.last()); + } + + actual_forward.next(); + expected_forward.next(); + } + + let mut actual_backward = ArrayIter::new(&array); + let mut expected_backward = source.iter().copied(); + for _ in 0..source.len() + 1 { + { + assert_eq!( + actual_backward.clone().last(), + expected_backward.clone().last() + ); + } + + actual_backward.next_back(); + expected_backward.next_back(); + } + } + } + + #[test] + fn assert_for_each() { + struct ForEachOp; + + impl ConsumingArrayIteratorOp for ForEachOp { + type Output = CallTrackingOnly; + + fn name(&self) -> String { + "for_each".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + iter.for_each(|item| { + items.push(item); + }); + + CallTrackingAndResult { + calls: items, + result: (), + } + } + } + + assert_array_iterator_cases(ForEachOp) + } + + #[test] + fn assert_fold() { + struct FoldOp { + reverse: bool, + } + + #[derive(Debug, PartialEq)] + struct CallArgs { + acc: Option, + item: Option, + } + + impl ConsumingArrayIteratorOp for FoldOp { + type Output = CallTrackingAndResult, CallArgs>; + + fn name(&self) -> String { + if self.reverse { + "rfold".to_string() + } else { + "fold".to_string() + } + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let cb = |acc, item| { + items.push(CallArgs { item, acc }); + + item.map(|val| val + 100) + }; + + let result = if self.reverse { + iter.rfold(Some(1), cb) + } else { + #[allow(clippy::manual_try_fold)] + iter.fold(Some(1), cb) + }; + + CallTrackingAndResult { + calls: items, + result, + } + } + } + + assert_array_iterator_cases(FoldOp { reverse: false }); + assert_array_iterator_cases(FoldOp { reverse: true }); + } + + #[test] + fn assert_count() { + struct CountOp; + + impl ConsumingArrayIteratorOp for CountOp { + type Output = usize; + + fn name(&self) -> String { + "count".to_string() + } + + fn get_value(&self, iter: T) -> Self::Output { + iter.count() + } + } + + assert_array_iterator_cases(CountOp) + } + + #[test] + fn assert_any() { + struct AnyOp { + false_count: usize, + } + + impl MutatingArrayIteratorOp for AnyOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("any with {} false returned", self.false_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.any(|item| { + items.push(item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AnyOp { false_count }); + } + } + + #[test] + fn assert_all() { + struct AllOp { + true_count: usize, + } + + impl MutatingArrayIteratorOp for AllOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("all with {} false returned", self.true_count) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = Vec::with_capacity(iter.len()); + + let mut count = 0; + let res = iter.all(|item| { + items.push(item); + + if count < self.true_count { + count += 1; + true + } else { + false + } + }); + + CallTrackingWithInputType { + calls: items, + result: res, + } + } + } + + for true_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(AllOp { true_count }); + } + } + + #[test] + fn assert_find() { + struct FindOp { + reverse: bool, + false_count: usize, + } + + impl MutatingArrayIteratorOp for FindOp { + type Output = CallTrackingWithInputType>>; + + fn name(&self) -> String { + if self.reverse { + format!("rfind with {} false returned", self.false_count) + } else { + format!("find with {} false returned", self.false_count) + } + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let cb = |item: &Option| { + items.push(*item); + + if count < self.false_count { + count += 1; + false + } else { + true + } + }; + + let position_result = if self.reverse { + iter.rfind(cb) + } else { + iter.find(cb) + }; + + CallTrackingWithInputType { + calls: items, + result: position_result, + } + } + } + + for reverse in [false, true] { + for false_count in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindOp { + reverse, + false_count, + }); + } + } + } + + #[test] + fn assert_find_map() { + struct FindMapOp { + number_of_nones: usize, + } + + impl MutatingArrayIteratorOp for FindMapOp { + type Output = CallTrackingWithInputType>; + + fn name(&self) -> String { + format!("find_map with {} None returned", self.number_of_nones) + } + + fn get_value( + &self, + iter: &mut T, + ) -> Self::Output { + let mut items = vec![]; + + let mut count = 0; + + let result = iter.find_map(|item| { + items.push(item); + + if count < self.number_of_nones { + count += 1; + None + } else { + Some("found it") + } + }); + + CallTrackingAndResult { + result, + calls: items, + } + } + } + + for number_of_nones in [0, 1, 2, usize::MAX] { + assert_array_iterator_cases_mutate(FindMapOp { number_of_nones }); + } + } + + #[test] + fn assert_partition() { + struct PartitionOp) -> bool> { + description: &'static str, + predicate: F, + } + + #[derive(Debug, PartialEq)] + struct PartitionResult { + left: Vec>, + right: Vec>, + } + + impl) -> bool> ConsumingArrayIteratorOp for PartitionOp { + type Output = CallTrackingWithInputType; + + fn name(&self) -> String { + format!("partition by {}", self.description) + } + + fn get_value(&self, iter: T) -> Self::Output { + let mut items = vec![]; + + let mut index = 0; + + let (left, right) = iter.partition(|item| { + items.push(*item); + + let res = (self.predicate)(index, item); + + index += 1; + res + }); + + CallTrackingAndResult { + result: PartitionResult { left, right }, + calls: items, + } + } + } + + assert_array_iterator_cases(PartitionOp { + description: "None on one side and Some(*) on the other", + predicate: |_, item| item.is_none(), + }); + + assert_array_iterator_cases(PartitionOp { + description: "all true", + predicate: |_, _| true, + }); + + assert_array_iterator_cases(PartitionOp { + description: "all false", + predicate: |_, _| false, + }); + + let random_values = (0..100).map(|_| rand::random_bool(0.5)).collect::>(); + assert_array_iterator_cases(PartitionOp { + description: "random", + predicate: |index, _| random_values[index % random_values.len()], + }); + } } diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs index 91696540d219..86c1c6550cdb 100644 --- a/arrow-array/src/lib.rs +++ b/arrow-array/src/lib.rs @@ -225,7 +225,7 @@ html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![deny(rustdoc::broken_intra_doc_links)] #![warn(missing_docs)] diff --git a/arrow-array/src/record_batch.rs b/arrow-array/src/record_batch.rs index 73464358657c..cfec969165a9 100644 --- a/arrow-array/src/record_batch.rs +++ b/arrow-array/src/record_batch.rs @@ -19,7 +19,7 @@ //! [schema](arrow_schema::Schema). use crate::cast::AsArray; -use crate::{new_empty_array, Array, ArrayRef, StructArray}; +use crate::{Array, ArrayRef, StructArray, new_empty_array}; use arrow_schema::{ArrowError, DataType, Field, FieldRef, Schema, SchemaBuilder, SchemaRef}; use std::ops::Index; use std::sync::Arc; @@ -65,7 +65,7 @@ pub trait RecordBatchWriter { /// Support for limited data types is available. The macro will return a compile error if an unsupported data type is used. /// Presently supported data types are: /// - `Boolean`, `Null` -/// - `Decimal128`, `Decimal256` +/// - `Decimal32`, `Decimal64`, `Decimal128`, `Decimal256` /// - `Float16`, `Float32`, `Float64` /// - `Int8`, `Int16`, `Int32`, `Int64` /// - `UInt8`, `UInt16`, `UInt32`, `UInt64` @@ -107,6 +107,8 @@ macro_rules! create_array { (@from DurationMillisecond) => { $crate::DurationMillisecondArray }; (@from DurationMicrosecond) => { $crate::DurationMicrosecondArray }; (@from DurationNanosecond) => { $crate::DurationNanosecondArray }; + (@from Decimal32) => { $crate::Decimal32Array }; + (@from Decimal64) => { $crate::Decimal64Array }; (@from Decimal128) => { $crate::Decimal128Array }; (@from Decimal256) => { $crate::Decimal256Array }; (@from TimestampSecond) => { $crate::TimestampSecondArray }; @@ -358,7 +360,8 @@ impl RecordBatch { if let Some((i, (col_type, field_type))) = not_match { return Err(ArrowError::InvalidArgumentError(format!( - "column types must match schema types, expected {field_type:?} but found {col_type:?} at column index {i}"))); + "column types must match schema types, expected {field_type} but found {col_type} at column index {i}" + ))); } Ok(RecordBatch { @@ -420,7 +423,7 @@ impl RecordBatch { /// // Insert a key-value pair into the metadata /// batch.schema_metadata_mut().insert("key".into(), "value".into()); /// assert_eq!(batch.schema().metadata().get("key"), Some(&String::from("value"))); - /// ``` + /// ``` pub fn schema_metadata_mut(&mut self) -> &mut std::collections::HashMap { let schema = Arc::make_mut(&mut self.schema); &mut schema.metadata @@ -442,14 +445,16 @@ impl RecordBatch { }) .collect::, _>>()?; - RecordBatch::try_new_with_options( - SchemaRef::new(projected_schema), - batch_fields, - &RecordBatchOptions { - match_field_names: true, - row_count: Some(self.row_count), - }, - ) + unsafe { + // Since we're starting from a valid RecordBatch and project + // creates a strict subset of the original, there's no need to + // redo the validation checks in `try_new_with_options`. + Ok(RecordBatch::new_unchecked( + SchemaRef::new(projected_schema), + batch_fields, + self.row_count, + )) + } } /// Normalize a semi-structured [`RecordBatch`] into a flat table. @@ -930,7 +935,7 @@ where mod tests { use super::*; use crate::{ - BooleanArray, Int32Array, Int64Array, Int8Array, ListArray, StringArray, StringViewArray, + BooleanArray, Int8Array, Int32Array, Int64Array, ListArray, StringArray, StringViewArray, }; use arrow_buffer::{Buffer, ToByteSlice}; use arrow_data::{ArrayData, ArrayDataBuilder}; @@ -1098,7 +1103,10 @@ mod tests { let a = Int64Array::from(vec![1, 2, 3, 4, 5]); let err = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)]).unwrap_err(); - assert_eq!(err.to_string(), "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0"); + assert_eq!( + err.to_string(), + "Invalid argument error: column types must match schema types, expected Int32 but found Int64 at column index 0" + ); } #[test] @@ -1572,9 +1580,10 @@ mod tests { let schema = Arc::new(Schema::empty()); let err = RecordBatch::try_new(schema.clone(), vec![]).unwrap_err(); - assert!(err - .to_string() - .contains("must either specify a row count or at least one column")); + assert!( + err.to_string() + .contains("must either specify a row count or at least one column") + ); let options = RecordBatchOptions::new().with_row_count(Some(10)); @@ -1598,7 +1607,10 @@ mod tests { schema, vec![Arc::new(Int32Array::from(vec![Some(1), None]))], ); - assert_eq!("Invalid argument error: Column 'a' is declared as non-nullable but contains null values", format!("{}", maybe_batch.err().unwrap())); + assert_eq!( + "Invalid argument error: Column 'a' is declared as non-nullable but contains null values", + format!("{}", maybe_batch.err().unwrap()) + ); } #[test] fn test_record_batch_options() { diff --git a/arrow-array/src/run_iterator.rs b/arrow-array/src/run_iterator.rs index 4fb0eef32eca..f7277a93ff62 100644 --- a/arrow-array/src/run_iterator.rs +++ b/arrow-array/src/run_iterator.rs @@ -17,7 +17,7 @@ //! Idiomatic iterator for [`RunArray`](crate::RunArray) -use crate::{array::ArrayAccessor, types::RunEndIndexType, Array, TypedRunArray}; +use crate::{Array, TypedRunArray, array::ArrayAccessor, types::RunEndIndexType}; use arrow_buffer::ArrowNativeType; /// The [`RunArrayIter`] provides an idiomatic way to iterate over the run array. @@ -172,13 +172,13 @@ where #[cfg(test)] mod tests { - use rand::{rng, seq::SliceRandom, Rng}; + use rand::{Rng, rng, seq::SliceRandom}; use crate::{ + Array, Int64RunArray, PrimitiveArray, RunArray, array::{Int32Array, StringArray}, builder::PrimitiveRunBuilder, types::{Int16Type, Int32Type}, - Array, Int64RunArray, PrimitiveArray, RunArray, }; fn build_input_array(size: usize) -> Vec> { diff --git a/arrow-array/src/temporal_conversions.rs b/arrow-array/src/temporal_conversions.rs index 7a4c67602932..a5ec50da1fc6 100644 --- a/arrow-array/src/temporal_conversions.rs +++ b/arrow-array/src/temporal_conversions.rs @@ -17,8 +17,8 @@ //! Conversion methods for dates and times. -use crate::timezone::Tz; use crate::ArrowPrimitiveType; +use crate::timezone::Tz; use arrow_schema::{DataType, TimeUnit}; use chrono::{DateTime, Duration, NaiveDate, NaiveDateTime, NaiveTime, TimeZone, Timelike, Utc}; @@ -322,9 +322,9 @@ pub fn as_duration(v: i64) -> Option { #[cfg(test)] mod tests { use crate::temporal_conversions::{ - date64_to_datetime, split_second, timestamp_ms_to_datetime, timestamp_ns_to_datetime, - timestamp_s_to_date, timestamp_s_to_datetime, timestamp_s_to_time, - timestamp_us_to_datetime, NANOSECONDS, + NANOSECONDS, date64_to_datetime, split_second, timestamp_ms_to_datetime, + timestamp_ns_to_datetime, timestamp_s_to_date, timestamp_s_to_datetime, + timestamp_s_to_time, timestamp_us_to_datetime, }; use chrono::DateTime; diff --git a/arrow-array/src/timezone.rs b/arrow-array/src/timezone.rs index b4df77deb4f5..bcf582152146 100644 --- a/arrow-array/src/timezone.rs +++ b/arrow-array/src/timezone.rs @@ -53,6 +53,7 @@ mod private { use super::*; use chrono::offset::TimeZone; use chrono::{LocalResult, NaiveDate, NaiveDateTime, Offset}; + use std::fmt::Display; use std::str::FromStr; /// An [`Offset`] for [`Tz`] @@ -97,6 +98,15 @@ mod private { } } + impl Display for Tz { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self.0 { + TzInner::Timezone(tz) => tz.fmt(f), + TzInner::Offset(offset) => offset.fmt(f), + } + } + } + macro_rules! tz { ($s:ident, $tz:ident, $b:block) => { match $s.0 { @@ -228,6 +238,15 @@ mod private { sydney_offset_with_dst ); } + + #[test] + fn test_timezone_display() { + let test_cases = ["UTC", "America/Los_Angeles", "-08:00", "+05:30"]; + for &case in &test_cases { + let tz: Tz = case.parse().unwrap(); + assert_eq!(tz.to_string(), case); + } + } } } diff --git a/arrow-array/src/trusted_len.rs b/arrow-array/src/trusted_len.rs index 781cad38f7e9..b2e1948ccc76 100644 --- a/arrow-array/src/trusted_len.rs +++ b/arrow-array/src/trusted_len.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, bit_util}; /// Creates two [`Buffer`]s from an iterator of `Option`. /// The first buffer corresponds to a bitmap buffer, the second one @@ -41,19 +41,19 @@ where for (i, item) in iterator.enumerate() { let item = item.borrow(); if let Some(item) = item { - std::ptr::write(dst, *item); - bit_util::set_bit_raw(dst_null, i); + unsafe { std::ptr::write(dst, *item) }; + unsafe { bit_util::set_bit_raw(dst_null, i) }; } else { - std::ptr::write(dst, T::default()); + unsafe { std::ptr::write(dst, T::default()) }; } - dst = dst.add(1); + dst = unsafe { dst.add(1) }; } assert_eq!( - dst.offset_from(buffer.as_ptr() as *mut T) as usize, + unsafe { dst.offset_from(buffer.as_ptr() as *mut T) as usize }, upper, "Trusted iterator length was not accurately reported" ); - buffer.set_len(len); + unsafe { buffer.set_len(len) }; (null.into(), buffer.into()) } diff --git a/arrow-array/src/types.rs b/arrow-array/src/types.rs index 3d8cfcdb112b..fcd2d6958f35 100644 --- a/arrow-array/src/types.rs +++ b/arrow-array/src/types.rs @@ -23,15 +23,18 @@ use crate::delta::{ use crate::temporal_conversions::as_datetime_with_timezone; use crate::timezone::Tz; use crate::{ArrowNativeTypeOp, OffsetSizeTrait}; -use arrow_buffer::{i256, Buffer, OffsetBuffer}; +use arrow_buffer::{Buffer, OffsetBuffer, i256}; use arrow_data::decimal::{ - is_validate_decimal256_precision, is_validate_decimal_precision, validate_decimal256_precision, - validate_decimal_precision, + format_decimal_str, is_validate_decimal_precision, is_validate_decimal32_precision, + is_validate_decimal64_precision, is_validate_decimal256_precision, validate_decimal_precision, + validate_decimal32_precision, validate_decimal64_precision, validate_decimal256_precision, }; use arrow_data::{validate_binary_view, validate_string_view}; use arrow_schema::{ - ArrowError, DataType, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, - DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL_DEFAULT_SCALE, + ArrowError, DECIMAL_DEFAULT_SCALE, DECIMAL32_DEFAULT_SCALE, DECIMAL32_MAX_PRECISION, + DECIMAL32_MAX_SCALE, DECIMAL64_DEFAULT_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, + DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, + DataType, IntervalUnit, TimeUnit, }; use chrono::{Duration, NaiveDate, NaiveDateTime}; use half::f16; @@ -68,12 +71,6 @@ pub trait ArrowPrimitiveType: primitive::PrimitiveTypeSealed + 'static { /// the corresponding Arrow data type of this primitive type. const DATA_TYPE: DataType; - /// Returns the byte width of this primitive type. - #[deprecated(since = "52.0.0", note = "Use ArrowNativeType::get_byte_width")] - fn get_byte_width() -> usize { - std::mem::size_of::() - } - /// Returns a default value of this primitive type. /// /// This is useful for aggregate array ops like `sum()`, `mean()`. @@ -1031,9 +1028,25 @@ impl Date64Type { /// # Arguments /// /// * `i` - The Date64Type to convert + #[deprecated(since = "56.0.0", note = "Use to_naive_date_opt instead.")] pub fn to_naive_date(i: ::Native) -> NaiveDate { + Self::to_naive_date_opt(i) + .unwrap_or_else(|| panic!("Date64Type::to_naive_date overflowed for date: {i}",)) + } + + /// Converts an arrow Date64Type into a chrono::NaiveDateTime if it fits in the range that chrono::NaiveDateTime can represent. + /// Returns `None` if the calculation would overflow or underflow. + /// + /// This function is able to handle dates ranging between 1677-09-21 (-9,223,372,800,000) and 2262-04-11 (9,223,286,400,000). + /// + /// # Arguments + /// + /// * `i` - The Date64Type to convert + /// + /// Returns `Some(NaiveDateTime)` if it fits, `None` otherwise. + pub fn to_naive_date_opt(i: ::Native) -> Option { let epoch = NaiveDate::from_ymd_opt(1970, 1, 1).unwrap(); - epoch.add(Duration::try_milliseconds(i).unwrap()) + Duration::try_milliseconds(i).and_then(|d| epoch.checked_add_signed(d)) } /// Converts a chrono::NaiveDate into an arrow Date64Type @@ -1052,14 +1065,35 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to add + #[deprecated( + since = "56.0.0", + note = "Use `add_year_months_opt` instead, which returns an Option to handle overflow." + )] pub fn add_year_months( date: ::Native, delta: ::Native, ) -> ::Native { - let prior = Date64Type::to_naive_date(date); + Self::add_year_months_opt(date, delta).unwrap_or_else(|| { + panic!("Date64Type::add_year_months overflowed for date: {date}, delta: {delta}",) + }) + } + + /// Adds the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn add_year_months_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { + let prior = Date64Type::to_naive_date_opt(date)?; let months = IntervalYearMonthType::to_months(delta); let posterior = shift_months(prior, months); - Date64Type::from_naive_date(posterior) + Some(Date64Type::from_naive_date(posterior)) } /// Adds the given IntervalDayTimeType to an arrow Date64Type @@ -1068,15 +1102,36 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to add + #[deprecated( + since = "56.0.0", + note = "Use `add_day_time_opt` instead, which returns an Option to handle overflow." + )] pub fn add_day_time( date: ::Native, delta: ::Native, ) -> ::Native { + Self::add_day_time_opt(date, delta).unwrap_or_else(|| { + panic!("Date64Type::add_day_time overflowed for date: {date}, delta: {delta:?}",) + }) + } + + /// Adds the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn add_day_time_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { let (days, ms) = IntervalDayTimeType::to_parts(delta); - let res = Date64Type::to_naive_date(date); - let res = res.add(Duration::try_days(days as i64).unwrap()); - let res = res.add(Duration::try_milliseconds(ms as i64).unwrap()); - Date64Type::from_naive_date(res) + let res = Date64Type::to_naive_date_opt(date)?; + let res = res.checked_add_signed(Duration::try_days(days as i64)?)?; + let res = res.checked_add_signed(Duration::try_milliseconds(ms as i64)?)?; + Some(Date64Type::from_naive_date(res)) } /// Adds the given IntervalMonthDayNanoType to an arrow Date64Type @@ -1085,16 +1140,37 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to add + #[deprecated( + since = "56.0.0", + note = "Use `add_month_day_nano_opt` instead, which returns an Option to handle overflow." + )] pub fn add_month_day_nano( date: ::Native, delta: ::Native, ) -> ::Native { + Self::add_month_day_nano_opt(date, delta).unwrap_or_else(|| { + panic!("Date64Type::add_month_day_nano overflowed for date: {date}, delta: {delta:?}",) + }) + } + + /// Adds the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to add + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn add_month_day_nano_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = Date64Type::to_naive_date(date); + let res = Date64Type::to_naive_date_opt(date)?; let res = shift_months(res, months); - let res = res.add(Duration::try_days(days as i64).unwrap()); - let res = res.add(Duration::nanoseconds(nanos)); - Date64Type::from_naive_date(res) + let res = res.checked_add_signed(Duration::try_days(days as i64)?)?; + let res = res.checked_add_signed(Duration::nanoseconds(nanos))?; + Some(Date64Type::from_naive_date(res)) } /// Subtract the given IntervalYearMonthType to an arrow Date64Type @@ -1103,14 +1179,35 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to subtract + #[deprecated( + since = "56.0.0", + note = "Use `subtract_year_months_opt` instead, which returns an Option to handle overflow." + )] pub fn subtract_year_months( date: ::Native, delta: ::Native, ) -> ::Native { - let prior = Date64Type::to_naive_date(date); + Self::subtract_year_months_opt(date, delta).unwrap_or_else(|| { + panic!("Date64Type::subtract_year_months overflowed for date: {date}, delta: {delta}",) + }) + } + + /// Subtract the given IntervalYearMonthType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn subtract_year_months_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { + let prior = Date64Type::to_naive_date_opt(date)?; let months = IntervalYearMonthType::to_months(-delta); let posterior = shift_months(prior, months); - Date64Type::from_naive_date(posterior) + Some(Date64Type::from_naive_date(posterior)) } /// Subtract the given IntervalDayTimeType to an arrow Date64Type @@ -1119,15 +1216,36 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to subtract + #[deprecated( + since = "56.0.0", + note = "Use `subtract_day_time_opt` instead, which returns an Option to handle overflow." + )] pub fn subtract_day_time( date: ::Native, delta: ::Native, ) -> ::Native { + Self::subtract_day_time_opt(date, delta).unwrap_or_else(|| { + panic!("Date64Type::subtract_day_time overflowed for date: {date}, delta: {delta:?}",) + }) + } + + /// Subtract the given IntervalDayTimeType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn subtract_day_time_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { let (days, ms) = IntervalDayTimeType::to_parts(delta); - let res = Date64Type::to_naive_date(date); - let res = res.sub(Duration::try_days(days as i64).unwrap()); - let res = res.sub(Duration::try_milliseconds(ms as i64).unwrap()); - Date64Type::from_naive_date(res) + let res = Date64Type::to_naive_date_opt(date)?; + let res = res.checked_sub_signed(Duration::try_days(days as i64)?)?; + let res = res.checked_sub_signed(Duration::try_milliseconds(ms as i64)?)?; + Some(Date64Type::from_naive_date(res)) } /// Subtract the given IntervalMonthDayNanoType to an arrow Date64Type @@ -1136,16 +1254,39 @@ impl Date64Type { /// /// * `date` - The date on which to perform the operation /// * `delta` - The interval to subtract + #[deprecated( + since = "56.0.0", + note = "Use `subtract_month_day_nano_opt` instead, which returns an Option to handle overflow." + )] pub fn subtract_month_day_nano( date: ::Native, delta: ::Native, ) -> ::Native { + Self::subtract_month_day_nano_opt(date, delta).unwrap_or_else(|| { + panic!( + "Date64Type::subtract_month_day_nano overflowed for date: {date}, delta: {delta:?}", + ) + }) + } + + /// Subtract the given IntervalMonthDayNanoType to an arrow Date64Type + /// + /// # Arguments + /// + /// * `date` - The date on which to perform the operation + /// * `delta` - The interval to subtract + /// + /// Returns `Some(Date64Type)` if it fits, `None` otherwise. + pub fn subtract_month_day_nano_opt( + date: ::Native, + delta: ::Native, + ) -> Option<::Native> { let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(delta); - let res = Date64Type::to_naive_date(date); + let res = Date64Type::to_naive_date_opt(date)?; let res = shift_months(res, -months); - let res = res.sub(Duration::try_days(days as i64).unwrap()); - let res = res.sub(Duration::nanoseconds(nanos)); - Date64Type::from_naive_date(res) + let res = res.checked_sub_signed(Duration::try_days(days as i64)?)?; + let res = res.checked_sub_signed(Duration::nanoseconds(nanos))?; + Some(Date64Type::from_naive_date(res)) } } @@ -1156,6 +1297,8 @@ mod decimal { use super::*; pub trait DecimalTypeSealed {} + impl DecimalTypeSealed for Decimal32Type {} + impl DecimalTypeSealed for Decimal64Type {} impl DecimalTypeSealed for Decimal128Type {} impl DecimalTypeSealed for Decimal256Type {} } @@ -1163,10 +1306,12 @@ mod decimal { /// A trait over the decimal types, used by [`PrimitiveArray`] to provide a generic /// implementation across the various decimal types /// -/// Implemented by [`Decimal128Type`] and [`Decimal256Type`] for [`Decimal128Array`] -/// and [`Decimal256Array`] respectively +/// Implemented by [`Decimal32Type`], [`Decimal64Type`], [`Decimal128Type`] and [`Decimal256Type`] +/// for [`Decimal32Array`], [`Decimal64Array`], [`Decimal128Array`] and [`Decimal256Array`] respectively /// /// [`PrimitiveArray`]: crate::array::PrimitiveArray +/// [`Decimal32Array`]: crate::array::Decimal32Array +/// [`Decimal64Array`]: crate::array::Decimal64Array /// [`Decimal128Array`]: crate::array::Decimal128Array /// [`Decimal256Array`]: crate::array::Decimal256Array pub trait DecimalType: @@ -1178,19 +1323,25 @@ pub trait DecimalType: const MAX_PRECISION: u8; /// Maximum no of digits after the decimal point (note the scale can be negative) const MAX_SCALE: i8; + /// The maximum value for each precision in `0..=MAX_PRECISION`: [0, 9, 99, ...] + const MAX_FOR_EACH_PRECISION: &'static [Self::Native]; /// fn to create its [`DataType`] const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType; /// Default values for [`DataType`] const DEFAULT_TYPE: DataType; - /// "Decimal128" or "Decimal256", for use in error messages + /// "Decimal32", "Decimal64", "Decimal128" or "Decimal256", for use in error messages const PREFIX: &'static str; /// Formats the decimal value with the provided precision and scale fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String; /// Validates that `value` contains no more than `precision` decimal digits - fn validate_decimal_precision(value: Self::Native, precision: u8) -> Result<(), ArrowError>; + fn validate_decimal_precision( + value: Self::Native, + precision: u8, + scale: i8, + ) -> Result<(), ArrowError>; /// Determines whether `value` contains no more than `precision` decimal digits fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool; @@ -1236,6 +1387,78 @@ pub fn validate_decimal_precision_and_scale( Ok(()) } +/// The decimal type for a Decimal32Array +#[derive(Debug)] +pub struct Decimal32Type {} + +impl DecimalType for Decimal32Type { + const BYTE_LENGTH: usize = 4; + const MAX_PRECISION: u8 = DECIMAL32_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL32_MAX_SCALE; + const MAX_FOR_EACH_PRECISION: &'static [i32] = + &arrow_data::decimal::MAX_DECIMAL32_FOR_EACH_PRECISION; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal32; + const DEFAULT_TYPE: DataType = + DataType::Decimal32(DECIMAL32_MAX_PRECISION, DECIMAL32_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal32"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i32, precision: u8, scale: i8) -> Result<(), ArrowError> { + validate_decimal32_precision(num, precision, scale) + } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal32_precision(value, precision) + } +} + +impl ArrowPrimitiveType for Decimal32Type { + type Native = i32; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal32Type {} + +/// The decimal type for a Decimal64Array +#[derive(Debug)] +pub struct Decimal64Type {} + +impl DecimalType for Decimal64Type { + const BYTE_LENGTH: usize = 8; + const MAX_PRECISION: u8 = DECIMAL64_MAX_PRECISION; + const MAX_SCALE: i8 = DECIMAL64_MAX_SCALE; + const MAX_FOR_EACH_PRECISION: &'static [i64] = + &arrow_data::decimal::MAX_DECIMAL64_FOR_EACH_PRECISION; + const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal64; + const DEFAULT_TYPE: DataType = + DataType::Decimal64(DECIMAL64_MAX_PRECISION, DECIMAL64_DEFAULT_SCALE); + const PREFIX: &'static str = "Decimal64"; + + fn format_decimal(value: Self::Native, precision: u8, scale: i8) -> String { + format_decimal_str(&value.to_string(), precision as usize, scale) + } + + fn validate_decimal_precision(num: i64, precision: u8, scale: i8) -> Result<(), ArrowError> { + validate_decimal64_precision(num, precision, scale) + } + + fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { + is_validate_decimal64_precision(value, precision) + } +} + +impl ArrowPrimitiveType for Decimal64Type { + type Native = i64; + + const DATA_TYPE: DataType = ::DEFAULT_TYPE; +} + +impl primitive::PrimitiveTypeSealed for Decimal64Type {} + /// The decimal type for a Decimal128Array #[derive(Debug)] pub struct Decimal128Type {} @@ -1244,6 +1467,8 @@ impl DecimalType for Decimal128Type { const BYTE_LENGTH: usize = 16; const MAX_PRECISION: u8 = DECIMAL128_MAX_PRECISION; const MAX_SCALE: i8 = DECIMAL128_MAX_SCALE; + const MAX_FOR_EACH_PRECISION: &'static [i128] = + &arrow_data::decimal::MAX_DECIMAL128_FOR_EACH_PRECISION; const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal128; const DEFAULT_TYPE: DataType = DataType::Decimal128(DECIMAL128_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); @@ -1253,8 +1478,8 @@ impl DecimalType for Decimal128Type { format_decimal_str(&value.to_string(), precision as usize, scale) } - fn validate_decimal_precision(num: i128, precision: u8) -> Result<(), ArrowError> { - validate_decimal_precision(num, precision) + fn validate_decimal_precision(num: i128, precision: u8, scale: i8) -> Result<(), ArrowError> { + validate_decimal_precision(num, precision, scale) } fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { @@ -1278,6 +1503,8 @@ impl DecimalType for Decimal256Type { const BYTE_LENGTH: usize = 32; const MAX_PRECISION: u8 = DECIMAL256_MAX_PRECISION; const MAX_SCALE: i8 = DECIMAL256_MAX_SCALE; + const MAX_FOR_EACH_PRECISION: &'static [i256] = + &arrow_data::decimal::MAX_DECIMAL256_FOR_EACH_PRECISION; const TYPE_CONSTRUCTOR: fn(u8, i8) -> DataType = DataType::Decimal256; const DEFAULT_TYPE: DataType = DataType::Decimal256(DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE); @@ -1287,8 +1514,8 @@ impl DecimalType for Decimal256Type { format_decimal_str(&value.to_string(), precision as usize, scale) } - fn validate_decimal_precision(num: i256, precision: u8) -> Result<(), ArrowError> { - validate_decimal256_precision(num, precision) + fn validate_decimal_precision(num: i256, precision: u8, scale: i8) -> Result<(), ArrowError> { + validate_decimal256_precision(num, precision, scale) } fn is_valid_decimal_precision(value: Self::Native, precision: u8) -> bool { @@ -1304,29 +1531,6 @@ impl ArrowPrimitiveType for Decimal256Type { impl primitive::PrimitiveTypeSealed for Decimal256Type {} -fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { - let (sign, rest) = match value_str.strip_prefix('-') { - Some(stripped) => ("-", stripped), - None => ("", value_str), - }; - let bound = precision.min(rest.len()) + sign.len(); - let value_str = &value_str[0..bound]; - - if scale == 0 { - value_str.to_string() - } else if scale < 0 { - let padding = value_str.len() + scale.unsigned_abs() as usize; - format!("{value_str:0 scale as usize { - // Decimal separator is in the middle of the string - let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); - format!("{whole}.{decimal}") - } else { - // String has to be padded - format!("{}0.{:0>width$}", sign, rest, width = scale as usize) - } -} - /// Crate private types for Byte Arrays /// /// Not intended to be used outside this crate @@ -1366,7 +1570,7 @@ pub(crate) mod bytes { #[inline] unsafe fn from_bytes_unchecked(b: &[u8]) -> &Self { - std::str::from_utf8_unchecked(b) + unsafe { std::str::from_utf8_unchecked(b) } } } } @@ -1541,7 +1745,7 @@ impl ByteViewType for BinaryViewType { #[cfg(test)] mod tests { use super::*; - use arrow_data::{layout, BufferSpec}; + use arrow_data::{BufferSpec, layout}; #[test] fn month_day_nano_should_roundtrip() { @@ -1607,6 +1811,8 @@ mod tests { test_layout::(); test_layout::(); test_layout::(); + test_layout::(); + test_layout::(); test_layout::(); test_layout::(); test_layout::(); diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index 24297f4a7e5f..48cea8467eb7 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -36,27 +36,62 @@ bench = false all-features = true [features] -default = ["deflate", "snappy", "zstd"] +default = ["deflate", "snappy", "zstd", "bzip2", "xz"] deflate = ["flate2"] snappy = ["snap", "crc"] +canonical_extension_types = ["arrow-schema/canonical_extension_types"] +md5 = ["dep:md5"] +sha256 = ["dep:sha2"] +small_decimals = [] +avro_custom_types = ["dep:arrow-select"] [dependencies] arrow-schema = { workspace = true } arrow-buffer = { workspace = true } arrow-array = { workspace = true } +arrow-select = { workspace = true, optional = true } serde_json = { version = "1.0", default-features = false, features = ["std"] } serde = { version = "1.0.188", features = ["derive"] } -flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } +flate2 = { version = "1.0", default-features = false, features = [ + "rust_backend", +], optional = true } snap = { version = "1.0", default-features = false, optional = true } zstd = { version = "0.13", default-features = false, optional = true } +bzip2 = { version = "0.6.0", optional = true } +xz = { package = "liblzma", version = "0.4", default-features = false, optional = true } crc = { version = "3.0", optional = true } +strum_macros = "0.27" +uuid = "1.17" +indexmap = "2.10" +rand = "0.9" +md5 = { version = "0.8", optional = true } +sha2 = { version = "0.10", optional = true } [dev-dependencies] -rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } -criterion = { version = "0.5", default-features = false } +arrow-data = { workspace = true } +rand = { version = "0.9.1", default-features = false, features = [ + "std", + "std_rng", + "thread_rng", +] } +criterion = { workspace = true, default-features = false } tempfile = "3.3" arrow = { workspace = true } +futures = "0.3.31" +bytes = "1.10.1" +async-stream = "0.3.6" +apache-avro = "0.21.0" +num-bigint = "0.4" +once_cell = "1.21.3" [[bench]] name = "avro_reader" harness = false + +[[bench]] +name = "decoder" +harness = false + +[[bench]] +name = "avro_writer" +harness = false diff --git a/arrow-avro/README.md b/arrow-avro/README.md new file mode 100644 index 000000000000..85fd76094755 --- /dev/null +++ b/arrow-avro/README.md @@ -0,0 +1,182 @@ + + +# `arrow-avro` + +[![crates.io](https://img.shields.io/crates/v/arrow-avro.svg)](https://crates.io/crates/arrow-avro) +[![docs.rs](https://img.shields.io/docsrs/arrow-avro.svg)](https://docs.rs/arrow-avro/latest/arrow_avro/) + +Transfer data between the [Apache Arrow] memory format and [Apache Avro]. + +This crate provides: + +- a **reader** that decodes Avro + - **Object Container Files (OCF)**, + - **Avro Single‑Object Encoding (SOE)**, and + - **Confluent Schema Registry wire format** + into Arrow `RecordBatch`es; and +- a **writer** that encodes Arrow `RecordBatch`es into Avro (**OCF** or **SOE**). + +> The latest API docs for `main` (unreleased) are published on the Arrow website: **arrow_avro**. + +[Apache Arrow]: https://arrow.apache.org/ +[Apache Avro]: https://avro.apache.org/ + +--- + +## Install + +```toml +[dependencies] +arrow-avro = "57.0.0" +```` + +Disable defaults and pick only what you need (see **Feature Flags**): + +```toml +[dependencies] +arrow-avro = { version = "57.0.0", default-features = false, features = ["deflate", "snappy"] } +``` + +--- + +## Quick start + +### Read an Avro OCF file into Arrow + +```rust +use std::fs::File; +use std::io::BufReader; + +use arrow_avro::reader::ReaderBuilder; +use arrow_array::RecordBatch; + +fn main() -> anyhow::Result<()> { + let file = BufReader::new(File::open("data/example.avro")?); + let mut reader = ReaderBuilder::new().build(file)?; + while let Some(batch) = reader.next() { + let batch: RecordBatch = batch?; + println!("rows: {}", batch.num_rows()); + } + Ok(()) +} +``` + +### Write Arrow to Avro OCF (in‑memory) + +```rust +use std::sync::Arc; + +use arrow_avro::writer::AvroWriter; +use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; + +fn main() -> anyhow::Result<()> { + let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + )?; + + let sink: Vec = Vec::new(); + let mut w = AvroWriter::new(sink, schema)?; + w.write(&batch)?; + w.finish()?; + assert!(!w.into_inner().is_empty()); + Ok(()) +} +``` + +See the crate docs for runnable SOE and Confluent round‑trip examples. + +--- + +## Feature Flags (what they do and when to use them) + +### Compression codecs (OCF block compression) + +`arrow-avro` supports the Avro‑standard OCF codecs. The **defaults** include all five: `deflate`, `snappy`, `zstd`, `bzip2`, and `xz`. + +| Feature | Default | What it enables | When to use | +|-----------|--------:|---------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------| +| `deflate` | ✅ | DEFLATE compression via `flate2` (pure‑Rust backend) | Most compatible; widely supported; good compression, slower than Snappy. | +| `snappy` | ✅ | Snappy block compression via `snap` with CRC‑32 as required by Avro | Fastest decode/encode; common in streaming/data‑lake pipelines. (Avro requires a 4‑byte big‑endian CRC of the **uncompressed** block.) | +| `zstd` | ✅ | Zstandard block compression via `zstd` | Great compression/speed trade‑off on modern systems. May pull in a native library. | +| `bzip2` | ✅ | BZip2 block compression | For compatibility with older datasets that used BZip2. Slower; larger deps. | +| `xz` | ✅ | XZ/LZMA block compression | Highest compression for archival data; slowest; larger deps. | + +> Avro defines these codecs for OCF: `null` (no compression), `deflate`, `snappy`, `bzip2`, `xz`, and `zstandard` (recent spec versions). + +**Notes** + +* Only **OCF** uses these codecs (they compress per‑block). They do **not** apply to raw Avro frames used by Confluent wire format or SOE. The crate’s `compression` module is specifically for **OCF blocks**. +* `deflate` uses `flate2` with the `rust_backend` (no system zlib required). + +### Schema fingerprints & custom logical type helpers + +| Feature | Default | What it enables | When to use | +|-----------------------------|--------:|----------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------| +| `md5` | ⬜ | `md5` dep for optional **MD5** schema fingerprints | If you want to compute MD5 fingerprints of writer schemas (i.e. for custom prefixing/validation). | +| `sha256` | ⬜ | `sha2` dep for optional **SHA‑256** schema fingerprints | If you prefer longer fingerprints; affects max prefix length (i.e. when framing). | +| `small_decimals` | ⬜ | Extra handling for **small decimal** logical types (`Decimal32` and `Decimal64`) | If your Avro `decimal` values are small and you want more compact Arrow representations. | +| `avro_custom_types` | ⬜ | Annotates Avro values using Arrow specific custom logical types | Enable when you need arrow-avro to reinterpret certain Avro fields as Arrow types that Avro doesn’t natively model. | +| `canonical_extension_types` | ⬜ | Re‑exports Arrow’s canonical extension types support from `arrow-schema` | Enable if your workflow uses Arrow [canonical extension types] and you want `arrow-avro` to respect them. | + +[canonical extension types]: https://arrow.apache.org/docs/format/CanonicalExtensions.html + +**Lower‑level/internal toggles (rarely used directly)** + +* `flate2`, `snap`, `crc`, `zstd`, `bzip2`, `xz` are optional **dependencies** wired to the user‑facing features above. You normally enable `deflate`/`snappy`/`zstd`/`bzip2`/`xz`, not these directly. + +### Feature snippets + +* Minimal, fast build (common pipelines): + + ```toml + arrow-avro = { version = "56", default-features = false, features = ["deflate", "snappy"] } + ``` +* Include Zstandard too (modern data lakes): + + ```toml + arrow-avro = { version = "56", default-features = false, features = ["deflate", "snappy", "zstd"] } + ``` +* Fingerprint helpers: + + ```toml + arrow-avro = { version = "56", features = ["md5", "sha256"] } + ``` + +--- + +## What formats are supported? + +* **OCF (Object Container Files)**: self‑describing Avro files with header, optional compression, sync markers; reader and writer supported. +* **Confluent Schema Registry wire format**: 1‑byte magic `0x00` + 4‑byte BE schema ID + Avro body; supports decode + encode helpers. +* **Avro Single‑Object Encoding (SOE)**: 2‑byte magic `0xC3 0x01` + 8‑byte LE CRC‑64‑AVRO fingerprint + Avro body; supports decode + encode helpers. + +--- + +## Examples + +* Read/write OCF in memory and from files (see crate docs “OCF round‑trip”). +* Confluent wire‑format and SOE quickstarts are provided as runnable snippets in docs. + +There are additional examples under `arrow-avro/examples/` in the repository. + +--- diff --git a/arrow-avro/benches/avro_reader.rs b/arrow-avro/benches/avro_reader.rs index b525a0c788cd..2f2a3a10dbf3 100644 --- a/arrow-avro/benches/avro_reader.rs +++ b/arrow-avro/benches/avro_reader.rs @@ -20,7 +20,7 @@ //! This benchmark suite compares the performance characteristics of StringArray vs //! StringViewArray across three key dimensions: //! 1. Array creation performance -//! 2. String value access operations +//! 2. String value access operations //! 3. Avro file reading with each array type use std::fs::File; @@ -31,14 +31,13 @@ use std::time::Duration; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, Schema}; use arrow_array::{ArrayRef, Int32Array, StringArray, StringViewArray}; -use arrow_avro::ReadOptions; use arrow_schema::ArrowError; use criterion::*; use tempfile::NamedTempFile; fn create_test_data(count: usize, str_length: usize) -> Vec { (0..count) - .map(|i| format!("str_{}", i) + &"a".repeat(str_length)) + .map(|i| format!("str_{i}") + &"a".repeat(str_length)) .collect() } @@ -79,7 +78,7 @@ fn create_avro_test_file(row_count: usize, str_length: usize) -> Result Result { let file = File::open(file_path)?; let mut reader = BufReader::new(file); @@ -101,7 +100,7 @@ fn read_avro_test_file( reader.read_exact(&mut buf)?; let s = String::from_utf8(buf) - .map_err(|e| ArrowError::ParseError(format!("Invalid UTF-8: {}", e)))?; + .map_err(|e| ArrowError::ParseError(format!("Invalid UTF-8: {e}")))?; strings.push(s); @@ -110,7 +109,7 @@ fn read_avro_test_file( ints.push(i32::from_le_bytes(int_bytes)); } - let string_array: ArrayRef = if options.use_utf8view() { + let string_array: ArrayRef = if use_utf8view { Arc::new(StringViewArray::from_iter( strings.iter().map(|s| Some(s.as_str())), )) @@ -123,7 +122,7 @@ fn read_avro_test_file( let int_array: ArrayRef = Arc::new(Int32Array::from(ints)); let schema = Arc::new(Schema::new(vec![ - if options.use_utf8view() { + if use_utf8view { Field::new("string_field", DataType::Utf8View, false) } else { Field::new("string_field", DataType::Utf8, false) @@ -143,7 +142,7 @@ fn bench_array_creation(c: &mut Criterion) { let data = create_test_data(10000, str_length); let row_count = 1000; - group.bench_function(format!("string_array_{}_chars", str_length), |b| { + group.bench_function(format!("string_array_{str_length}_chars"), |b| { b.iter(|| { let string_array = StringArray::from_iter(data[0..row_count].iter().map(|s| Some(s.as_str()))); @@ -163,11 +162,11 @@ fn bench_array_creation(c: &mut Criterion) { ) .unwrap(); - criterion::black_box(batch) + std::hint::black_box(batch) }) }); - group.bench_function(format!("string_view_{}_chars", str_length), |b| { + group.bench_function(format!("string_view_{str_length}_chars"), |b| { b.iter(|| { let string_array = StringViewArray::from_iter(data[0..row_count].iter().map(|s| Some(s.as_str()))); @@ -187,7 +186,7 @@ fn bench_array_creation(c: &mut Criterion) { ) .unwrap(); - criterion::black_box(batch) + std::hint::black_box(batch) }) }); } @@ -208,23 +207,23 @@ fn bench_string_operations(c: &mut Criterion) { let string_view_array = StringViewArray::from_iter(data[0..rows].iter().map(|s| Some(s.as_str()))); - group.bench_function(format!("string_array_value_{}_chars", str_length), |b| { + group.bench_function(format!("string_array_value_{str_length}_chars"), |b| { b.iter(|| { let mut sum_len = 0; for i in 0..rows { sum_len += string_array.value(i).len(); } - criterion::black_box(sum_len) + std::hint::black_box(sum_len) }) }); - group.bench_function(format!("string_view_value_{}_chars", str_length), |b| { + group.bench_function(format!("string_view_value_{str_length}_chars"), |b| { b.iter(|| { let mut sum_len = 0; for i in 0..rows { sum_len += string_view_array.value(i).len(); } - criterion::black_box(sum_len) + std::hint::black_box(sum_len) }) }); } @@ -242,19 +241,17 @@ fn bench_avro_reader(c: &mut Criterion) { let temp_file = create_avro_test_file(row_count, str_length).unwrap(); let file_path = temp_file.path(); - group.bench_function(format!("string_array_{}_chars", str_length), |b| { + group.bench_function(format!("string_array_{str_length}_chars"), |b| { b.iter(|| { - let options = ReadOptions::default(); - let batch = read_avro_test_file(file_path, &options).unwrap(); - criterion::black_box(batch) + let batch = read_avro_test_file(file_path, false).unwrap(); + std::hint::black_box(batch) }) }); - group.bench_function(format!("string_view_{}_chars", str_length), |b| { + group.bench_function(format!("string_view_{str_length}_chars"), |b| { b.iter(|| { - let options = ReadOptions::default().with_utf8view(true); - let batch = read_avro_test_file(file_path, &options).unwrap(); - criterion::black_box(batch) + let batch = read_avro_test_file(file_path, true).unwrap(); + std::hint::black_box(batch) }) }); } diff --git a/arrow-avro/benches/avro_writer.rs b/arrow-avro/benches/avro_writer.rs new file mode 100644 index 000000000000..58b014c5a3fe --- /dev/null +++ b/arrow-avro/benches/avro_writer.rs @@ -0,0 +1,849 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `arrow-avro` Writer (Avro Object Container File) + +extern crate arrow_avro; +extern crate criterion; +extern crate once_cell; + +use arrow_array::{ + ArrayRef, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, FixedSizeBinaryArray, + Float32Array, Float64Array, ListArray, PrimitiveArray, RecordBatch, StringArray, StructArray, + builder::{ListBuilder, StringBuilder}, + types::{Int32Type, Int64Type, IntervalMonthDayNanoType, TimestampMicrosecondType}, +}; +#[cfg(feature = "small_decimals")] +use arrow_array::{Decimal32Array, Decimal64Array}; +use arrow_avro::writer::AvroWriter; +use arrow_buffer::{Buffer, i256}; +use arrow_schema::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionFields, UnionMode}; +use criterion::{BatchSize, BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use once_cell::sync::Lazy; +use rand::{ + Rng, SeedableRng, + distr::uniform::{SampleRange, SampleUniform}, + rngs::StdRng, +}; +use std::collections::HashMap; +use std::io::Cursor; +use std::sync::Arc; +use std::time::Duration; +use tempfile::tempfile; + +const SIZES: [usize; 4] = [4_096, 8_192, 100_000, 1_000_000]; +const BASE_SEED: u64 = 0x5EED_1234_ABCD_EF01; +const MIX_CONST_1: u64 = 0x9E37_79B1_85EB_CA87; +const MIX_CONST_2: u64 = 0xC2B2_AE3D_27D4_EB4F; + +#[inline] +fn rng_for(tag: u64, n: usize) -> StdRng { + let seed = BASE_SEED ^ tag.wrapping_mul(MIX_CONST_1) ^ (n as u64).wrapping_mul(MIX_CONST_2); + StdRng::seed_from_u64(seed) +} + +#[inline] +fn sample_in(rng: &mut StdRng, range: Rg) -> T +where + T: SampleUniform, + Rg: SampleRange, +{ + rng.random_range(range) +} + +#[inline] +fn make_bool_array_with_tag(n: usize, tag: u64) -> BooleanArray { + let mut rng = rng_for(tag, n); + // Can't use SampleUniform for bool; use the RNG's boolean helper + let values = (0..n).map(|_| rng.random_bool(0.5)); + // This repo exposes `from_iter`, not `from_iter_values` for BooleanArray + BooleanArray::from_iter(values.map(Some)) +} + +#[inline] +fn make_i32_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn make_i64_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn rand_ascii_string(rng: &mut StdRng, min_len: usize, max_len: usize) -> String { + let len = rng.random_range(min_len..=max_len); + (0..len) + .map(|_| rng.random_range(b'a'..=b'z') as char) + .collect() +} + +#[inline] +fn make_utf8_array_with_tag(n: usize, tag: u64) -> StringArray { + let mut rng = rng_for(tag, n); + let data: Vec = (0..n).map(|_| rand_ascii_string(&mut rng, 3, 16)).collect(); + StringArray::from_iter_values(data) +} + +#[inline] +fn make_f32_array_with_tag(n: usize, tag: u64) -> Float32Array { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + Float32Array::from_iter_values(values) +} + +#[inline] +fn make_f64_array_with_tag(n: usize, tag: u64) -> Float64Array { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| rng.random::()); + Float64Array::from_iter_values(values) +} + +#[inline] +fn make_binary_array_with_tag(n: usize, tag: u64) -> BinaryArray { + let mut rng = rng_for(tag, n); + let mut payloads: Vec> = Vec::with_capacity(n); + for _ in 0..n { + let len = rng.random_range(1..=16); + let mut p = vec![0u8; len]; + rng.fill(&mut p[..]); + payloads.push(p); + } + let views: Vec<&[u8]> = payloads.iter().map(|p| &p[..]).collect(); + // This repo exposes a simple `from_vec` for BinaryArray + BinaryArray::from_vec(views) +} + +#[inline] +fn make_fixed16_array_with_tag(n: usize, tag: u64) -> FixedSizeBinaryArray { + let mut rng = rng_for(tag, n); + let payloads = (0..n) + .map(|_| { + let mut b = [0u8; 16]; + rng.fill(&mut b); + b + }) + .collect::>(); + // Fixed-size constructor available in this repo + FixedSizeBinaryArray::try_from_iter(payloads.into_iter()).expect("build FixedSizeBinaryArray") +} + +/// Make an Arrow `Interval(IntervalUnit::MonthDayNano)` array with **non-negative** +/// (months, days, nanos) values, and nanos as **multiples of 1_000_000** (whole ms), +/// per Avro `duration` constraints used by the writer. +#[inline] +fn make_interval_mdn_array_with_tag( + n: usize, + tag: u64, +) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let values = (0..n).map(|_| { + let months: i32 = rng.random_range(0..=120); + let days: i32 = rng.random_range(0..=31); + // pick millis within a day (safe within u32::MAX and realistic) + let millis: u32 = rng.random_range(0..=86_400_000); + let nanos: i64 = (millis as i64) * 1_000_000; + IntervalMonthDayNanoType::make_value(months, days, nanos) + }); + PrimitiveArray::::from_iter_values(values) +} + +#[inline] +fn make_ts_micros_array_with_tag(n: usize, tag: u64) -> PrimitiveArray { + let mut rng = rng_for(tag, n); + let base: i64 = 1_600_000_000_000_000; + let year_us: i64 = 31_536_000_000_000; + let values = (0..n).map(|_| base + sample_in::(&mut rng, 0..year_us)); + PrimitiveArray::::from_iter_values(values) +} + +// === Decimal helpers & generators === + +#[inline] +#[cfg(feature = "small_decimals")] +fn pow10_i32(p: u8) -> i32 { + (0..p).fold(1i32, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +#[cfg(feature = "small_decimals")] +fn pow10_i64(p: u8) -> i64 { + (0..p).fold(1i64, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +fn pow10_i128(p: u8) -> i128 { + (0..p).fold(1i128, |acc, _| acc.saturating_mul(10)) +} + +#[inline] +#[cfg(feature = "small_decimals")] +fn make_decimal32_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal32Array { + let mut rng = rng_for(tag, n); + let max = pow10_i32(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal32Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal32Array") +} + +#[inline] +#[cfg(feature = "small_decimals")] +fn make_decimal64_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal64Array { + let mut rng = rng_for(tag, n); + let max = pow10_i64(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal64Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal64Array") +} + +#[inline] +fn make_decimal128_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal128Array { + let mut rng = rng_for(tag, n); + let max = pow10_i128(precision).saturating_sub(1); + let values = (0..n).map(|_| rng.random_range(-max..=max)); + Decimal128Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal128Array") +} + +#[inline] +fn make_decimal256_array_with_tag(n: usize, tag: u64, precision: u8, scale: i8) -> Decimal256Array { + // Generate within i128 range and widen to i256 to keep generation cheap and portable + let mut rng = rng_for(tag, n); + let max128 = pow10_i128(30).saturating_sub(1); + let values = (0..n).map(|_| { + let v: i128 = rng.random_range(-max128..=max128); + i256::from_i128(v) + }); + Decimal256Array::from_iter_values(values) + .with_precision_and_scale(precision, scale) + .expect("set precision/scale on Decimal256Array") +} + +#[inline] +fn make_fixed16_array(n: usize) -> FixedSizeBinaryArray { + make_fixed16_array_with_tag(n, 0xF15E_D016) +} + +#[inline] +fn make_interval_mdn_array(n: usize) -> PrimitiveArray { + make_interval_mdn_array_with_tag(n, 0xD0_1E_AD) +} + +#[inline] +fn make_bool_array(n: usize) -> BooleanArray { + make_bool_array_with_tag(n, 0xB001) +} +#[inline] +fn make_i32_array(n: usize) -> PrimitiveArray { + make_i32_array_with_tag(n, 0x1337_0032) +} +#[inline] +fn make_i64_array(n: usize) -> PrimitiveArray { + make_i64_array_with_tag(n, 0x1337_0064) +} +#[inline] +fn make_f32_array(n: usize) -> Float32Array { + make_f32_array_with_tag(n, 0xF0_0032) +} +#[inline] +fn make_f64_array(n: usize) -> Float64Array { + make_f64_array_with_tag(n, 0xF0_0064) +} +#[inline] +fn make_binary_array(n: usize) -> BinaryArray { + make_binary_array_with_tag(n, 0xB1_0001) +} +#[inline] +fn make_ts_micros_array(n: usize) -> PrimitiveArray { + make_ts_micros_array_with_tag(n, 0x7157_0001) +} +#[inline] +fn make_utf8_array(n: usize) -> StringArray { + make_utf8_array_with_tag(n, 0x5712_07F8) +} +#[inline] +fn make_list_utf8_array(n: usize) -> ListArray { + make_list_utf8_array_with_tag(n, 0x0A11_57ED) +} +#[inline] +fn make_struct_array(n: usize) -> StructArray { + make_struct_array_with_tag(n, 0x57_AB_C7) +} + +#[inline] +fn make_list_utf8_array_with_tag(n: usize, tag: u64) -> ListArray { + let mut rng = rng_for(tag, n); + let mut builder = ListBuilder::new(StringBuilder::new()); + for _ in 0..n { + let items = rng.random_range(0..=5); + for _ in 0..items { + let s = rand_ascii_string(&mut rng, 1, 12); + builder.values().append_value(s.as_str()); + } + builder.append(true); + } + builder.finish() +} + +#[inline] +fn make_struct_array_with_tag(n: usize, tag: u64) -> StructArray { + let s_tag = tag ^ 0x5u64; + let i_tag = tag ^ 0x6u64; + let f_tag = tag ^ 0x7u64; + let s_col: ArrayRef = Arc::new(make_utf8_array_with_tag(n, s_tag)); + let i_col: ArrayRef = Arc::new(make_i32_array_with_tag(n, i_tag)); + let f_col: ArrayRef = Arc::new(make_f64_array_with_tag(n, f_tag)); + StructArray::from(vec![ + ( + Arc::new(Field::new("s1", DataType::Utf8, false)), + s_col.clone(), + ), + ( + Arc::new(Field::new("s2", DataType::Int32, false)), + i_col.clone(), + ), + ( + Arc::new(Field::new("s3", DataType::Float64, false)), + f_col.clone(), + ), + ]) +} + +#[inline] +fn schema_single(name: &str, dt: DataType) -> Arc { + Arc::new(Schema::new(vec![Field::new(name, dt, false)])) +} + +#[inline] +fn schema_mixed() -> Arc { + Arc::new(Schema::new(vec![ + Field::new("f1", DataType::Int32, false), + Field::new("f2", DataType::Int64, false), + Field::new("f3", DataType::Binary, false), + Field::new("f4", DataType::Float64, false), + ])) +} + +#[inline] +fn schema_fixed16() -> Arc { + schema_single("field1", DataType::FixedSizeBinary(16)) +} + +#[inline] +fn schema_uuid16() -> Arc { + let mut md = HashMap::new(); + md.insert("logicalType".to_string(), "uuid".to_string()); + let field = Field::new("uuid", DataType::FixedSizeBinary(16), false).with_metadata(md); + Arc::new(Schema::new(vec![field])) +} + +#[inline] +fn schema_interval_mdn() -> Arc { + schema_single("duration", DataType::Interval(IntervalUnit::MonthDayNano)) +} + +#[inline] +fn schema_decimal_with_size(name: &str, dt: DataType, size_meta: Option) -> Arc { + let field = if let Some(size) = size_meta { + let mut md = HashMap::new(); + md.insert("size".to_string(), size.to_string()); + Field::new(name, dt, false).with_metadata(md) + } else { + Field::new(name, dt, false) + }; + Arc::new(Schema::new(vec![field])) +} + +static BOOLEAN_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Boolean); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_bool_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INT32_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Int32); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_i32_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INT64_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Int64); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_i64_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FLOAT32_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Float32); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_f32_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FLOAT64_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Float64); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_f64_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static BINARY_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Binary); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_binary_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static FIXED16_DATA: Lazy> = Lazy::new(|| { + let schema = schema_fixed16(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_fixed16_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static UUID16_DATA: Lazy> = Lazy::new(|| { + let schema = schema_uuid16(); + SIZES + .iter() + .map(|&n| { + // Same values as Fixed16; writer path differs because of field metadata + let col: ArrayRef = Arc::new(make_fixed16_array_with_tag(n, 0x7575_6964_7575_6964)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static INTERVAL_MDN_DATA: Lazy> = Lazy::new(|| { + let schema = schema_interval_mdn(); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_interval_mdn_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static TIMESTAMP_US_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Timestamp(TimeUnit::Microsecond, None)); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_ts_micros_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static MIXED_DATA: Lazy> = Lazy::new(|| { + let schema = schema_mixed(); + SIZES + .iter() + .map(|&n| { + let f1: ArrayRef = Arc::new(make_i32_array_with_tag(n, 0xA1)); + let f2: ArrayRef = Arc::new(make_i64_array_with_tag(n, 0xA2)); + let f3: ArrayRef = Arc::new(make_binary_array_with_tag(n, 0xA3)); + let f4: ArrayRef = Arc::new(make_f64_array_with_tag(n, 0xA4)); + RecordBatch::try_new(schema.clone(), vec![f1, f2, f3, f4]).unwrap() + }) + .collect() +}); + +static UTF8_DATA: Lazy> = Lazy::new(|| { + let schema = schema_single("field1", DataType::Utf8); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static LIST_UTF8_DATA: Lazy> = Lazy::new(|| { + // IMPORTANT: ListBuilder creates a child field named "item" that is nullable by default. + // Make the schema's list item nullable to match the array we construct. + let item_field = Arc::new(Field::new("item", DataType::Utf8, true)); + let schema = schema_single("field1", DataType::List(item_field)); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_list_utf8_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static STRUCT_DATA: Lazy> = Lazy::new(|| { + let struct_dt = DataType::Struct( + vec![ + Field::new("s1", DataType::Utf8, false), + Field::new("s2", DataType::Int32, false), + Field::new("s3", DataType::Float64, false), + ] + .into(), + ); + let schema = schema_single("field1", struct_dt); + SIZES + .iter() + .map(|&n| { + let col: ArrayRef = Arc::new(make_struct_array(n)); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +#[cfg(feature = "small_decimals")] +static DECIMAL32_DATA: Lazy> = Lazy::new(|| { + // Choose a representative precision/scale within Decimal32 limits + let precision: u8 = 7; + let scale: i8 = 2; + let schema = schema_single("amount", DataType::Decimal32(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal32_array_with_tag(n, 0xDEC_0032, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +#[cfg(feature = "small_decimals")] +static DECIMAL64_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 13; + let scale: i8 = 3; + let schema = schema_single("amount", DataType::Decimal64(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal64_array_with_tag(n, 0xDEC_0064, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_BYTES_DATA: Lazy> = Lazy::new(|| { + let precision: u8 = 25; + let scale: i8 = 6; + let schema = schema_single("amount", DataType::Decimal128(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_0128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL128_FIXED16_DATA: Lazy> = Lazy::new(|| { + // Same logical type as above but force Avro fixed(16) via metadata "size": "16" + let precision: u8 = 25; + let scale: i8 = 6; + let schema = + schema_decimal_with_size("amount", DataType::Decimal128(precision, scale), Some(16)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal128_array_with_tag(n, 0xDEC_F128, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static DECIMAL256_DATA: Lazy> = Lazy::new(|| { + // Use a higher precision typical of 256-bit decimals + let precision: u8 = 50; + let scale: i8 = 10; + let schema = schema_single("amount", DataType::Decimal256(precision, scale)); + SIZES + .iter() + .map(|&n| { + let arr = make_decimal256_array_with_tag(n, 0xDEC_0256, precision, scale); + let col: ArrayRef = Arc::new(arr); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static MAP_DATA: Lazy> = Lazy::new(|| { + use arrow_array::builder::{MapBuilder, StringBuilder}; + + let key_field = Arc::new(Field::new("keys", DataType::Utf8, false)); + let value_field = Arc::new(Field::new("values", DataType::Utf8, true)); + let entry_struct = Field::new( + "entries", + DataType::Struct(vec![key_field.as_ref().clone(), value_field.as_ref().clone()].into()), + false, + ); + let map_dt = DataType::Map(Arc::new(entry_struct), false); + let schema = schema_single("field1", map_dt); + + SIZES + .iter() + .map(|&n| { + // Build a MapArray with n rows + let mut builder = MapBuilder::new(None, StringBuilder::new(), StringBuilder::new()); + let mut rng = rng_for(0x00D0_0D1A, n); + for _ in 0..n { + let entries = rng.random_range(0..=5); + for _ in 0..entries { + let k = rand_ascii_string(&mut rng, 3, 10); + let v = rand_ascii_string(&mut rng, 0, 12); + // keys non-nullable, values nullable allowed but we provide non-null here + builder.keys().append_value(k); + builder.values().append_value(v); + } + builder.append(true).expect("Error building MapArray"); + } + let col: ArrayRef = Arc::new(builder.finish()); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static ENUM_DATA: Lazy> = Lazy::new(|| { + // To represent an Avro enum, the Arrow writer expects a Dictionary + // field with metadata specifying the enum symbols. + let enum_symbols = r#"["RED", "GREEN", "BLUE"]"#; + let mut metadata = HashMap::new(); + metadata.insert("avro.enum.symbols".to_string(), enum_symbols.to_string()); + + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field = Field::new("color_enum", dict_type, false).with_metadata(metadata); + let schema = Arc::new(Schema::new(vec![field])); + + let dict_values: ArrayRef = Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])); + + SIZES + .iter() + .map(|&n| { + use arrow_array::DictionaryArray; + let mut rng = rng_for(0x3A7A, n); + let keys_vec: Vec = (0..n).map(|_| rng.random_range(0..=2)).collect(); + let keys = PrimitiveArray::::from(keys_vec); + + let dict_array = + DictionaryArray::::try_new(keys, dict_values.clone()).unwrap(); + let col: ArrayRef = Arc::new(dict_array); + + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +static UNION_DATA: Lazy> = Lazy::new(|| { + // Basic Dense Union of three types: Utf8, Int32, Float64 + let union_fields = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("u_str", DataType::Utf8, true), + Field::new("u_int", DataType::Int32, true), + Field::new("u_f64", DataType::Float64, true), + ], + ) + .expect("UnionFields should be valid"); + let union_dt = DataType::Union(union_fields.clone(), UnionMode::Dense); + let schema = schema_single("field1", union_dt); + + SIZES + .iter() + .map(|&n| { + // Cycle type ids 0 -> 1 -> 2 ... for determinism + let mut type_ids: Vec = Vec::with_capacity(n); + let mut offsets: Vec = Vec::with_capacity(n); + let (mut c0, mut c1, mut c2) = (0i32, 0i32, 0i32); + for i in 0..n { + let tid = (i % 3) as i8; + type_ids.push(tid); + match tid { + 0 => { + offsets.push(c0); + c0 += 1; + } + 1 => { + offsets.push(c1); + c1 += 1; + } + _ => { + offsets.push(c2); + c2 += 1; + } + } + } + + // Build children arrays with lengths equal to counts per type id + let mut rng = rng_for(0xDEAD_0003, n); + let strings: Vec = (0..c0) + .map(|_| rand_ascii_string(&mut rng, 3, 12)) + .collect(); + let ints = 0..c1; + let floats = (0..c2).map(|_| rng.random::()); + + let str_arr = StringArray::from_iter_values(strings); + let int_arr: PrimitiveArray = PrimitiveArray::from_iter_values(ints); + let f_arr = Float64Array::from_iter_values(floats); + + let type_ids_buf = Buffer::from_slice_ref(type_ids.as_slice()); + let offsets_buf = Buffer::from_slice_ref(offsets.as_slice()); + + let union_array = arrow_array::UnionArray::try_new( + union_fields.clone(), + type_ids_buf.into(), + Some(offsets_buf.into()), + vec![ + Arc::new(str_arr) as ArrayRef, + Arc::new(int_arr) as ArrayRef, + Arc::new(f_arr) as ArrayRef, + ], + ) + .unwrap(); + + let col: ArrayRef = Arc::new(union_array); + RecordBatch::try_new(schema.clone(), vec![col]).unwrap() + }) + .collect() +}); + +fn ocf_size_for_batch(batch: &RecordBatch) -> usize { + let schema_owned: Schema = (*batch.schema()).clone(); + let cursor = Cursor::new(Vec::::with_capacity(1024)); + let mut writer = AvroWriter::new(cursor, schema_owned).expect("create writer"); + writer.write(batch).expect("write batch"); + writer.finish().expect("finish writer"); + let inner = writer.into_inner(); + inner.into_inner().len() +} + +fn bench_writer_scenario(c: &mut Criterion, name: &str, data_sets: &[RecordBatch]) { + let mut group = c.benchmark_group(name); + let schema_owned: Schema = (*data_sets[0].schema()).clone(); + for (idx, &rows) in SIZES.iter().enumerate() { + let batch = &data_sets[idx]; + let bytes = ocf_size_for_batch(batch); + group.throughput(Throughput::Bytes(bytes as u64)); + match rows { + 4_096 | 8_192 => { + group + .sample_size(40) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + 100_000 => { + group + .sample_size(20) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + 1_000_000 => { + group + .sample_size(10) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + _ => {} + } + group.bench_function(BenchmarkId::from_parameter(rows), |b| { + b.iter_batched_ref( + || { + let file = tempfile().expect("create temp file"); + AvroWriter::new(file, schema_owned.clone()).expect("create writer") + }, + |writer| { + writer.write(batch).unwrap(); + writer.finish().unwrap(); + }, + BatchSize::SmallInput, + ) + }); + } + group.finish(); +} + +fn criterion_benches(c: &mut Criterion) { + bench_writer_scenario(c, "write-Boolean", &BOOLEAN_DATA); + bench_writer_scenario(c, "write-Int32", &INT32_DATA); + bench_writer_scenario(c, "write-Int64", &INT64_DATA); + bench_writer_scenario(c, "write-Float32", &FLOAT32_DATA); + bench_writer_scenario(c, "write-Float64", &FLOAT64_DATA); + bench_writer_scenario(c, "write-Binary(Bytes)", &BINARY_DATA); + bench_writer_scenario(c, "write-TimestampMicros", &TIMESTAMP_US_DATA); + bench_writer_scenario(c, "write-Mixed", &MIXED_DATA); + bench_writer_scenario(c, "write-Utf8", &UTF8_DATA); + bench_writer_scenario(c, "write-List", &LIST_UTF8_DATA); + bench_writer_scenario(c, "write-Struct", &STRUCT_DATA); + bench_writer_scenario(c, "write-FixedSizeBinary16", &FIXED16_DATA); + bench_writer_scenario(c, "write-UUID(logicalType)", &UUID16_DATA); + bench_writer_scenario(c, "write-IntervalMonthDayNanoDuration", &INTERVAL_MDN_DATA); + #[cfg(feature = "small_decimals")] + bench_writer_scenario(c, "write-Decimal32(bytes)", &DECIMAL32_DATA); + #[cfg(feature = "small_decimals")] + bench_writer_scenario(c, "write-Decimal64(bytes)", &DECIMAL64_DATA); + bench_writer_scenario(c, "write-Decimal128(bytes)", &DECIMAL128_BYTES_DATA); + bench_writer_scenario(c, "write-Decimal128(fixed16)", &DECIMAL128_FIXED16_DATA); + bench_writer_scenario(c, "write-Decimal256(bytes)", &DECIMAL256_DATA); + bench_writer_scenario(c, "write-Map", &MAP_DATA); + bench_writer_scenario(c, "write-Enum", &ENUM_DATA); + bench_writer_scenario(c, "write-Union", &UNION_DATA); +} + +criterion_group! { + name = avro_writer; + config = Criterion::default().configure_from_args(); + targets = criterion_benches +} +criterion_main!(avro_writer); diff --git a/arrow-avro/benches/decoder.rs b/arrow-avro/benches/decoder.rs new file mode 100644 index 000000000000..7180826b7b7d --- /dev/null +++ b/arrow-avro/benches/decoder.rs @@ -0,0 +1,600 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Benchmarks for `arrow‑avro` **Decoder** +//! + +extern crate apache_avro; +extern crate arrow_avro; +extern crate criterion; +extern crate num_bigint; +extern crate once_cell; +extern crate uuid; + +use apache_avro::types::Value; +use apache_avro::{Decimal, Schema as ApacheSchema, to_avro_datum}; +use arrow_avro::schema::{CONFLUENT_MAGIC, Fingerprint, FingerprintAlgorithm, SINGLE_OBJECT_MAGIC}; +use arrow_avro::{reader::ReaderBuilder, schema::AvroSchema}; +use criterion::{BatchSize, BenchmarkId, Criterion, Throughput, criterion_group, criterion_main}; +use once_cell::sync::Lazy; +use std::{hint::black_box, time::Duration}; +use uuid::Uuid; + +fn make_prefix(fp: Fingerprint) -> Vec { + match fp { + Fingerprint::Rabin(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of::()); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val.to_le_bytes()); // little-endian + buf + } + Fingerprint::Id(id) => { + let mut buf = Vec::with_capacity(CONFLUENT_MAGIC.len() + size_of::()); + buf.extend_from_slice(&CONFLUENT_MAGIC); // 00 + buf.extend_from_slice(&id.to_be_bytes()); // big-endian + buf + } + Fingerprint::Id64(id) => { + let mut buf = Vec::with_capacity(CONFLUENT_MAGIC.len() + size_of::()); + buf.extend_from_slice(&CONFLUENT_MAGIC); // 00 + buf.extend_from_slice(&id.to_be_bytes()); // big-endian + buf + } + #[cfg(feature = "md5")] + Fingerprint::MD5(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val)); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val); + buf + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(val) => { + let mut buf = Vec::with_capacity(SINGLE_OBJECT_MAGIC.len() + size_of_val(&val)); + buf.extend_from_slice(&SINGLE_OBJECT_MAGIC); // C3 01 + buf.extend_from_slice(&val); + buf + } + } +} + +fn encode_records_with_prefix( + schema: &ApacheSchema, + prefix: &[u8], + rows: impl Iterator, +) -> Vec { + let mut out = Vec::new(); + for v in rows { + out.extend_from_slice(prefix); + out.extend_from_slice(&to_avro_datum(schema, v).expect("encode datum failed")); + } + out +} + +fn gen_int(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Int(i as i32))])), + ) +} + +fn gen_long(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Long(i as i64))])), + ) +} + +fn gen_float(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Float(i as f32 + 0.5678))])), + ) +} + +fn gen_bool(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Boolean(i % 2 == 0))])), + ) +} + +fn gen_double(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Double(i as f64 + 0.1234))])), + ) +} + +fn gen_bytes(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let payload = vec![(i & 0xFF) as u8; 16]; + Value::Record(vec![("field1".into(), Value::Bytes(payload))]) + }), + ) +} + +fn gen_string(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let s = if i % 3 == 0 { + format!("value-{i}") + } else { + "abcdefghij".into() + }; + Value::Record(vec![("field1".into(), Value::String(s))]) + }), + ) +} + +fn gen_date(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Int(i as i32))])), + ) +} + +fn gen_timemillis(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Int((i * 37) as i32))])), + ) +} + +fn gen_timemicros(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| Value::Record(vec![("field1".into(), Value::Long((i * 1_001) as i64))])), + ) +} + +fn gen_ts_millis(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + Value::Record(vec![( + "field1".into(), + Value::Long(1_600_000_000_000 + i as i64), + )]) + }), + ) +} + +fn gen_ts_micros(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + Value::Record(vec![( + "field1".into(), + Value::Long(1_600_000_000_000_000 + i as i64), + )]) + }), + ) +} + +fn gen_map(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + use std::collections::HashMap; + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let mut m = HashMap::new(); + let int_val = |v: i32| Value::Union(0, Box::new(Value::Int(v))); + m.insert("key1".into(), int_val(i as i32)); + let key2_val = if i % 5 == 0 { + Value::Union(1, Box::new(Value::Null)) + } else { + int_val(i as i32 + 1) + }; + m.insert("key2".into(), key2_val); + m.insert("key3".into(), int_val(42)); + Value::Record(vec![("field1".into(), Value::Map(m))]) + }), + ) +} + +fn gen_array(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let items = (0..5).map(|j| Value::Int(i as i32 + j)).collect(); + Value::Record(vec![("field1".into(), Value::Array(items))]) + }), + ) +} + +fn trim_i128_be(v: i128) -> Vec { + let full = v.to_be_bytes(); + let first = full + .iter() + .enumerate() + .take_while(|(i, b)| { + *i < 15 + && ((**b == 0x00 && full[i + 1] & 0x80 == 0) + || (**b == 0xFF && full[i + 1] & 0x80 != 0)) + }) + .count(); + full[first..].to_vec() +} + +fn gen_decimal(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let unscaled = if i % 2 == 0 { i as i128 } else { -(i as i128) }; + Value::Record(vec![( + "field1".into(), + Value::Decimal(Decimal::from(trim_i128_be(unscaled))), + )]) + }), + ) +} + +fn gen_uuid(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let mut raw = (i as u128).to_be_bytes(); + raw[6] = (raw[6] & 0x0F) | 0x40; + raw[8] = (raw[8] & 0x3F) | 0x80; + Value::Record(vec![("field1".into(), Value::Uuid(Uuid::from_bytes(raw)))]) + }), + ) +} + +fn gen_fixed(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let mut buf = vec![0u8; 16]; + buf[..8].copy_from_slice(&(i as u64).to_be_bytes()); + Value::Record(vec![("field1".into(), Value::Fixed(16, buf))]) + }), + ) +} + +fn gen_interval(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let months = (i % 24) as u32; + let days = (i % 32) as u32; + let millis = (i * 10) as u32; + let mut buf = Vec::with_capacity(12); + buf.extend_from_slice(&months.to_le_bytes()); + buf.extend_from_slice(&days.to_le_bytes()); + buf.extend_from_slice(&millis.to_le_bytes()); + Value::Record(vec![("field1".into(), Value::Fixed(12, buf))]) + }), + ) +} + +fn gen_enum(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + const SYMBOLS: [&str; 3] = ["A", "B", "C"]; + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let idx = i % 3; + Value::Record(vec![( + "field1".into(), + Value::Enum(idx as u32, SYMBOLS[idx].into()), + )]) + }), + ) +} + +fn gen_mixed(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + Value::Record(vec![ + ("f1".into(), Value::Int(i as i32)), + ("f2".into(), Value::Long(i as i64)), + ("f3".into(), Value::String(format!("name-{i}"))), + ("f4".into(), Value::Double(i as f64 * 1.5)), + ]) + }), + ) +} + +fn gen_nested(sc: &ApacheSchema, n: usize, prefix: &[u8]) -> Vec { + encode_records_with_prefix( + sc, + prefix, + (0..n).map(|i| { + let sub = Value::Record(vec![ + ("x".into(), Value::Int(i as i32)), + ("y".into(), Value::String("constant".into())), + ]); + Value::Record(vec![("sub".into(), sub)]) + }), + ) +} + +const LARGE_BATCH: usize = 65_536; +const SMALL_BATCH: usize = 4096; + +fn new_decoder( + schema_json: &'static str, + batch_size: usize, + utf8view: bool, +) -> arrow_avro::reader::Decoder { + let schema = AvroSchema::new(schema_json.parse().unwrap()); + let mut store = arrow_avro::schema::SchemaStore::new(); + store.register(schema.clone()).unwrap(); + ReaderBuilder::new() + .with_writer_schema_store(store) + .with_batch_size(batch_size) + .with_utf8_view(utf8view) + .build_decoder() + .expect("failed to build decoder") +} + +fn new_decoder_id( + schema_json: &'static str, + batch_size: usize, + utf8view: bool, + id: u32, +) -> arrow_avro::reader::Decoder { + let schema = AvroSchema::new(schema_json.parse().unwrap()); + let mut store = arrow_avro::schema::SchemaStore::new_with_type(FingerprintAlgorithm::Id); + // Register the schema with a provided Confluent-style ID + store + .set(Fingerprint::Id(id), schema.clone()) + .expect("failed to set schema with id"); + ReaderBuilder::new() + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .with_batch_size(batch_size) + .with_utf8_view(utf8view) + .build_decoder() + .expect("failed to build decoder for id") +} + +const SIZES: [usize; 3] = [100, 10_000, 1_000_000]; + +const INT_SCHEMA: &str = + r#"{"type":"record","name":"IntRec","fields":[{"name":"field1","type":"int"}]}"#; +const LONG_SCHEMA: &str = + r#"{"type":"record","name":"LongRec","fields":[{"name":"field1","type":"long"}]}"#; +const FLOAT_SCHEMA: &str = + r#"{"type":"record","name":"FloatRec","fields":[{"name":"field1","type":"float"}]}"#; +const BOOL_SCHEMA: &str = + r#"{"type":"record","name":"BoolRec","fields":[{"name":"field1","type":"boolean"}]}"#; +const DOUBLE_SCHEMA: &str = + r#"{"type":"record","name":"DoubleRec","fields":[{"name":"field1","type":"double"}]}"#; +const BYTES_SCHEMA: &str = + r#"{"type":"record","name":"BytesRec","fields":[{"name":"field1","type":"bytes"}]}"#; +const STRING_SCHEMA: &str = + r#"{"type":"record","name":"StrRec","fields":[{"name":"field1","type":"string"}]}"#; +const DATE_SCHEMA: &str = r#"{"type":"record","name":"DateRec","fields":[{"name":"field1","type":{"type":"int","logicalType":"date"}}]}"#; +const TMILLIS_SCHEMA: &str = r#"{"type":"record","name":"TimeMsRec","fields":[{"name":"field1","type":{"type":"int","logicalType":"time-millis"}}]}"#; +const TMICROS_SCHEMA: &str = r#"{"type":"record","name":"TimeUsRec","fields":[{"name":"field1","type":{"type":"long","logicalType":"time-micros"}}]}"#; +const TSMILLIS_SCHEMA: &str = r#"{"type":"record","name":"TsMsRec","fields":[{"name":"field1","type":{"type":"long","logicalType":"timestamp-millis"}}]}"#; +const TSMICROS_SCHEMA: &str = r#"{"type":"record","name":"TsUsRec","fields":[{"name":"field1","type":{"type":"long","logicalType":"timestamp-micros"}}]}"#; +const MAP_SCHEMA: &str = r#"{"type":"record","name":"MapRec","fields":[{"name":"field1","type":{"type":"map","values":["int","null"]}}]}"#; +const ARRAY_SCHEMA: &str = r#"{"type":"record","name":"ArrRec","fields":[{"name":"field1","type":{"type":"array","items":"int"}}]}"#; +const DECIMAL_SCHEMA: &str = r#"{"type":"record","name":"DecRec","fields":[{"name":"field1","type":{"type":"bytes","logicalType":"decimal","precision":10,"scale":3}}]}"#; +const UUID_SCHEMA: &str = r#"{"type":"record","name":"UuidRec","fields":[{"name":"field1","type":{"type":"string","logicalType":"uuid"}}]}"#; +const FIXED_SCHEMA: &str = r#"{"type":"record","name":"FixRec","fields":[{"name":"field1","type":{"type":"fixed","name":"Fixed16","size":16}}]}"#; +const INTERVAL_SCHEMA: &str = r#"{"type":"record","name":"DurRec","fields":[{"name":"field1","type":{"type":"fixed","name":"Duration12","size":12,"logicalType":"duration"}}]}"#; +const INTERVAL_SCHEMA_ENCODE: &str = r#"{"type":"record","name":"DurRec","fields":[{"name":"field1","type":{"type":"fixed","name":"Duration12","size":12}}]}"#; +const ENUM_SCHEMA: &str = r#"{"type":"record","name":"EnumRec","fields":[{"name":"field1","type":{"type":"enum","name":"MyEnum","symbols":["A","B","C"]}}]}"#; +const MIX_SCHEMA: &str = r#"{"type":"record","name":"MixRec","fields":[{"name":"f1","type":"int"},{"name":"f2","type":"long"},{"name":"f3","type":"string"},{"name":"f4","type":"double"}]}"#; +const NEST_SCHEMA: &str = r#"{"type":"record","name":"NestRec","fields":[{"name":"sub","type":{"type":"record","name":"Sub","fields":[{"name":"x","type":"int"},{"name":"y","type":"string"}]}}]}"#; + +macro_rules! dataset { + ($name:ident, $schema_json:expr, $gen_fn:ident) => { + static $name: Lazy>> = Lazy::new(|| { + let schema = + ApacheSchema::parse_str($schema_json).expect("invalid schema for generator"); + let arrow_schema = AvroSchema::new($schema_json.parse().unwrap()); + let fingerprint = arrow_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .expect("fingerprint failed"); + let prefix = make_prefix(fingerprint); + SIZES + .iter() + .map(|&n| $gen_fn(&schema, n, &prefix)) + .collect() + }); + }; +} + +/// Additional helper for Confluent's ID-based wire format (00 + BE u32). +macro_rules! dataset_id { + ($name:ident, $schema_json:expr, $gen_fn:ident, $id:expr) => { + static $name: Lazy>> = Lazy::new(|| { + let schema = + ApacheSchema::parse_str($schema_json).expect("invalid schema for generator"); + let prefix = make_prefix(Fingerprint::Id($id)); + SIZES + .iter() + .map(|&n| $gen_fn(&schema, n, &prefix)) + .collect() + }); + }; +} + +const ID_BENCH_ID: u32 = 7; + +dataset_id!(INT_DATA_ID, INT_SCHEMA, gen_int, ID_BENCH_ID); +dataset!(INT_DATA, INT_SCHEMA, gen_int); +dataset!(LONG_DATA, LONG_SCHEMA, gen_long); +dataset!(FLOAT_DATA, FLOAT_SCHEMA, gen_float); +dataset!(BOOL_DATA, BOOL_SCHEMA, gen_bool); +dataset!(DOUBLE_DATA, DOUBLE_SCHEMA, gen_double); +dataset!(BYTES_DATA, BYTES_SCHEMA, gen_bytes); +dataset!(STRING_DATA, STRING_SCHEMA, gen_string); +dataset!(DATE_DATA, DATE_SCHEMA, gen_date); +dataset!(TMILLIS_DATA, TMILLIS_SCHEMA, gen_timemillis); +dataset!(TMICROS_DATA, TMICROS_SCHEMA, gen_timemicros); +dataset!(TSMILLIS_DATA, TSMILLIS_SCHEMA, gen_ts_millis); +dataset!(TSMICROS_DATA, TSMICROS_SCHEMA, gen_ts_micros); +dataset!(MAP_DATA, MAP_SCHEMA, gen_map); +dataset!(ARRAY_DATA, ARRAY_SCHEMA, gen_array); +dataset!(DECIMAL_DATA, DECIMAL_SCHEMA, gen_decimal); +dataset!(UUID_DATA, UUID_SCHEMA, gen_uuid); +dataset!(FIXED_DATA, FIXED_SCHEMA, gen_fixed); +dataset!(INTERVAL_DATA, INTERVAL_SCHEMA_ENCODE, gen_interval); +dataset!(ENUM_DATA, ENUM_SCHEMA, gen_enum); +dataset!(MIX_DATA, MIX_SCHEMA, gen_mixed); +dataset!(NEST_DATA, NEST_SCHEMA, gen_nested); + +fn bench_with_decoder( + c: &mut Criterion, + name: &str, + data_sets: &[Vec], + rows: &[usize], + mut new_decoder: F, +) where + F: FnMut() -> arrow_avro::reader::Decoder, +{ + let mut group = c.benchmark_group(name); + for (idx, &row_count) in rows.iter().enumerate() { + let datum = &data_sets[idx]; + group.throughput(Throughput::Bytes(datum.len() as u64)); + match row_count { + 10_000 => { + group + .sample_size(25) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + 1_000_000 => { + group + .sample_size(10) + .measurement_time(Duration::from_secs(10)) + .warm_up_time(Duration::from_secs(3)); + } + _ => {} + } + group.bench_function(BenchmarkId::from_parameter(row_count), |b| { + b.iter_batched_ref( + &mut new_decoder, + |decoder| { + black_box(decoder.decode(datum).unwrap()); + black_box(decoder.flush().unwrap().unwrap()); + }, + BatchSize::SmallInput, + ) + }); + } + group.finish(); +} + +fn criterion_benches(c: &mut Criterion) { + for &batch_size in &[SMALL_BATCH, LARGE_BATCH] { + bench_with_decoder(c, "Interval", &INTERVAL_DATA, &SIZES, || { + new_decoder(INTERVAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32", &INT_DATA, &SIZES, || { + new_decoder(INT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Int32_Id", &INT_DATA_ID, &SIZES, || { + new_decoder_id(INT_SCHEMA, batch_size, false, ID_BENCH_ID) + }); + bench_with_decoder(c, "Int64", &LONG_DATA, &SIZES, || { + new_decoder(LONG_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float32", &FLOAT_DATA, &SIZES, || { + new_decoder(FLOAT_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Boolean", &BOOL_DATA, &SIZES, || { + new_decoder(BOOL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Float64", &DOUBLE_DATA, &SIZES, || { + new_decoder(DOUBLE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Binary(Bytes)", &BYTES_DATA, &SIZES, || { + new_decoder(BYTES_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "String", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "StringView", &STRING_DATA, &SIZES, || { + new_decoder(STRING_SCHEMA, batch_size, true) + }); + bench_with_decoder(c, "Date32", &DATE_DATA, &SIZES, || { + new_decoder(DATE_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMillis", &TMILLIS_DATA, &SIZES, || { + new_decoder(TMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimeMicros", &TMICROS_DATA, &SIZES, || { + new_decoder(TMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMillis", &TSMILLIS_DATA, &SIZES, || { + new_decoder(TSMILLIS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "TimestampMicros", &TSMICROS_DATA, &SIZES, || { + new_decoder(TSMICROS_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Map", &MAP_DATA, &SIZES, || { + new_decoder(MAP_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Array", &ARRAY_DATA, &SIZES, || { + new_decoder(ARRAY_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Decimal128", &DECIMAL_DATA, &SIZES, || { + new_decoder(DECIMAL_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "UUID", &UUID_DATA, &SIZES, || { + new_decoder(UUID_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "FixedSizeBinary", &FIXED_DATA, &SIZES, || { + new_decoder(FIXED_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Enum(Dictionary)", &ENUM_DATA, &SIZES, || { + new_decoder(ENUM_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Mixed", &MIX_DATA, &SIZES, || { + new_decoder(MIX_SCHEMA, batch_size, false) + }); + bench_with_decoder(c, "Nested(Struct)", &NEST_DATA, &SIZES, || { + new_decoder(NEST_SCHEMA, batch_size, false) + }); + } +} + +criterion_group! { + name = avro_decoder; + config = Criterion::default().configure_from_args(); + targets = criterion_benches +} +criterion_main!(avro_decoder); diff --git a/arrow-avro/examples/decode_kafka_stream.rs b/arrow-avro/examples/decode_kafka_stream.rs new file mode 100644 index 000000000000..46309ecd0cb9 --- /dev/null +++ b/arrow-avro/examples/decode_kafka_stream.rs @@ -0,0 +1,233 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Decode **Confluent Schema Registry - framed** Avro messages into Arrow [`RecordBatch`]es, +//! resolving **older writer schemas** against a **current reader schema** without adding +//! any new reader‑only fields. +//! +//! What this example shows: +//! * A **reader schema** for the current topic version with fields: `{ id: long, name: string }`. +//! * Two older **writer schemas** (Confluent IDs **0** and **1**): +//! - v0: `{ id: int, name: string }` (older type for `id`) +//! - v1: `{ id: long, name: string, email: ["null","string"] }` (extra writer field `email`) +//! * Streaming decode with `ReaderBuilder::with_reader_schema(...)` so that: +//! - v0's `id:int` is **promoted** to `long` for the reader +//! - v1's extra `email` field is **ignored** by the reader (projection) +//! +//! Wire format reminder (message value bytes): +//! `0x00` magic byte + 4‑byte **big‑endian** schema ID + Avro **binary** body. +//! + +use arrow_array::{Int64Array, RecordBatch, StringArray}; +use arrow_avro::reader::ReaderBuilder; +use arrow_avro::schema::{ + AvroSchema, CONFLUENT_MAGIC, Fingerprint, FingerprintAlgorithm, SchemaStore, +}; +use arrow_schema::ArrowError; + +fn encode_long(value: i64, out: &mut Vec) { + let mut n = ((value << 1) ^ (value >> 63)) as u64; + while (n & !0x7F) != 0 { + out.push(((n as u8) & 0x7F) | 0x80); + n >>= 7; + } + out.push(n as u8); +} + +fn encode_len(len: usize, out: &mut Vec) { + encode_long(len as i64, out) +} + +fn encode_string(s: &str, out: &mut Vec) { + encode_len(s.len(), out); + out.extend_from_slice(s.as_bytes()); +} + +fn encode_union_index(index: i64, out: &mut Vec) { + encode_long(index, out); +} + +// Writer v0 (ID=0): +// {"type":"record","name":"User","fields":[ +// {"name":"id","type":"int"}, +// {"name":"name","type":"string"}]} +fn encode_user_v0_body(id: i32, name: &str) -> Vec { + let mut v = Vec::with_capacity(16 + name.len()); + encode_long(id as i64, &mut v); + encode_string(name, &mut v); + v +} + +// Writer v1 (ID=1): +// {"type":"record","name":"User","fields":[ +// {"name":"id","type":"long"}, +// {"name":"name","type":"string"}, +// {"name":"email","type":["null","string"],"default":null}]} +fn encode_user_v1_body(id: i64, name: &str, email: Option<&str>) -> Vec { + let mut v = Vec::with_capacity(24 + name.len() + email.map(|s| s.len()).unwrap_or(0)); + encode_long(id, &mut v); // id: long + encode_string(name, &mut v); // name: string + match email { + None => { + // union index 0 => null + encode_union_index(0, &mut v); + // no value bytes follow for null + } + Some(s) => { + // union index 1 => string, followed by the string payload + encode_union_index(1, &mut v); + encode_string(s, &mut v); + } + } + v +} + +fn frame_confluent(id_be: u32, body: &[u8]) -> Vec { + let mut out = Vec::with_capacity(5 + body.len()); + out.extend_from_slice(&CONFLUENT_MAGIC); // 0x00 + out.extend_from_slice(&id_be.to_be_bytes()); + out.extend_from_slice(body); + out +} + +fn print_arrow_schema(schema: &arrow_schema::Schema) { + println!("Resolved Arrow schema (via reader schema):"); + for (i, f) in schema.fields().iter().enumerate() { + println!( + " {i:>2}: {}: {:?} (nullable: {})", + f.name(), + f.data_type(), + f.is_nullable() + ); + } + if !schema.metadata.is_empty() { + println!(" metadata: {:?}", schema.metadata()); + } +} + +fn print_rows(batch: &RecordBatch) -> Result<(), ArrowError> { + let ids = batch + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::ComputeError("col 0 not Int64".into()))?; + let names = batch + .column(1) + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::ComputeError("col 1 not Utf8".into()))?; + for row in 0..batch.num_rows() { + let id = ids.value(row); + let name = names.value(row); + println!(" row {row}: id={id}, name={name}"); + } + Ok(()) +} + +fn main() -> Result<(), Box> { + // The current topic schema as a READER schema + let reader_schema = AvroSchema::new( + r#"{ + "type":"record","name":"User","fields":[ + {"name":"id","type":"long"}, + {"name":"name","type":"string"} + ]}"# + .to_string(), + ); + + // Two prior WRITER schemas versions under Confluent IDs 0 and 1 + let writer_v0 = AvroSchema::new( + r#"{ + "type":"record","name":"User","fields":[ + {"name":"id","type":"int"}, + {"name":"name","type":"string"} + ]}"# + .to_string(), + ); + let writer_v1 = AvroSchema::new( + r#"{ + "type":"record","name":"User","fields":[ + {"name":"id","type":"long"}, + {"name":"name","type":"string"}, + {"name":"email","type":["null","string"],"default":null} + ]}"# + .to_string(), + ); + + let id_v0: u32 = 0; + let id_v1: u32 = 1; + + // Confluent SchemaStore keyed by integer IDs (FingerprintAlgorithm::Id) + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); + store.set(Fingerprint::Id(id_v0), writer_v0.clone())?; + store.set(Fingerprint::Id(id_v1), writer_v1.clone())?; + + // Build a streaming Decoder with the READER schema + let mut decoder = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .with_writer_schema_store(store) + .with_batch_size(8) // small batches for demo output + .build_decoder()?; + + // Print the resolved Arrow schema (derived from reader and writer) + let resolved = decoder.schema(); + print_arrow_schema(resolved.as_ref()); + println!(); + + // Simulate an interleaved Kafka stream (IDs 0 and 1) + // - v0: {id:int, name:string} --> reader: id promoted to long + // - v1: {id:long, name:string, email: ...} --> reader ignores extra field + let mut frames: Vec<(u32, Vec)> = Vec::new(); + + // Some v0 messages + for (i, name) in ["v0-alice", "v0-bob", "v0-carol"].iter().enumerate() { + let body = encode_user_v0_body(1000 + i as i32, name); + frames.push((id_v0, frame_confluent(id_v0, &body))); + } + + // Some v1 messages (may include optional email on the writer side) + let v1_rows = [ + (2001_i64, "v1-dave", Some("dave@example.com")), + (2002_i64, "v1-erin", None), + (2003_i64, "v1-frank", Some("frank@example.com")), + ]; + for (id, name, email) in v1_rows { + let body = encode_user_v1_body(id, name, email); + frames.push((id_v1, frame_confluent(id_v1, &body))); + } + + // Interleave to show mid-stream schema ID changes (0,1,0,1, ...) + frames.swap(1, 3); // crude interleave for demo + + // Decode frames as if they were Kafka record values + for (schema_id, frame) in frames { + println!("Decoding record framed with Confluent schema id = {schema_id}"); + let _consumed = decoder.decode(&frame)?; + while let Some(batch) = decoder.flush()? { + println!( + " -> Emitted batch: rows = {}, cols = {}", + batch.num_rows(), + batch.num_columns() + ); + print_rows(&batch)?; + } + println!(); + } + + println!("Done decoding Kafka-style stream with schema resolution (no reader-added fields)."); + Ok(()) +} diff --git a/arrow-avro/examples/read_ocf_with_resolution.rs b/arrow-avro/examples/read_ocf_with_resolution.rs new file mode 100644 index 000000000000..7367ba3cd5b0 --- /dev/null +++ b/arrow-avro/examples/read_ocf_with_resolution.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Read an Avro **Object Container File (OCF)** using an inline **reader schema** +//! that differs from the writer schema, demonstrating Avro **schema resolution** +//! (field projection and legal type promotion) without ever fetching the writer +//! schema from the file. +//! +//! What this example does: +//! 1. Locates `/test/data/skippable_types.avro` (portable path). +//! 2. Defines an inline **reader schema** JSON: +//! * Projects a subset of fields from the writer schema, and +//! * Promotes `"int"` to `"long"` where applicable. +//! 3. Builds a `Reader` with `ReaderBuilder::with_reader_schema(...)` and prints batches. + +use std::fs::File; +use std::io::BufReader; +use std::path::PathBuf; + +use arrow_array::RecordBatch; +use arrow_avro::reader::ReaderBuilder; +use arrow_avro::schema::AvroSchema; + +fn default_ocf_path() -> PathBuf { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test") + .join("data") + .join("skippable_types.avro") +} + +// A minimal reader schema compatible with the provided writer schema +const READER_SCHEMA_JSON: &str = r#" +{ + "type": "record", + "name": "SkippableTypesRecord", + "fields": [ + { "name": "boolean_field", "type": "boolean" }, + { "name": "int_field", "type": "long" }, + { "name": "long_field", "type": "long" }, + { "name": "string_field", "type": "string" }, + { "name": "nullable_nullfirst_field", "type": ["null", "long"] } + ] +} +"#; + +fn main() -> Result<(), Box> { + let ocf_path = default_ocf_path(); + let file = File::open(&ocf_path)?; + let reader_schema = AvroSchema::new(READER_SCHEMA_JSON.to_string()); + + let reader = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .build(BufReader::new(file))?; + + let resolved_schema = reader.schema(); + println!( + "Reader-based decode: resolved Arrow schema with {} fields", + resolved_schema.fields().len() + ); + + // Iterate batches and print a brief summary + let mut total_batches = 0usize; + let mut total_rows = 0usize; + for next in reader { + let batch: RecordBatch = next?; + total_batches += 1; + total_rows += batch.num_rows(); + println!( + " Batch {:>3}: rows = {:>6}, cols = {:>2}", + total_batches, + batch.num_rows(), + batch.num_columns() + ); + } + + println!(); + println!("Done (with reader/writer schema resolution)."); + println!(" Batches : {total_batches}"); + println!(" Rows : {total_rows}"); + + Ok(()) +} diff --git a/arrow-avro/examples/read_with_utf8view.rs b/arrow-avro/examples/read_with_utf8view.rs index 2fa47820346b..85b07c8d033c 100644 --- a/arrow-avro/examples/read_with_utf8view.rs +++ b/arrow-avro/examples/read_with_utf8view.rs @@ -23,12 +23,10 @@ use std::env; use std::fs::File; use std::io::{BufReader, Seek, SeekFrom}; -use std::sync::Arc; use std::time::Instant; -use arrow_array::{ArrayRef, Int32Array, RecordBatch, StringArray, StringViewArray}; -use arrow_avro::reader::ReadOptions; -use arrow_schema::{ArrowError, DataType, Field, Schema}; +use arrow_array::{RecordBatch, StringArray, StringViewArray}; +use arrow_avro::reader::ReaderBuilder; fn main() -> Result<(), Box> { let args: Vec = env::args().collect(); @@ -41,22 +39,29 @@ fn main() -> Result<(), Box> { }; let file = File::open(file_path)?; - let mut reader = BufReader::new(file); + let mut file_for_view = file.try_clone()?; let start = Instant::now(); - let batch = read_avro_with_options(&mut reader, &ReadOptions::default())?; + let reader = BufReader::new(file); + let avro_reader = ReaderBuilder::new().build(reader)?; + let schema = avro_reader.schema(); + let batches: Vec = avro_reader.collect::>()?; let regular_duration = start.elapsed(); - reader.seek(SeekFrom::Start(0))?; - + file_for_view.seek(SeekFrom::Start(0))?; let start = Instant::now(); - let options = ReadOptions::default().with_utf8view(true); - let batch_view = read_avro_with_options(&mut reader, &options)?; + let reader_view = BufReader::new(file_for_view); + let avro_reader_view = ReaderBuilder::new() + .with_utf8_view(true) + .build(reader_view)?; + let batches_view: Vec = avro_reader_view.collect::>()?; let view_duration = start.elapsed(); - println!("Read {} rows from {}", batch.num_rows(), file_path); - println!("Reading with StringArray: {:?}", regular_duration); - println!("Reading with StringViewArray: {:?}", view_duration); + let num_rows = batches.iter().map(|b| b.num_rows()).sum::(); + + println!("Read {num_rows} rows from {file_path}"); + println!("Reading with StringArray: {regular_duration:?}"); + println!("Reading with StringViewArray: {view_duration:?}"); if regular_duration > view_duration { println!( @@ -70,7 +75,16 @@ fn main() -> Result<(), Box> { ); } - for (i, field) in batch.schema().fields().iter().enumerate() { + if batches.is_empty() { + println!("No data read from file."); + return Ok(()); + } + + // Inspect the first batch from each run to show the array types + let batch = &batches[0]; + let batch_view = &batches_view[0]; + + for (i, field) in schema.fields().iter().enumerate() { let col = batch.column(i); let col_view = batch_view.column(i); @@ -93,29 +107,3 @@ fn main() -> Result<(), Box> { Ok(()) } - -fn read_avro_with_options( - reader: &mut BufReader, - options: &ReadOptions, -) -> Result { - reader.get_mut().seek(SeekFrom::Start(0))?; - - let mock_schema = Schema::new(vec![ - Field::new("string_field", DataType::Utf8, false), - Field::new("int_field", DataType::Int32, false), - ]); - - let string_data = vec!["avro1", "avro2", "avro3", "avro4", "avro5"]; - let int_data = vec![1, 2, 3, 4, 5]; - - let string_array: ArrayRef = if options.use_utf8view() { - Arc::new(StringViewArray::from(string_data)) - } else { - Arc::new(StringArray::from(string_data)) - }; - - let int_array: ArrayRef = Arc::new(Int32Array::from(int_data)); - - RecordBatch::try_new(Arc::new(mock_schema), vec![string_array, int_array]) - .map_err(|e| ArrowError::ComputeError(format!("Failed to create record batch: {}", e))) -} diff --git a/arrow-avro/examples/write_avro_ocf.rs b/arrow-avro/examples/write_avro_ocf.rs new file mode 100644 index 000000000000..5bdca0de7a3d --- /dev/null +++ b/arrow-avro/examples/write_avro_ocf.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! # Write an Avro Object Container File (OCF) from an Arrow `RecordBatch` +//! +//! This example builds a small Arrow `RecordBatch` and persists it to an +//! **Avro Object Container File (OCF)** using +//! `arrow_avro::writer::{Writer, WriterBuilder}`. +//! +//! ## What this example does +//! - Define an Arrow schema with supported types (`Int64`, `Utf8`, `Boolean`, +//! `Float64`, `Binary`, and `Timestamp (Microsecond, "UTC")`). +//! - Constructs arrays and a `RecordBatch`, ensuring each column’s data type +//! **exactly** matches the schema (timestamps include the `"UTC"` timezone). +//! - Writes a single batch to `target/write_avro_ocf_example.avro` as an OCF, +//! using Snappy block compression (you can disable or change the codec). +//! - Prints the file’s 16‑byte sync marker (used by OCF to delimit blocks). + +use std::fs::File; +use std::io::BufWriter; +use std::sync::Arc; + +use arrow_array::{ + ArrayRef, BinaryArray, BooleanArray, Float64Array, Int64Array, RecordBatch, StringArray, + TimestampMicrosecondArray, +}; +use arrow_avro::compression::CompressionCodec; +use arrow_avro::writer::format::AvroOcfFormat; +use arrow_avro::writer::{Writer, WriterBuilder}; +use arrow_schema::{DataType, Field, Schema, TimeUnit}; + +fn main() -> Result<(), Box> { + // Arrow schema + // id: Int64 (non-null) + // name: Utf8 (nullable) + // active: Boolean (non-null) + // score: Float64 (nullable) + // payload: Binary (nullable) + // created_at: Timestamp(Microsecond, Some("UTC")) (non-null) + let schema = Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, true), + Field::new("active", DataType::Boolean, false), + Field::new("score", DataType::Float64, true), + Field::new("payload", DataType::Binary, true), + Field::new( + "created_at", + DataType::Timestamp(TimeUnit::Microsecond, Some(Arc::from("UTC".to_string()))), + false, + ), + ]); + + let schema_ref = Arc::new(schema.clone()); + let ids = Int64Array::from(vec![1_i64, 2, 3]); + let names = StringArray::from(vec![Some("alpha"), None, Some("gamma")]); + let active = BooleanArray::from(vec![true, false, true]); + let scores = Float64Array::from(vec![Some(1.5_f64), None, Some(3.0)]); + + // BinaryArray: include a null + let payload = BinaryArray::from_opt_vec(vec![Some(&b"abc"[..]), None, Some(&[0u8, 1, 2][..])]); + + // Timestamp in microseconds since UNIX epoch + let created_at = TimestampMicrosecondArray::from(vec![ + Some(1_722_000_000_000_000_i64), + Some(1_722_000_123_456_000_i64), + Some(1_722_000_999_999_000_i64), + ]) + .with_timezone("UTC".to_string()); + + let columns: Vec = vec![ + Arc::new(ids), + Arc::new(names), + Arc::new(active), + Arc::new(scores), + Arc::new(payload), + Arc::new(created_at), + ]; + + let batch = RecordBatch::try_new(schema_ref, columns)?; + + // Build an OCF writer with optional compression + let out_path = "target/write_avro_ocf_example.avro"; + let file = File::create(out_path)?; + let mut writer: Writer<_, AvroOcfFormat> = WriterBuilder::new(schema) + .with_compression(Some(CompressionCodec::Snappy)) + .build(BufWriter::new(file))?; + + // Write a single batch (use `write_batches` for multiple) + writer.write(&batch)?; + writer.finish()?; // flush and finalize + + if let Some(sync) = writer.sync_marker() { + println!("Wrote OCF to {out_path} (sync marker: {:02x?})", &sync[..]); + } else { + println!("Wrote OCF to {out_path}"); + } + + Ok(()) +} diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 70f162f1471d..04ef87d7ef20 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,38 +15,171 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +//! Codec for Mapping Avro and Arrow types. + +use crate::schema::{ + AVRO_ENUM_SYMBOLS_METADATA_KEY, AVRO_FIELD_DEFAULT_METADATA_KEY, AVRO_NAME_METADATA_KEY, + AVRO_NAMESPACE_METADATA_KEY, Array, Attributes, ComplexType, Enum, Fixed, Map, Nullability, + PrimitiveType, Record, Schema, Type, TypeName, make_full_name, +}; use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + ArrowError, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Field, Fields, + IntervalUnit, TimeUnit, UnionFields, UnionMode, }; -use std::borrow::Cow; -use std::collections::HashMap; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; +use indexmap::IndexMap; +use serde_json::Value; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use std::fmt; +use std::fmt::Display; use std::sync::Arc; +use strum_macros::AsRefStr; + +/// Contains information about how to resolve differences between a writer's and a reader's schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum ResolutionInfo { + /// Indicates that the writer's type should be promoted to the reader's type. + Promotion(Promotion), + /// Indicates that a default value should be used for a field. + DefaultValue(AvroLiteral), + /// Provides mapping information for resolving enums. + EnumMapping(EnumMapping), + /// Provides resolution information for record fields. + Record(ResolvedRecord), + /// Provides mapping and shape info for resolving unions. + Union(ResolvedUnion), +} + +/// Represents a literal Avro value. +/// +/// This is used to represent default values in an Avro schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum AvroLiteral { + /// Represents a null value. + Null, + /// Represents a boolean value. + Boolean(bool), + /// Represents an integer value. + Int(i32), + /// Represents a long value. + Long(i64), + /// Represents a float value. + Float(f32), + /// Represents a double value. + Double(f64), + /// Represents a bytes value. + Bytes(Vec), + /// Represents a string value. + String(String), + /// Represents an enum symbol. + Enum(String), + /// Represents a JSON array default for an Avro array, containing element literals. + Array(Vec), + /// Represents a JSON object default for an Avro map/struct, mapping string keys to value literals. + Map(IndexMap), +} + +/// Contains the necessary information to resolve a writer's record against a reader's record schema. +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct ResolvedRecord { + /// Maps a writer's field index to the corresponding reader's field index. + /// `None` if the writer's field is not present in the reader's schema. + pub(crate) writer_to_reader: Arc<[Option]>, + /// A list of indices in the reader's schema for fields that have a default value. + pub(crate) default_fields: Arc<[usize]>, + /// For fields present in the writer's schema but not the reader's, this stores their data type. + /// This is needed to correctly skip over these fields during deserialization. + pub(crate) skip_fields: Arc<[Option]>, +} + +/// Defines the type of promotion to be applied during schema resolution. +/// +/// Schema resolution may require promoting a writer's data type to a reader's data type. +/// For example, an `int` can be promoted to a `long`, `float`, or `double`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub(crate) enum Promotion { + /// Direct read with no data type promotion. + Direct, + /// Promotes an `int` to a `long`. + IntToLong, + /// Promotes an `int` to a `float`. + IntToFloat, + /// Promotes an `int` to a `double`. + IntToDouble, + /// Promotes a `long` to a `float`. + LongToFloat, + /// Promotes a `long` to a `double`. + LongToDouble, + /// Promotes a `float` to a `double`. + FloatToDouble, + /// Promotes a `string` to `bytes`. + StringToBytes, + /// Promotes `bytes` to a `string`. + BytesToString, +} + +impl Display for Promotion { + fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Direct => write!(formatter, "Direct"), + Self::IntToLong => write!(formatter, "Int->Long"), + Self::IntToFloat => write!(formatter, "Int->Float"), + Self::IntToDouble => write!(formatter, "Int->Double"), + Self::LongToFloat => write!(formatter, "Long->Float"), + Self::LongToDouble => write!(formatter, "Long->Double"), + Self::FloatToDouble => write!(formatter, "Float->Double"), + Self::StringToBytes => write!(formatter, "String->Bytes"), + Self::BytesToString => write!(formatter, "Bytes->String"), + } + } +} + +/// Information required to resolve a writer union against a reader union (or single type). +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct ResolvedUnion { + /// For each writer branch index, the reader branch index and how to read it. + /// `None` means the writer branch doesn't resolve against the reader. + pub(crate) writer_to_reader: Arc<[Option<(usize, Promotion)>]>, + /// Whether the writer schema at this site is a union + pub(crate) writer_is_union: bool, + /// Whether the reader schema at this site is a union + pub(crate) reader_is_union: bool, +} -/// Avro types are not nullable, with nullability instead encoded as a union -/// where one of the variants is the null type. +/// Holds the mapping information for resolving Avro enums. /// -/// To accommodate this we special case two-variant unions where one of the -/// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone)] -pub enum Nullability { - /// The nulls are encoded as the first union variant - NullFirst, - /// The nulls are encoded as the second union variant - NullSecond, +/// When resolving schemas, the writer's enum symbols must be mapped to the reader's symbols. +#[derive(Debug, Clone, PartialEq, Eq)] +pub(crate) struct EnumMapping { + /// A mapping from the writer's symbol index to the reader's symbol index. + pub(crate) mapping: Arc<[i32]>, + /// The index to use for a writer's symbol that is not present in the reader's enum + /// and a default value is specified in the reader's schema. + pub(crate) default_index: i32, +} + +#[cfg(feature = "canonical_extension_types")] +fn with_extension_type(codec: &Codec, field: Field) -> Field { + match codec { + Codec::Uuid => field.with_extension_type(arrow_schema::extension::Uuid), + _ => field, + } } /// An Avro datatype mapped to the arrow data model -#[derive(Debug, Clone)] -pub struct AvroDataType { +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct AvroDataType { nullability: Option, metadata: HashMap, codec: Codec, + pub(crate) resolution: Option, } impl AvroDataType { /// Create a new [`AvroDataType`] with the given parts. - pub fn new( + pub(crate) fn new( codec: Codec, metadata: HashMap, nullability: Option, @@ -55,20 +188,49 @@ impl AvroDataType { codec, metadata, nullability, + resolution: None, + } + } + + #[inline] + fn new_with_resolution( + codec: Codec, + metadata: HashMap, + nullability: Option, + resolution: Option, + ) -> Self { + Self { + codec, + metadata, + nullability, + resolution, } } /// Returns an arrow [`Field`] with the given name - pub fn field_with_name(&self, name: &str) -> Field { - let d = self.codec.data_type(); - Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) + pub(crate) fn field_with_name(&self, name: &str) -> Field { + let mut nullable = self.nullability.is_some(); + if !nullable { + if let Codec::Union(children, _, _) = self.codec() { + // If any encoded branch is `null`, mark field as nullable + if children.iter().any(|c| matches!(c.codec(), Codec::Null)) { + nullable = true; + } + } + } + let data_type = self.codec.data_type(); + let field = Field::new(name, data_type, nullable).with_metadata(self.metadata.clone()); + #[cfg(feature = "canonical_extension_types")] + return with_extension_type(&self.codec, field); + #[cfg(not(feature = "canonical_extension_types"))] + field } /// Returns a reference to the codec used by this data type /// /// The codec determines how Avro data is encoded and mapped to Arrow data types. /// This is useful when we need to inspect or use the specific encoding of a field. - pub fn codec(&self) -> &Codec { + pub(crate) fn codec(&self) -> &Codec { &self.codec } @@ -79,26 +241,266 @@ impl AvroDataType { /// - `Some(Nullability::NullFirst)` - Nulls are encoded as the first union variant /// - `Some(Nullability::NullSecond)` - Nulls are encoded as the second union variant /// - `None` - The type is not nullable - pub fn nullability(&self) -> Option { + pub(crate) fn nullability(&self) -> Option { self.nullability } + + #[inline] + fn parse_default_literal(&self, default_json: &Value) -> Result { + fn expect_string<'v>( + default_json: &'v Value, + data_type: &str, + ) -> Result<&'v str, ArrowError> { + match default_json { + Value::String(s) => Ok(s.as_str()), + _ => Err(ArrowError::SchemaError(format!( + "Default value must be a JSON string for {data_type}" + ))), + } + } + + fn parse_bytes_default( + default_json: &Value, + expected_len: Option, + ) -> Result, ArrowError> { + let s = expect_string(default_json, "bytes/fixed logical types")?; + let mut out = Vec::with_capacity(s.len()); + for ch in s.chars() { + let cp = ch as u32; + if cp > 0xFF { + return Err(ArrowError::SchemaError(format!( + "Invalid codepoint U+{cp:04X} in bytes/fixed default; must be ≤ 0xFF" + ))); + } + out.push(cp as u8); + } + if let Some(len) = expected_len { + if out.len() != len { + return Err(ArrowError::SchemaError(format!( + "Default length {} does not match expected fixed size {len}", + out.len(), + ))); + } + } + Ok(out) + } + + fn parse_json_i64(default_json: &Value, data_type: &str) -> Result { + match default_json { + Value::Number(n) => n.as_i64().ok_or_else(|| { + ArrowError::SchemaError(format!("Default {data_type} must be an integer")) + }), + _ => Err(ArrowError::SchemaError(format!( + "Default {data_type} must be a JSON integer" + ))), + } + } + + fn parse_json_f64(default_json: &Value, data_type: &str) -> Result { + match default_json { + Value::Number(n) => n.as_f64().ok_or_else(|| { + ArrowError::SchemaError(format!("Default {data_type} must be a number")) + }), + _ => Err(ArrowError::SchemaError(format!( + "Default {data_type} must be a JSON number" + ))), + } + } + + // Handle JSON nulls per-spec: allowed only for `null` type or unions with null FIRST + if default_json.is_null() { + return match self.codec() { + Codec::Null => Ok(AvroLiteral::Null), + Codec::Union(encodings, _, _) if !encodings.is_empty() + && matches!(encodings[0].codec(), Codec::Null) => + { + Ok(AvroLiteral::Null) + } + _ if self.nullability() == Some(Nullability::NullFirst) => Ok(AvroLiteral::Null), + _ => Err(ArrowError::SchemaError( + "JSON null default is only valid for `null` type or for a union whose first branch is `null`" + .to_string(), + )), + }; + } + let lit = match self.codec() { + Codec::Null => { + return Err(ArrowError::SchemaError( + "Default for `null` type must be JSON null".to_string(), + )); + } + Codec::Boolean => match default_json { + Value::Bool(b) => AvroLiteral::Boolean(*b), + _ => { + return Err(ArrowError::SchemaError( + "Boolean default must be a JSON boolean".to_string(), + )); + } + }, + Codec::Int32 | Codec::Date32 | Codec::TimeMillis => { + let i = parse_json_i64(default_json, "int")?; + if i < i32::MIN as i64 || i > i32::MAX as i64 { + return Err(ArrowError::SchemaError(format!( + "Default int {i} out of i32 range" + ))); + } + AvroLiteral::Int(i as i32) + } + Codec::Int64 + | Codec::TimeMicros + | Codec::TimestampMillis(_) + | Codec::TimestampMicros(_) + | Codec::TimestampNanos(_) => AvroLiteral::Long(parse_json_i64(default_json, "long")?), + #[cfg(feature = "avro_custom_types")] + Codec::DurationNanos + | Codec::DurationMicros + | Codec::DurationMillis + | Codec::DurationSeconds => AvroLiteral::Long(parse_json_i64(default_json, "long")?), + Codec::Float32 => { + let f = parse_json_f64(default_json, "float")?; + if !f.is_finite() || f < f32::MIN as f64 || f > f32::MAX as f64 { + return Err(ArrowError::SchemaError(format!( + "Default float {f} out of f32 range or not finite" + ))); + } + AvroLiteral::Float(f as f32) + } + Codec::Float64 => AvroLiteral::Double(parse_json_f64(default_json, "double")?), + Codec::Utf8 | Codec::Utf8View | Codec::Uuid => { + AvroLiteral::String(expect_string(default_json, "string/uuid")?.to_string()) + } + Codec::Binary => AvroLiteral::Bytes(parse_bytes_default(default_json, None)?), + Codec::Fixed(sz) => { + AvroLiteral::Bytes(parse_bytes_default(default_json, Some(*sz as usize))?) + } + Codec::Decimal(_, _, fixed_size) => { + AvroLiteral::Bytes(parse_bytes_default(default_json, *fixed_size)?) + } + Codec::Enum(symbols) => { + let s = expect_string(default_json, "enum")?; + if symbols.iter().any(|sym| sym == s) { + AvroLiteral::Enum(s.to_string()) + } else { + return Err(ArrowError::SchemaError(format!( + "Default enum symbol {s:?} not found in reader enum symbols" + ))); + } + } + Codec::Interval => AvroLiteral::Bytes(parse_bytes_default(default_json, Some(12))?), + Codec::List(item_dt) => match default_json { + Value::Array(items) => AvroLiteral::Array( + items + .iter() + .map(|v| item_dt.parse_default_literal(v)) + .collect::>()?, + ), + _ => { + return Err(ArrowError::SchemaError( + "Default value must be a JSON array for Avro array type".to_string(), + )); + } + }, + Codec::Map(val_dt) => match default_json { + Value::Object(map) => { + let mut out = IndexMap::with_capacity(map.len()); + for (k, v) in map { + out.insert(k.clone(), val_dt.parse_default_literal(v)?); + } + AvroLiteral::Map(out) + } + _ => { + return Err(ArrowError::SchemaError( + "Default value must be a JSON object for Avro map type".to_string(), + )); + } + }, + Codec::Struct(fields) => match default_json { + Value::Object(obj) => { + let mut out: IndexMap = + IndexMap::with_capacity(fields.len()); + for f in fields.as_ref() { + let name = f.name().to_string(); + if let Some(sub) = obj.get(&name) { + out.insert(name, f.data_type().parse_default_literal(sub)?); + } else { + // Cache metadata lookup once + let stored_default = + f.data_type().metadata.get(AVRO_FIELD_DEFAULT_METADATA_KEY); + if stored_default.is_none() + && f.data_type().nullability() == Some(Nullability::default()) + { + out.insert(name, AvroLiteral::Null); + } else if let Some(default_json) = stored_default { + let v: Value = + serde_json::from_str(default_json).map_err(|e| { + ArrowError::SchemaError(format!( + "Failed to parse stored subfield default JSON for '{}': {e}", + f.name(), + )) + })?; + out.insert(name, f.data_type().parse_default_literal(&v)?); + } else { + return Err(ArrowError::SchemaError(format!( + "Record default missing required subfield '{}' with non-nullable type {:?}", + f.name(), + f.data_type().codec() + ))); + } + } + } + AvroLiteral::Map(out) + } + _ => { + return Err(ArrowError::SchemaError( + "Default value for record/struct must be a JSON object".to_string(), + )); + } + }, + Codec::Union(encodings, _, _) => { + let Some(default_encoding) = encodings.first() else { + return Err(ArrowError::SchemaError( + "Union with no branches cannot have a default".to_string(), + )); + }; + default_encoding.parse_default_literal(default_json)? + } + #[cfg(feature = "avro_custom_types")] + Codec::RunEndEncoded(values, _) => values.parse_default_literal(default_json)?, + }; + Ok(lit) + } + + fn store_default(&mut self, default_json: &Value) -> Result<(), ArrowError> { + let json_text = serde_json::to_string(default_json).map_err(|e| { + ArrowError::ParseError(format!("Failed to serialize default to JSON: {e}")) + })?; + self.metadata + .insert(AVRO_FIELD_DEFAULT_METADATA_KEY.to_string(), json_text); + Ok(()) + } + + fn parse_and_store_default(&mut self, default_json: &Value) -> Result { + let lit = self.parse_default_literal(default_json)?; + self.store_default(default_json)?; + Ok(lit) + } } /// A named [`AvroDataType`] -#[derive(Debug, Clone)] -pub struct AvroField { +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct AvroField { name: String, data_type: AvroDataType, } impl AvroField { /// Returns the arrow [`Field`] - pub fn field(&self) -> Field { + pub(crate) fn field(&self) -> Field { self.data_type.field_with_name(&self.name) } /// Returns the [`AvroDataType`] - pub fn data_type(&self) -> &AvroDataType { + pub(crate) fn data_type(&self) -> &AvroDataType { &self.data_type } @@ -110,7 +512,7 @@ impl AvroField { /// /// Returns a new `AvroField` with the same structure, but with string types /// converted to use `Utf8View` instead of `Utf8`. - pub fn with_utf8view(&self) -> Self { + pub(crate) fn with_utf8view(&self) -> Self { let mut field = self.clone(); if let Codec::Utf8 = field.data_type.codec { field.data_type.codec = Codec::Utf8View; @@ -122,7 +524,7 @@ impl AvroField { /// /// This is the field name as defined in the Avro schema. /// It's used to identify fields within a record structure. - pub fn name(&self) -> &str { + pub(crate) fn name(&self) -> &str { &self.name } } @@ -133,8 +535,8 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { fn try_from(schema: &Schema<'a>) -> Result { match schema { Schema::Complex(ComplexType::Record(r)) => { - let mut resolver = Resolver::default(); - let data_type = make_data_type(schema, None, &mut resolver, false)?; + let mut resolver = Maker::new(false, false); + let data_type = resolver.make_data_type(schema, None, None)?; Ok(AvroField { data_type, name: r.name.to_string(), @@ -147,11 +549,73 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField { } } +/// Builder for an [`AvroField`] +#[derive(Debug)] +pub(crate) struct AvroFieldBuilder<'a> { + writer_schema: &'a Schema<'a>, + reader_schema: Option<&'a Schema<'a>>, + use_utf8view: bool, + strict_mode: bool, +} + +impl<'a> AvroFieldBuilder<'a> { + /// Creates a new [`AvroFieldBuilder`] for a given writer schema. + pub(crate) fn new(writer_schema: &'a Schema<'a>) -> Self { + Self { + writer_schema, + reader_schema: None, + use_utf8view: false, + strict_mode: false, + } + } + + /// Sets the reader schema for schema resolution. + /// + /// If a reader schema is provided, the builder will produce a resolved `AvroField` + /// that can handle differences between the writer's and reader's schemas. + #[inline] + pub(crate) fn with_reader_schema(mut self, reader_schema: &'a Schema<'a>) -> Self { + self.reader_schema = Some(reader_schema); + self + } + + /// Enable or disable Utf8View support + pub(crate) fn with_utf8view(mut self, use_utf8view: bool) -> Self { + self.use_utf8view = use_utf8view; + self + } + + /// Enable or disable strict mode. + pub(crate) fn with_strict_mode(mut self, strict_mode: bool) -> Self { + self.strict_mode = strict_mode; + self + } + + /// Build an [`AvroField`] from the builder + pub(crate) fn build(self) -> Result { + match self.writer_schema { + Schema::Complex(ComplexType::Record(r)) => { + let mut resolver = Maker::new(self.use_utf8view, self.strict_mode); + let data_type = + resolver.make_data_type(self.writer_schema, self.reader_schema, None)?; + Ok(AvroField { + name: r.name.to_string(), + data_type, + }) + } + _ => Err(ArrowError::ParseError(format!( + "Expected a Record schema to build an AvroField, but got {:?}", + self.writer_schema + ))), + } + } +} + /// An Avro encoding /// /// -#[derive(Debug, Clone)] -pub enum Codec { +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum Codec { /// Represents Avro null type, maps to Arrow's Null data type Null, /// Represents Avro boolean type, maps to Arrow's Boolean data type @@ -189,9 +653,27 @@ pub enum Codec { /// Maps to Arrow's Timestamp(TimeUnit::Microsecond) data type /// The boolean parameter indicates whether the timestamp has a UTC timezone (true) or is local time (false) TimestampMicros(bool), + /// Represents Avro timestamp-nanos or local-timestamp-nanos logical type + /// + /// Maps to Arrow's Timestamp(TimeUnit::Nanosecond) data type + /// The boolean parameter indicates whether the timestamp has a UTC timezone (true) or is local time (false) + TimestampNanos(bool), /// Represents Avro fixed type, maps to Arrow's FixedSizeBinary data type /// The i32 parameter indicates the fixed binary size Fixed(i32), + /// Represents Avro decimal type, maps to Arrow's Decimal32, Decimal64, Decimal128, or Decimal256 data types + /// + /// The fields are `(precision, scale, fixed_size)`. + /// - `precision` (`usize`): Total number of digits. + /// - `scale` (`Option`): Number of fractional digits. + /// - `fixed_size` (`Option`): Size in bytes if backed by a `fixed` type, otherwise `None`. + Decimal(usize, Option, Option), + /// Represents Avro Uuid type, a FixedSizeBinary with a length of 16. + Uuid, + /// Represents an Avro enum, maps to Arrow's Dictionary(Int32, Utf8) type. + /// + /// The enclosed value contains the enum's symbols. + Enum(Arc<[String]>), /// Represents Avro array type, maps to Arrow's List data type List(Arc), /// Represents Avro record type, maps to Arrow's Struct data type @@ -200,6 +682,22 @@ pub enum Codec { Map(Arc), /// Represents Avro duration logical type, maps to Arrow's Interval(IntervalUnit::MonthDayNano) data type Interval, + /// Represents Avro union type, maps to Arrow's Union data type + Union(Arc<[AvroDataType]>, UnionFields, UnionMode), + /// Represents Avro custom logical type to map to Arrow Duration(TimeUnit::Nanosecond) + #[cfg(feature = "avro_custom_types")] + DurationNanos, + /// Represents Avro custom logical type to map to Arrow Duration(TimeUnit::Microsecond) + #[cfg(feature = "avro_custom_types")] + DurationMicros, + /// Represents Avro custom logical type to map to Arrow Duration(TimeUnit::Millisecond) + #[cfg(feature = "avro_custom_types")] + DurationMillis, + /// Represents Avro custom logical type to map to Arrow Duration(TimeUnit::Second) + #[cfg(feature = "avro_custom_types")] + DurationSeconds, + #[cfg(feature = "avro_custom_types")] + RunEndEncoded(Arc, u8), } impl Codec { @@ -223,16 +721,45 @@ impl Codec { Self::TimestampMicros(is_utc) => { DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) } + Self::TimestampNanos(is_utc) => { + DataType::Timestamp(TimeUnit::Nanosecond, is_utc.then(|| "+00:00".into())) + } Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), Self::Fixed(size) => DataType::FixedSizeBinary(*size), + Self::Decimal(precision, scale, _size) => { + let p = *precision as u8; + let s = scale.unwrap_or(0) as i8; + #[cfg(feature = "small_decimals")] + { + if *precision <= DECIMAL32_MAX_PRECISION as usize { + DataType::Decimal32(p, s) + } else if *precision <= DECIMAL64_MAX_PRECISION as usize { + DataType::Decimal64(p, s) + } else if *precision <= DECIMAL128_MAX_PRECISION as usize { + DataType::Decimal128(p, s) + } else { + DataType::Decimal256(p, s) + } + } + #[cfg(not(feature = "small_decimals"))] + { + if *precision <= DECIMAL128_MAX_PRECISION as usize { + DataType::Decimal128(p, s) + } else { + DataType::Decimal256(p, s) + } + } + } + Self::Uuid => DataType::FixedSizeBinary(16), + Self::Enum(_) => { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + } Self::List(f) => { DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) } Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), Self::Map(value_type) => { - let val_dt = value_type.codec.data_type(); - let val_field = Field::new("value", val_dt, value_type.nullability.is_some()) - .with_metadata(value_type.metadata.clone()); + let val_field = value_type.field_with_name("value"); DataType::Map( Arc::new(Field::new( "entries", @@ -245,8 +772,48 @@ impl Codec { false, ) } + Self::Union(_, fields, mode) => DataType::Union(fields.clone(), *mode), + #[cfg(feature = "avro_custom_types")] + Self::DurationNanos => DataType::Duration(TimeUnit::Nanosecond), + #[cfg(feature = "avro_custom_types")] + Self::DurationMicros => DataType::Duration(TimeUnit::Microsecond), + #[cfg(feature = "avro_custom_types")] + Self::DurationMillis => DataType::Duration(TimeUnit::Millisecond), + #[cfg(feature = "avro_custom_types")] + Self::DurationSeconds => DataType::Duration(TimeUnit::Second), + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(values, bits) => { + let run_ends_dt = match *bits { + 16 => DataType::Int16, + 32 => DataType::Int32, + 64 => DataType::Int64, + _ => unreachable!(), + }; + DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", run_ends_dt, false)), + Arc::new(Field::new("values", values.codec().data_type(), true)), + ) + } + } + } + + /// Converts a string codec to use Utf8View if requested + /// + /// The conversion only happens if both: + /// 1. `use_utf8view` is true + /// 2. The codec is currently `Utf8` + pub(crate) fn with_utf8view(self, use_utf8view: bool) -> Self { + if use_utf8view && matches!(self, Self::Utf8) { + Self::Utf8View + } else { + self } } + + #[inline] + fn union_field_name(&self) -> String { + UnionFieldKind::from(self).as_ref().to_owned() + } } impl From for Codec { @@ -264,34 +831,175 @@ impl From for Codec { } } -impl Codec { - /// Converts a string codec to use Utf8View if requested - /// - /// The conversion only happens if both: - /// 1. `use_utf8view` is true - /// 2. The codec is currently `Utf8` - /// - /// # Example - /// ``` - /// # use arrow_avro::codec::Codec; - /// let utf8_codec1 = Codec::Utf8; - /// let utf8_codec2 = Codec::Utf8; - /// - /// // Convert to Utf8View - /// let view_codec = utf8_codec1.with_utf8view(true); - /// assert!(matches!(view_codec, Codec::Utf8View)); - /// - /// // Don't convert if use_utf8view is false - /// let unchanged_codec = utf8_codec2.with_utf8view(false); - /// assert!(matches!(unchanged_codec, Codec::Utf8)); - /// ``` - pub fn with_utf8view(self, use_utf8view: bool) -> Self { - if use_utf8view && matches!(self, Self::Utf8) { - Self::Utf8View - } else { - self +/// Compute the exact maximum base‑10 precision that fits in `n` bytes for Avro +/// `fixed` decimals stored as two's‑complement unscaled integers (big‑endian). +/// +/// Per Avro spec (Decimal logical type), for a fixed length `n`: +/// max precision = ⌊log₁₀(2^(8n − 1) − 1)⌋. +/// +/// This function returns `None` if `n` is 0 or greater than 32 (Arrow supports +/// Decimal256, which is 32 bytes and has max precision 76). +const fn max_precision_for_fixed_bytes(n: usize) -> Option { + // Precomputed exact table for n = 1..=32 + // 1:2, 2:4, 3:6, 4:9, 5:11, 6:14, 7:16, 8:18, 9:21, 10:23, 11:26, 12:28, + // 13:31, 14:33, 15:35, 16:38, 17:40, 18:43, 19:45, 20:47, 21:50, 22:52, + // 23:55, 24:57, 25:59, 26:62, 27:64, 28:67, 29:69, 30:71, 31:74, 32:76 + const MAX_P: [usize; 32] = [ + 2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26, 28, 31, 33, 35, 38, 40, 43, 45, 47, 50, 52, 55, 57, + 59, 62, 64, 67, 69, 71, 74, 76, + ]; + match n { + 1..=32 => Some(MAX_P[n - 1]), + _ => None, + } +} + +fn parse_decimal_attributes( + attributes: &Attributes, + fallback_size: Option, + precision_required: bool, +) -> Result<(usize, usize, Option), ArrowError> { + let precision = attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .or(if precision_required { None } else { Some(10) }) + .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))? + as usize; + let scale = attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let size = attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .map(|s| s as usize) + .or(fallback_size); + if precision == 0 { + return Err(ArrowError::ParseError( + "Decimal requires precision > 0".to_string(), + )); + } + if scale > precision { + return Err(ArrowError::ParseError(format!( + "Decimal has invalid scale > precision: scale={scale}, precision={precision}" + ))); + } + if precision > DECIMAL256_MAX_PRECISION as usize { + return Err(ArrowError::ParseError(format!( + "Decimal precision {precision} exceeds maximum supported by Arrow ({})", + DECIMAL256_MAX_PRECISION + ))); + } + if let Some(sz) = size { + let max_p = max_precision_for_fixed_bytes(sz).ok_or_else(|| { + ArrowError::ParseError(format!( + "Invalid fixed size for decimal: {sz}, must be between 1 and 32 bytes" + )) + })?; + if precision > max_p { + return Err(ArrowError::ParseError(format!( + "Decimal precision {precision} exceeds capacity of fixed size {sz} bytes (max {max_p})" + ))); + } + } + Ok((precision, scale, size)) +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsRefStr)] +#[strum(serialize_all = "snake_case")] +enum UnionFieldKind { + Null, + Boolean, + Int, + Long, + Float, + Double, + Bytes, + String, + Date, + TimeMillis, + TimeMicros, + TimestampMillisUtc, + TimestampMillisLocal, + TimestampMicrosUtc, + TimestampMicrosLocal, + TimestampNanosUtc, + TimestampNanosLocal, + Duration, + Fixed, + Decimal, + Enum, + Array, + Record, + Map, + Uuid, + Union, +} + +impl From<&Codec> for UnionFieldKind { + fn from(c: &Codec) -> Self { + match c { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 => Self::Int, + Codec::Int64 => Self::Long, + Codec::Float32 => Self::Float, + Codec::Float64 => Self::Double, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Date32 => Self::Date, + Codec::TimeMillis => Self::TimeMillis, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(true) => Self::TimestampMillisUtc, + Codec::TimestampMillis(false) => Self::TimestampMillisLocal, + Codec::TimestampMicros(true) => Self::TimestampMicrosUtc, + Codec::TimestampMicros(false) => Self::TimestampMicrosLocal, + Codec::TimestampNanos(true) => Self::TimestampNanosUtc, + Codec::TimestampNanos(false) => Self::TimestampNanosLocal, + Codec::Interval => Self::Duration, + Codec::Fixed(_) => Self::Fixed, + Codec::Decimal(..) => Self::Decimal, + Codec::Enum(_) => Self::Enum, + Codec::List(_) => Self::Array, + Codec::Struct(_) => Self::Record, + Codec::Map(_) => Self::Map, + Codec::Uuid => Self::Uuid, + Codec::Union(..) => Self::Union, + #[cfg(feature = "avro_custom_types")] + Codec::RunEndEncoded(values, _) => UnionFieldKind::from(values.codec()), + #[cfg(feature = "avro_custom_types")] + Codec::DurationNanos + | Codec::DurationMicros + | Codec::DurationMillis + | Codec::DurationSeconds => Self::Duration, + } + } +} + +fn union_branch_name(dt: &AvroDataType) -> String { + if let Some(name) = dt.metadata.get(AVRO_NAME_METADATA_KEY) { + if name.contains(".") { + // Full name + return name.to_string(); + } + if let Some(ns) = dt.metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + return format!("{ns}.{name}"); } + return name.to_string(); } + dt.codec.union_field_name() +} + +fn build_union_fields(encodings: &[AvroDataType]) -> Result { + let arrow_fields: Vec = encodings + .iter() + .map(|encoding| encoding.field_with_name(&union_branch_name(encoding))) + .collect(); + let type_ids: Vec = (0..arrow_fields.len()).map(|i| i as i8).collect(); + UnionFields::try_new(type_ids, arrow_fields) } /// Resolves Avro type names to [`AvroDataType`] @@ -304,14 +1012,13 @@ struct Resolver<'a> { impl<'a> Resolver<'a> { fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { - self.map.insert((name, namespace.unwrap_or("")), schema); + self.map.insert((namespace.unwrap_or(""), name), schema); } fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { let (namespace, name) = name .rsplit_once('.') .unwrap_or_else(|| (namespace.unwrap_or(""), name)); - self.map .get(&(namespace, name)) .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}"))) @@ -319,169 +1026,973 @@ impl<'a> Resolver<'a> { } } -/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` -/// -/// `name`: is name used to refer to `schema` in its parent -/// `namespace`: an optional qualifier used as part of a type hierarchy -/// If the data type is a string, convert to use Utf8View if requested -/// -/// This function is used during the schema conversion process to determine whether -/// string data should be represented as StringArray (default) or StringViewArray. -/// -/// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types +fn full_name_set(name: &str, ns: Option<&str>, aliases: &[&str]) -> HashSet { + let mut out = HashSet::with_capacity(1 + aliases.len()); + let (full, _) = make_full_name(name, ns, None); + out.insert(full); + for a in aliases { + let (fa, _) = make_full_name(a, None, ns); + out.insert(fa); + } + out +} + +fn names_match( + writer_name: &str, + writer_namespace: Option<&str>, + writer_aliases: &[&str], + reader_name: &str, + reader_namespace: Option<&str>, + reader_aliases: &[&str], +) -> bool { + let writer_set = full_name_set(writer_name, writer_namespace, writer_aliases); + let reader_set = full_name_set(reader_name, reader_namespace, reader_aliases); + // If the canonical full names match, or any alias matches cross-wise. + !writer_set.is_disjoint(&reader_set) +} + +fn ensure_names_match( + data_type: &str, + writer_name: &str, + writer_namespace: Option<&str>, + writer_aliases: &[&str], + reader_name: &str, + reader_namespace: Option<&str>, + reader_aliases: &[&str], +) -> Result<(), ArrowError> { + if names_match( + writer_name, + writer_namespace, + writer_aliases, + reader_name, + reader_namespace, + reader_aliases, + ) { + Ok(()) + } else { + Err(ArrowError::ParseError(format!( + "{data_type} name mismatch writer={writer_name}, reader={reader_name}" + ))) + } +} + +fn primitive_of(schema: &Schema) -> Option { + match schema { + Schema::TypeName(TypeName::Primitive(primitive)) => Some(*primitive), + Schema::Type(Type { + r#type: TypeName::Primitive(primitive), + .. + }) => Some(*primitive), + _ => None, + } +} + +fn nullable_union_variants<'x, 'y>( + variant: &'y [Schema<'x>], +) -> Option<(Nullability, &'y Schema<'x>)> { + if variant.len() != 2 { + return None; + } + let is_null = |schema: &Schema<'x>| { + matches!( + schema, + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)) + ) + }; + match (is_null(&variant[0]), is_null(&variant[1])) { + (true, false) => Some((Nullability::NullFirst, &variant[1])), + (false, true) => Some((Nullability::NullSecond, &variant[0])), + _ => None, + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +enum UnionBranchKey { + Named(String), + Primitive(PrimitiveType), + Array, + Map, +} + +fn branch_key_of<'a>(s: &Schema<'a>, enclosing_ns: Option<&'a str>) -> Option { + let (name, namespace) = match s { + Schema::TypeName(TypeName::Primitive(p)) + | Schema::Type(Type { + r#type: TypeName::Primitive(p), + .. + }) => return Some(UnionBranchKey::Primitive(*p)), + Schema::TypeName(TypeName::Ref(name)) + | Schema::Type(Type { + r#type: TypeName::Ref(name), + .. + }) => (name, None), + Schema::Complex(ComplexType::Array(_)) => return Some(UnionBranchKey::Array), + Schema::Complex(ComplexType::Map(_)) => return Some(UnionBranchKey::Map), + Schema::Complex(ComplexType::Record(r)) => (&r.name, r.namespace), + Schema::Complex(ComplexType::Enum(e)) => (&e.name, e.namespace), + Schema::Complex(ComplexType::Fixed(f)) => (&f.name, f.namespace), + Schema::Union(_) => return None, + }; + let (full, _) = make_full_name(name, namespace, enclosing_ns); + Some(UnionBranchKey::Named(full)) +} + +fn union_first_duplicate<'a>( + branches: &'a [Schema<'a>], + enclosing_ns: Option<&'a str>, +) -> Option { + let mut seen = HashSet::with_capacity(branches.len()); + for schema in branches { + if let Some(key) = branch_key_of(schema, enclosing_ns) { + if !seen.insert(key.clone()) { + let msg = match key { + UnionBranchKey::Named(full) => format!("named type {full}"), + UnionBranchKey::Primitive(p) => format!("primitive {}", p.as_ref()), + UnionBranchKey::Array => "array".to_string(), + UnionBranchKey::Map => "map".to_string(), + }; + return Some(msg); + } + } + } + None +} + +/// Resolves Avro type names to [`AvroDataType`] /// -/// See [`Resolver`] for more information -fn make_data_type<'a>( - schema: &Schema<'a>, - namespace: Option<&'a str>, - resolver: &mut Resolver<'a>, +/// See +struct Maker<'a> { + resolver: Resolver<'a>, use_utf8view: bool, -) -> Result { - match schema { - Schema::TypeName(TypeName::Primitive(p)) => { - let codec: Codec = (*p).into(); - let codec = codec.with_utf8view(use_utf8view); - Ok(AvroDataType { - nullability: None, - metadata: Default::default(), - codec, - }) + strict_mode: bool, +} + +impl<'a> Maker<'a> { + fn new(use_utf8view: bool, strict_mode: bool) -> Self { + Self { + resolver: Default::default(), + use_utf8view, + strict_mode, } - Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - Schema::Union(f) => { - // Special case the common case of nullable primitives - let null = f - .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { - (true, Some(0)) => { - let mut field = make_data_type(&f[1], namespace, resolver, use_utf8view)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) + } + + #[cfg(feature = "avro_custom_types")] + #[inline] + fn propagate_nullability_into_ree(dt: &mut AvroDataType, nb: Nullability) { + if let Codec::RunEndEncoded(values, bits) = dt.codec.clone() { + let mut inner = (*values).clone(); + inner.nullability = Some(nb); + dt.codec = Codec::RunEndEncoded(Arc::new(inner), bits); + } + } + + fn make_data_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: Option<&'s Schema<'a>>, + namespace: Option<&'a str>, + ) -> Result { + match reader_schema { + Some(reader_schema) => self.resolve_type(writer_schema, reader_schema, namespace), + None => self.parse_type(writer_schema, namespace), + } + } + + /// Parses a [`AvroDataType`] from the provided `Schema` and the given `name` and `namespace` + /// + /// `name`: is the name used to refer to `schema` in its parent + /// `namespace`: an optional qualifier used as part of a type hierarchy + /// If the data type is a string, convert to use Utf8View if requested + /// + /// This function is used during the schema conversion process to determine whether + /// string data should be represented as StringArray (default) or StringViewArray. + /// + /// `use_utf8view`: if true, use Utf8View instead of Utf8 for string types + /// + /// See [`Resolver`] for more information + fn parse_type<'s>( + &mut self, + schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + match schema { + Schema::TypeName(TypeName::Primitive(p)) => Ok(AvroDataType::new( + Codec::from(*p).with_utf8view(self.use_utf8view), + Default::default(), + None, + )), + Schema::TypeName(TypeName::Ref(name)) => self.resolver.resolve(name, namespace), + Schema::Union(f) => { + let null = f + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); + match (f.len() == 2, null) { + (true, Some(0)) => { + let mut field = self.parse_type(&f[1], namespace)?; + field.nullability = Some(Nullability::NullFirst); + #[cfg(feature = "avro_custom_types")] + Self::propagate_nullability_into_ree(&mut field, Nullability::NullFirst); + return Ok(field); + } + (true, Some(1)) => { + if self.strict_mode { + return Err(ArrowError::SchemaError( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + .to_string(), + )); + } + let mut field = self.parse_type(&f[0], namespace)?; + field.nullability = Some(Nullability::NullSecond); + #[cfg(feature = "avro_custom_types")] + Self::propagate_nullability_into_ree(&mut field, Nullability::NullSecond); + return Ok(field); + } + _ => {} } - (true, Some(1)) => { - let mut field = make_data_type(&f[0], namespace, resolver, use_utf8view)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) + // Validate: unions may not immediately contain unions + if f.iter().any(|s| matches!(s, Schema::Union(_))) { + return Err(ArrowError::SchemaError( + "Avro unions may not immediately contain other unions".to_string(), + )); } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), - } - } - Schema::Complex(c) => match c { - ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); - let fields = r - .fields + // Validate: duplicates (named by full name; non-named by kind) + if let Some(dup) = union_first_duplicate(f, namespace) { + return Err(ArrowError::SchemaError(format!( + "Avro union contains duplicate branch type: {dup}" + ))); + } + // Parse all branches + let children: Vec = f .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type( - &field.r#type, - namespace, - resolver, - use_utf8view, - )?, + .map(|s| self.parse_type(s, namespace)) + .collect::>()?; + // Build Arrow layout once here + let union_fields = build_union_fields(&children)?; + Ok(AvroDataType::new( + Codec::Union(Arc::from(children), union_fields, UnionMode::Dense), + Default::default(), + None, + )) + } + Schema::Complex(c) => match c { + ComplexType::Record(r) => { + let namespace = r.namespace.or(namespace); + let mut metadata = r.attributes.field_metadata(); + let fields = r + .fields + .iter() + .map(|field| { + Ok(AvroField { + name: field.name.to_string(), + data_type: self.parse_type(&field.r#type, namespace)?, + }) }) + .collect::>()?; + metadata.insert(AVRO_NAME_METADATA_KEY.to_string(), r.name.to_string()); + if let Some(ns) = namespace { + metadata.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), ns.to_string()); + } + let field = AvroDataType { + nullability: None, + codec: Codec::Struct(fields), + metadata, + resolution: None, + }; + self.resolver.register(r.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Array(a) => { + let field = self.parse_type(a.items.as_ref(), namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: a.attributes.field_metadata(), + codec: Codec::List(Arc::new(field)), + resolution: None, }) - .collect::>()?; - - let field = AvroDataType { - nullability: None, - codec: Codec::Struct(fields), - metadata: r.attributes.field_metadata(), - }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Array(a) => { - let mut field = - make_data_type(a.items.as_ref(), namespace, resolver, use_utf8view)?; - Ok(AvroDataType { - nullability: None, - metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), - }) - } - ComplexType::Fixed(f) => { - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - - let field = AvroDataType { - nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), - }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) - } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - ))), - ComplexType::Map(m) => { - let val = make_data_type(&m.values, namespace, resolver, use_utf8view)?; - Ok(AvroDataType { - nullability: None, - metadata: m.attributes.field_metadata(), - codec: Codec::Map(Arc::new(val)), - }) - } - }, - Schema::Type(t) => { - let mut field = make_data_type( - &Schema::TypeName(t.r#type.clone()), - namespace, - resolver, - use_utf8view, - )?; - - // https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) } - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) + ComplexType::Fixed(f) => { + let size = f.size.try_into().map_err(|e| { + ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) + })?; + let namespace = f.namespace.or(namespace); + let mut metadata = f.attributes.field_metadata(); + metadata.insert(AVRO_NAME_METADATA_KEY.to_string(), f.name.to_string()); + if let Some(ns) = namespace { + metadata.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), ns.to_string()); + } + let field = match f.attributes.logical_type { + Some("decimal") => { + let (precision, scale, _) = + parse_decimal_attributes(&f.attributes, Some(size as usize), true)?; + AvroDataType { + nullability: None, + metadata, + codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + resolution: None, + } + } + Some("duration") => { + if size != 12 { + return Err(ArrowError::ParseError(format!( + "Invalid fixed size for Duration: {size}, must be 12" + ))); + }; + AvroDataType { + nullability: None, + metadata, + codec: Codec::Interval, + resolution: None, + } + } + _ => AvroDataType { + nullability: None, + metadata, + codec: Codec::Fixed(size), + resolution: None, + }, + }; + self.resolver.register(f.name, namespace, field.clone()); + Ok(field) + } + ComplexType::Enum(e) => { + let namespace = e.namespace.or(namespace); + let symbols = e + .symbols + .iter() + .map(|s| s.to_string()) + .collect::>(); + let mut metadata = e.attributes.field_metadata(); + let symbols_json = serde_json::to_string(&e.symbols).map_err(|e| { + ArrowError::ParseError(format!("Failed to serialize enum symbols: {e}")) + })?; + metadata.insert(AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), symbols_json); + metadata.insert(AVRO_NAME_METADATA_KEY.to_string(), e.name.to_string()); + if let Some(ns) = namespace { + metadata.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), ns.to_string()); + } + let field = AvroDataType { + nullability: None, + metadata, + codec: Codec::Enum(symbols), + resolution: None, + }; + self.resolver.register(e.name, namespace, field.clone()); + Ok(field) } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) + ComplexType::Map(m) => { + let val = self.parse_type(&m.values, namespace)?; + Ok(AvroDataType { + nullability: None, + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(val)), + resolution: None, + }) } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata map - field.metadata.insert("logicalType".into(), logical.into()); + }, + Schema::Type(t) => { + let mut field = self.parse_type(&Schema::TypeName(t.r#type.clone()), namespace)?; + // https://avro.apache.org/docs/1.11.1/specification/#logical-types + match (t.attributes.logical_type, &mut field.codec) { + (Some("decimal"), c @ Codec::Binary) => { + let (prec, sc, _) = parse_decimal_attributes(&t.attributes, None, false)?; + *c = Codec::Decimal(prec, Some(sc), None); + } + (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, + (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, + (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, + (Some("timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(true) + } + (Some("timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(true) + } + (Some("local-timestamp-millis"), c @ Codec::Int64) => { + *c = Codec::TimestampMillis(false) + } + (Some("local-timestamp-micros"), c @ Codec::Int64) => { + *c = Codec::TimestampMicros(false) + } + (Some("timestamp-nanos"), c @ Codec::Int64) => *c = Codec::TimestampNanos(true), + (Some("local-timestamp-nanos"), c @ Codec::Int64) => { + *c = Codec::TimestampNanos(false) + } + (Some("uuid"), c @ Codec::Utf8) => { + // Map Avro string+logicalType=uuid into the UUID Codec, + // and preserve the logicalType in Arrow field metadata + // so writers can round-trip it correctly. + *c = Codec::Uuid; + field.metadata.insert("logicalType".into(), "uuid".into()); + } + #[cfg(feature = "avro_custom_types")] + (Some("arrow.duration-nanos"), c @ Codec::Int64) => *c = Codec::DurationNanos, + #[cfg(feature = "avro_custom_types")] + (Some("arrow.duration-micros"), c @ Codec::Int64) => *c = Codec::DurationMicros, + #[cfg(feature = "avro_custom_types")] + (Some("arrow.duration-millis"), c @ Codec::Int64) => *c = Codec::DurationMillis, + #[cfg(feature = "avro_custom_types")] + (Some("arrow.duration-seconds"), c @ Codec::Int64) => { + *c = Codec::DurationSeconds + } + #[cfg(feature = "avro_custom_types")] + (Some("arrow.run-end-encoded"), _) => { + let bits_u8: u8 = t + .attributes + .additional + .get("arrow.runEndIndexBits") + .and_then(|v| v.as_u64()) + .and_then(|n| u8::try_from(n).ok()) + .ok_or_else(|| ArrowError::ParseError( + "arrow.run-end-encoded requires 'arrow.runEndIndexBits' (one of 16, 32, or 64)" + .to_string(), + ))?; + if bits_u8 != 16 && bits_u8 != 32 && bits_u8 != 64 { + return Err(ArrowError::ParseError(format!( + "Invalid 'arrow.runEndIndexBits' value {bits_u8}; must be 16, 32, or 64" + ))); + } + // Wrap the parsed underlying site as REE + let values_site = field.clone(); + field.codec = Codec::RunEndEncoded(Arc::new(values_site), bits_u8); + } + (Some(logical), _) => { + // Insert unrecognized logical type into metadata map + field.metadata.insert("logicalType".into(), logical.into()); + } + (None, _) => {} } - (None, _) => {} - } - - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); + if matches!(field.codec, Codec::Int64) { + if let Some(unit) = t + .attributes + .additional + .get("arrowTimeUnit") + .and_then(|v| v.as_str()) + { + if unit == "nanosecond" { + field.codec = Codec::TimestampNanos(false); + } + } } + if !t.attributes.additional.is_empty() { + for (k, v) in &t.attributes.additional { + field.metadata.insert(k.to_string(), v.to_string()); + } + } + Ok(field) } - Ok(field) } } -} -#[cfg(test)] + fn resolve_type<'s>( + &mut self, + writer_schema: &'s Schema<'a>, + reader_schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + if let (Some(write_primitive), Some(read_primitive)) = + (primitive_of(writer_schema), primitive_of(reader_schema)) + { + return self.resolve_primitives(write_primitive, read_primitive, reader_schema); + } + match (writer_schema, reader_schema) { + (Schema::Union(writer_variants), Schema::Union(reader_variants)) => { + let writer_variants = writer_variants.as_slice(); + let reader_variants = reader_variants.as_slice(); + match ( + nullable_union_variants(writer_variants), + nullable_union_variants(reader_variants), + ) { + (Some((w_nb, w_nonnull)), Some((_r_nb, r_nonnull))) => { + let mut dt = self.make_data_type(w_nonnull, Some(r_nonnull), namespace)?; + dt.nullability = Some(w_nb); + #[cfg(feature = "avro_custom_types")] + Self::propagate_nullability_into_ree(&mut dt, w_nb); + Ok(dt) + } + _ => self.resolve_unions(writer_variants, reader_variants, namespace), + } + } + (Schema::Union(writer_variants), reader_non_union) => { + let writer_to_reader: Vec> = writer_variants + .iter() + .map(|writer| { + self.resolve_type(writer, reader_non_union, namespace) + .ok() + .map(|tmp| (0usize, Self::coercion_from(&tmp))) + }) + .collect(); + let mut dt = self.parse_type(reader_non_union, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: false, + })); + Ok(dt) + } + (writer_non_union, Schema::Union(reader_variants)) => { + let promo = self.find_best_promotion( + writer_non_union, + reader_variants.as_slice(), + namespace, + ); + let Some((reader_index, promotion)) = promo else { + return Err(ArrowError::SchemaError( + "Writer schema does not match any reader union branch".to_string(), + )); + }; + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(vec![Some((reader_index, promotion))]), + writer_is_union: false, + reader_is_union: true, + })); + Ok(dt) + } + ( + Schema::Complex(ComplexType::Array(writer_array)), + Schema::Complex(ComplexType::Array(reader_array)), + ) => self.resolve_array(writer_array, reader_array, namespace), + ( + Schema::Complex(ComplexType::Map(writer_map)), + Schema::Complex(ComplexType::Map(reader_map)), + ) => self.resolve_map(writer_map, reader_map, namespace), + ( + Schema::Complex(ComplexType::Fixed(writer_fixed)), + Schema::Complex(ComplexType::Fixed(reader_fixed)), + ) => self.resolve_fixed(writer_fixed, reader_fixed, reader_schema, namespace), + ( + Schema::Complex(ComplexType::Record(writer_record)), + Schema::Complex(ComplexType::Record(reader_record)), + ) => self.resolve_records(writer_record, reader_record, namespace), + ( + Schema::Complex(ComplexType::Enum(writer_enum)), + Schema::Complex(ComplexType::Enum(reader_enum)), + ) => self.resolve_enums(writer_enum, reader_enum, reader_schema, namespace), + (Schema::TypeName(TypeName::Ref(_)), _) => self.parse_type(reader_schema, namespace), + (_, Schema::TypeName(TypeName::Ref(_))) => self.parse_type(reader_schema, namespace), + _ => Err(ArrowError::NotYetImplemented( + "Other resolutions not yet implemented".to_string(), + )), + } + } + + #[inline] + fn coercion_from(dt: &AvroDataType) -> Promotion { + match dt.resolution.as_ref() { + Some(ResolutionInfo::Promotion(promotion)) => *promotion, + _ => Promotion::Direct, + } + } + + fn find_best_promotion( + &mut self, + writer: &Schema<'a>, + reader_variants: &[Schema<'a>], + namespace: Option<&'a str>, + ) -> Option<(usize, Promotion)> { + let mut first_promotion: Option<(usize, Promotion)> = None; + for (reader_index, reader) in reader_variants.iter().enumerate() { + if let Ok(tmp) = self.resolve_type(writer, reader, namespace) { + let promotion = Self::coercion_from(&tmp); + if promotion == Promotion::Direct { + // An exact match is best, return immediately. + return Some((reader_index, promotion)); + } else if first_promotion.is_none() { + // Store the first valid promotion but keep searching for a direct match. + first_promotion = Some((reader_index, promotion)); + } + } + } + first_promotion + } + + fn resolve_unions<'s>( + &mut self, + writer_variants: &'s [Schema<'a>], + reader_variants: &'s [Schema<'a>], + namespace: Option<&'a str>, + ) -> Result { + let reader_encodings: Vec = reader_variants + .iter() + .map(|reader_schema| self.parse_type(reader_schema, namespace)) + .collect::>()?; + let mut writer_to_reader: Vec> = + Vec::with_capacity(writer_variants.len()); + for writer in writer_variants { + writer_to_reader.push(self.find_best_promotion(writer, reader_variants, namespace)); + } + let union_fields = build_union_fields(&reader_encodings)?; + let mut dt = AvroDataType::new( + Codec::Union(reader_encodings.into(), union_fields, UnionMode::Dense), + Default::default(), + None, + ); + dt.resolution = Some(ResolutionInfo::Union(ResolvedUnion { + writer_to_reader: Arc::from(writer_to_reader), + writer_is_union: true, + reader_is_union: true, + })); + Ok(dt) + } + + fn resolve_array( + &mut self, + writer_array: &Array<'a>, + reader_array: &Array<'a>, + namespace: Option<&'a str>, + ) -> Result { + Ok(AvroDataType { + nullability: None, + metadata: reader_array.attributes.field_metadata(), + codec: Codec::List(Arc::new(self.make_data_type( + writer_array.items.as_ref(), + Some(reader_array.items.as_ref()), + namespace, + )?)), + resolution: None, + }) + } + + fn resolve_map( + &mut self, + writer_map: &Map<'a>, + reader_map: &Map<'a>, + namespace: Option<&'a str>, + ) -> Result { + Ok(AvroDataType { + nullability: None, + metadata: reader_map.attributes.field_metadata(), + codec: Codec::Map(Arc::new(self.make_data_type( + &writer_map.values, + Some(&reader_map.values), + namespace, + )?)), + resolution: None, + }) + } + + fn resolve_fixed<'s>( + &mut self, + writer_fixed: &Fixed<'a>, + reader_fixed: &Fixed<'a>, + reader_schema: &'s Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Fixed", + writer_fixed.name, + writer_fixed.namespace, + &writer_fixed.aliases, + reader_fixed.name, + reader_fixed.namespace, + &reader_fixed.aliases, + )?; + if writer_fixed.size != reader_fixed.size { + return Err(ArrowError::SchemaError(format!( + "Fixed size mismatch for {}: writer={}, reader={}", + reader_fixed.name, writer_fixed.size, reader_fixed.size + ))); + } + self.parse_type(reader_schema, namespace) + } + + fn resolve_primitives( + &mut self, + write_primitive: PrimitiveType, + read_primitive: PrimitiveType, + reader_schema: &Schema<'a>, + ) -> Result { + if write_primitive == read_primitive { + return self.parse_type(reader_schema, None); + } + let promotion = match (write_primitive, read_primitive) { + (PrimitiveType::Int, PrimitiveType::Long) => Promotion::IntToLong, + (PrimitiveType::Int, PrimitiveType::Float) => Promotion::IntToFloat, + (PrimitiveType::Int, PrimitiveType::Double) => Promotion::IntToDouble, + (PrimitiveType::Long, PrimitiveType::Float) => Promotion::LongToFloat, + (PrimitiveType::Long, PrimitiveType::Double) => Promotion::LongToDouble, + (PrimitiveType::Float, PrimitiveType::Double) => Promotion::FloatToDouble, + (PrimitiveType::String, PrimitiveType::Bytes) => Promotion::StringToBytes, + (PrimitiveType::Bytes, PrimitiveType::String) => Promotion::BytesToString, + _ => { + return Err(ArrowError::ParseError(format!( + "Illegal promotion {write_primitive:?} to {read_primitive:?}" + ))); + } + }; + let mut datatype = self.parse_type(reader_schema, None)?; + datatype.resolution = Some(ResolutionInfo::Promotion(promotion)); + Ok(datatype) + } + + // Resolve writer vs. reader enum schemas according to Avro 1.11.1. + // + // # How enums resolve (writer to reader) + // Per “Schema Resolution”: + // * The two schemas must refer to the same (unqualified) enum name (or match + // via alias rewriting). + // * If the writer’s symbol is not present in the reader’s enum and the reader + // enum has a `default`, that `default` symbol must be used; otherwise, + // error. + // https://avro.apache.org/docs/1.11.1/specification/#schema-resolution + // * Avro “Aliases” are applied from the reader side to rewrite the writer’s + // names during resolution. For robustness across ecosystems, we also accept + // symmetry here (see note below). + // https://avro.apache.org/docs/1.11.1/specification/#aliases + // + // # Rationale for this code path + // 1. Do the work once at schema‑resolution time. Avro serializes an enum as a + // writer‑side position. Mapping positions on the hot decoder path is expensive + // if done with string lookups. This method builds a `writer_index to reader_index` + // vector once, so decoding just does an O(1) table lookup. + // 2. Adopt the reader’s symbol set and order. We return an Arrow + // `Dictionary(Int32, Utf8)` whose dictionary values are the reader enum + // symbols. This makes downstream semantics match the reader schema, including + // Avro’s sort order rule that orders enums by symbol position in the schema. + // https://avro.apache.org/docs/1.11.1/specification/#sort-order + // 3. Honor Avro’s `default` for enums. Avro 1.9+ allows a type‑level default + // on the enum. When the writer emits a symbol unknown to the reader, we map it + // to the reader’s validated `default` symbol if present; otherwise we signal an + // error at decoding time. + // https://avro.apache.org/docs/1.11.1/specification/#enums + // + // # Implementation notes + // * We first check that enum names match or are*alias‑equivalent. The Avro + // spec describes alias rewriting using reader aliases; this implementation + // additionally treats writer aliases as acceptable for name matching to be + // resilient with schemas produced by different tooling. + // * We build `EnumMapping`: + // - `mapping[i]` = reader index of the writer symbol at writer index `i`. + // - If the writer symbol is absent and the reader has a default, we store the + // reader index of that default. + // - Otherwise we store `-1` as a sentinel meaning unresolvable; the decoder + // must treat encountering such a value as an error, per the spec. + // * We persist the reader symbol list in field metadata under + // `AVRO_ENUM_SYMBOLS_METADATA_KEY`, so consumers can inspect the dictionary + // without needing the original Avro schema. + // * The Arrow representation is `Dictionary(Int32, Utf8)`, which aligns with + // Avro’s integer index encoding for enums. + // + // # Examples + // * Writer `["A","B","C"]`, Reader `["A","B"]`, Reader default `"A"` + // `mapping = [0, 1, 0]`, `default_index = 0`. + // * Writer `["A","B"]`, Reader `["B","A"]` (no default) + // `mapping = [1, 0]`, `default_index = -1`. + // * Writer `["A","B","C"]`, Reader `["A","B"]` (no default) + // `mapping = [0, 1, -1]` (decode must error on `"C"`). + fn resolve_enums( + &mut self, + writer_enum: &Enum<'a>, + reader_enum: &Enum<'a>, + reader_schema: &Schema<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Enum", + writer_enum.name, + writer_enum.namespace, + &writer_enum.aliases, + reader_enum.name, + reader_enum.namespace, + &reader_enum.aliases, + )?; + if writer_enum.symbols == reader_enum.symbols { + return self.parse_type(reader_schema, namespace); + } + let reader_index: HashMap<&str, i32> = reader_enum + .symbols + .iter() + .enumerate() + .map(|(index, &symbol)| (symbol, index as i32)) + .collect(); + let default_index: i32 = match reader_enum.default { + Some(symbol) => *reader_index.get(symbol).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Reader enum '{}' default symbol '{symbol}' not found in symbols list", + reader_enum.name, + )) + })?, + None => -1, + }; + let mapping: Vec = writer_enum + .symbols + .iter() + .map(|&write_symbol| { + reader_index + .get(write_symbol) + .copied() + .unwrap_or(default_index) + }) + .collect(); + if self.strict_mode && mapping.iter().any(|&m| m < 0) { + return Err(ArrowError::SchemaError(format!( + "Reader enum '{}' does not cover all writer symbols and no default is provided", + reader_enum.name + ))); + } + let mut dt = self.parse_type(reader_schema, namespace)?; + dt.resolution = Some(ResolutionInfo::EnumMapping(EnumMapping { + mapping: Arc::from(mapping), + default_index, + })); + let reader_ns = reader_enum.namespace.or(namespace); + self.resolver + .register(reader_enum.name, reader_ns, dt.clone()); + Ok(dt) + } + + #[inline] + fn build_writer_lookup( + writer_record: &Record<'a>, + ) -> (HashMap<&'a str, usize>, HashSet<&'a str>) { + let mut map: HashMap<&str, usize> = HashMap::with_capacity(writer_record.fields.len() * 2); + for (idx, wf) in writer_record.fields.iter().enumerate() { + // Avro field names are unique; last-in wins are acceptable and match previous behavior. + map.insert(wf.name, idx); + } + // Track ambiguous writer aliases (alias used by multiple writer fields) + let mut ambiguous: HashSet<&str> = HashSet::new(); + for (idx, wf) in writer_record.fields.iter().enumerate() { + for &alias in &wf.aliases { + match map.entry(alias) { + Entry::Occupied(e) if *e.get() != idx => { + ambiguous.insert(alias); + } + Entry::Vacant(e) => { + e.insert(idx); + } + _ => {} + } + } + } + (map, ambiguous) + } + + fn resolve_records( + &mut self, + writer_record: &Record<'a>, + reader_record: &Record<'a>, + namespace: Option<&'a str>, + ) -> Result { + ensure_names_match( + "Record", + writer_record.name, + writer_record.namespace, + &writer_record.aliases, + reader_record.name, + reader_record.namespace, + &reader_record.aliases, + )?; + let writer_ns = writer_record.namespace.or(namespace); + let reader_ns = reader_record.namespace.or(namespace); + let reader_md = reader_record.attributes.field_metadata(); + // Build writer lookup and ambiguous alias set. + let (writer_lookup, ambiguous_writer_aliases) = Self::build_writer_lookup(writer_record); + let mut writer_to_reader: Vec> = vec![None; writer_record.fields.len()]; + let mut reader_fields: Vec = Vec::with_capacity(reader_record.fields.len()); + // Capture default field indices during the main loop (one pass). + let mut default_fields: Vec = Vec::new(); + for (reader_idx, r_field) in reader_record.fields.iter().enumerate() { + // Direct name match, then reader aliases (a writer alias map is pre-populated). + let mut match_idx = writer_lookup.get(r_field.name).copied(); + let mut matched_via_alias: Option<&str> = None; + if match_idx.is_none() { + for &alias in &r_field.aliases { + if let Some(i) = writer_lookup.get(alias).copied() { + if self.strict_mode && ambiguous_writer_aliases.contains(alias) { + return Err(ArrowError::SchemaError(format!( + "Ambiguous alias '{alias}' on reader field '{}' matches multiple writer fields", + r_field.name + ))); + } + match_idx = Some(i); + matched_via_alias = Some(alias); + break; + } + } + } + if let Some(wi) = match_idx { + if writer_to_reader[wi].is_none() { + let w_schema = &writer_record.fields[wi].r#type; + let dt = self.make_data_type(w_schema, Some(&r_field.r#type), reader_ns)?; + writer_to_reader[wi] = Some(reader_idx); + reader_fields.push(AvroField { + name: r_field.name.to_owned(), + data_type: dt, + }); + continue; + } else if self.strict_mode { + // Writer field already mapped and strict_mode => error + let existing_reader = writer_to_reader[wi].unwrap(); + let via = matched_via_alias + .map(|a| format!("alias '{a}'")) + .unwrap_or_else(|| "name match".to_string()); + return Err(ArrowError::SchemaError(format!( + "Multiple reader fields map to the same writer field '{}' via {via} (existing reader index {existing_reader}, new reader index {reader_idx})", + writer_record.fields[wi].name + ))); + } + // Non-strict and already mapped -> fall through to defaulting logic + } + // No match (or conflicted in non-strict mode): attach default per Avro spec. + let mut dt = self.parse_type(&r_field.r#type, reader_ns)?; + if let Some(default_json) = r_field.default.as_ref() { + dt.resolution = Some(ResolutionInfo::DefaultValue( + dt.parse_and_store_default(default_json)?, + )); + default_fields.push(reader_idx); + } else if dt.nullability() == Some(Nullability::NullFirst) { + // The only valid implicit default for a union is the first branch (null-first case). + dt.resolution = Some(ResolutionInfo::DefaultValue( + dt.parse_and_store_default(&Value::Null)?, + )); + default_fields.push(reader_idx); + } else { + return Err(ArrowError::SchemaError(format!( + "Reader field '{}' not present in writer schema must have a default value", + r_field.name + ))); + } + reader_fields.push(AvroField { + name: r_field.name.to_owned(), + data_type: dt, + }); + } + // Build skip_fields in writer order; pre-size and push. + let mut skip_fields: Vec> = + Vec::with_capacity(writer_record.fields.len()); + for (writer_index, writer_field) in writer_record.fields.iter().enumerate() { + if writer_to_reader[writer_index].is_some() { + skip_fields.push(None); + } else { + skip_fields.push(Some(self.parse_type(&writer_field.r#type, writer_ns)?)); + } + } + let resolved = AvroDataType::new_with_resolution( + Codec::Struct(Arc::from(reader_fields)), + reader_md, + None, + Some(ResolutionInfo::Record(ResolvedRecord { + writer_to_reader: Arc::from(writer_to_reader), + default_fields: Arc::from(default_fields), + skip_fields: Arc::from(skip_fields), + })), + ); + // Register a resolved record by reader name+namespace for potential named type refs. + self.resolver + .register(reader_record.name, reader_ns, resolved.clone()); + Ok(resolved) + } +} + +#[cfg(test)] mod tests { use super::*; use crate::schema::{ - Attributes, ComplexType, Fixed, PrimitiveType, Record, Schema, Type, TypeName, + AVRO_ROOT_RECORD_DEFAULT_NAME, Array, Attributes, ComplexType, Field as AvroFieldSchema, + Fixed, PrimitiveType, Record, Schema, Type, TypeName, }; - use serde_json; - use std::collections::HashMap; + use indexmap::IndexMap; + use serde_json::{self, Value}; fn create_schema_with_logical_type( primitive_type: PrimitiveType, @@ -498,27 +2009,28 @@ mod tests { }) } - fn create_fixed_schema(size: usize, logical_type: &'static str) -> Schema<'static> { - let attributes = Attributes { - logical_type: Some(logical_type), - additional: Default::default(), - }; + fn resolve_promotion(writer: PrimitiveType, reader: PrimitiveType) -> AvroDataType { + let writer_schema = Schema::TypeName(TypeName::Primitive(writer)); + let reader_schema = Schema::TypeName(TypeName::Primitive(reader)); + let mut maker = Maker::new(false, false); + maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("promotion should resolve") + } - Schema::Complex(ComplexType::Fixed(Fixed { - name: "fixed_type", - namespace: None, - aliases: Vec::new(), - size, - attributes, - })) + fn mk_primitive(pt: PrimitiveType) -> Schema<'static> { + Schema::TypeName(TypeName::Primitive(pt)) + } + fn mk_union(branches: Vec>) -> Schema<'_> { + Schema::Union(branches) } #[test] fn test_date_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "date"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Date32)); } @@ -527,8 +2039,8 @@ mod tests { fn test_time_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "time-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimeMillis)); } @@ -537,8 +2049,8 @@ mod tests { fn test_time_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "time-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimeMicros)); } @@ -547,8 +2059,8 @@ mod tests { fn test_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(true))); } @@ -557,8 +2069,8 @@ mod tests { fn test_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(true))); } @@ -567,8 +2079,8 @@ mod tests { fn test_local_timestamp_millis_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-millis"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMillis(false))); } @@ -577,12 +2089,21 @@ mod tests { fn test_local_timestamp_micros_logical_type() { let schema = create_schema_with_logical_type(PrimitiveType::Long, "local-timestamp-micros"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::TimestampMicros(false))); } + #[test] + fn test_uuid_type() { + let mut codec = Codec::Fixed(16); + if let c @ Codec::Fixed(16) = &mut codec { + *c = Codec::Uuid; + } + assert!(matches!(codec, Codec::Uuid)); + } + #[test] fn test_duration_logical_type() { let mut codec = Codec::Fixed(12); @@ -596,7 +2117,7 @@ mod tests { #[test] fn test_decimal_logical_type_not_implemented() { - let mut codec = Codec::Fixed(16); + let codec = Codec::Fixed(16); let process_decimal = || -> Result<(), ArrowError> { if let Codec::Fixed(_) = codec { @@ -616,13 +2137,12 @@ mod tests { panic!("Expected NotYetImplemented error"); } } - #[test] fn test_unknown_logical_type_added_to_metadata() { let schema = create_schema_with_logical_type(PrimitiveType::Int, "custom-type"); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert_eq!( result.metadata.get("logicalType"), @@ -634,8 +2154,8 @@ mod tests { fn test_string_with_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8View)); } @@ -644,8 +2164,8 @@ mod tests { fn test_string_without_utf8view_enabled() { let schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, false).unwrap(); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); assert!(matches!(result.codec, Codec::Utf8)); } @@ -659,6 +2179,7 @@ mod tests { r#type: field_schema, default: None, doc: None, + aliases: vec![], }; let record = Record { @@ -672,8 +2193,8 @@ mod tests { let schema = Schema::Complex(ComplexType::Record(record)); - let mut resolver = Resolver::default(); - let result = make_data_type(&schema, None, &mut resolver, true).unwrap(); + let mut maker = Maker::new(true, false); + let result = maker.make_data_type(&schema, None, None).unwrap(); if let Codec::Struct(fields) = &result.codec { let first_field_codec = &fields[0].data_type().codec; @@ -682,4 +2203,957 @@ mod tests { panic!("Expected Struct codec"); } } + + #[test] + fn test_union_with_strict_mode() { + let schema = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]); + + let mut maker = Maker::new(false, true); + let result = maker.make_data_type(&schema, None, None); + + assert!(result.is_err()); + match result { + Err(ArrowError::SchemaError(msg)) => { + assert!(msg.contains( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + )); + } + _ => panic!("Expected SchemaError"), + } + } + + #[test] + fn test_resolve_int_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToFloat)) + ); + } + + #[test] + fn test_resolve_int_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Int, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + } + + #[test] + fn test_resolve_long_to_float_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Float); + assert!(matches!(result.codec, Codec::Float32)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToFloat)) + ); + } + + #[test] + fn test_resolve_long_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Long, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::LongToDouble)) + ); + } + + #[test] + fn test_resolve_float_to_double_promotion() { + let result = resolve_promotion(PrimitiveType::Float, PrimitiveType::Double); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::FloatToDouble)) + ); + } + + #[test] + fn test_resolve_string_to_bytes_promotion() { + let result = resolve_promotion(PrimitiveType::String, PrimitiveType::Bytes); + assert!(matches!(result.codec, Codec::Binary)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::StringToBytes)) + ); + } + + #[test] + fn test_resolve_bytes_to_string_promotion() { + let result = resolve_promotion(PrimitiveType::Bytes, PrimitiveType::String); + assert!(matches!(result.codec, Codec::Utf8)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::BytesToString)) + ); + } + + #[test] + fn test_resolve_illegal_promotion_double_to_float_errors() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer_schema, Some(&reader_schema), None); + assert!(result.is_err()); + match result { + Err(ArrowError::ParseError(msg)) => { + assert!(msg.contains("Illegal promotion")); + } + _ => panic!("Expected ParseError for illegal promotion Double -> Float"), + } + } + + #[test] + fn test_promotion_within_nullable_union_keeps_writer_null_ordering() { + let writer = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + ]); + let reader = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ]); + let mut maker = Maker::new(false, false); + let result = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(result.codec, Codec::Float64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + assert_eq!(result.nullability, Some(Nullability::NullFirst)); + } + + #[test] + fn test_resolve_writer_union_to_reader_non_union_partial_coverage() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Long), + ]); + let reader = mk_primitive(PrimitiveType::Bytes); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Binary)); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(resolved.writer_is_union && !resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::StringToBytes)), None] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_prefers_direct_over_promotion() { + let writer = mk_primitive(PrimitiveType::Long); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::Double), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert!(!resolved.writer_is_union && resolved.reader_is_union); + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((0, Promotion::Direct))] + ); + } + + #[test] + fn test_resolve_writer_non_union_to_reader_union_uses_promotion_when_needed() { + let writer = mk_primitive(PrimitiveType::Int); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Long), + mk_primitive(PrimitiveType::String), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + let resolved = match dt.resolution { + Some(ResolutionInfo::Union(u)) => u, + other => panic!("expected union resolution info, got {other:?}"), + }; + assert_eq!( + resolved.writer_to_reader.as_ref(), + &[Some((1, Promotion::IntToLong))] + ); + } + + #[test] + fn test_resolve_both_nullable_unions_direct_match() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::String), + ]); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::String), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Utf8)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert!(dt.resolution.is_none()); + } + + #[test] + fn test_resolve_both_nullable_unions_with_promotion() { + let writer = mk_union(vec![ + mk_primitive(PrimitiveType::Null), + mk_primitive(PrimitiveType::Int), + ]); + let reader = mk_union(vec![ + mk_primitive(PrimitiveType::Double), + mk_primitive(PrimitiveType::Null), + ]); + let mut maker = Maker::new(false, false); + let dt = maker.make_data_type(&writer, Some(&reader), None).unwrap(); + assert!(matches!(dt.codec(), Codec::Float64)); + assert_eq!(dt.nullability, Some(Nullability::NullFirst)); + assert_eq!( + dt.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToDouble)) + ); + } + + #[test] + fn test_resolve_type_promotion() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + let mut maker = Maker::new(false, false); + let result = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert!(matches!(result.codec, Codec::Int64)); + assert_eq!( + result.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + ); + } + + #[test] + fn test_nested_record_type_reuse_without_namespace() { + let schema_str = r#" + { + "type": "record", + "name": "Record", + "fields": [ + { + "name": "nested", + "type": { + "type": "record", + "name": "Nested", + "fields": [ + { "name": "nested_int", "type": "int" } + ] + } + }, + { "name": "nestedRecord", "type": "Nested" }, + { "name": "nestedArray", "type": { "type": "array", "items": "Nested" } }, + { "name": "nestedMap", "type": { "type": "map", "values": "Nested" } } + ] + } + "#; + + let schema: Schema = serde_json::from_str(schema_str).unwrap(); + + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); + + if let Codec::Struct(fields) = avro_data_type.codec() { + assert_eq!(fields.len(), 4); + + // nested + assert_eq!(fields[0].name(), "nested"); + let nested_data_type = fields[0].data_type(); + if let Codec::Struct(nested_fields) = nested_data_type.codec() { + assert_eq!(nested_fields.len(), 1); + assert_eq!(nested_fields[0].name(), "nested_int"); + assert!(matches!(nested_fields[0].data_type().codec(), Codec::Int32)); + } else { + panic!( + "'nested' field is not a struct but {:?}", + nested_data_type.codec() + ); + } + + // nestedRecord + assert_eq!(fields[1].name(), "nestedRecord"); + let nested_record_data_type = fields[1].data_type(); + assert_eq!( + nested_record_data_type.codec().data_type(), + nested_data_type.codec().data_type() + ); + + // nestedArray + assert_eq!(fields[2].name(), "nestedArray"); + if let Codec::List(item_type) = fields[2].data_type().codec() { + assert_eq!( + item_type.codec().data_type(), + nested_data_type.codec().data_type() + ); + } else { + panic!("'nestedArray' field is not a list"); + } + + // nestedMap + assert_eq!(fields[3].name(), "nestedMap"); + if let Codec::Map(value_type) = fields[3].data_type().codec() { + assert_eq!( + value_type.codec().data_type(), + nested_data_type.codec().data_type() + ); + } else { + panic!("'nestedMap' field is not a map"); + } + } else { + panic!("Top-level schema is not a struct"); + } + } + + #[test] + fn test_nested_enum_type_reuse_with_namespace() { + let schema_str = r#" + { + "type": "record", + "name": "Record", + "namespace": "record_ns", + "fields": [ + { + "name": "status", + "type": { + "type": "enum", + "name": "Status", + "namespace": "enum_ns", + "symbols": ["ACTIVE", "INACTIVE", "PENDING"] + } + }, + { "name": "backupStatus", "type": "enum_ns.Status" }, + { "name": "statusHistory", "type": { "type": "array", "items": "enum_ns.Status" } }, + { "name": "statusMap", "type": { "type": "map", "values": "enum_ns.Status" } } + ] + } + "#; + + let schema: Schema = serde_json::from_str(schema_str).unwrap(); + + let mut maker = Maker::new(false, false); + let avro_data_type = maker.make_data_type(&schema, None, None).unwrap(); + + if let Codec::Struct(fields) = avro_data_type.codec() { + assert_eq!(fields.len(), 4); + + // status + assert_eq!(fields[0].name(), "status"); + let status_data_type = fields[0].data_type(); + if let Codec::Enum(symbols) = status_data_type.codec() { + assert_eq!(symbols.as_ref(), &["ACTIVE", "INACTIVE", "PENDING"]); + } else { + panic!( + "'status' field is not an enum but {:?}", + status_data_type.codec() + ); + } + + // backupStatus + assert_eq!(fields[1].name(), "backupStatus"); + let backup_status_data_type = fields[1].data_type(); + assert_eq!( + backup_status_data_type.codec().data_type(), + status_data_type.codec().data_type() + ); + + // statusHistory + assert_eq!(fields[2].name(), "statusHistory"); + if let Codec::List(item_type) = fields[2].data_type().codec() { + assert_eq!( + item_type.codec().data_type(), + status_data_type.codec().data_type() + ); + } else { + panic!("'statusHistory' field is not a list"); + } + + // statusMap + assert_eq!(fields[3].name(), "statusMap"); + if let Codec::Map(value_type) = fields[3].data_type().codec() { + assert_eq!( + value_type.codec().data_type(), + status_data_type.codec().data_type() + ); + } else { + panic!("'statusMap' field is not a map"); + } + } else { + panic!("Top-level schema is not a struct"); + } + } + + #[test] + fn test_resolve_from_writer_and_reader_defaults_root_name_for_non_record_reader() { + let writer_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); + let reader_schema = Schema::TypeName(TypeName::Primitive(PrimitiveType::String)); + let mut maker = Maker::new(false, false); + let data_type = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("resolution should succeed"); + let field = AvroField { + name: AVRO_ROOT_RECORD_DEFAULT_NAME.to_string(), + data_type, + }; + assert_eq!(field.name(), AVRO_ROOT_RECORD_DEFAULT_NAME); + assert!(matches!(field.data_type().codec(), Codec::Utf8)); + } + + fn json_string(s: &str) -> Value { + Value::String(s.to_string()) + } + + fn assert_default_stored(dt: &AvroDataType, default_json: &Value) { + let stored = dt + .metadata + .get(AVRO_FIELD_DEFAULT_METADATA_KEY) + .cloned() + .unwrap_or_default(); + let expected = serde_json::to_string(default_json).unwrap(); + assert_eq!(stored, expected, "stored default metadata should match"); + } + + #[test] + fn test_validate_and_store_default_null_and_nullability_rules() { + let mut dt_null = AvroDataType::new(Codec::Null, HashMap::new(), None); + let lit = dt_null.parse_and_store_default(&Value::Null).unwrap(); + assert_eq!(lit, AvroLiteral::Null); + assert_default_stored(&dt_null, &Value::Null); + let mut dt_int = AvroDataType::new(Codec::Int32, HashMap::new(), None); + let err = dt_int.parse_and_store_default(&Value::Null).unwrap_err(); + assert!( + err.to_string() + .contains("JSON null default is only valid for `null` type"), + "unexpected error: {err}" + ); + let mut dt_int_nf = + AvroDataType::new(Codec::Int32, HashMap::new(), Some(Nullability::NullFirst)); + let lit2 = dt_int_nf.parse_and_store_default(&Value::Null).unwrap(); + assert_eq!(lit2, AvroLiteral::Null); + assert_default_stored(&dt_int_nf, &Value::Null); + let mut dt_int_ns = + AvroDataType::new(Codec::Int32, HashMap::new(), Some(Nullability::NullSecond)); + let err2 = dt_int_ns.parse_and_store_default(&Value::Null).unwrap_err(); + assert!( + err2.to_string() + .contains("JSON null default is only valid for `null` type"), + "unexpected error: {err2}" + ); + } + + #[test] + fn test_validate_and_store_default_primitives_and_temporal() { + let mut dt_bool = AvroDataType::new(Codec::Boolean, HashMap::new(), None); + let lit = dt_bool.parse_and_store_default(&Value::Bool(true)).unwrap(); + assert_eq!(lit, AvroLiteral::Boolean(true)); + assert_default_stored(&dt_bool, &Value::Bool(true)); + let mut dt_i32 = AvroDataType::new(Codec::Int32, HashMap::new(), None); + let lit = dt_i32 + .parse_and_store_default(&serde_json::json!(123)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Int(123)); + assert_default_stored(&dt_i32, &serde_json::json!(123)); + let err = dt_i32 + .parse_and_store_default(&serde_json::json!(i64::from(i32::MAX) + 1)) + .unwrap_err(); + assert!(format!("{err}").contains("out of i32 range")); + let mut dt_i64 = AvroDataType::new(Codec::Int64, HashMap::new(), None); + let lit = dt_i64 + .parse_and_store_default(&serde_json::json!(1234567890)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Long(1234567890)); + assert_default_stored(&dt_i64, &serde_json::json!(1234567890)); + let mut dt_f32 = AvroDataType::new(Codec::Float32, HashMap::new(), None); + let lit = dt_f32 + .parse_and_store_default(&serde_json::json!(1.25)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Float(1.25)); + assert_default_stored(&dt_f32, &serde_json::json!(1.25)); + let err = dt_f32 + .parse_and_store_default(&serde_json::json!(1e39)) + .unwrap_err(); + assert!(format!("{err}").contains("out of f32 range")); + let mut dt_f64 = AvroDataType::new(Codec::Float64, HashMap::new(), None); + let lit = dt_f64 + .parse_and_store_default(&serde_json::json!(std::f64::consts::PI)) + .unwrap(); + assert_eq!(lit, AvroLiteral::Double(std::f64::consts::PI)); + assert_default_stored(&dt_f64, &serde_json::json!(std::f64::consts::PI)); + let mut dt_str = AvroDataType::new(Codec::Utf8, HashMap::new(), None); + let l = dt_str + .parse_and_store_default(&json_string("hello")) + .unwrap(); + assert_eq!(l, AvroLiteral::String("hello".into())); + assert_default_stored(&dt_str, &json_string("hello")); + let mut dt_strv = AvroDataType::new(Codec::Utf8View, HashMap::new(), None); + let l = dt_strv + .parse_and_store_default(&json_string("view")) + .unwrap(); + assert_eq!(l, AvroLiteral::String("view".into())); + assert_default_stored(&dt_strv, &json_string("view")); + let mut dt_uuid = AvroDataType::new(Codec::Uuid, HashMap::new(), None); + let l = dt_uuid + .parse_and_store_default(&json_string("00000000-0000-0000-0000-000000000000")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::String("00000000-0000-0000-0000-000000000000".into()) + ); + let mut dt_bin = AvroDataType::new(Codec::Binary, HashMap::new(), None); + let l = dt_bin.parse_and_store_default(&json_string("ABC")).unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![65, 66, 67])); + let err = dt_bin + .parse_and_store_default(&json_string("€")) // U+20AC + .unwrap_err(); + assert!(format!("{err}").contains("Invalid codepoint")); + let mut dt_date = AvroDataType::new(Codec::Date32, HashMap::new(), None); + let ld = dt_date + .parse_and_store_default(&serde_json::json!(1)) + .unwrap(); + assert_eq!(ld, AvroLiteral::Int(1)); + let mut dt_tmill = AvroDataType::new(Codec::TimeMillis, HashMap::new(), None); + let lt = dt_tmill + .parse_and_store_default(&serde_json::json!(86_400_000)) + .unwrap(); + assert_eq!(lt, AvroLiteral::Int(86_400_000)); + let mut dt_tmicros = AvroDataType::new(Codec::TimeMicros, HashMap::new(), None); + let ltm = dt_tmicros + .parse_and_store_default(&serde_json::json!(1_000_000)) + .unwrap(); + assert_eq!(ltm, AvroLiteral::Long(1_000_000)); + let mut dt_ts_milli = AvroDataType::new(Codec::TimestampMillis(true), HashMap::new(), None); + let l1 = dt_ts_milli + .parse_and_store_default(&serde_json::json!(123)) + .unwrap(); + assert_eq!(l1, AvroLiteral::Long(123)); + let mut dt_ts_micro = + AvroDataType::new(Codec::TimestampMicros(false), HashMap::new(), None); + let l2 = dt_ts_micro + .parse_and_store_default(&serde_json::json!(456)) + .unwrap(); + assert_eq!(l2, AvroLiteral::Long(456)); + } + + #[test] + fn test_validate_and_store_default_fixed_decimal_interval() { + let mut dt_fixed = AvroDataType::new(Codec::Fixed(4), HashMap::new(), None); + let l = dt_fixed + .parse_and_store_default(&json_string("WXYZ")) + .unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![87, 88, 89, 90])); + let err = dt_fixed + .parse_and_store_default(&json_string("TOO LONG")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + let mut dt_dec_fixed = + AvroDataType::new(Codec::Decimal(10, Some(2), Some(3)), HashMap::new(), None); + let l = dt_dec_fixed + .parse_and_store_default(&json_string("abc")) + .unwrap(); + assert_eq!(l, AvroLiteral::Bytes(vec![97, 98, 99])); + let err = dt_dec_fixed + .parse_and_store_default(&json_string("toolong")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + let mut dt_dec_bytes = + AvroDataType::new(Codec::Decimal(10, Some(2), None), HashMap::new(), None); + let l = dt_dec_bytes + .parse_and_store_default(&json_string("freeform")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::Bytes("freeform".bytes().collect::>()) + ); + let mut dt_interval = AvroDataType::new(Codec::Interval, HashMap::new(), None); + let l = dt_interval + .parse_and_store_default(&json_string("ABCDEFGHIJKL")) + .unwrap(); + assert_eq!( + l, + AvroLiteral::Bytes("ABCDEFGHIJKL".bytes().collect::>()) + ); + let err = dt_interval + .parse_and_store_default(&json_string("short")) + .unwrap_err(); + assert!(err.to_string().contains("Default length")); + } + + #[test] + fn test_validate_and_store_default_enum_list_map_struct() { + let symbols: Arc<[String]> = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()] + .into_iter() + .collect(); + let mut dt_enum = AvroDataType::new(Codec::Enum(symbols), HashMap::new(), None); + let l = dt_enum + .parse_and_store_default(&json_string("GREEN")) + .unwrap(); + assert_eq!(l, AvroLiteral::Enum("GREEN".into())); + let err = dt_enum + .parse_and_store_default(&json_string("YELLOW")) + .unwrap_err(); + assert!(err.to_string().contains("Default enum symbol")); + let item = AvroDataType::new(Codec::Int64, HashMap::new(), None); + let mut dt_list = AvroDataType::new(Codec::List(Arc::new(item)), HashMap::new(), None); + let val = serde_json::json!([1, 2, 3]); + let l = dt_list.parse_and_store_default(&val).unwrap(); + assert_eq!( + l, + AvroLiteral::Array(vec![ + AvroLiteral::Long(1), + AvroLiteral::Long(2), + AvroLiteral::Long(3) + ]) + ); + let err = dt_list + .parse_and_store_default(&serde_json::json!({"not":"array"})) + .unwrap_err(); + assert!(err.to_string().contains("JSON array")); + let val_dt = AvroDataType::new(Codec::Float64, HashMap::new(), None); + let mut dt_map = AvroDataType::new(Codec::Map(Arc::new(val_dt)), HashMap::new(), None); + let mv = serde_json::json!({"x": 1.5, "y": 2.5}); + let l = dt_map.parse_and_store_default(&mv).unwrap(); + let mut expected = IndexMap::new(); + expected.insert("x".into(), AvroLiteral::Double(1.5)); + expected.insert("y".into(), AvroLiteral::Double(2.5)); + assert_eq!(l, AvroLiteral::Map(expected)); + // Not object -> error + let err = dt_map + .parse_and_store_default(&serde_json::json!(123)) + .unwrap_err(); + assert!(err.to_string().contains("JSON object")); + let mut field_a = AvroField { + name: "a".into(), + data_type: AvroDataType::new(Codec::Int32, HashMap::new(), None), + }; + let field_b = AvroField { + name: "b".into(), + data_type: AvroDataType::new( + Codec::Int64, + HashMap::new(), + Some(Nullability::NullFirst), + ), + }; + let mut c_md = HashMap::new(); + c_md.insert(AVRO_FIELD_DEFAULT_METADATA_KEY.into(), "\"xyz\"".into()); + let field_c = AvroField { + name: "c".into(), + data_type: AvroDataType::new(Codec::Utf8, c_md, None), + }; + field_a.data_type.metadata.insert("doc".into(), "na".into()); + let struct_fields: Arc<[AvroField]> = Arc::from(vec![field_a, field_b, field_c]); + let mut dt_struct = AvroDataType::new(Codec::Struct(struct_fields), HashMap::new(), None); + let default_obj = serde_json::json!({"a": 7}); + let l = dt_struct.parse_and_store_default(&default_obj).unwrap(); + let mut expected = IndexMap::new(); + expected.insert("a".into(), AvroLiteral::Int(7)); + expected.insert("b".into(), AvroLiteral::Null); + expected.insert("c".into(), AvroLiteral::String("xyz".into())); + assert_eq!(l, AvroLiteral::Map(expected)); + assert_default_stored(&dt_struct, &default_obj); + let req_field = AvroField { + name: "req".into(), + data_type: AvroDataType::new(Codec::Boolean, HashMap::new(), None), + }; + let mut dt_bad = AvroDataType::new( + Codec::Struct(Arc::from(vec![req_field])), + HashMap::new(), + None, + ); + let err = dt_bad + .parse_and_store_default(&serde_json::json!({})) + .unwrap_err(); + assert!( + err.to_string().contains("missing required subfield 'req'"), + "unexpected error: {err}" + ); + let err = dt_struct + .parse_and_store_default(&serde_json::json!(10)) + .unwrap_err(); + err.to_string().contains("must be a JSON object"); + } + + #[test] + fn test_resolve_array_promotion_and_reader_metadata() { + let mut w_add: HashMap<&str, Value> = HashMap::new(); + w_add.insert("who", json_string("writer")); + let mut r_add: HashMap<&str, Value> = HashMap::new(); + r_add.insert("who", json_string("reader")); + let writer_schema = Schema::Complex(ComplexType::Array(Array { + items: Box::new(Schema::TypeName(TypeName::Primitive(PrimitiveType::Int))), + attributes: Attributes { + logical_type: None, + additional: w_add, + }, + })); + let reader_schema = Schema::Complex(ComplexType::Array(Array { + items: Box::new(Schema::TypeName(TypeName::Primitive(PrimitiveType::Long))), + attributes: Attributes { + logical_type: None, + additional: r_add, + }, + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert_eq!(dt.metadata.get("who"), Some(&"\"reader\"".to_string())); + if let Codec::List(inner) = dt.codec() { + assert!(matches!(inner.codec(), Codec::Int64)); + assert_eq!( + inner.resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + ); + } else { + panic!("expected list codec"); + } + } + + #[test] + fn test_resolve_fixed_success_name_and_size_match_and_alias() { + let writer_schema = Schema::Complex(ComplexType::Fixed(Fixed { + name: "MD5", + namespace: None, + aliases: vec!["Hash16"], + size: 16, + attributes: Attributes::default(), + })); + let reader_schema = Schema::Complex(ComplexType::Fixed(Fixed { + name: "Hash16", + namespace: None, + aliases: vec![], + size: 16, + attributes: Attributes::default(), + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .unwrap(); + assert!(matches!(dt.codec(), Codec::Fixed(16))); + } + + #[test] + fn test_resolve_records_mapping_default_fields_and_skip_fields() { + let writer = Schema::Complex(ComplexType::Record(Record { + name: "R", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + crate::schema::Field { + name: "a", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + default: None, + aliases: vec![], + }, + crate::schema::Field { + name: "skipme", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: None, + aliases: vec![], + }, + crate::schema::Field { + name: "b", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + aliases: vec![], + }, + ], + attributes: Attributes::default(), + })); + let reader = Schema::Complex(ComplexType::Record(Record { + name: "R", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![ + crate::schema::Field { + name: "b", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + aliases: vec![], + }, + crate::schema::Field { + name: "a", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + default: None, + aliases: vec![], + }, + crate::schema::Field { + name: "name", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: Some(json_string("anon")), + aliases: vec![], + }, + crate::schema::Field { + name: "opt", + doc: None, + r#type: Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + ]), + default: None, // should default to null because NullFirst + aliases: vec![], + }, + ], + attributes: Attributes::default(), + })); + let mut maker = Maker::new(false, false); + let dt = maker + .make_data_type(&writer, Some(&reader), None) + .expect("record resolution"); + let fields = match dt.codec() { + Codec::Struct(f) => f, + other => panic!("expected struct, got {other:?}"), + }; + assert_eq!(fields.len(), 4); + assert_eq!(fields[0].name(), "b"); + assert_eq!(fields[1].name(), "a"); + assert_eq!(fields[2].name(), "name"); + assert_eq!(fields[3].name(), "opt"); + assert!(matches!( + fields[1].data_type().resolution, + Some(ResolutionInfo::Promotion(Promotion::IntToLong)) + )); + let rec = match dt.resolution { + Some(ResolutionInfo::Record(ref r)) => r.clone(), + other => panic!("expected record resolution, got {other:?}"), + }; + assert_eq!(rec.writer_to_reader.as_ref(), &[Some(1), None, Some(0)]); + assert_eq!(rec.default_fields.as_ref(), &[2usize, 3usize]); + assert!(rec.skip_fields[0].is_none()); + assert!(rec.skip_fields[2].is_none()); + let skip1 = rec.skip_fields[1].as_ref().expect("skip field present"); + assert!(matches!(skip1.codec(), Codec::Utf8)); + let name_md = &fields[2].data_type().metadata; + assert_eq!( + name_md.get(AVRO_FIELD_DEFAULT_METADATA_KEY), + Some(&"\"anon\"".to_string()) + ); + let opt_md = &fields[3].data_type().metadata; + assert_eq!( + opt_md.get(AVRO_FIELD_DEFAULT_METADATA_KEY), + Some(&"null".to_string()) + ); + } + + #[test] + fn test_named_type_alias_resolution_record_cross_namespace() { + let writer_record = Record { + name: "PersonV2", + namespace: Some("com.example.v2"), + doc: None, + aliases: vec!["com.example.Person"], + fields: vec![ + AvroFieldSchema { + name: "name", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: None, + aliases: vec![], + }, + AvroFieldSchema { + name: "age", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + default: None, + aliases: vec![], + }, + ], + attributes: Attributes::default(), + }; + let reader_record = Record { + name: "Person", + namespace: Some("com.example"), + doc: None, + aliases: vec![], + fields: writer_record.fields.clone(), + attributes: Attributes::default(), + }; + let writer_schema = Schema::Complex(ComplexType::Record(writer_record)); + let reader_schema = Schema::Complex(ComplexType::Record(reader_record)); + let mut maker = Maker::new(false, false); + let result = maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("record alias resolution should succeed"); + match result.codec { + Codec::Struct(ref fields) => assert_eq!(fields.len(), 2), + other => panic!("expected struct, got {other:?}"), + } + } + + #[test] + fn test_named_type_alias_resolution_enum_cross_namespace() { + let writer_enum = Enum { + name: "ColorV2", + namespace: Some("org.example.v2"), + doc: None, + aliases: vec!["org.example.Color"], + symbols: vec!["RED", "GREEN", "BLUE"], + default: None, + attributes: Attributes::default(), + }; + let reader_enum = Enum { + name: "Color", + namespace: Some("org.example"), + doc: None, + aliases: vec![], + symbols: vec!["RED", "GREEN", "BLUE"], + default: None, + attributes: Attributes::default(), + }; + let writer_schema = Schema::Complex(ComplexType::Enum(writer_enum)); + let reader_schema = Schema::Complex(ComplexType::Enum(reader_enum)); + let mut maker = Maker::new(false, false); + maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("enum alias resolution should succeed"); + } + + #[test] + fn test_named_type_alias_resolution_fixed_cross_namespace() { + let writer_fixed = Fixed { + name: "Fx10V2", + namespace: Some("ns.v2"), + aliases: vec!["ns.Fx10"], + size: 10, + attributes: Attributes::default(), + }; + let reader_fixed = Fixed { + name: "Fx10", + namespace: Some("ns"), + aliases: vec![], + size: 10, + attributes: Attributes::default(), + }; + let writer_schema = Schema::Complex(ComplexType::Fixed(writer_fixed)); + let reader_schema = Schema::Complex(ComplexType::Fixed(reader_fixed)); + let mut maker = Maker::new(false, false); + maker + .make_data_type(&writer_schema, Some(&reader_schema), None) + .expect("fixed alias resolution should succeed"); + } } diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index 69aee634977a..0cb2878a132d 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -16,8 +16,13 @@ // under the License. use arrow_schema::ArrowError; -use std::io; -use std::io::Read; +#[cfg(any( + feature = "deflate", + feature = "zstd", + feature = "bzip2", + feature = "xz" +))] +use std::io::{Read, Write}; /// The metadata key used for storing the JSON encoded [`CompressionCodec`] pub const CODEC_METADATA_KEY: &str = "avro.codec"; @@ -34,9 +39,14 @@ pub enum CompressionCodec { Snappy, /// ZStandard compression ZStandard, + /// Bzip2 compression + Bzip2, + /// Xz compression + Xz, } impl CompressionCodec { + #[allow(unused_variables)] pub(crate) fn decompress(&self, block: &[u8]) -> Result, ArrowError> { match self { #[cfg(feature = "deflate")] @@ -84,6 +94,102 @@ impl CompressionCodec { CompressionCodec::ZStandard => Err(ArrowError::ParseError( "ZStandard codec requires zstd feature".to_string(), )), + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut decoder = bzip2::read::BzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut decoder = xz::read::XzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), + } + } + + #[allow(unused_variables)] + pub(crate) fn compress(&self, data: &[u8]) -> Result, ArrowError> { + match self { + #[cfg(feature = "deflate")] + CompressionCodec::Deflate => { + let mut encoder = + flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default()); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "deflate"))] + CompressionCodec::Deflate => Err(ArrowError::ParseError( + "Deflate codec requires deflate feature".to_string(), + )), + + #[cfg(feature = "snappy")] + CompressionCodec::Snappy => { + let mut encoder = snap::raw::Encoder::new(); + // Allocate and compress in one step for efficiency + let mut compressed = encoder + .compress_vec(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + // Compute CRC32 (ISO‑HDLC poly) of **uncompressed** data + let crc_val = crc::Crc::::new(&crc::CRC_32_ISO_HDLC).checksum(data); + compressed.extend_from_slice(&crc_val.to_be_bytes()); + Ok(compressed) + } + #[cfg(not(feature = "snappy"))] + CompressionCodec::Snappy => Err(ArrowError::ParseError( + "Snappy codec requires snappy feature".to_string(), + )), + + #[cfg(feature = "zstd")] + CompressionCodec::ZStandard => { + let mut encoder = zstd::Encoder::new(Vec::new(), 0) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + encoder.write_all(data)?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "zstd"))] + CompressionCodec::ZStandard => Err(ArrowError::ParseError( + "ZStandard codec requires zstd feature".to_string(), + )), + + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut encoder = + bzip2::write::BzEncoder::new(Vec::new(), bzip2::Compression::default()); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut encoder = xz::write::XzEncoder::new(Vec::new(), 6); + encoder.write_all(data)?; + let compressed = encoder.finish()?; + Ok(compressed) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), } } } diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index e413e0aa9173..032ad683ff77 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -15,28 +15,173 @@ // specific language governing permissions and limitations // under the License. -//! Convert data to / from the [Apache Arrow] memory format and [Apache Avro] +//! Convert data to / from the [Apache Arrow] memory format and [Apache Avro]. //! -//! [Apache Arrow]: https://arrow.apache.org +//! This crate provides: +//! - a [`reader`] that decodes Avro (Object Container Files, Avro Single‑Object encoding, +//! and Confluent Schema Registry wire format) into Arrow `RecordBatch`es, +//! - and a [`writer`] that encodes Arrow `RecordBatch`es into Avro (OCF or SOE). +//! +//! If you’re new to Arrow or Avro, see: +//! - Arrow project site: +//! - Avro 1.11.1 specification: +//! +//! ## Example: OCF (Object Container File) round‑trip *(runnable)* +//! +//! The example below creates an Arrow table, writes an **Avro OCF** fully in memory, +//! and then reads it back. OCF is a self‑describing file format that embeds the Avro +//! schema in a header with optional compression and block sync markers. +//! Spec: +//! +//! ``` +//! use std::io::Cursor; +//! use std::sync::Arc; +//! use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::writer::AvroWriter; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! # fn main() -> Result<(), Box> { +//! // Build a tiny Arrow batch +//! let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); +//! let batch = RecordBatch::try_new( +//! Arc::new(schema.clone()), +//! vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], +//! )?; +//! +//! // Write an Avro **Object Container File** (OCF) to a Vec +//! let sink: Vec = Vec::new(); +//! let mut w = AvroWriter::new(sink, schema.clone())?; +//! w.write(&batch)?; +//! w.finish()?; +//! let bytes = w.into_inner(); +//! assert!(!bytes.is_empty()); +//! +//! // Read it back +//! let mut r = ReaderBuilder::new().build(Cursor::new(bytes))?; +//! let out = r.next().unwrap()?; +//! assert_eq!(out.num_rows(), 3); +//! # Ok(()) } +//! ``` +//! +//! ## Quickstart: SOE (Single‑Object Encoding) round‑trip *(runnable)* +//! +//! Avro **Single‑Object Encoding (SOE)** wraps an Avro body with a 2‑byte marker +//! `0xC3 0x01` and an **8‑byte little‑endian CRC‑64‑AVRO Rabin fingerprint** of the +//! writer schema, then the Avro body. Spec: +//! +//! +//! This example registers the writer schema (computing a Rabin fingerprint), writes a +//! single‑row Avro body (using `AvroStreamWriter`), constructs the SOE frame, and decodes it back to Arrow. +//! +//! ``` +//! use std::collections::HashMap; +//! use std::sync::Arc; +//! use arrow_array::{ArrayRef, Int64Array, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::writer::{AvroStreamWriter, WriterBuilder}; +//! use arrow_avro::reader::ReaderBuilder; +//! use arrow_avro::schema::{AvroSchema, SchemaStore, FingerprintStrategy, SCHEMA_METADATA_KEY}; +//! +//! # fn main() -> Result<(), Box> { +//! // Writer schema: { "type":"record","name":"User","fields":[{"name":"x","type":"long"}] } +//! let writer_json = r#"{"type":"record","name":"User","fields":[{"name":"x","type":"long"}]}"#; +//! let mut store = SchemaStore::new(); // Rabin CRC‑64‑AVRO by default +//! let _fp = store.register(AvroSchema::new(writer_json.to_string()))?; +//! +//! // Build an Arrow schema that references the same Avro JSON +//! let mut md = HashMap::new(); +//! md.insert(SCHEMA_METADATA_KEY.to_string(), writer_json.to_string()); +//! let schema = Schema::new_with_metadata( +//! vec![Field::new("x", DataType::Int64, false)], +//! md, +//! ); +//! +//! // One‑row batch: { x: 7 } +//! let batch = RecordBatch::try_new( +//! Arc::new(schema.clone()), +//! vec![Arc::new(Int64Array::from(vec![7])) as ArrayRef], +//! )?; +//! +//! // Stream‑write a single record; the writer adds **SOE** (C3 01 + Rabin) automatically. +//! let sink: Vec = Vec::new(); +//! let mut w: AvroStreamWriter> = WriterBuilder::new(schema.clone()) +//! .with_fingerprint_strategy(FingerprintStrategy::Rabin) +//! .build(sink)?; +//! w.write(&batch)?; +//! w.finish()?; +//! let frame = w.into_inner(); // already: C3 01 + 8B LE Rabin + Avro body +//! assert!(frame.len() > 10); +//! +//! // Decode +//! let mut dec = ReaderBuilder::new() +//! .with_writer_schema_store(store) +//! .build_decoder()?; +//! dec.decode(&frame)?; +//! let out = dec.flush()?.expect("one row"); +//! assert_eq!(out.num_rows(), 1); +//! # Ok(()) } +//! ``` +//! +//! --- +//! +//! ### Modules +//! +//! - [`reader`]: read Avro (OCF, SOE, Confluent) into Arrow `RecordBatch`es. +//! - [`writer`]: write Arrow `RecordBatch`es as Avro (OCF, SOE, Confluent, Apicurio). +//! - [`schema`]: Avro schema parsing / fingerprints / registries. +//! - [`compression`]: codecs used for **OCF block compression** (i.e., Deflate, Snappy, Zstandard, BZip2, and XZ). +//! - [`codec`]: internal Avro-Arrow type conversion and row decode/encode plans. +//! +//! ### Features +//! +//! **OCF compression (enabled by default)** +//! - `deflate` — enable DEFLATE block compression (via `flate2`). +//! - `snappy` — enable Snappy block compression with 4‑byte BE CRC32 (per Avro). +//! - `zstd` — enable Zstandard block compression. +//! - `bzip2` — enable BZip2 block compression. +//! - `xz` — enable XZ/LZMA block compression. +//! +//! **Schema fingerprints & helpers (opt‑in)** +//! - `md5` — enable MD5 writer‑schema fingerprints. +//! - `sha256` — enable SHA‑256 writer‑schema fingerprints. +//! - `small_decimals` — support for compact Arrow representations of small Avro decimals (`Decimal32` and `Decimal64`). +//! - `avro_custom_types` — interpret Avro fields annotated with Arrow‑specific logical +//! types such as `arrow.duration-nanos`, `arrow.duration-micros`, +//! `arrow.duration-millis`, or `arrow.duration-seconds` as Arrow `Duration(TimeUnit)`. +//! - `canonical_extension_types` — enable support for Arrow [canonical extension types] +//! from `arrow-schema` so `arrow-avro` can respect them during Avro↔Arrow mapping. +//! +//! **Notes** +//! - OCF compression codecs apply only to **Object Container Files**; they do not affect Avro +//! single object encodings. +//! +//! [canonical extension types]: https://arrow.apache.org/docs/format/CanonicalExtensions.html +//! +//! [Apache Arrow]: https://arrow.apache.org/ //! [Apache Avro]: https://avro.apache.org/ #![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] -#![allow(unused)] // Temporary /// Core functionality for reading Avro data into Arrow arrays /// /// Implements the primary reader interface and record decoding logic. pub mod reader; -// Avro schema parsing and representation -// -// Provides types for parsing and representing Avro schema definitions. -mod schema; +/// Core functionality for writing Arrow arrays as Avro data +/// +/// Implements the primary writer interface and record encoding logic. +pub mod writer; + +/// Avro schema parsing and representation +/// +/// Provides types for parsing and representing Avro schema definitions. +pub mod schema; /// Compression codec implementations for Avro /// @@ -50,8 +195,6 @@ pub mod compression; /// Avro data types and Arrow data types. pub mod codec; -pub use reader::ReadOptions; - /// Extension trait for AvroField to add Utf8View support /// /// This trait adds methods for working with Utf8View support to the AvroField struct. diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..23d9e503339d 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -85,7 +85,7 @@ impl<'a> AvroCursor<'a> { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -97,7 +97,7 @@ impl<'a> AvroCursor<'a> { #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -109,7 +109,7 @@ impl<'a> AvroCursor<'a> { #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -118,4 +118,16 @@ impl<'a> AvroCursor<'a> { self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). + pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } } diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 98c285171bf3..aac267f50e9e 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -17,10 +17,31 @@ //! Decoder for [`Header`] -use crate::compression::{CompressionCodec, CODEC_METADATA_KEY}; +use crate::compression::{CODEC_METADATA_KEY, CompressionCodec}; use crate::reader::vlq::VLQDecoder; -use crate::schema::{Schema, SCHEMA_METADATA_KEY}; +use crate::schema::{SCHEMA_METADATA_KEY, Schema}; use arrow_schema::ArrowError; +use std::io::BufRead; + +/// Read the Avro file header (magic, metadata, sync marker) from `reader`. +pub(crate) fn read_header(mut reader: R) -> Result { + let mut decoder = HeaderDecoder::default(); + loop { + let buf = reader.fill_buf()?; + if buf.is_empty() { + break; + } + let read = buf.len(); + let decoded = decoder.decode(buf)?; + reader.consume(decoded); + if decoded != read { + break; + } + } + decoder.flush().ok_or_else(|| { + ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string()) + }) +} #[derive(Debug)] enum HeaderDecoderState { @@ -77,12 +98,13 @@ impl Header { /// Returns the [`CompressionCodec`] if any pub fn compression(&self) -> Result, ArrowError> { let v = self.get(CODEC_METADATA_KEY); - match v { None | Some(b"null") => Ok(None), Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)), Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)), Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)), + Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)), + Some(b"xz") => Ok(Some(CompressionCodec::Xz)), Some(v) => Err(ArrowError::ParseError(format!( "Unrecognized compression codec \'{}\'", String::from_utf8_lossy(v) @@ -90,8 +112,8 @@ impl Header { } } - /// Returns the [`Schema`] if any - pub fn schema(&self) -> Result>, ArrowError> { + /// Returns the `Schema` if any + pub(crate) fn schema(&self) -> Result>, ArrowError> { self.get(SCHEMA_METADATA_KEY) .map(|x| { serde_json::from_slice(x).map_err(|e| { @@ -264,13 +286,16 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use crate::reader::read_header; - use crate::schema::SCHEMA_METADATA_KEY; + use crate::schema::{ + AVRO_NAME_METADATA_KEY, AVRO_ROOT_RECORD_DEFAULT_NAME, SCHEMA_METADATA_KEY, + }; use crate::test_util::arrow_test_data; use arrow_schema::{DataType, Field, Fields, TimeUnit}; + use std::collections::HashMap; use std::fs::File; - use std::io::{BufRead, BufReader}; + use std::io::BufReader; #[test] fn test_header_decode() { @@ -290,7 +315,7 @@ mod test { fn decode_file(file: &str) -> Header { let file = File::open(file).unwrap(); - read_header(BufReader::with_capacity(100, file)).unwrap() + read_header(BufReader::with_capacity(1000, file)).unwrap() } #[test] @@ -325,6 +350,10 @@ mod test { ])), false ) + .with_metadata(HashMap::from([( + AVRO_NAME_METADATA_KEY.to_string(), + AVRO_ROOT_RECORD_DEFAULT_NAME.to_string() + )])) ); assert_eq!( diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 61e3e8511caa..546650faf568 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -15,11 +15,482 @@ // specific language governing permissions and limitations // under the License. -//! Read Avro data to Arrow - -use crate::reader::block::{Block, BlockDecoder}; -use crate::reader::header::{Header, HeaderDecoder}; -use arrow_schema::ArrowError; +//! Avro reader +//! +//! Facilities to read Apache Avro–encoded data into Arrow's `RecordBatch` format. +//! +//! ### Limitations +//! +//!- **Avro unions with > 127 branches are not supported.** +//! When decoding Avro unions to Arrow `UnionArray`, Arrow stores the union +//! type identifiers in an **8‑bit signed** buffer (`i8`). This implies a +//! practical limit of **127** distinct branch ids. Inputs that resolve to +//! more than 127 branches will return an error. If you truly need more, +//! model the schema as a **union of unions**, per the Arrow format spec. +//! +//! See: Arrow Columnar Format — Dense Union (“types buffer: 8‑bit signed; +//! a union with more than 127 possible types can be modeled as a union of +//! unions”). +//! +//! This module exposes three layers of the API surface, from highest to lowest-level: +//! +//! * [`ReaderBuilder`](crate::reader::ReaderBuilder): configures how Avro is read (batch size, strict union handling, +//! string representation, reader schema, etc.) and produces either: +//! * a `Reader` for **Avro Object Container Files (OCF)** read from any `BufRead`, or +//! * a low-level `Decoder` for **single‑object encoded** Avro bytes and Confluent +//! **Schema Registry** framed messages. +//! * [`Reader`](crate::reader::Reader): a convenient, synchronous iterator over `RecordBatch` decoded from an OCF +//! input. Implements [`Iterator>`] and +//! `RecordBatchReader`. +//! * [`Decoder`](crate::reader::Decoder): a push‑based row decoder that consumes SOE framed Avro bytes and yields ready +//! `RecordBatch` values when batches fill. This is suitable for integrating with async +//! byte streams, network protocols, or other custom data sources. +//! +//! ## Encodings and when to use which type +//! +//! * **Object Container File (OCF)**: A self‑describing file format with a header containing +//! the writer schema, optional compression codec, and a sync marker, followed by one or +//! more data blocks. Use `Reader` for this format. See the Avro 1.11.1 specification +//! (“Object Container Files”). +//! * **Single‑Object Encoding**: A stream‑friendly framing that prefixes each record body with +//! the 2‑byte marker `0xC3 0x01` followed by the **8‑byte little‑endian CRC‑64‑AVRO Rabin +//! fingerprint** of the writer schema, then the Avro binary body. Use `Decoder` with a +//! populated `SchemaStore` to resolve fingerprints to full schemas. +//! See “Single object encoding” in the Avro 1.11.1 spec. +//! +//! * **Confluent Schema Registry wire format**: A 1‑byte magic `0x00`, a **4‑byte big‑endian** +//! schema ID, then the Avro‑encoded body. Use `Decoder` with a `SchemaStore` configured +//! for `FingerprintAlgorithm::Id` and entries keyed by `Fingerprint::Id`. See +//! Confluent’s “Wire format” documentation. +//! +//! * **Apicurio Schema Registry wire format**: A 1‑byte magic `0x00`, a **8‑byte big‑endian** +//! global schema ID, then the Avro‑encoded body. Use `Decoder` with a `SchemaStore` configured +//! for `FingerprintAlgorithm::Id64` and entries keyed by `Fingerprint::Id64`. See +//! Apicurio’s “Avro SerDe” documentation. +//! +//! +//! ## Basic file usage (OCF) +//! +//! Use `ReaderBuilder::build` to construct a `Reader` from any `BufRead`. The doctest below +//! creates a tiny OCF in memory using `AvroWriter` and then reads it back. +//! +//! ``` +//! use std::io::Cursor; +//! use std::sync::Arc; +//! use arrow_array::{ArrayRef, Int32Array, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::writer::AvroWriter; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! # fn main() -> Result<(), Box> { +//! // Build a minimal Arrow schema and batch +//! let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); +//! let batch = RecordBatch::try_new( +//! Arc::new(schema.clone()), +//! vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], +//! )?; +//! +//! // Write an Avro OCF to memory +//! let buffer: Vec = Vec::new(); +//! let mut writer = AvroWriter::new(buffer, schema.clone())?; +//! writer.write(&batch)?; +//! writer.finish()?; +//! let bytes = writer.into_inner(); +//! +//! // Read it back with ReaderBuilder +//! let mut reader = ReaderBuilder::new().build(Cursor::new(bytes))?; +//! let out = reader.next().unwrap()?; +//! assert_eq!(out.num_rows(), 3); +//! # Ok(()) } +//! ``` +//! +//! ## Streaming usage (single‑object / Confluent / Apicurio) +//! +//! The `Decoder` lets you integrate Avro decoding with **any** source of bytes by +//! periodically calling `Decoder::decode` with new data and calling `Decoder::flush` +//! to get a `RecordBatch` once at least one row is complete. +//! +//! The example below shows how to decode from an arbitrary stream of `bytes::Bytes` using +//! `futures` utilities. Note: this is illustrative and keeps a single in‑memory `Bytes` +//! buffer for simplicity—real applications typically maintain a rolling buffer. +//! +//! ``` +//! use bytes::{Buf, Bytes}; +//! use futures::{Stream, StreamExt}; +//! use std::task::{Poll, ready}; +//! use arrow_array::RecordBatch; +//! use arrow_schema::ArrowError; +//! use arrow_avro::reader::Decoder; +//! +//! /// Decode a stream of Avro-framed bytes into RecordBatch values. +//! fn decode_stream + Unpin>( +//! mut decoder: Decoder, +//! mut input: S, +//! ) -> impl Stream> { +//! let mut buffered = Bytes::new(); +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! if buffered.is_empty() { +//! buffered = match ready!(input.poll_next_unpin(cx)) { +//! Some(b) => b, +//! None => break, // EOF +//! }; +//! } +//! // Feed as much as possible +//! let decoded = match decoder.decode(buffered.as_ref()) { +//! Ok(n) => n, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! let read = buffered.len(); +//! buffered.advance(decoded); +//! if decoded != read { +//! // decoder made partial progress; request more bytes +//! break +//! } +//! } +//! // Return a batch if one or more rows are complete +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! +//! ### Building and using a `Decoder` for **single‑object encoding** (Rabin fingerprints) +//! +//! The doctest below **writes** a single‑object framed record using the Avro writer +//! (no manual varints) for the writer schema +//! (`{"type":"record","name":"User","fields":[{"name":"id","type":"long"}]}`) +//! and then decodes it into a `RecordBatch`. +//! +//! ``` +//! use std::sync::Arc; +//! use std::collections::HashMap; +//! use arrow_array::{ArrayRef, Int64Array, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::schema::{AvroSchema, SchemaStore, SCHEMA_METADATA_KEY, FingerprintStrategy}; +//! use arrow_avro::writer::{WriterBuilder, format::AvroSoeFormat}; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! # fn main() -> Result<(), Box> { +//! // Register the writer schema (Rabin fingerprint by default). +//! let mut store = SchemaStore::new(); +//! let avro_schema = AvroSchema::new(r#"{"type":"record","name":"User","fields":[ +//! {"name":"id","type":"long"}]}"#.to_string()); +//! let _fp = store.register(avro_schema.clone())?; +//! +//! // Create a single-object framed record { id: 42 } with the Avro writer. +//! let mut md = HashMap::new(); +//! md.insert(SCHEMA_METADATA_KEY.to_string(), avro_schema.json_string.clone()); +//! let arrow = Schema::new_with_metadata(vec![Field::new("id", DataType::Int64, false)], md); +//! let batch = RecordBatch::try_new( +//! Arc::new(arrow.clone()), +//! vec![Arc::new(Int64Array::from(vec![42])) as ArrayRef], +//! )?; +//! let mut w = WriterBuilder::new(arrow) +//! .with_fingerprint_strategy(FingerprintStrategy::Rabin) // SOE prefix +//! .build::<_, AvroSoeFormat>(Vec::new())?; +//! w.write(&batch)?; +//! w.finish()?; +//! let frame = w.into_inner(); // C3 01 + fp + Avro body +//! +//! // Decode with a `Decoder` +//! let mut dec = ReaderBuilder::new() +//! .with_writer_schema_store(store) +//! .with_batch_size(1024) +//! .build_decoder()?; +//! +//! dec.decode(&frame)?; +//! let out = dec.flush()?.expect("one batch"); +//! assert_eq!(out.num_rows(), 1); +//! # Ok(()) } +//! ``` +//! +//! See Avro 1.11.1 “Single object encoding” for details of the 2‑byte marker +//! and little‑endian CRC‑64‑AVRO fingerprint: +//! +//! +//! ### Building and using a `Decoder` for **Confluent Schema Registry** framing +//! +//! The Confluent wire format is: 1‑byte magic `0x00`, then a **4‑byte big‑endian** schema ID, +//! then the Avro body. The doctest below crafts two messages for the same schema ID and +//! decodes them into a single `RecordBatch` with two rows. +//! +//! ``` +//! use std::sync::Arc; +//! use std::collections::HashMap; +//! use arrow_array::{ArrayRef, Int64Array, StringArray, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm, SCHEMA_METADATA_KEY, FingerprintStrategy}; +//! use arrow_avro::writer::{WriterBuilder, format::AvroSoeFormat}; +//! use arrow_avro::reader::ReaderBuilder; +//! +//! # fn main() -> Result<(), Box> { +//! // Set up a store keyed by numeric IDs (Confluent). +//! let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); +//! let schema_id = 7u32; +//! let avro_schema = AvroSchema::new(r#"{"type":"record","name":"User","fields":[ +//! {"name":"id","type":"long"}, {"name":"name","type":"string"}]}"#.to_string()); +//! store.set(Fingerprint::Id(schema_id), avro_schema.clone())?; +//! +//! // Write two Confluent-framed messages {id:1,name:"a"} and {id:2,name:"b"}. +//! fn msg(id: i64, name: &str, schema: &AvroSchema, schema_id: u32) -> Result, Box> { +//! let mut md = HashMap::new(); +//! md.insert(SCHEMA_METADATA_KEY.to_string(), schema.json_string.clone()); +//! let arrow = Schema::new_with_metadata( +//! vec![Field::new("id", DataType::Int64, false), Field::new("name", DataType::Utf8, false)], +//! md, +//! ); +//! let batch = RecordBatch::try_new( +//! Arc::new(arrow.clone()), +//! vec![ +//! Arc::new(Int64Array::from(vec![id])) as ArrayRef, +//! Arc::new(StringArray::from(vec![name])) as ArrayRef, +//! ], +//! )?; +//! let mut w = WriterBuilder::new(arrow) +//! .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id)) // 0x00 + ID + body +//! .build::<_, AvroSoeFormat>(Vec::new())?; +//! w.write(&batch)?; w.finish()?; +//! Ok(w.into_inner()) +//! } +//! let m1 = msg(1, "a", &avro_schema, schema_id)?; +//! let m2 = msg(2, "b", &avro_schema, schema_id)?; +//! +//! // Decode both into a single batch. +//! let mut dec = ReaderBuilder::new() +//! .with_writer_schema_store(store) +//! .with_batch_size(1024) +//! .build_decoder()?; +//! dec.decode(&m1)?; +//! dec.decode(&m2)?; +//! let batch = dec.flush()?.expect("batch"); +//! assert_eq!(batch.num_rows(), 2); +//! # Ok(()) } +//! ``` +//! +//! See Confluent’s “Wire format” notes: magic byte `0x00`, 4‑byte **big‑endian** schema ID, +//! then the Avro‑encoded payload. +//! +//! +//! ## Schema resolution (reader vs. writer schemas) +//! +//! Avro supports resolving data written with one schema (“writer”) into another (“reader”) +//! using rules like **field aliases**, **default values**, and **numeric promotions**. +//! In practice this lets you evolve schemas over time while remaining compatible with old data. +//! +//! *Spec background:* See Avro’s **Schema Resolution** (aliases, defaults) and the Confluent +//! **Wire format** (magic `0x00` + big‑endian schema id + Avro body). +//! +//! +//! +//! ### OCF example: rename a field and add a default via a reader schema +//! +//! Below we write an OCF with a *writer schema* having fields `id: long`, `name: string`. +//! We then read it with a *reader schema* that: +//! - **renames** `name` to `full_name` via `aliases`, and +//! - **adds** `is_active: boolean` with a **default** value `true`. +//! +//! ``` +//! use std::io::Cursor; +//! use std::sync::Arc; +//! use arrow_array::{ArrayRef, Int64Array, StringArray, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! use arrow_avro::writer::AvroWriter; +//! use arrow_avro::reader::ReaderBuilder; +//! use arrow_avro::schema::AvroSchema; +//! +//! # fn main() -> Result<(), Box> { +//! // Writer (past version): { id: long, name: string } +//! let writer_arrow = Schema::new(vec![ +//! Field::new("id", DataType::Int64, false), +//! Field::new("name", DataType::Utf8, false), +//! ]); +//! let batch = RecordBatch::try_new( +//! Arc::new(writer_arrow.clone()), +//! vec![ +//! Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef, +//! Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef, +//! ], +//! )?; +//! +//! // Write an OCF entirely in memory +//! let mut w = AvroWriter::new(Vec::::new(), writer_arrow)?; +//! w.write(&batch)?; +//! w.finish()?; +//! let bytes = w.into_inner(); +//! +//! // Reader (current version): +//! // - record name "topLevelRecord" matches the crate's default for OCF +//! // - rename `name` -> `full_name` using aliases (optional) +//! let reader_json = r#" +//! { +//! "type": "record", +//! "name": "topLevelRecord", +//! "fields": [ +//! { "name": "id", "type": "long" }, +//! { "name": "full_name", "type": ["null","string"], "aliases": ["name"], "default": null }, +//! { "name": "is_active", "type": "boolean", "default": true } +//! ] +//! }"#; +//! +//! let mut reader = ReaderBuilder::new() +//! .with_reader_schema(AvroSchema::new(reader_json.to_string())) +//! .build(Cursor::new(bytes))?; +//! +//! let out = reader.next().unwrap()?; +//! assert_eq!(out.num_rows(), 2); +//! # Ok(()) } +//! ``` +//! +//! ### Confluent single‑object example: resolve *past* writer versions to the topic’s **current** reader schema +//! +//! In this scenario, the **reader schema** is the topic’s *current* schema, while the two +//! **writer schemas** registered under Confluent IDs **1** and **2** represent *past versions*. +//! The decoder uses the reader schema to resolve both versions. +//! +//! ``` +//! use std::sync::Arc; +//! use std::collections::HashMap; +//! use arrow_avro::reader::ReaderBuilder; +//! use arrow_avro::schema::{ +//! AvroSchema, Fingerprint, FingerprintAlgorithm, SchemaStore, +//! SCHEMA_METADATA_KEY, FingerprintStrategy, +//! }; +//! use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray, RecordBatch}; +//! use arrow_schema::{DataType, Field, Schema}; +//! +//! fn main() -> Result<(), Box> { +//! // Reader: current topic schema (no reader-added fields) +//! // {"type":"record","name":"User","fields":[ +//! // {"name":"id","type":"long"}, +//! // {"name":"name","type":"string"}]} +//! let reader_schema = AvroSchema::new( +//! r#"{"type":"record","name":"User", +//! "fields":[{"name":"id","type":"long"},{"name":"name","type":"string"}]}"# +//! .to_string(), +//! ); +//! +//! // Register two *writer* schemas under Confluent IDs 0 and 1 +//! let writer_v0 = AvroSchema::new( +//! r#"{"type":"record","name":"User", +//! "fields":[{"name":"id","type":"int"},{"name":"name","type":"string"}]}"# +//! .to_string(), +//! ); +//! let writer_v1 = AvroSchema::new( +//! r#"{"type":"record","name":"User", +//! "fields":[{"name":"id","type":"long"},{"name":"name","type":"string"}, +//! {"name":"email","type":["null","string"],"default":null}]}"# +//! .to_string(), +//! ); +//! +//! let id_v0: u32 = 0; +//! let id_v1: u32 = 1; +//! +//! let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); // integer IDs +//! store.set(Fingerprint::Id(id_v0), writer_v0.clone())?; +//! store.set(Fingerprint::Id(id_v1), writer_v1.clone())?; +//! +//! // Write two Confluent-framed messages using each writer version +//! // frame0: writer v0 body {id:1001_i32, name:"v0-alice"} +//! let mut md0 = HashMap::new(); +//! md0.insert(SCHEMA_METADATA_KEY.to_string(), writer_v0.json_string.clone()); +//! let arrow0 = Schema::new_with_metadata( +//! vec![Field::new("id", DataType::Int32, false), +//! Field::new("name", DataType::Utf8, false)], md0); +//! let batch0 = RecordBatch::try_new( +//! Arc::new(arrow0.clone()), +//! vec![Arc::new(Int32Array::from(vec![1001])) as ArrayRef, +//! Arc::new(StringArray::from(vec!["v0-alice"])) as ArrayRef])?; +//! let mut w0 = arrow_avro::writer::WriterBuilder::new(arrow0) +//! .with_fingerprint_strategy(FingerprintStrategy::Id(id_v0)) +//! .build::<_, arrow_avro::writer::format::AvroSoeFormat>(Vec::new())?; +//! w0.write(&batch0)?; w0.finish()?; +//! let frame0 = w0.into_inner(); // 0x00 + id_v0 + body +//! +//! // frame1: writer v1 body {id:2002_i64, name:"v1-bob", email: Some("bob@example.com")} +//! let mut md1 = HashMap::new(); +//! md1.insert(SCHEMA_METADATA_KEY.to_string(), writer_v1.json_string.clone()); +//! let arrow1 = Schema::new_with_metadata( +//! vec![Field::new("id", DataType::Int64, false), +//! Field::new("name", DataType::Utf8, false), +//! Field::new("email", DataType::Utf8, true)], md1); +//! let batch1 = RecordBatch::try_new( +//! Arc::new(arrow1.clone()), +//! vec![Arc::new(Int64Array::from(vec![2002])) as ArrayRef, +//! Arc::new(StringArray::from(vec!["v1-bob"])) as ArrayRef, +//! Arc::new(StringArray::from(vec![Some("bob@example.com")])) as ArrayRef])?; +//! let mut w1 = arrow_avro::writer::WriterBuilder::new(arrow1) +//! .with_fingerprint_strategy(FingerprintStrategy::Id(id_v1)) +//! .build::<_, arrow_avro::writer::format::AvroSoeFormat>(Vec::new())?; +//! w1.write(&batch1)?; w1.finish()?; +//! let frame1 = w1.into_inner(); // 0x00 + id_v1 + body +//! +//! // Build a streaming Decoder that understands Confluent framing +//! let mut decoder = ReaderBuilder::new() +//! .with_reader_schema(reader_schema) +//! .with_writer_schema_store(store) +//! .with_batch_size(8) // small demo batches +//! .build_decoder()?; +//! +//! // Decode each whole frame, then drain completed rows with flush() +//! let mut total_rows = 0usize; +//! +//! let consumed0 = decoder.decode(&frame0)?; +//! assert_eq!(consumed0, frame0.len(), "decoder must consume the whole frame"); +//! while let Some(batch) = decoder.flush()? { total_rows += batch.num_rows(); } +//! +//! let consumed1 = decoder.decode(&frame1)?; +//! assert_eq!(consumed1, frame1.len(), "decoder must consume the whole frame"); +//! while let Some(batch) = decoder.flush()? { total_rows += batch.num_rows(); } +//! +//! // We sent 2 records so we should get 2 rows (possibly one per flush) +//! assert_eq!(total_rows, 2); +//! Ok(()) +//! } +//! ``` +//! +//! ## Schema evolution and batch boundaries +//! +//! `Decoder` supports mid‑stream schema changes when the input framing carries a schema +//! fingerprint (single‑object or Confluent). When a new fingerprint is observed: +//! +//! * If the current `RecordBatch` is **empty**, the decoder switches to the new schema +//! immediately. +//! * If not, the decoder finishes the current batch first and only then switches. +//! +//! Consequently, the schema of batches produced by `Decoder::flush` may change over time, +//! and `Decoder` intentionally does **not** implement `RecordBatchReader`. In contrast, +//! `Reader` (OCF) has a single writer schema for the entire file and therefore implements +//! `RecordBatchReader`. +//! +//! ## Performance & memory +//! +//! * `batch_size` controls the maximum number of rows per `RecordBatch`. Larger batches +//! amortize per‑batch overhead; smaller batches reduce peak memory usage and latency. +//! * When `utf8_view` is enabled, string columns use Arrow’s `StringViewArray`, which can +//! reduce allocations for short strings. +//! * For OCF, blocks may be compressed; `Reader` will decompress using the codec specified +//! in the file header and feed uncompressed bytes to the row `Decoder`. +//! +//! ## Error handling +//! +//! * Incomplete inputs return parse errors with "Unexpected EOF"; callers typically provide +//! more bytes and try again. +//! * If a fingerprint is unknown to the provided `SchemaStore`, decoding fails with a +//! descriptive error. Populate the store up front to avoid this. +//! +//! --- +use crate::codec::AvroFieldBuilder; +use crate::reader::header::read_header; +use crate::schema::{ + AvroSchema, CONFLUENT_MAGIC, Fingerprint, FingerprintAlgorithm, SINGLE_OBJECT_MAGIC, Schema, + SchemaStore, +}; +use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow_schema::{ArrowError, SchemaRef}; +use block::BlockDecoder; +use header::Header; +use indexmap::IndexMap; +use record::RecordDecoder; use std::io::BufRead; mod block; @@ -28,259 +499,2284 @@ mod header; mod record; mod vlq; -/// Configuration options for reading Avro data into Arrow arrays +fn is_incomplete_data(err: &ArrowError) -> bool { + matches!( + err, + ArrowError::ParseError(msg) + if msg.contains("Unexpected EOF") + ) +} + +/// A low‑level, push‑based decoder from Avro bytes to Arrow `RecordBatch`. +/// +/// `Decoder` is designed for **streaming** scenarios: +/// +/// * You *feed* freshly received bytes using `Self::decode`, potentially multiple times, +/// until at least one row is complete. +/// * You then *drain* completed rows with `Self::flush`, which yields a `RecordBatch` +/// if any rows were finished since the last flush. +/// +/// Unlike `Reader`, which is specialized for Avro **Object Container Files**, `Decoder` +/// understands **framed single‑object** inputs and **Confluent Schema Registry** messages, +/// switching schemas mid‑stream when the framing indicates a new fingerprint. +/// +/// ### Supported prefixes +/// +/// On each new row boundary, `Decoder` tries to match one of the following "prefixes": +/// +/// * **Single‑Object encoding**: magic `0xC3 0x01` + schema fingerprint (length depends on +/// the configured `FingerprintAlgorithm`); see `SINGLE_OBJECT_MAGIC`. +/// * **Confluent wire format**: magic `0x00` + 4‑byte big‑endian schema id; see +/// `CONFLUENT_MAGIC`. +/// +/// The active fingerprint determines which cached row decoder is used to decode the following +/// record body bytes. +/// +/// ### Schema switching semantics +/// +/// When a new fingerprint is observed: +/// +/// * If the current batch is empty, the decoder switches immediately; +/// * Otherwise, the current batch is finalized on the next `flush` and only then +/// does the decoder switch to the new schema. This guarantees that a single `RecordBatch` +/// never mixes rows with different schemas. +/// +/// ### Examples +/// +/// Build and use a `Decoder` for single‑object encoding: +/// +/// ``` +/// use arrow_avro::schema::{AvroSchema, SchemaStore}; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// # fn main() -> Result<(), Box> { +/// // Use a record schema at the top level so we can build an Arrow RecordBatch +/// let mut store = SchemaStore::new(); // Rabin fingerprinting by default +/// let avro = AvroSchema::new( +/// r#"{"type":"record","name":"E","fields":[{"name":"x","type":"long"}]}"#.to_string() +/// ); +/// let fp = store.register(avro)?; +/// +/// // --- Hidden: write a single-object framed row {x:7} --- +/// # use std::sync::Arc; +/// # use std::collections::HashMap; +/// # use arrow_array::{ArrayRef, Int64Array, RecordBatch}; +/// # use arrow_schema::{DataType, Field, Schema}; +/// # use arrow_avro::schema::{SCHEMA_METADATA_KEY, FingerprintStrategy}; +/// # use arrow_avro::writer::{WriterBuilder, format::AvroSoeFormat}; +/// # let mut md = HashMap::new(); +/// # md.insert(SCHEMA_METADATA_KEY.to_string(), +/// # r#"{"type":"record","name":"E","fields":[{"name":"x","type":"long"}]}"#.to_string()); +/// # let arrow = Schema::new_with_metadata(vec![Field::new("x", DataType::Int64, false)], md); +/// # let batch = RecordBatch::try_new(Arc::new(arrow.clone()), vec![Arc::new(Int64Array::from(vec![7])) as ArrayRef])?; +/// # let mut w = WriterBuilder::new(arrow) +/// # .with_fingerprint_strategy(fp.into()) +/// # .build::<_, AvroSoeFormat>(Vec::new())?; +/// # w.write(&batch)?; w.finish()?; let frame = w.into_inner(); +/// +/// let mut decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .with_batch_size(16) +/// .build_decoder()?; +/// +/// # decoder.decode(&frame)?; +/// let batch = decoder.flush()?.expect("one row"); +/// assert_eq!(batch.num_rows(), 1); +/// # Ok(()) } +/// ``` +/// +/// *Background:* Avro's single‑object encoding is defined as `0xC3 0x01` + 8‑byte +/// little‑endian CRC‑64‑AVRO fingerprint of the **writer schema** + Avro binary body. +/// See the Avro 1.11.1 spec for details. +/// +/// Build and use a `Decoder` for Confluent Registry messages: +/// +/// ``` +/// use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm}; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// # fn main() -> Result<(), Box> { +/// let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); +/// store.set(Fingerprint::Id(1234), AvroSchema::new(r#"{"type":"record","name":"E","fields":[{"name":"x","type":"long"}]}"#.to_string()))?; +/// +/// // --- Hidden: encode two Confluent-framed messages {x:1} and {x:2} --- +/// # use std::sync::Arc; +/// # use std::collections::HashMap; +/// # use arrow_array::{ArrayRef, Int64Array, RecordBatch}; +/// # use arrow_schema::{DataType, Field, Schema}; +/// # use arrow_avro::schema::{SCHEMA_METADATA_KEY, FingerprintStrategy}; +/// # use arrow_avro::writer::{WriterBuilder, format::AvroSoeFormat}; +/// # fn msg(x: i64) -> Result, Box> { +/// # let mut md = HashMap::new(); +/// # md.insert(SCHEMA_METADATA_KEY.to_string(), +/// # r#"{"type":"record","name":"E","fields":[{"name":"x","type":"long"}]}"#.to_string()); +/// # let arrow = Schema::new_with_metadata(vec![Field::new("x", DataType::Int64, false)], md); +/// # let batch = RecordBatch::try_new(Arc::new(arrow.clone()), vec![Arc::new(Int64Array::from(vec![x])) as ArrayRef])?; +/// # let mut w = WriterBuilder::new(arrow) +/// # .with_fingerprint_strategy(FingerprintStrategy::Id(1234)) +/// # .build::<_, AvroSoeFormat>(Vec::new())?; +/// # w.write(&batch)?; w.finish()?; Ok(w.into_inner()) +/// # } +/// # let m1 = msg(1)?; +/// # let m2 = msg(2)?; +/// +/// let mut decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .build_decoder()?; +/// # decoder.decode(&m1)?; +/// # decoder.decode(&m2)?; +/// let batch = decoder.flush()?.expect("two rows"); +/// assert_eq!(batch.num_rows(), 2); +/// # Ok(()) } +/// ``` +#[derive(Debug)] +pub struct Decoder { + active_decoder: RecordDecoder, + active_fingerprint: Option, + batch_size: usize, + remaining_capacity: usize, + cache: IndexMap, + fingerprint_algorithm: FingerprintAlgorithm, + pending_schema: Option<(Fingerprint, RecordDecoder)>, + awaiting_body: bool, +} + +impl Decoder { + /// Returns the Arrow schema for the rows decoded by this decoder. + /// + /// **Note:** With single‑object or Confluent framing, the schema may change + /// at a row boundary when the input indicates a new fingerprint. + pub fn schema(&self) -> SchemaRef { + self.active_decoder.schema().clone() + } + + /// Returns the configured maximum number of rows per batch. + pub fn batch_size(&self) -> usize { + self.batch_size + } + + /// Feed a chunk of bytes into the decoder. + /// + /// This will: + /// + /// * Decode at most `Self::batch_size` rows; + /// * Return the number of input bytes **consumed** from `data` (which may be 0 if more + /// bytes are required, or less than `data.len()` if a prefix/body straddles the + /// chunk boundary); + /// * Defer producing a `RecordBatch` until you call `Self::flush`. + /// + /// # Returns + /// The number of bytes consumed from `data`. + /// + /// # Errors + /// Returns an error if: + /// + /// * The input indicates an unknown fingerprint (not present in the provided + /// `SchemaStore`; + /// * The Avro body is malformed; + /// * A strict‑mode union rule is violated (see `ReaderBuilder::with_strict_mode`). + pub fn decode(&mut self, data: &[u8]) -> Result { + let mut total_consumed = 0usize; + while total_consumed < data.len() && self.remaining_capacity > 0 { + if self.awaiting_body { + match self.active_decoder.decode(&data[total_consumed..], 1) { + Ok(n) => { + self.remaining_capacity -= 1; + total_consumed += n; + self.awaiting_body = false; + continue; + } + Err(ref e) if is_incomplete_data(e) => break, + err => return err, + }; + } + match self.handle_prefix(&data[total_consumed..])? { + Some(0) => break, // Insufficient bytes + Some(n) => { + total_consumed += n; + self.apply_pending_schema_if_batch_empty(); + self.awaiting_body = true; + } + None => { + return Err(ArrowError::ParseError( + "Missing magic bytes and fingerprint".to_string(), + )); + } + } + } + Ok(total_consumed) + } + + // Attempt to handle a prefix at the current position. + // * Ok(None) – buffer does not start with the prefix. + // * Ok(Some(0)) – prefix detected, but the buffer is too short; caller should await more bytes. + // * Ok(Some(n)) – consumed `n > 0` bytes of a complete prefix (magic and fingerprint). + fn handle_prefix(&mut self, buf: &[u8]) -> Result, ArrowError> { + match self.fingerprint_algorithm { + FingerprintAlgorithm::Rabin => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::Rabin(u64::from_le_bytes(bytes)) + }) + } + FingerprintAlgorithm::Id => self.handle_prefix_common(buf, &CONFLUENT_MAGIC, |bytes| { + Fingerprint::Id(u32::from_be_bytes(bytes)) + }), + FingerprintAlgorithm::Id64 => { + self.handle_prefix_common(buf, &CONFLUENT_MAGIC, |bytes| { + Fingerprint::Id64(u64::from_be_bytes(bytes)) + }) + } + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::MD5(bytes) + }) + } + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => { + self.handle_prefix_common(buf, &SINGLE_OBJECT_MAGIC, |bytes| { + Fingerprint::SHA256(bytes) + }) + } + } + } + + /// This method checks for the provided `magic` bytes at the start of `buf` and, if present, + /// attempts to read the following fingerprint of `N` bytes, converting it to a + /// `Fingerprint` using `fingerprint_from`. + fn handle_prefix_common( + &mut self, + buf: &[u8], + magic: &[u8; MAGIC_LEN], + fingerprint_from: impl FnOnce([u8; N]) -> Fingerprint, + ) -> Result, ArrowError> { + // Need at least the magic bytes to decide + // 2 bytes for Avro Spec and 1 byte for Confluent Wire Protocol. + if buf.len() < MAGIC_LEN { + return Ok(Some(0)); + } + // Bail out early if the magic does not match. + if &buf[..MAGIC_LEN] != magic { + return Ok(None); + } + // Try to parse the fingerprint that follows the magic. + let consumed_fp = self.handle_fingerprint(&buf[MAGIC_LEN..], fingerprint_from)?; + // Convert the inner result into a “bytes consumed” count. + // NOTE: Incomplete fingerprint consumes no bytes. + Ok(Some(consumed_fp.map_or(0, |n| n + MAGIC_LEN))) + } + + // Attempts to read and install a new fingerprint of `N` bytes. + // + // * Ok(None) – insufficient bytes (`buf.len() < `N`). + // * Ok(Some(N)) – fingerprint consumed (always `N`). + fn handle_fingerprint( + &mut self, + buf: &[u8], + fingerprint_from: impl FnOnce([u8; N]) -> Fingerprint, + ) -> Result, ArrowError> { + // Need enough bytes to get fingerprint (next N bytes) + let Some(fingerprint_bytes) = buf.get(..N) else { + return Ok(None); // insufficient bytes + }; + // SAFETY: length checked above. + let new_fingerprint = fingerprint_from(fingerprint_bytes.try_into().unwrap()); + // If the fingerprint indicates a schema change, prepare to switch decoders. + if self.active_fingerprint != Some(new_fingerprint) { + let Some(new_decoder) = self.cache.shift_remove(&new_fingerprint) else { + return Err(ArrowError::ParseError(format!( + "Unknown fingerprint: {new_fingerprint:?}" + ))); + }; + self.pending_schema = Some((new_fingerprint, new_decoder)); + // If there are already decoded rows, we must flush them first. + // Reducing `remaining_capacity` to 0 ensures `flush` is called next. + if self.remaining_capacity < self.batch_size { + self.remaining_capacity = 0; + } + } + Ok(Some(N)) + } + + fn apply_pending_schema(&mut self) { + if let Some((new_fingerprint, new_decoder)) = self.pending_schema.take() { + if let Some(old_fingerprint) = self.active_fingerprint.replace(new_fingerprint) { + let old_decoder = std::mem::replace(&mut self.active_decoder, new_decoder); + self.cache.shift_remove(&old_fingerprint); + self.cache.insert(old_fingerprint, old_decoder); + } else { + self.active_decoder = new_decoder; + } + } + } + + fn apply_pending_schema_if_batch_empty(&mut self) { + if self.batch_is_empty() { + self.apply_pending_schema(); + } + } + + fn flush_and_reset(&mut self) -> Result, ArrowError> { + if self.batch_is_empty() { + return Ok(None); + } + let batch = self.active_decoder.flush()?; + self.remaining_capacity = self.batch_size; + Ok(Some(batch)) + } + + /// Produce a `RecordBatch` if at least one row is fully decoded, returning + /// `Ok(None)` if no new rows are available. + /// + /// If a schema change was detected while decoding rows for the current batch, the + /// schema switch is applied **after** flushing this batch, so the **next** batch + /// (if any) may have a different schema. + pub fn flush(&mut self) -> Result, ArrowError> { + // We must flush the active decoder before switching to the pending one. + let batch = self.flush_and_reset(); + self.apply_pending_schema(); + batch + } + + /// Returns the number of rows that can be added to this decoder before it is full. + pub fn capacity(&self) -> usize { + self.remaining_capacity + } + + /// Returns true if the decoder has reached its capacity for the current batch. + pub fn batch_is_full(&self) -> bool { + self.remaining_capacity == 0 + } + + /// Returns true if the decoder has not decoded any batches yet (i.e., the current batch is empty). + pub fn batch_is_empty(&self) -> bool { + self.remaining_capacity == self.batch_size + } + + // Decode either the block count or remaining capacity from `data` (an OCF block payload). + // + // Returns the number of bytes consumed from `data` along with the number of records decoded. + fn decode_block(&mut self, data: &[u8], count: usize) -> Result<(usize, usize), ArrowError> { + // OCF decoding never interleaves records across blocks, so no chunking. + let to_decode = std::cmp::min(count, self.remaining_capacity); + if to_decode == 0 { + return Ok((0, 0)); + } + let consumed = self.active_decoder.decode(data, to_decode)?; + self.remaining_capacity -= to_decode; + Ok((consumed, to_decode)) + } + + // Produce a `RecordBatch` if at least one row is fully decoded, returning + // `Ok(None)` if no new rows are available. + fn flush_block(&mut self) -> Result, ArrowError> { + self.flush_and_reset() + } +} + +/// A builder that configures and constructs Avro readers and decoders. +/// +/// `ReaderBuilder` is the primary entry point for this module. It supports: +/// +/// * OCF reading via `Self::build`, returning a `Reader` over any `BufRead`; +/// * streaming decoding via `Self::build_decoder`, returning a `Decoder`. +/// +/// ### Options +/// +/// * **`batch_size`**: Max rows per `RecordBatch` (default: `1024`). See `Self::with_batch_size`. +/// * **`utf8_view`**: Use Arrow `StringViewArray` for string columns (default: `false`). +/// See `Self::with_utf8_view`. +/// * **`strict_mode`**: Opt‑in to stricter union handling (default: `false`). +/// See `Self::with_strict_mode`. +/// * **`reader_schema`**: Optional reader schema (projection / evolution) used when decoding +/// values (default: `None`). See `Self::with_reader_schema`. +/// * **`writer_schema_store`**: Required for building a `Decoder` for single‑object or +/// Confluent framing. Maps fingerprints to Avro schemas. See `Self::with_writer_schema_store`. +/// * **`active_fingerprint`**: Optional starting fingerprint for streaming decode when the +/// first frame omits one (rare). See `Self::with_active_fingerprint`. +/// +/// ### Examples +/// +/// Read an OCF file in batches of 4096 rows: /// -/// This struct contains configuration options that control how Avro data is -/// converted into Arrow arrays. It allows customizing various aspects of the -/// data conversion process. +/// ```no_run +/// use std::fs::File; +/// use std::io::BufReader; +/// use arrow_avro::reader::ReaderBuilder; /// -/// # Examples +/// let file = File::open("data.avro")?; +/// let mut reader = ReaderBuilder::new() +/// .with_batch_size(4096) +/// .build(BufReader::new(file))?; +/// # Ok::<(), Box>(()) +/// ``` +/// +/// Build a `Decoder` for Confluent messages: /// /// ``` -/// # use arrow_avro::reader::ReadOptions; -/// // Use default options (regular StringArray for strings) -/// let default_options = ReadOptions::default(); +/// use arrow_avro::schema::{AvroSchema, SchemaStore, Fingerprint, FingerprintAlgorithm}; +/// use arrow_avro::reader::ReaderBuilder; /// -/// // Enable Utf8View support for better string performance -/// let options = ReadOptions::default() -/// .with_utf8view(true); +/// let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); +/// store.set(Fingerprint::Id(1234), AvroSchema::new(r#"{"type":"record","name":"E","fields":[]}"#.to_string()))?; +/// +/// let decoder = ReaderBuilder::new() +/// .with_writer_schema_store(store) +/// .build_decoder()?; +/// # Ok::<(), Box>(()) /// ``` -#[derive(Default, Debug, Clone)] -pub struct ReadOptions { - use_utf8view: bool, +#[derive(Debug)] +pub struct ReaderBuilder { + batch_size: usize, + strict_mode: bool, + utf8_view: bool, + reader_schema: Option, + writer_schema_store: Option, + active_fingerprint: Option, +} + +impl Default for ReaderBuilder { + fn default() -> Self { + Self { + batch_size: 1024, + strict_mode: false, + utf8_view: false, + reader_schema: None, + writer_schema_store: None, + active_fingerprint: None, + } + } } -impl ReadOptions { - /// Create a new `ReadOptions` with default values +impl ReaderBuilder { + /// Creates a new `ReaderBuilder` with defaults: + /// + /// * `batch_size = 1024` + /// * `strict_mode = false` + /// * `utf8_view = false` + /// * `reader_schema = None` + /// * `writer_schema_store = None` + /// * `active_fingerprint = None` pub fn new() -> Self { Self::default() } - /// Set whether to use StringViewArray for string data + fn make_record_decoder( + &self, + writer_schema: &Schema, + reader_schema: Option<&Schema>, + ) -> Result { + let mut builder = AvroFieldBuilder::new(writer_schema); + if let Some(reader_schema) = reader_schema { + builder = builder.with_reader_schema(reader_schema); + } + let root = builder + .with_utf8view(self.utf8_view) + .with_strict_mode(self.strict_mode) + .build()?; + RecordDecoder::try_new_with_options(root.data_type()) + } + + fn make_record_decoder_from_schemas( + &self, + writer_schema: &Schema, + reader_schema: Option<&AvroSchema>, + ) -> Result { + let reader_schema_raw = reader_schema.map(|s| s.schema()).transpose()?; + self.make_record_decoder(writer_schema, reader_schema_raw.as_ref()) + } + + fn make_decoder_with_parts( + &self, + active_decoder: RecordDecoder, + active_fingerprint: Option, + cache: IndexMap, + fingerprint_algorithm: FingerprintAlgorithm, + ) -> Decoder { + Decoder { + batch_size: self.batch_size, + remaining_capacity: self.batch_size, + active_fingerprint, + active_decoder, + cache, + fingerprint_algorithm, + pending_schema: None, + awaiting_body: false, + } + } + + fn make_decoder( + &self, + header: Option<&Header>, + reader_schema: Option<&AvroSchema>, + ) -> Result { + if let Some(hdr) = header { + let writer_schema = hdr + .schema() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))? + .ok_or_else(|| { + ArrowError::ParseError("No Avro schema present in file header".into()) + })?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; + return Ok(self.make_decoder_with_parts( + record_decoder, + None, + IndexMap::new(), + FingerprintAlgorithm::Rabin, + )); + } + let store = self.writer_schema_store.as_ref().ok_or_else(|| { + ArrowError::ParseError("Writer schema store required for raw Avro".into()) + })?; + let fingerprints = store.fingerprints(); + if fingerprints.is_empty() { + return Err(ArrowError::ParseError( + "Writer schema store must contain at least one schema".into(), + )); + } + let start_fingerprint = self + .active_fingerprint + .or_else(|| fingerprints.first().copied()) + .ok_or_else(|| { + ArrowError::ParseError("Could not determine initial schema fingerprint".into()) + })?; + let mut cache = IndexMap::with_capacity(fingerprints.len().saturating_sub(1)); + let mut active_decoder: Option = None; + for fingerprint in store.fingerprints() { + let avro_schema = match store.lookup(&fingerprint) { + Some(schema) => schema, + None => { + return Err(ArrowError::ComputeError(format!( + "Fingerprint {fingerprint:?} not found in schema store", + ))); + } + }; + let writer_schema = avro_schema.schema()?; + let record_decoder = + self.make_record_decoder_from_schemas(&writer_schema, reader_schema)?; + if fingerprint == start_fingerprint { + active_decoder = Some(record_decoder); + } else { + cache.insert(fingerprint, record_decoder); + } + } + let active_decoder = active_decoder.ok_or_else(|| { + ArrowError::ComputeError(format!( + "Initial fingerprint {start_fingerprint:?} not found in schema store" + )) + })?; + Ok(self.make_decoder_with_parts( + active_decoder, + Some(start_fingerprint), + cache, + store.fingerprint_algorithm(), + )) + } + + /// Sets the **row‑based batch size**. + /// + /// Each call to `Decoder::flush` or each iteration of `Reader` yields a batch with + /// *up to* this many rows. Larger batches can reduce overhead; smaller batches can + /// reduce peak memory usage and latency. + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Choose Arrow's `StringViewArray` for UTF‑8 string data. /// - /// When enabled, string data from Avro files will be loaded into - /// Arrow's StringViewArray instead of the standard StringArray. - pub fn with_utf8view(mut self, use_utf8view: bool) -> Self { - self.use_utf8view = use_utf8view; + /// When enabled, textual Avro fields are loaded into Arrow’s **StringViewArray** + /// instead of the standard `StringArray`. This can improve performance for workloads + /// with many short strings by reducing allocations. + pub fn with_utf8_view(mut self, utf8_view: bool) -> Self { + self.utf8_view = utf8_view; self } - /// Get whether StringViewArray is enabled for string data + /// Returns whether `StringViewArray` is enabled for string data. pub fn use_utf8view(&self) -> bool { - self.use_utf8view + self.utf8_view } -} -/// Read a [`Header`] from the provided [`BufRead`] -fn read_header(mut reader: R) -> Result { - let mut decoder = HeaderDecoder::default(); - loop { - let buf = reader.fill_buf()?; - if buf.is_empty() { - break; - } - let read = buf.len(); - let decoded = decoder.decode(buf)?; - reader.consume(decoded); - if decoded != read { - break; + /// Enable stricter behavior for certain Avro unions (e.g., `[T, "null"]`). + /// + /// When `true`, ambiguous or lossy unions that would otherwise be coerced may instead + /// produce a descriptive error. Use this to catch schema issues early during ingestion. + pub fn with_strict_mode(mut self, strict_mode: bool) -> Self { + self.strict_mode = strict_mode; + self + } + + /// Sets the **reader schema** used during decoding. + /// + /// If not provided, the writer schema from the OCF header (for `Reader`) or the + /// schema looked up from the fingerprint (for `Decoder`) is used directly. + /// + /// A reader schema can be used for **schema evolution** or **projection**. + pub fn with_reader_schema(mut self, schema: AvroSchema) -> Self { + self.reader_schema = Some(schema); + self + } + + /// Sets the `SchemaStore` used to resolve writer schemas by fingerprint. + /// + /// This is required when building a `Decoder` for **single‑object encoding** or the + /// **Confluent** wire format. The store maps a fingerprint (Rabin / MD5 / SHA‑256 / + /// ID) to a full Avro schema. + /// + /// Defaults to `None`. + pub fn with_writer_schema_store(mut self, store: SchemaStore) -> Self { + self.writer_schema_store = Some(store); + self + } + + /// Sets the initial schema fingerprint for stream decoding. + /// + /// This can be useful for streams that **do not include** a fingerprint before the first + /// record body (uncommon). If not set, the first observed fingerprint is used. + pub fn with_active_fingerprint(mut self, fp: Fingerprint) -> Self { + self.active_fingerprint = Some(fp); + self + } + + /// Build a `Reader` (OCF) from this builder and a `BufRead`. + /// + /// This reads and validates the OCF header, initializes an internal row decoder from + /// the discovered writer (and optional reader) schema, and prepares to iterate blocks, + /// decompressing if necessary. + pub fn build(self, mut reader: R) -> Result, ArrowError> { + let header = read_header(&mut reader)?; + let decoder = self.make_decoder(Some(&header), self.reader_schema.as_ref())?; + Ok(Reader { + reader, + header, + decoder, + block_decoder: BlockDecoder::default(), + block_data: Vec::new(), + block_count: 0, + block_cursor: 0, + finished: false, + }) + } + + /// Build a streaming `Decoder` from this builder. + /// + /// # Requirements + /// * `SchemaStore` **must** be provided via `Self::with_writer_schema_store`. + /// * The store should contain **all** fingerprints that may appear on the stream. + /// + /// # Errors + /// * Returns [`ArrowError::InvalidArgumentError`] if the schema store is missing + pub fn build_decoder(self) -> Result { + if self.writer_schema_store.is_none() { + return Err(ArrowError::InvalidArgumentError( + "Building a decoder requires a writer schema store".to_string(), + )); } + self.make_decoder(None, self.reader_schema.as_ref()) } +} - decoder - .flush() - .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) +/// A high‑level Avro **Object Container File** reader. +/// +/// `Reader` pulls blocks from a `BufRead` source, handles optional block compression, +/// and decodes them row‑by‑row into Arrow `RecordBatch` values using an internal +/// `Decoder`. It implements both: +/// +/// * [`Iterator>`], and +/// * `RecordBatchReader`, guaranteeing a consistent schema across all produced batches. +/// +#[derive(Debug)] +pub struct Reader { + reader: R, + header: Header, + decoder: Decoder, + block_decoder: BlockDecoder, + block_data: Vec, + block_count: usize, + block_cursor: usize, + finished: bool, } -/// Return an iterator of [`Block`] from the provided [`BufRead`] -fn read_blocks(mut reader: R) -> impl Iterator> { - let mut decoder = BlockDecoder::default(); +impl Reader { + /// Returns the Arrow schema discovered from the Avro file header (or derived via + /// the optional reader schema). + pub fn schema(&self) -> SchemaRef { + self.decoder.schema() + } - let mut try_next = move || { - loop { - let buf = reader.fill_buf()?; - if buf.is_empty() { - break; + /// Returns a reference to the parsed Avro container‑file header (magic, metadata, codec, sync). + pub fn avro_header(&self) -> &Header { + &self.header + } + + /// Reads the next `RecordBatch` from the Avro file, or `Ok(None)` on EOF. + /// + /// Batches are bounded by `batch_size`; a single OCF block may yield multiple batches, + /// and a batch may also span multiple blocks. + fn read(&mut self) -> Result, ArrowError> { + 'outer: while !self.finished && !self.decoder.batch_is_full() { + while self.block_cursor == self.block_data.len() { + let buf = self.reader.fill_buf()?; + if buf.is_empty() { + self.finished = true; + break 'outer; + } + // Try to decode another block from the buffered reader. + let consumed = self.block_decoder.decode(buf)?; + self.reader.consume(consumed); + if let Some(block) = self.block_decoder.flush() { + // Successfully decoded a block. + self.block_data = if let Some(ref codec) = self.header.compression()? { + codec.decompress(&block.data)? + } else { + block.data + }; + self.block_count = block.count; + self.block_cursor = 0; + } else if consumed == 0 { + // The block decoder made no progress on a non-empty buffer. + return Err(ArrowError::ParseError( + "Could not decode next Avro block from partial data".to_string(), + )); + } } - let read = buf.len(); - let decoded = decoder.decode(buf)?; - reader.consume(decoded); - if decoded != read { - break; + // Decode as many rows as will fit in the current batch + if self.block_cursor < self.block_data.len() { + let (consumed, records_decoded) = self + .decoder + .decode_block(&self.block_data[self.block_cursor..], self.block_count)?; + self.block_cursor += consumed; + self.block_count -= records_decoded; } } - Ok(decoder.flush()) - }; - std::iter::from_fn(move || try_next().transpose()) + self.decoder.flush_block() + } +} + +impl Iterator for Reader { + type Item = Result; + + fn next(&mut self) -> Option { + self.read().transpose() + } +} + +impl RecordBatchReader for Reader { + fn schema(&self) -> SchemaRef { + self.schema() + } } #[cfg(test)] mod test { - use crate::codec::{AvroDataType, AvroField, Codec}; - use crate::compression::CompressionCodec; + use crate::codec::AvroFieldBuilder; use crate::reader::record::RecordDecoder; - use crate::reader::{read_blocks, read_header}; + use crate::reader::{Decoder, Reader, ReaderBuilder}; + use crate::schema::{ + AVRO_ENUM_SYMBOLS_METADATA_KEY, AVRO_NAME_METADATA_KEY, AVRO_NAMESPACE_METADATA_KEY, + AvroSchema, CONFLUENT_MAGIC, Fingerprint, FingerprintAlgorithm, PrimitiveType, + SINGLE_OBJECT_MAGIC, SchemaStore, + }; use crate::test_util::arrow_test_data; + use crate::writer::AvroWriter; + use arrow_array::builder::{ + ArrayBuilder, BooleanBuilder, Float32Builder, Int32Builder, Int64Builder, ListBuilder, + MapBuilder, StringBuilder, StructBuilder, + }; + #[cfg(feature = "snappy")] + use arrow_array::builder::{Float64Builder, MapFieldNames}; + use arrow_array::cast::AsArray; + #[cfg(not(feature = "avro_custom_types"))] + use arrow_array::types::Int64Type; + #[cfg(feature = "avro_custom_types")] + use arrow_array::types::{ + DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType, + DurationSecondType, + }; + use arrow_array::types::{Int32Type, IntervalMonthDayNanoType}; use arrow_array::*; - use arrow_schema::{DataType, Field}; + #[cfg(feature = "snappy")] + use arrow_buffer::{Buffer, NullBuffer}; + use arrow_buffer::{IntervalMonthDayNano, OffsetBuffer, ScalarBuffer, i256}; + #[cfg(feature = "avro_custom_types")] + use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema, TimeUnit, UnionFields, + UnionMode, + }; + #[cfg(not(feature = "avro_custom_types"))] + use arrow_schema::{ + ArrowError, DataType, Field, FieldRef, Fields, IntervalUnit, Schema, UnionFields, UnionMode, + }; + use bytes::Bytes; + use futures::executor::block_on; + use futures::{Stream, StreamExt, TryStreamExt, stream}; + use serde_json::{Value, json}; use std::collections::HashMap; use std::fs::File; - use std::io::BufReader; + use std::io::{BufReader, Cursor}; use std::sync::Arc; - fn read_file(file: &str, batch_size: usize) -> RecordBatch { - read_file_with_options(file, batch_size, &crate::ReadOptions::default()) + fn files() -> impl Iterator { + [ + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + "avro/alltypes_plain.avro", + #[cfg(feature = "snappy")] + "avro/alltypes_plain.snappy.avro", + #[cfg(feature = "zstd")] + "avro/alltypes_plain.zstandard.avro", + #[cfg(feature = "bzip2")] + "avro/alltypes_plain.bzip2.avro", + #[cfg(feature = "xz")] + "avro/alltypes_plain.xz.avro", + ] + .into_iter() + } + + fn read_file(path: &str, batch_size: usize, utf8_view: bool) -> RecordBatch { + let file = File::open(path).unwrap(); + let reader = ReaderBuilder::new() + .with_batch_size(batch_size) + .with_utf8_view(utf8_view) + .build(BufReader::new(file)) + .unwrap(); + let schema = reader.schema(); + let batches = reader.collect::, _>>().unwrap(); + arrow::compute::concat_batches(&schema, &batches).unwrap() } - fn read_file_with_options( - file: &str, + fn read_file_strict( + path: &str, batch_size: usize, - options: &crate::ReadOptions, - ) -> RecordBatch { - let file = File::open(file).unwrap(); - let mut reader = BufReader::new(file); - let header = read_header(&mut reader).unwrap(); - let compression = header.compression().unwrap(); - let schema = header.schema().unwrap().unwrap(); - let root = AvroField::try_from(&schema).unwrap(); - - let mut decoder = - RecordDecoder::try_new_with_options(root.data_type(), options.clone()).unwrap(); - - for result in read_blocks(reader) { - let block = result.unwrap(); - assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - let decompressed = c.decompress(&block.data).unwrap(); - - let mut offset = 0; - let mut remaining = block.count; - while remaining > 0 { - let to_read = remaining.max(batch_size); - offset += decoder - .decode(&decompressed[offset..], block.count) - .unwrap(); + utf8_view: bool, + ) -> Result>, ArrowError> { + let file = File::open(path)?; + ReaderBuilder::new() + .with_batch_size(batch_size) + .with_utf8_view(utf8_view) + .with_strict_mode(true) + .build(BufReader::new(file)) + } - remaining -= to_read; + fn decode_stream + Unpin>( + mut decoder: Decoder, + mut input: S, + ) -> impl Stream> { + async_stream::try_stream! { + if let Some(data) = input.next().await { + let consumed = decoder.decode(&data)?; + if consumed < data.len() { + Err(ArrowError::ParseError( + "did not consume all bytes".to_string(), + ))?; } - assert_eq!(offset, decompressed.len()); + } + if let Some(batch) = decoder.flush()? { + yield batch } } - decoder.flush().unwrap() } - #[test] - fn test_utf8view_support() { - let schema_json = r#"{ - "type": "record", - "name": "test", - "fields": [{ - "name": "str_field", - "type": "string" - }] - }"#; + fn make_record_schema(pt: PrimitiveType) -> AvroSchema { + let js = format!( + r#"{{"type":"record","name":"TestRecord","fields":[{{"name":"a","type":"{}"}}]}}"#, + pt.as_ref() + ); + AvroSchema::new(js) + } + + fn make_two_schema_store() -> ( + SchemaStore, + Fingerprint, + Fingerprint, + AvroSchema, + AvroSchema, + ) { + let schema_int = make_record_schema(PrimitiveType::Int); + let schema_long = make_record_schema(PrimitiveType::Long); + let mut store = SchemaStore::new(); + let fp_int = store + .register(schema_int.clone()) + .expect("register int schema"); + let fp_long = store + .register(schema_long.clone()) + .expect("register long schema"); + (store, fp_int, fp_long, schema_int, schema_long) + } + + fn make_prefix(fp: Fingerprint) -> Vec { + match fp { + Fingerprint::Rabin(v) => { + let mut out = Vec::with_capacity(2 + 8); + out.extend_from_slice(&SINGLE_OBJECT_MAGIC); + out.extend_from_slice(&v.to_le_bytes()); + out + } + Fingerprint::Id(v) => { + panic!("make_prefix expects a Rabin fingerprint, got ({v})"); + } + Fingerprint::Id64(v) => { + panic!("make_prefix expects a Rabin fingerprint, got ({v})"); + } + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => { + panic!("make_prefix expects a Rabin fingerprint, got ({v:?})"); + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(id) => { + panic!("make_prefix expects a Rabin fingerprint, got ({id:?})"); + } + } + } - let schema: crate::schema::Schema = serde_json::from_str(schema_json).unwrap(); - let avro_field = AvroField::try_from(&schema).unwrap(); + fn make_decoder(store: &SchemaStore, fp: Fingerprint, reader_schema: &AvroSchema) -> Decoder { + ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema.clone()) + .with_writer_schema_store(store.clone()) + .with_active_fingerprint(fp) + .build_decoder() + .expect("decoder") + } - let data_type = avro_field.data_type(); + fn make_id_prefix(id: u32, additional: usize) -> Vec { + let capacity = CONFLUENT_MAGIC.len() + size_of::() + additional; + let mut out = Vec::with_capacity(capacity); + out.extend_from_slice(&CONFLUENT_MAGIC); + out.extend_from_slice(&id.to_be_bytes()); + out + } - struct TestHelper; - impl TestHelper { - fn with_utf8view(field: &Field) -> Field { - match field.data_type() { - DataType::Utf8 => { - Field::new(field.name(), DataType::Utf8View, field.is_nullable()) - .with_metadata(field.metadata().clone()) + fn make_message_id(id: u32, value: i64) -> Vec { + let encoded_value = encode_zigzag(value); + let mut msg = make_id_prefix(id, encoded_value.len()); + msg.extend_from_slice(&encoded_value); + msg + } + + fn make_id64_prefix(id: u64, additional: usize) -> Vec { + let capacity = CONFLUENT_MAGIC.len() + size_of::() + additional; + let mut out = Vec::with_capacity(capacity); + out.extend_from_slice(&CONFLUENT_MAGIC); + out.extend_from_slice(&id.to_be_bytes()); + out + } + + fn make_message_id64(id: u64, value: i64) -> Vec { + let encoded_value = encode_zigzag(value); + let mut msg = make_id64_prefix(id, encoded_value.len()); + msg.extend_from_slice(&encoded_value); + msg + } + + fn make_value_schema(pt: PrimitiveType) -> AvroSchema { + let json_schema = format!( + r#"{{"type":"record","name":"S","fields":[{{"name":"v","type":"{}"}}]}}"#, + pt.as_ref() + ); + AvroSchema::new(json_schema) + } + + fn encode_zigzag(value: i64) -> Vec { + let mut n = ((value << 1) ^ (value >> 63)) as u64; + let mut out = Vec::new(); + loop { + if (n & !0x7F) == 0 { + out.push(n as u8); + break; + } else { + out.push(((n & 0x7F) | 0x80) as u8); + n >>= 7; + } + } + out + } + + fn make_message(fp: Fingerprint, value: i64) -> Vec { + let mut msg = make_prefix(fp); + msg.extend_from_slice(&encode_zigzag(value)); + msg + } + + fn load_writer_schema_json(path: &str) -> Value { + let file = File::open(path).unwrap(); + let header = super::read_header(BufReader::new(file)).unwrap(); + let schema = header.schema().unwrap().unwrap(); + serde_json::to_value(&schema).unwrap() + } + + fn make_reader_schema_with_promotions( + path: &str, + promotions: &HashMap<&str, &str>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_ty) = promotions.get(name) { + let ty = f.get_mut("type").expect("field has a type"); + match ty { + Value::String(_) => { + *ty = Value::String((*new_ty).to_string()); } - _ => field.clone(), + // Union + Value::Array(arr) => { + for b in arr.iter_mut() { + match b { + Value::String(s) if s != "null" => { + *b = Value::String((*new_ty).to_string()); + break; + } + Value::Object(_) => { + *b = Value::String((*new_ty).to_string()); + break; + } + _ => {} + } + } + } + Value::Object(_) => { + *ty = Value::String((*new_ty).to_string()); + } + _ => {} } } } + AvroSchema::new(root.to_string()) + } - let field = TestHelper::with_utf8view(&Field::new("str_field", DataType::Utf8, false)); + fn make_reader_schema_with_enum_remap( + path: &str, + remap: &HashMap<&str, Vec<&str>>, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); - assert_eq!(field.data_type(), &DataType::Utf8View); + fn to_symbols_array(symbols: &[&str]) -> Value { + Value::Array(symbols.iter().map(|s| Value::String((*s).into())).collect()) + } - let array = StringViewArray::from(vec!["test1", "test2"]); - let batch = - RecordBatch::try_from_iter(vec![("str_field", Arc::new(array) as ArrayRef)]).unwrap(); + fn update_enum_symbols(ty: &mut Value, symbols: &Value) { + match ty { + Value::Object(map) => { + if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") { + map.insert("symbols".to_string(), symbols.clone()); + } + } + Value::Array(arr) => { + for b in arr.iter_mut() { + if let Value::Object(map) = b { + if matches!(map.get("type"), Some(Value::String(t)) if t == "enum") { + map.insert("symbols".to_string(), symbols.clone()); + } + } + } + } + _ => {} + } + } + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + if let Some(new_symbols) = remap.get(name) { + let symbols_val = to_symbols_array(new_symbols); + let ty = f.get_mut("type").expect("field has a type"); + update_enum_symbols(ty, &symbols_val); + } + } + AvroSchema::new(root.to_string()) + } - assert!(batch.column(0).as_any().is::()); + fn read_alltypes_with_reader_schema(path: &str, reader_schema: AvroSchema) -> RecordBatch { + let file = File::open(path).unwrap(); + let reader = ReaderBuilder::new() + .with_batch_size(1024) + .with_utf8_view(false) + .with_reader_schema(reader_schema) + .build(BufReader::new(file)) + .unwrap(); + let schema = reader.schema(); + let batches = reader.collect::, _>>().unwrap(); + arrow::compute::concat_batches(&schema, &batches).unwrap() } - #[test] - fn test_alltypes() { - let files = [ - "avro/alltypes_plain.avro", - "avro/alltypes_plain.snappy.avro", - "avro/alltypes_plain.zstandard.avro", - ]; + fn make_reader_schema_with_selected_fields_in_order( + path: &str, + selected: &[&str], + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let writer_fields = root + .get("fields") + .and_then(|f| f.as_array()) + .expect("record has fields"); + let mut field_map: HashMap = HashMap::with_capacity(writer_fields.len()); + for f in writer_fields { + if let Some(name) = f.get("name").and_then(|n| n.as_str()) { + field_map.insert(name.to_string(), f.clone()); + } + } + let mut new_fields = Vec::with_capacity(selected.len()); + for name in selected { + let f = field_map + .get(*name) + .unwrap_or_else(|| panic!("field '{name}' not found in writer schema")) + .clone(); + new_fields.push(f); + } + root["fields"] = Value::Array(new_fields); + AvroSchema::new(root.to_string()) + } - let expected = RecordBatch::try_from_iter_with_nullable([ - ( - "id", - Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, - true, - ), - ( - "bool_col", - Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, - true, - ), - ( - "tinyint_col", - Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, - true, - ), - ( - "smallint_col", - Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, - true, - ), - ( - "int_col", - Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, - true, - ), - ( - "bigint_col", - Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, - true, - ), - ( - "float_col", - Arc::new(Float32Array::from_iter_values( - (0..8).map(|x| (x % 2) as f32 * 1.1), - )) as _, - true, - ), - ( - "double_col", - Arc::new(Float64Array::from_iter_values( - (0..8).map(|x| (x % 2) as f64 * 10.1), - )) as _, - true, - ), - ( - "date_string_col", - Arc::new(BinaryArray::from_iter_values([ - [48, 51, 47, 48, 49, 47, 48, 57], - [48, 51, 47, 48, 49, 47, 48, 57], - [48, 52, 47, 48, 49, 47, 48, 57], - [48, 52, 47, 48, 49, 47, 48, 57], - [48, 50, 47, 48, 49, 47, 48, 57], - [48, 50, 47, 48, 49, 47, 48, 57], - [48, 49, 47, 48, 49, 47, 48, 57], - [48, 49, 47, 48, 49, 47, 48, 57], - ])) as _, - true, - ), - ( - "string_col", - Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, - true, - ), + fn write_ocf(schema: &Schema, batches: &[RecordBatch]) -> Vec { + let mut w = AvroWriter::new(Vec::::new(), schema.clone()).expect("writer"); + for b in batches { + w.write(b).expect("write"); + } + w.finish().expect("finish"); + w.into_inner() + } + + #[test] + fn writer_string_reader_nullable_with_alias() -> Result<(), Box> { + // Writer: { id: long, name: string } + let writer_schema = Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ]); + let batch = RecordBatch::try_new( + Arc::new(writer_schema.clone()), + vec![ + Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef, + Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef, + ], + )?; + let bytes = write_ocf(&writer_schema, &[batch]); + let reader_json = r#" + { + "type": "record", + "name": "topLevelRecord", + "fields": [ + { "name": "id", "type": "long" }, + { "name": "full_name", "type": ["null","string"], "aliases": ["name"], "default": null }, + { "name": "is_active", "type": "boolean", "default": true } + ] + }"#; + let mut reader = ReaderBuilder::new() + .with_reader_schema(AvroSchema::new(reader_json.to_string())) + .build(Cursor::new(bytes))?; + let out = reader.next().unwrap()?; + // Evolved aliased field should be non-null and match original writer values + let full_name = out.column(1).as_string::(); + assert_eq!(full_name.value(0), "a"); + assert_eq!(full_name.value(1), "b"); + + Ok(()) + } + + #[test] + fn writer_string_reader_string_null_order_second() -> Result<(), Box> { + // Writer: { name: string } + let writer_schema = Schema::new(vec![Field::new("name", DataType::Utf8, false)]); + let batch = RecordBatch::try_new( + Arc::new(writer_schema.clone()), + vec![Arc::new(StringArray::from(vec!["x", "y"])) as ArrayRef], + )?; + let bytes = write_ocf(&writer_schema, &[batch]); + + // Reader: ["string","null"] (NullSecond) + let reader_json = r#" + { + "type":"record", "name":"topLevelRecord", + "fields":[ { "name":"name", "type":["string","null"], "default":"x" } ] + }"#; + + let mut reader = ReaderBuilder::new() + .with_reader_schema(AvroSchema::new(reader_json.to_string())) + .build(Cursor::new(bytes))?; + + let out = reader.next().unwrap()?; + assert_eq!(out.num_rows(), 2); + + // Should decode as non-null strings (writer non-union -> reader union) + let name = out.column(0).as_string::(); + assert_eq!(name.value(0), "x"); + assert_eq!(name.value(1), "y"); + + Ok(()) + } + + #[test] + fn promotion_writer_int_reader_nullable_long() -> Result<(), Box> { + // Writer: { v: int } + let writer_schema = Schema::new(vec![Field::new("v", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(writer_schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + )?; + let bytes = write_ocf(&writer_schema, &[batch]); + + // Reader: { v: ["null","long"] } + let reader_json = r#" + { + "type":"record", "name":"topLevelRecord", + "fields":[ { "name":"v", "type":["null","long"], "default": null } ] + }"#; + + let mut reader = ReaderBuilder::new() + .with_reader_schema(AvroSchema::new(reader_json.to_string())) + .build(Cursor::new(bytes))?; + + let out = reader.next().unwrap()?; + assert_eq!(out.num_rows(), 3); + + // Should have promoted to Int64 and be non-null (no union tag in writer) + let v = out + .column(0) + .as_primitive::(); + assert_eq!(v.values(), &[1, 2, 3]); + assert!( + out.column(0).nulls().is_none(), + "expected no validity bitmap for all-valid column" + ); + + Ok(()) + } + + #[test] + fn test_alltypes_schema_promotion_mixed() { + for file in files() { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("id", "long"); + promotions.insert("tinyint_col", "float"); + promotions.insert("smallint_col", "double"); + promotions.insert("int_col", "double"); + promotions.insert("bigint_col", "double"); + promotions.insert("float_col", "double"); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int64Array::from(vec![4i64, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32), + )) as _, + true, + ), + ( + "smallint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "int_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64), + )) as _, + true, + ), + ( + "bigint_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f64), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| ((x % 2) as f32 * 1.1f32) as f64), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_long_to_float_only() { + for file in files() { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bigint_col", "float"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| ((x % 2) * 10) as f32), + )) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + fn test_alltypes_schema_promotion_bytes_to_string_only() { + for file in files() { + let file = arrow_test_data(file); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("date_string_col", "string"); + promotions.insert("string_col", "string"); + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(StringArray::from(vec![ + "03/01/09", "03/01/09", "04/01/09", "04/01/09", "02/01/09", "02/01/09", + "01/01/09", "01/01/09", + ])) as _, + true, + ), + ( + "string_col", + Arc::new(StringArray::from( + (0..8) + .map(|x| if x % 2 == 0 { "0" } else { "1" }) + .collect::>(), + )) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected, "mismatch for file {file}"); + } + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_alltypes_illegal_promotion_bool_to_double_errors() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let mut promotions: HashMap<&str, &str> = HashMap::new(); + promotions.insert("bool_col", "double"); // illegal + let reader_schema = make_reader_schema_with_promotions(&file, &promotions); + let file_handle = File::open(&file).unwrap(); + let result = ReaderBuilder::new() + .with_reader_schema(reader_schema) + .build(BufReader::new(file_handle)); + let err = result.expect_err("expected illegal promotion to error"); + let msg = err.to_string(); + assert!( + msg.contains("Illegal promotion") || msg.contains("illegal promotion"), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_simple_enum_with_reader_schema_mapping() { + let file = arrow_test_data("avro/simple_enum.avro"); + let mut remap: HashMap<&str, Vec<&str>> = HashMap::new(); + remap.insert("f1", vec!["d", "c", "b", "a"]); + remap.insert("f2", vec!["h", "g", "f", "e"]); + remap.insert("f3", vec!["k", "i", "j"]); + let reader_schema = make_reader_schema_with_enum_remap(&file, &remap); + let actual = read_alltypes_with_reader_schema(&file, reader_schema); + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + // f1 + let f1_keys = Int32Array::from(vec![3, 2, 1, 0]); + let f1_vals = StringArray::from(vec!["d", "c", "b", "a"]); + let f1 = DictionaryArray::::try_new(f1_keys, Arc::new(f1_vals)).unwrap(); + let mut md_f1 = HashMap::new(); + md_f1.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["d","c","b","a"]"#.to_string(), + ); + // New named-type metadata + md_f1.insert("avro.name".to_string(), "enum1".to_string()); + md_f1.insert("avro.namespace".to_string(), "ns1".to_string()); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + // f2 + let f2_keys = Int32Array::from(vec![1, 0, 3, 2]); + let f2_vals = StringArray::from(vec!["h", "g", "f", "e"]); + let f2 = DictionaryArray::::try_new(f2_keys, Arc::new(f2_vals)).unwrap(); + let mut md_f2 = HashMap::new(); + md_f2.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["h","g","f","e"]"#.to_string(), + ); + // New named-type metadata + md_f2.insert("avro.name".to_string(), "enum2".to_string()); + md_f2.insert("avro.namespace".to_string(), "ns2".to_string()); + let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + // f3 + let f3_keys = Int32Array::from(vec![Some(2), Some(0), None, Some(1)]); + let f3_vals = StringArray::from(vec!["k", "i", "j"]); + let f3 = DictionaryArray::::try_new(f3_keys, Arc::new(f3_vals)).unwrap(); + let mut md_f3 = HashMap::new(); + md_f3.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["k","i","j"]"#.to_string(), + ); + // New named-type metadata + md_f3.insert("avro.name".to_string(), "enum3".to_string()); + md_f3.insert("avro.namespace".to_string(), "ns1".to_string()); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field])); + let expected = RecordBatch::try_new( + expected_schema, + vec![Arc::new(f1) as ArrayRef, Arc::new(f2), Arc::new(f3)], + ) + .unwrap(); + assert_eq!(actual, expected); + } + + #[test] + fn test_schema_store_register_lookup() { + let schema_int = make_record_schema(PrimitiveType::Int); + let schema_long = make_record_schema(PrimitiveType::Long); + let mut store = SchemaStore::new(); + let fp_int = store.register(schema_int.clone()).unwrap(); + let fp_long = store.register(schema_long.clone()).unwrap(); + assert_eq!(store.lookup(&fp_int).cloned(), Some(schema_int)); + assert_eq!(store.lookup(&fp_long).cloned(), Some(schema_long)); + assert_eq!(store.fingerprint_algorithm(), FingerprintAlgorithm::Rabin); + } + + #[test] + fn test_unknown_fingerprint_is_error() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let unknown_fp = Fingerprint::Rabin(0xDEAD_BEEF_DEAD_BEEF); + let prefix = make_prefix(unknown_fp); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let err = decoder.decode(&prefix).expect_err("decode should error"); + let msg = err.to_string(); + assert!( + msg.contains("Unknown fingerprint"), + "unexpected message: {msg}" + ); + } + + #[test] + fn test_handle_prefix_incomplete_magic() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let buf = &SINGLE_OBJECT_MAGIC[..1]; + let res = decoder.handle_prefix(buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + #[test] + fn test_handle_prefix_magic_mismatch() { + let (store, fp_int, _fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let buf = [0xFFu8, 0x00u8, 0x01u8]; + let res = decoder.handle_prefix(&buf).unwrap(); + assert!(res.is_none()); + } + + #[test] + fn test_handle_prefix_incomplete_fingerprint() { + let (store, fp_int, fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let long_bytes = match fp_long { + Fingerprint::Rabin(v) => v.to_le_bytes(), + Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"), + Fingerprint::Id64(id) => panic!("expected Rabin fingerprint, got ({id})"), + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + #[cfg(feature = "sha256")] + Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + }; + let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); + buf.extend_from_slice(&long_bytes[..4]); + let res = decoder.handle_prefix(&buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + #[test] + fn test_handle_prefix_valid_prefix_switches_schema() { + let (store, fp_int, fp_long, _schema_int, schema_long) = make_two_schema_store(); + let mut decoder = make_decoder(&store, fp_int, &schema_long); + let writer_schema_long = schema_long.schema().unwrap(); + let root_long = AvroFieldBuilder::new(&writer_schema_long).build().unwrap(); + let long_decoder = RecordDecoder::try_new_with_options(root_long.data_type()).unwrap(); + let _ = decoder.cache.insert(fp_long, long_decoder); + let mut buf = Vec::from(SINGLE_OBJECT_MAGIC); + match fp_long { + Fingerprint::Rabin(v) => buf.extend_from_slice(&v.to_le_bytes()), + Fingerprint::Id(id) => panic!("expected Rabin fingerprint, got ({id})"), + Fingerprint::Id64(id) => panic!("expected Rabin fingerprint, got ({id})"), + #[cfg(feature = "md5")] + Fingerprint::MD5(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + #[cfg(feature = "sha256")] + Fingerprint::SHA256(v) => panic!("expected Rabin fingerprint, got ({v:?})"), + } + let consumed = decoder.handle_prefix(&buf).unwrap().unwrap(); + assert_eq!(consumed, buf.len()); + assert!(decoder.pending_schema.is_some()); + assert_eq!(decoder.pending_schema.as_ref().unwrap().0, fp_long); + } + + #[test] + fn test_two_messages_same_schema() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let mut store = SchemaStore::new(); + let fp = store.register(writer_schema).unwrap(); + let msg1 = make_message(fp, 42); + let msg2 = make_message(fp, 11); + let input = [msg1.clone(), msg2.clone()].concat(); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema.clone()) + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&input).unwrap(); + let batch = decoder.flush().unwrap().expect("batch"); + assert_eq!(batch.num_rows(), 2); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 42); + assert_eq!(col.value(1), 11); + } + + #[test] + fn test_two_messages_schema_switch() { + let w_int = make_value_schema(PrimitiveType::Int); + let w_long = make_value_schema(PrimitiveType::Long); + let mut store = SchemaStore::new(); + let fp_int = store.register(w_int).unwrap(); + let fp_long = store.register(w_long).unwrap(); + let msg_int = make_message(fp_int, 1); + let msg_long = make_message(fp_long, 123456789_i64); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_writer_schema_store(store) + .with_active_fingerprint(fp_int) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&msg_int).unwrap(); + let batch1 = decoder.flush().unwrap().expect("batch1"); + assert_eq!(batch1.num_rows(), 1); + assert_eq!( + batch1 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 1 + ); + let _ = decoder.decode(&msg_long).unwrap(); + let batch2 = decoder.flush().unwrap().expect("batch2"); + assert_eq!(batch2.num_rows(), 1); + assert_eq!( + batch2 + .column(0) + .as_any() + .downcast_ref::() + .unwrap() + .value(0), + 123456789_i64 + ); + } + + #[test] + fn test_two_messages_same_schema_id() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let id = 100u32; + // Set up store with None fingerprint algorithm and register schema by id + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); + let _ = store + .set(Fingerprint::Id(id), writer_schema.clone()) + .expect("set id schema"); + let msg1 = make_message_id(id, 21); + let msg2 = make_message_id(id, 22); + let input = [msg1.clone(), msg2.clone()].concat(); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&input).unwrap(); + let batch = decoder.flush().unwrap().expect("batch"); + assert_eq!(batch.num_rows(), 2); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 21); + assert_eq!(col.value(1), 22); + } + + #[test] + fn test_unknown_id_fingerprint_is_error() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let id_known = 7u32; + let id_unknown = 9u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); + let _ = store + .set(Fingerprint::Id(id_known), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id_known)) + .build_decoder() + .unwrap(); + let prefix = make_id_prefix(id_unknown, 0); + let err = decoder.decode(&prefix).expect_err("decode should error"); + let msg = err.to_string(); + assert!( + msg.contains("Unknown fingerprint"), + "unexpected message: {msg}" + ); + } + + #[test] + fn test_handle_prefix_id_incomplete_magic() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let id = 5u32; + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); + let _ = store + .set(Fingerprint::Id(id), writer_schema.clone()) + .expect("set id schema"); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(writer_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id(id)) + .build_decoder() + .unwrap(); + let buf = &CONFLUENT_MAGIC[..0]; // empty incomplete magic + let res = decoder.handle_prefix(buf).unwrap(); + assert_eq!(res, Some(0)); + assert!(decoder.pending_schema.is_none()); + } + + #[test] + fn test_two_messages_same_schema_id64() { + let writer_schema = make_value_schema(PrimitiveType::Int); + let reader_schema = writer_schema.clone(); + let id = 100u64; + // Set up store with None fingerprint algorithm and register schema by id + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id64); + let _ = store + .set(Fingerprint::Id64(id), writer_schema.clone()) + .expect("set id schema"); + let msg1 = make_message_id64(id, 21); + let msg2 = make_message_id64(id, 22); + let input = [msg1.clone(), msg2.clone()].concat(); + let mut decoder = ReaderBuilder::new() + .with_batch_size(8) + .with_reader_schema(reader_schema) + .with_writer_schema_store(store) + .with_active_fingerprint(Fingerprint::Id64(id)) + .build_decoder() + .unwrap(); + let _ = decoder.decode(&input).unwrap(); + let batch = decoder.flush().unwrap().expect("batch"); + assert_eq!(batch.num_rows(), 2); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col.value(0), 21); + assert_eq!(col.value(1), 22); + } + + #[test] + fn test_decode_stream_with_schema() { + struct TestCase<'a> { + name: &'a str, + schema: &'a str, + expected_error: Option<&'a str>, + } + let tests = vec![ + TestCase { + name: "success", + schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"string"}]}"#, + expected_error: None, + }, + TestCase { + name: "valid schema invalid data", + schema: r#"{"type":"record","name":"test","fields":[{"name":"f2","type":"long"}]}"#, + expected_error: Some("did not consume all bytes"), + }, + ]; + for test in tests { + let avro_schema = AvroSchema::new(test.schema.to_string()); + let mut store = SchemaStore::new(); + let fp = store.register(avro_schema.clone()).unwrap(); + let prefix = make_prefix(fp); + let record_val = "some_string"; + let mut body = prefix; + body.push((record_val.len() as u8) << 1); + body.extend_from_slice(record_val.as_bytes()); + let decoder_res = ReaderBuilder::new() + .with_batch_size(1) + .with_writer_schema_store(store) + .with_active_fingerprint(fp) + .build_decoder(); + let decoder = match decoder_res { + Ok(d) => d, + Err(e) => { + if let Some(expected) = test.expected_error { + assert!( + e.to_string().contains(expected), + "Test '{}' failed at build – expected '{expected}', got '{e}'", + test.name + ); + continue; + } else { + panic!("Test '{}' failed during build: {e}", test.name); + } + } + }; + let stream = Box::pin(stream::once(async { Bytes::from(body) })); + let decoded_stream = decode_stream(decoder, stream); + let batches_result: Result, ArrowError> = + block_on(decoded_stream.try_collect()); + match (batches_result, test.expected_error) { + (Ok(batches), None) => { + let batch = + arrow::compute::concat_batches(&batches[0].schema(), &batches).unwrap(); + let expected_field = Field::new("f2", DataType::Utf8, false); + let expected_schema = Arc::new(Schema::new(vec![expected_field])); + let expected_array = Arc::new(StringArray::from(vec![record_val])); + let expected_batch = + RecordBatch::try_new(expected_schema, vec![expected_array]).unwrap(); + assert_eq!(batch, expected_batch, "Test '{}'", test.name); + } + (Err(e), Some(expected)) => { + assert!( + e.to_string().contains(expected), + "Test '{}' – expected error containing '{expected}', got '{e}'", + test.name + ); + } + (Ok(_), Some(expected)) => { + panic!( + "Test '{}' expected failure ('{expected}') but succeeded", + test.name + ); + } + (Err(e), None) => { + panic!("Test '{}' unexpectedly failed with '{e}'", test.name); + } + } + } + } + + #[test] + fn test_utf8view_support() { + struct TestHelper; + impl TestHelper { + fn with_utf8view(field: &Field) -> Field { + match field.data_type() { + DataType::Utf8 => { + Field::new(field.name(), DataType::Utf8View, field.is_nullable()) + .with_metadata(field.metadata().clone()) + } + _ => field.clone(), + } + } + } + + let field = TestHelper::with_utf8view(&Field::new("str_field", DataType::Utf8, false)); + + assert_eq!(field.data_type(), &DataType::Utf8View); + + let array = StringViewArray::from(vec!["test1", "test2"]); + let batch = + RecordBatch::try_from_iter(vec![("str_field", Arc::new(array) as ArrayRef)]).unwrap(); + + assert!(batch.column(0).as_any().is::()); + } + + fn make_reader_schema_with_default_fields( + path: &str, + default_fields: Vec, + ) -> AvroSchema { + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + root.as_object_mut() + .expect("schema is a JSON object") + .insert("fields".to_string(), Value::Array(default_fields)); + AvroSchema::new(root.to_string()) + } + + #[test] + fn test_schema_resolution_defaults_all_supported_types() { + let path = "test/data/skippable_types.avro"; + let duration_default = "\u{0000}".repeat(12); + let reader_schema = make_reader_schema_with_default_fields( + path, + vec![ + serde_json::json!({"name":"d_bool","type":"boolean","default":true}), + serde_json::json!({"name":"d_int","type":"int","default":42}), + serde_json::json!({"name":"d_long","type":"long","default":12345}), + serde_json::json!({"name":"d_float","type":"float","default":1.5}), + serde_json::json!({"name":"d_double","type":"double","default":2.25}), + serde_json::json!({"name":"d_bytes","type":"bytes","default":"XYZ"}), + serde_json::json!({"name":"d_string","type":"string","default":"hello"}), + serde_json::json!({"name":"d_date","type":{"type":"int","logicalType":"date"},"default":0}), + serde_json::json!({"name":"d_time_ms","type":{"type":"int","logicalType":"time-millis"},"default":1000}), + serde_json::json!({"name":"d_time_us","type":{"type":"long","logicalType":"time-micros"},"default":2000}), + serde_json::json!({"name":"d_ts_ms","type":{"type":"long","logicalType":"local-timestamp-millis"},"default":0}), + serde_json::json!({"name":"d_ts_us","type":{"type":"long","logicalType":"local-timestamp-micros"},"default":0}), + serde_json::json!({"name":"d_decimal","type":{"type":"bytes","logicalType":"decimal","precision":10,"scale":2},"default":""}), + serde_json::json!({"name":"d_fixed","type":{"type":"fixed","name":"F4","size":4},"default":"ABCD"}), + serde_json::json!({"name":"d_enum","type":{"type":"enum","name":"E","symbols":["A","B","C"]},"default":"A"}), + serde_json::json!({"name":"d_duration","type":{"type":"fixed","name":"Dur","size":12,"logicalType":"duration"},"default":duration_default}), + serde_json::json!({"name":"d_uuid","type":{"type":"string","logicalType":"uuid"},"default":"00000000-0000-0000-0000-000000000000"}), + serde_json::json!({"name":"d_array","type":{"type":"array","items":"int"},"default":[1,2,3]}), + serde_json::json!({"name":"d_map","type":{"type":"map","values":"long"},"default":{"a":1,"b":2}}), + serde_json::json!({"name":"d_record","type":{ + "type":"record","name":"DefaultRec","fields":[ + {"name":"x","type":"int"}, + {"name":"y","type":["null","string"],"default":null} + ] + },"default":{"x":7}}), + serde_json::json!({"name":"d_nullable_null","type":["null","int"],"default":null}), + serde_json::json!({"name":"d_nullable_value","type":["int","null"],"default":123}), + ], + ); + let actual = read_alltypes_with_reader_schema(path, reader_schema); + let num_rows = actual.num_rows(); + assert!(num_rows > 0, "skippable_types.avro should contain rows"); + assert_eq!( + actual.num_columns(), + 22, + "expected exactly our defaulted fields" + ); + let mut arrays: Vec> = Vec::with_capacity(22); + arrays.push(Arc::new(BooleanArray::from_iter(std::iter::repeat_n( + Some(true), + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 42, num_rows, + )))); + arrays.push(Arc::new(Int64Array::from_iter_values(std::iter::repeat_n( + 12345, num_rows, + )))); + arrays.push(Arc::new(Float32Array::from_iter_values( + std::iter::repeat_n(1.5f32, num_rows), + ))); + arrays.push(Arc::new(Float64Array::from_iter_values( + std::iter::repeat_n(2.25f64, num_rows), + ))); + arrays.push(Arc::new(BinaryArray::from_iter_values( + std::iter::repeat_n(b"XYZ".as_ref(), num_rows), + ))); + arrays.push(Arc::new(StringArray::from_iter_values( + std::iter::repeat_n("hello", num_rows), + ))); + arrays.push(Arc::new(Date32Array::from_iter_values( + std::iter::repeat_n(0, num_rows), + ))); + arrays.push(Arc::new(Time32MillisecondArray::from_iter_values( + std::iter::repeat_n(1_000, num_rows), + ))); + arrays.push(Arc::new(Time64MicrosecondArray::from_iter_values( + std::iter::repeat_n(2_000i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMillisecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + arrays.push(Arc::new(TimestampMicrosecondArray::from_iter_values( + std::iter::repeat_n(0i64, num_rows), + ))); + #[cfg(feature = "small_decimals")] + let decimal = Decimal64Array::from_iter_values(std::iter::repeat_n(0i64, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + #[cfg(not(feature = "small_decimals"))] + let decimal = Decimal128Array::from_iter_values(std::iter::repeat_n(0i128, num_rows)) + .with_precision_and_scale(10, 2) + .unwrap(); + arrays.push(Arc::new(decimal)); + let fixed_iter = std::iter::repeat_n(Some(*b"ABCD"), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(fixed_iter, 4).unwrap(), + )); + let enum_keys = Int32Array::from_iter_values(std::iter::repeat_n(0, num_rows)); + let enum_values = StringArray::from_iter_values(["A", "B", "C"]); + let enum_arr = + DictionaryArray::::try_new(enum_keys, Arc::new(enum_values)).unwrap(); + arrays.push(Arc::new(enum_arr)); + let duration_values = std::iter::repeat_n( + Some(IntervalMonthDayNanoType::make_value(0, 0, 0)), + num_rows, + ); + let duration_arr: IntervalMonthDayNanoArray = duration_values.collect(); + arrays.push(Arc::new(duration_arr)); + let uuid_bytes = [0u8; 16]; + let uuid_iter = std::iter::repeat_n(Some(uuid_bytes), num_rows); + arrays.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(uuid_iter, 16).unwrap(), + )); + let item_field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int32, + false, + )); + let mut list_builder = ListBuilder::new(Int32Builder::new()).with_field(item_field); + for _ in 0..num_rows { + list_builder.values().append_value(1); + list_builder.values().append_value(2); + list_builder.values().append_value(3); + list_builder.append(true); + } + arrays.push(Arc::new(list_builder.finish())); + let values_field = Arc::new(Field::new("value", DataType::Int64, false)); + let mut map_builder = MapBuilder::new( + Some(builder::MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + StringBuilder::new(), + Int64Builder::new(), + ) + .with_values_field(values_field); + for _ in 0..num_rows { + let (keys, vals) = map_builder.entries(); + keys.append_value("a"); + vals.append_value(1); + keys.append_value("b"); + vals.append_value(2); + map_builder.append(true).unwrap(); + } + arrays.push(Arc::new(map_builder.finish())); + let rec_fields: Fields = Fields::from(vec![ + Field::new("x", DataType::Int32, false), + Field::new("y", DataType::Utf8, true), + ]); + let mut sb = StructBuilder::new( + rec_fields.clone(), + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + for _ in 0..num_rows { + sb.field_builder::(0).unwrap().append_value(7); + sb.field_builder::(1).unwrap().append_null(); + sb.append(true); + } + arrays.push(Arc::new(sb.finish())); + arrays.push(Arc::new(Int32Array::from_iter(std::iter::repeat_n( + None::, + num_rows, + )))); + arrays.push(Arc::new(Int32Array::from_iter_values(std::iter::repeat_n( + 123, num_rows, + )))); + let expected = RecordBatch::try_new(actual.schema(), arrays).unwrap(); + assert_eq!( + actual, expected, + "defaults should materialize correctly for all fields" + ); + } + + #[test] + fn test_schema_resolution_default_enum_invalid_symbol_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_enum", + "type":{"type":"enum","name":"E","symbols":["A","B","C"]}, + "default":"Z" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected enum default validation to fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("enum") + && (lower_msg.contains("symbol") || lower_msg.contains("default")), + "unexpected error: {msg}" + ); + } + + #[test] + fn test_schema_resolution_default_fixed_size_mismatch_errors() { + let path = "test/data/skippable_types.avro"; + let bad_schema = make_reader_schema_with_default_fields( + path, + vec![serde_json::json!({ + "name":"bad_fixed", + "type":{"type":"fixed","name":"F","size":4}, + "default":"ABC" + })], + ); + let file = File::open(path).unwrap(); + let res = ReaderBuilder::new() + .with_reader_schema(bad_schema) + .build(BufReader::new(file)); + let err = res.expect_err("expected fixed default validation to fail"); + let msg = err.to_string(); + let lower_msg = msg.to_lowercase(); + assert!( + lower_msg.contains("fixed") + && (lower_msg.contains("size") + || lower_msg.contains("length") + || lower_msg.contains("does not match")), + "unexpected error: {msg}" + ); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_alltypes_skip_writer_fields_keep_double_only() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, &["double_col"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_alltypes_skip_writer_fields_reorder_and_skip_many() { + let file = arrow_test_data("avro/alltypes_plain.avro"); + let reader_schema = + make_reader_schema_with_selected_fields_in_order(&file, &["timestamp_col", "id"]); + let batch = read_alltypes_with_reader_schema(&file, reader_schema); + let expected = RecordBatch::try_from_iter_with_nullable([ ( "timestamp_col", Arc::new( @@ -298,14 +2794,5920 @@ mod test { ) as _, true, ), + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), ]) .unwrap(); + assert_eq!(batch, expected); + } - for file in files { - let file = arrow_test_data(file); + #[test] + fn test_skippable_types_project_each_field_individually() { + let path = "test/data/skippable_types.avro"; + let full = read_file(path, 1024, false); + let schema_full = full.schema(); + let num_rows = full.num_rows(); + let writer_json = load_writer_schema_json(path); + assert_eq!( + writer_json["type"], "record", + "writer schema must be a record" + ); + let fields_json = writer_json + .get("fields") + .and_then(|f| f.as_array()) + .expect("record has fields"); + assert_eq!( + schema_full.fields().len(), + fields_json.len(), + "full read column count vs writer fields" + ); + fn rebuild_list_array_with_element( + col: &ArrayRef, + new_elem: Arc, + is_large: bool, + ) -> ArrayRef { + if is_large { + let list = col + .as_any() + .downcast_ref::() + .expect("expected LargeListArray"); + let offsets = list.offsets().clone(); + let values = list.values().clone(); + let validity = list.nulls().cloned(); + Arc::new(LargeListArray::try_new(new_elem, offsets, values, validity).unwrap()) + } else { + let list = col + .as_any() + .downcast_ref::() + .expect("expected ListArray"); + let offsets = list.offsets().clone(); + let values = list.values().clone(); + let validity = list.nulls().cloned(); + Arc::new(ListArray::try_new(new_elem, offsets, values, validity).unwrap()) + } + } + for (idx, f) in fields_json.iter().enumerate() { + let name = f + .get("name") + .and_then(|n| n.as_str()) + .unwrap_or_else(|| panic!("field at index {idx} has no name")); + let reader_schema = make_reader_schema_with_selected_fields_in_order(path, &[name]); + let projected = read_alltypes_with_reader_schema(path, reader_schema); + assert_eq!( + projected.num_columns(), + 1, + "projected batch should contain exactly the selected column '{name}'" + ); + assert_eq!( + projected.num_rows(), + num_rows, + "row count mismatch for projected column '{name}'" + ); + let col_full = full.column(idx).clone(); + let full_field = schema_full.field(idx).as_ref().clone(); + let proj_field_ref = projected.schema().field(0).clone(); + let proj_field = proj_field_ref.as_ref(); + let top_meta = proj_field.metadata().clone(); + let (expected_field_ref, expected_col): (Arc, ArrayRef) = + match (full_field.data_type(), proj_field.data_type()) { + (&DataType::List(_), DataType::List(proj_elem)) => { + let new_col = + rebuild_list_array_with_element(&col_full, proj_elem.clone(), false); + let nf = Field::new( + full_field.name().clone(), + proj_field.data_type().clone(), + full_field.is_nullable(), + ) + .with_metadata(top_meta); + (Arc::new(nf), new_col) + } + (&DataType::LargeList(_), DataType::LargeList(proj_elem)) => { + let new_col = + rebuild_list_array_with_element(&col_full, proj_elem.clone(), true); + let nf = Field::new( + full_field.name().clone(), + proj_field.data_type().clone(), + full_field.is_nullable(), + ) + .with_metadata(top_meta); + (Arc::new(nf), new_col) + } + _ => { + let nf = full_field.with_metadata(top_meta); + (Arc::new(nf), col_full) + } + }; + + let expected = RecordBatch::try_new( + Arc::new(Schema::new(vec![expected_field_ref])), + vec![expected_col], + ) + .unwrap(); + assert_eq!( + projected, expected, + "projected column '{name}' mismatch vs full read column" + ); + } + } + + #[test] + fn test_union_fields_avro_nullable_and_general_unions() { + let path = "test/data/union_fields.avro"; + let batch = read_file(path, 1024, false); + let schema = batch.schema(); + let idx = schema.index_of("nullable_int_nullfirst").unwrap(); + let a = batch.column(idx).as_primitive::(); + assert_eq!(a.len(), 4); + assert!(a.is_null(0)); + assert_eq!(a.value(1), 42); + assert!(a.is_null(2)); + assert_eq!(a.value(3), 0); + let idx = schema.index_of("nullable_string_nullsecond").unwrap(); + let s = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("nullable_string_nullsecond should be Utf8"); + assert_eq!(s.len(), 4); + assert_eq!(s.value(0), "s1"); + assert!(s.is_null(1)); + assert_eq!(s.value(2), "s3"); + assert!(s.is_valid(3)); // empty string, not null + assert_eq!(s.value(3), ""); + let idx = schema.index_of("union_prim").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_prim should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, mode) => { + assert!(matches!(mode, UnionMode::Dense), "expect dense unions"); + fields + } + other => panic!("expected Union, got {other:?}"), + }; + let tid_by_name = |name: &str| -> i8 { + for (tid, f) in fields.iter() { + if f.name() == name { + return tid; + } + } + panic!("union child '{name}' not found"); + }; + let expected_type_ids = vec![ + tid_by_name("long"), + tid_by_name("int"), + tid_by_name("float"), + tid_by_name("double"), + ]; + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, expected_type_ids, + "branch selection for union_prim rows" + ); + let longs = u + .child(tid_by_name("long")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(longs.len(), 1); + let ints = u + .child(tid_by_name("int")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(ints.len(), 1); + let floats = u + .child(tid_by_name("float")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(floats.len(), 1); + let doubles = u + .child(tid_by_name("double")) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(doubles.len(), 1); + let idx = schema.index_of("union_bytes_vs_string").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_bytes_vs_string should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let tid_by_name = |name: &str| -> i8 { + for (tid, f) in fields.iter() { + if f.name() == name { + return tid; + } + } + panic!("union child '{name}' not found"); + }; + let tid_bytes = tid_by_name("bytes"); + let tid_string = tid_by_name("string"); + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_bytes, tid_string, tid_string, tid_bytes], + "branch selection for bytes/string union" + ); + let s_child = u + .child(tid_string) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(s_child.len(), 2); + assert_eq!(s_child.value(0), "hello"); + assert_eq!(s_child.value(1), "world"); + let b_child = u + .child(tid_bytes) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_child.len(), 2); + assert_eq!(b_child.value(0), &[0x00, 0xFF, 0x7F]); + assert_eq!(b_child.value(1), b""); // previously: &[] + let idx = schema.index_of("union_enum_records_array_map").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_enum_records_array_map should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut tid_enum: Option = None; + let mut tid_rec_a: Option = None; + let mut tid_rec_b: Option = None; + let mut tid_array: Option = None; + for (tid, f) in fields.iter() { + match f.data_type() { + DataType::Dictionary(_, _) => tid_enum = Some(tid), + DataType::Struct(childs) => { + if childs.len() == 2 && childs[0].name() == "a" && childs[1].name() == "b" { + tid_rec_a = Some(tid); + } else if childs.len() == 2 + && childs[0].name() == "x" + && childs[1].name() == "y" + { + tid_rec_b = Some(tid); + } + } + DataType::List(_) => tid_array = Some(tid), + _ => {} + } + } + let (tid_enum, tid_rec_a, tid_rec_b, tid_array) = ( + tid_enum.expect("enum child"), + tid_rec_a.expect("RecA child"), + tid_rec_b.expect("RecB child"), + tid_array.expect("array child"), + ); + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_enum, tid_rec_a, tid_rec_b, tid_array], + "branch selection for complex union" + ); + let dict = u + .child(tid_enum) + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict.len(), 1); + assert!(dict.is_valid(0)); + let rec_a = u + .child(tid_rec_a) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(rec_a.len(), 1); + let a_val = rec_a + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a_val.value(0), 7); + let b_val = rec_a + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(b_val.value(0), "x"); + // RecB row: {"x": 123456789, "y": b"\xFF\x00"} + let rec_b = u + .child(tid_rec_b) + .as_any() + .downcast_ref::() + .unwrap(); + let x_val = rec_b + .column_by_name("x") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(x_val.value(0), 123_456_789_i64); + let y_val = rec_b + .column_by_name("y") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(y_val.value(0), &[0xFF, 0x00]); + let arr = u + .child(tid_array) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.len(), 1); + let first_values = arr.value(0); + let longs = first_values.as_any().downcast_ref::().unwrap(); + assert_eq!(longs.len(), 3); + assert_eq!(longs.value(0), 1); + assert_eq!(longs.value(1), 2); + assert_eq!(longs.value(2), 3); + let idx = schema.index_of("union_date_or_fixed4").unwrap(); + let u = batch + .column(idx) + .as_any() + .downcast_ref::() + .expect("union_date_or_fixed4 should be Union"); + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut tid_date: Option = None; + let mut tid_fixed: Option = None; + for (tid, f) in fields.iter() { + match f.data_type() { + DataType::Date32 => tid_date = Some(tid), + DataType::FixedSizeBinary(4) => tid_fixed = Some(tid), + _ => {} + } + } + let (tid_date, tid_fixed) = (tid_date.expect("date"), tid_fixed.expect("fixed(4)")); + let type_ids: Vec = u.type_ids().iter().copied().collect(); + assert_eq!( + type_ids, + vec![tid_date, tid_fixed, tid_date, tid_fixed], + "branch selection for date/fixed4 union" + ); + let dates = u + .child(tid_date) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dates.len(), 2); + assert_eq!(dates.value(0), 19_000); // ~2022‑01‑15 + assert_eq!(dates.value(1), 0); // epoch + let fixed = u + .child(tid_fixed) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(fixed.len(), 2); + assert_eq!(fixed.value(0), b"ABCD"); + assert_eq!(fixed.value(1), &[0x00, 0x11, 0x22, 0x33]); + } - assert_eq!(read_file(&file, 8), expected); - assert_eq!(read_file(&file, 3), expected); + #[test] + fn test_union_schema_resolution_all_type_combinations() { + let path = "test/data/union_fields.avro"; + let baseline = read_file(path, 1024, false); + let baseline_schema = baseline.schema(); + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + fn is_named_type(obj: &Value, ty: &str, nm: &str) -> bool { + obj.get("type").and_then(|v| v.as_str()) == Some(ty) + && obj.get("name").and_then(|v| v.as_str()) == Some(nm) + } + fn is_logical(obj: &Value, prim: &str, lt: &str) -> bool { + obj.get("type").and_then(|v| v.as_str()) == Some(prim) + && obj.get("logicalType").and_then(|v| v.as_str()) == Some(lt) + } + fn find_first(arr: &[Value], pred: impl Fn(&Value) -> bool) -> Option { + arr.iter().find(|v| pred(v)).cloned() + } + fn prim(s: &str) -> Value { + Value::String(s.to_string()) + } + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + match name { + // Flip null ordering – should not affect values + "nullable_int_nullfirst" => { + f["type"] = json!(["int", "null"]); + } + "nullable_string_nullsecond" => { + f["type"] = json!(["null", "string"]); + } + "union_prim" => { + let orig = f["type"].as_array().unwrap().clone(); + let long = prim("long"); + let double = prim("double"); + let string = prim("string"); + let bytes = prim("bytes"); + let boolean = prim("boolean"); + assert!(orig.contains(&long)); + assert!(orig.contains(&double)); + assert!(orig.contains(&string)); + assert!(orig.contains(&bytes)); + assert!(orig.contains(&boolean)); + f["type"] = json!([long, double, string, bytes, boolean]); + } + "union_bytes_vs_string" => { + f["type"] = json!(["string", "bytes"]); + } + "union_fixed_dur_decfix" => { + let orig = f["type"].as_array().unwrap().clone(); + let fx8 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx8")).unwrap(); + let dur12 = find_first(&orig, |o| is_named_type(o, "fixed", "Dur12")).unwrap(); + let decfix16 = + find_first(&orig, |o| is_named_type(o, "fixed", "DecFix16")).unwrap(); + f["type"] = json!([decfix16, dur12, fx8]); + } + "union_enum_records_array_map" => { + let orig = f["type"].as_array().unwrap().clone(); + let enum_color = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("enum") + }) + .unwrap(); + let rec_a = find_first(&orig, |o| is_named_type(o, "record", "RecA")).unwrap(); + let rec_b = find_first(&orig, |o| is_named_type(o, "record", "RecB")).unwrap(); + let arr = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("array") + }) + .unwrap(); + let map = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("map") + }) + .unwrap(); + f["type"] = json!([arr, map, rec_b, rec_a, enum_color]); + } + "union_date_or_fixed4" => { + let orig = f["type"].as_array().unwrap().clone(); + let date = find_first(&orig, |o| is_logical(o, "int", "date")).unwrap(); + let fx4 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx4")).unwrap(); + f["type"] = json!([fx4, date]); + } + "union_time_millis_or_enum" => { + let orig = f["type"].as_array().unwrap().clone(); + let time_ms = + find_first(&orig, |o| is_logical(o, "int", "time-millis")).unwrap(); + let en = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("enum") + }) + .unwrap(); + f["type"] = json!([en, time_ms]); + } + "union_time_micros_or_string" => { + let orig = f["type"].as_array().unwrap().clone(); + let time_us = + find_first(&orig, |o| is_logical(o, "long", "time-micros")).unwrap(); + f["type"] = json!(["string", time_us]); + } + "union_ts_millis_utc_or_array" => { + let orig = f["type"].as_array().unwrap().clone(); + let ts_ms = + find_first(&orig, |o| is_logical(o, "long", "timestamp-millis")).unwrap(); + let arr = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("array") + }) + .unwrap(); + f["type"] = json!([arr, ts_ms]); + } + "union_ts_micros_local_or_bytes" => { + let orig = f["type"].as_array().unwrap().clone(); + let lts_us = + find_first(&orig, |o| is_logical(o, "long", "local-timestamp-micros")) + .unwrap(); + f["type"] = json!(["bytes", lts_us]); + } + "union_uuid_or_fixed10" => { + let orig = f["type"].as_array().unwrap().clone(); + let uuid = find_first(&orig, |o| is_logical(o, "string", "uuid")).unwrap(); + let fx10 = find_first(&orig, |o| is_named_type(o, "fixed", "Fx10")).unwrap(); + f["type"] = json!([fx10, uuid]); + } + "union_dec_bytes_or_dec_fixed" => { + let orig = f["type"].as_array().unwrap().clone(); + let dec_bytes = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("bytes") + && o.get("logicalType").and_then(|v| v.as_str()) == Some("decimal") + }) + .unwrap(); + let dec_fix = find_first(&orig, |o| { + is_named_type(o, "fixed", "DecFix20") + && o.get("logicalType").and_then(|v| v.as_str()) == Some("decimal") + }) + .unwrap(); + f["type"] = json!([dec_fix, dec_bytes]); + } + "union_null_bytes_string" => { + f["type"] = json!(["bytes", "string", "null"]); + } + "array_of_union" => { + let obj = f + .get_mut("type") + .expect("array type") + .as_object_mut() + .unwrap(); + obj.insert("items".to_string(), json!(["string", "long"])); + } + "map_of_union" => { + let obj = f + .get_mut("type") + .expect("map type") + .as_object_mut() + .unwrap(); + obj.insert("values".to_string(), json!(["double", "null"])); + } + "record_with_union_field" => { + let rec = f + .get_mut("type") + .expect("record type") + .as_object_mut() + .unwrap(); + let rec_fields = rec.get_mut("fields").unwrap().as_array_mut().unwrap(); + let mut found = false; + for rf in rec_fields.iter_mut() { + if rf.get("name").and_then(|v| v.as_str()) == Some("u") { + rf["type"] = json!(["string", "long"]); // rely on int→long promotion + found = true; + break; + } + } + assert!(found, "field 'u' expected in HasUnion"); + } + "union_ts_micros_utc_or_map" => { + let orig = f["type"].as_array().unwrap().clone(); + let ts_us = + find_first(&orig, |o| is_logical(o, "long", "timestamp-micros")).unwrap(); + let map = find_first(&orig, |o| { + o.get("type").and_then(|v| v.as_str()) == Some("map") + }) + .unwrap(); + f["type"] = json!([map, ts_us]); + } + "union_ts_millis_local_or_string" => { + let orig = f["type"].as_array().unwrap().clone(); + let lts_ms = + find_first(&orig, |o| is_logical(o, "long", "local-timestamp-millis")) + .unwrap(); + f["type"] = json!(["string", lts_ms]); + } + "union_bool_or_string" => { + f["type"] = json!(["string", "boolean"]); + } + _ => {} + } + } + let reader_schema = AvroSchema::new(root.to_string()); + let resolved = read_alltypes_with_reader_schema(path, reader_schema); + + fn branch_token(dt: &DataType) -> String { + match dt { + DataType::Null => "null".into(), + DataType::Boolean => "boolean".into(), + DataType::Int32 => "int".into(), + DataType::Int64 => "long".into(), + DataType::Float32 => "float".into(), + DataType::Float64 => "double".into(), + DataType::Binary => "bytes".into(), + DataType::Utf8 => "string".into(), + DataType::Date32 => "date".into(), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => "time-millis".into(), + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => "time-micros".into(), + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => if tz.is_some() { + "timestamp-millis" + } else { + "local-timestamp-millis" + } + .into(), + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => if tz.is_some() { + "timestamp-micros" + } else { + "local-timestamp-micros" + } + .into(), + DataType::Interval(IntervalUnit::MonthDayNano) => "duration".into(), + DataType::FixedSizeBinary(n) => format!("fixed{n}"), + DataType::Dictionary(_, _) => "enum".into(), + DataType::Decimal128(p, s) => format!("decimal({p},{s})"), + DataType::Decimal256(p, s) => format!("decimal({p},{s})"), + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => format!("decimal({p},{s})"), + DataType::Struct(fields) => { + if fields.len() == 2 && fields[0].name() == "a" && fields[1].name() == "b" { + "record:RecA".into() + } else if fields.len() == 2 + && fields[0].name() == "x" + && fields[1].name() == "y" + { + "record:RecB".into() + } else { + "record".into() + } + } + DataType::List(_) => "array".into(), + DataType::Map(_, _) => "map".into(), + other => format!("{other:?}"), + } + } + + fn union_tokens(u: &UnionArray) -> (Vec, HashMap) { + let fields = match u.data_type() { + DataType::Union(fields, _) => fields, + other => panic!("expected Union, got {other:?}"), + }; + let mut dict: HashMap = HashMap::with_capacity(fields.len()); + for (tid, f) in fields.iter() { + dict.insert(tid, branch_token(f.data_type())); + } + let ids: Vec = u.type_ids().iter().copied().collect(); + (ids, dict) + } + + fn expected_token(field_name: &str, writer_token: &str) -> String { + match field_name { + "union_prim" => match writer_token { + "int" => "long".into(), + "float" => "double".into(), + other => other.into(), + }, + "record_with_union_field.u" => match writer_token { + "int" => "long".into(), + other => other.into(), + }, + _ => writer_token.into(), + } + } + + fn get_union<'a>( + rb: &'a RecordBatch, + schema: arrow_schema::SchemaRef, + fname: &str, + ) -> &'a UnionArray { + let idx = schema.index_of(fname).unwrap(); + rb.column(idx) + .as_any() + .downcast_ref::() + .unwrap_or_else(|| panic!("{fname} should be a Union")) + } + + fn assert_union_equivalent(field_name: &str, u_writer: &UnionArray, u_reader: &UnionArray) { + let (ids_w, dict_w) = union_tokens(u_writer); + let (ids_r, dict_r) = union_tokens(u_reader); + assert_eq!( + ids_w.len(), + ids_r.len(), + "{field_name}: row count mismatch between baseline and resolved" + ); + for (i, (id_w, id_r)) in ids_w.iter().zip(ids_r.iter()).enumerate() { + let w_tok = dict_w.get(id_w).unwrap(); + let want = expected_token(field_name, w_tok); + let got = dict_r.get(id_r).unwrap(); + assert_eq!( + got, &want, + "{field_name}: row {i} resolved to wrong union branch (writer={w_tok}, expected={want}, got={got})" + ); + } + } + + for (fname, dt) in [ + ("nullable_int_nullfirst", DataType::Int32), + ("nullable_string_nullsecond", DataType::Utf8), + ] { + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let col_b = baseline.column(idx_b); + let col_r = resolved.column(idx_r); + assert_eq!( + col_b.data_type(), + &dt, + "baseline {fname} should decode as non-union with nullability" + ); + assert_eq!( + col_b.as_ref(), + col_r.as_ref(), + "{fname}: values must be identical regardless of null-branch order" + ); + } + let union_fields = [ + "union_prim", + "union_bytes_vs_string", + "union_fixed_dur_decfix", + "union_enum_records_array_map", + "union_date_or_fixed4", + "union_time_millis_or_enum", + "union_time_micros_or_string", + "union_ts_millis_utc_or_array", + "union_ts_micros_local_or_bytes", + "union_uuid_or_fixed10", + "union_dec_bytes_or_dec_fixed", + "union_null_bytes_string", + "union_ts_micros_utc_or_map", + "union_ts_millis_local_or_string", + "union_bool_or_string", + ]; + for fname in union_fields { + let u_b = get_union(&baseline, baseline_schema.clone(), fname); + let u_r = get_union(&resolved, resolved.schema(), fname); + assert_union_equivalent(fname, u_b, u_r); + } + { + let fname = "array_of_union"; + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let arr_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("array_of_union should be a List"); + let arr_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("array_of_union should be a List"); + assert_eq!( + arr_b.value_offsets(), + arr_r.value_offsets(), + "{fname}: list offsets changed after resolution" + ); + let u_b = arr_b + .values() + .as_any() + .downcast_ref::() + .expect("array items should be Union"); + let u_r = arr_r + .values() + .as_any() + .downcast_ref::() + .expect("array items should be Union"); + let (ids_b, dict_b) = union_tokens(u_b); + let (ids_r, dict_r) = union_tokens(u_r); + assert_eq!(ids_b.len(), ids_r.len(), "{fname}: values length mismatch"); + for (i, (id_b, id_r)) in ids_b.iter().zip(ids_r.iter()).enumerate() { + let w_tok = dict_b.get(id_b).unwrap(); + let got = dict_r.get(id_r).unwrap(); + assert_eq!( + got, w_tok, + "{fname}: value {i} resolved to wrong branch (writer={w_tok}, got={got})" + ); + } + } + { + let fname = "map_of_union"; + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let map_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("map_of_union should be a Map"); + let map_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("map_of_union should be a Map"); + assert_eq!( + map_b.value_offsets(), + map_r.value_offsets(), + "{fname}: map value offsets changed after resolution" + ); + let ent_b = map_b.entries(); + let ent_r = map_r.entries(); + let val_b_any = ent_b.column(1).as_ref(); + let val_r_any = ent_r.column(1).as_ref(); + let b_union = val_b_any.as_any().downcast_ref::(); + let r_union = val_r_any.as_any().downcast_ref::(); + if let (Some(u_b), Some(u_r)) = (b_union, r_union) { + assert_union_equivalent(fname, u_b, u_r); + } else { + assert_eq!( + val_b_any.data_type(), + val_r_any.data_type(), + "{fname}: value data types differ after resolution" + ); + assert_eq!( + val_b_any, val_r_any, + "{fname}: value arrays differ after resolution (nullable value column case)" + ); + let value_nullable = |m: &MapArray| -> bool { + match m.data_type() { + DataType::Map(entries_field, _sorted) => match entries_field.data_type() { + DataType::Struct(fields) => { + assert_eq!(fields.len(), 2, "entries struct must have 2 fields"); + assert_eq!(fields[0].name(), "key"); + assert_eq!(fields[1].name(), "value"); + fields[1].is_nullable() + } + other => panic!("Map entries field must be Struct, got {other:?}"), + }, + other => panic!("expected Map data type, got {other:?}"), + } + }; + assert!( + value_nullable(map_b), + "{fname}: baseline Map value field should be nullable per Arrow spec" + ); + assert!( + value_nullable(map_r), + "{fname}: resolved Map value field should be nullable per Arrow spec" + ); + } + } + { + let fname = "record_with_union_field"; + let idx_b = baseline_schema.index_of(fname).unwrap(); + let idx_r = resolved.schema().index_of(fname).unwrap(); + let rec_b = baseline + .column(idx_b) + .as_any() + .downcast_ref::() + .expect("record_with_union_field should be a Struct"); + let rec_r = resolved + .column(idx_r) + .as_any() + .downcast_ref::() + .expect("record_with_union_field should be a Struct"); + let u_b = rec_b + .column_by_name("u") + .unwrap() + .as_any() + .downcast_ref::() + .expect("field 'u' should be Union (baseline)"); + let u_r = rec_r + .column_by_name("u") + .unwrap() + .as_any() + .downcast_ref::() + .expect("field 'u' should be Union (resolved)"); + assert_union_equivalent("record_with_union_field.u", u_b, u_r); + } + } + + #[test] + fn test_union_fields_end_to_end_expected_arrays() { + fn tid_by_name(fields: &UnionFields, want: &str) -> i8 { + for (tid, f) in fields.iter() { + if f.name() == want { + return tid; + } + } + panic!("union child '{want}' not found") + } + + fn tid_by_dt(fields: &UnionFields, pred: impl Fn(&DataType) -> bool) -> i8 { + for (tid, f) in fields.iter() { + if pred(f.data_type()) { + return tid; + } + } + panic!("no union child matches predicate"); + } + + fn uuid16_from_str(s: &str) -> [u8; 16] { + fn hex(b: u8) -> u8 { + match b { + b'0'..=b'9' => b - b'0', + b'a'..=b'f' => b - b'a' + 10, + b'A'..=b'F' => b - b'A' + 10, + _ => panic!("invalid hex"), + } + } + let mut out = [0u8; 16]; + let bytes = s.as_bytes(); + let (mut i, mut j) = (0, 0); + while i < bytes.len() { + if bytes[i] == b'-' { + i += 1; + continue; + } + let hi = hex(bytes[i]); + let lo = hex(bytes[i + 1]); + out[j] = (hi << 4) | lo; + j += 1; + i += 2; + } + assert_eq!(j, 16, "uuid must decode to 16 bytes"); + out + } + + fn empty_child_for(dt: &DataType) -> Arc { + match dt { + DataType::Null => Arc::new(NullArray::new(0)), + DataType::Boolean => Arc::new(BooleanArray::from(Vec::::new())), + DataType::Int32 => Arc::new(Int32Array::from(Vec::::new())), + DataType::Int64 => Arc::new(Int64Array::from(Vec::::new())), + DataType::Float32 => Arc::new(arrow_array::Float32Array::from(Vec::::new())), + DataType::Float64 => Arc::new(arrow_array::Float64Array::from(Vec::::new())), + DataType::Binary => Arc::new(BinaryArray::from(Vec::<&[u8]>::new())), + DataType::Utf8 => Arc::new(StringArray::from(Vec::<&str>::new())), + DataType::Date32 => Arc::new(arrow_array::Date32Array::from(Vec::::new())), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Arc::new(Time32MillisecondArray::from(Vec::::new())) + } + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Arc::new(Time64MicrosecondArray::from(Vec::::new())) + } + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => { + let a = TimestampMicrosecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Arc::new(arrow_array::IntervalMonthDayNanoArray::from(Vec::< + IntervalMonthDayNano, + >::new( + ))) + } + DataType::FixedSizeBinary(n) => Arc::new(FixedSizeBinaryArray::new_null(*n, 0)), + DataType::Dictionary(k, v) => { + assert_eq!(**k, DataType::Int32, "expect int32 keys for enums"); + let keys = Int32Array::from(Vec::::new()); + let values = match v.as_ref() { + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + other => panic!("unexpected dictionary value type {other:?}"), + }; + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + } + DataType::List(field) => { + let values: ArrayRef = match field.data_type() { + DataType::Int32 => { + Arc::new(Int32Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Union(_, _) => { + let (uf, _) = if let DataType::Union(f, m) = field.data_type() { + (f.clone(), m) + } else { + unreachable!() + }; + let children: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + children, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported list item type: {other:?}"), + }; + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(ListArray::try_new(field.clone(), offsets, values, None).unwrap()) + } + DataType::Map(entry_field, ordered) => { + let DataType::Struct(childs) = entry_field.data_type() else { + panic!("map entries must be struct") + }; + let key_field = &childs[0]; + let val_field = &childs[1]; + assert_eq!(key_field.data_type(), &DataType::Utf8); + let keys = StringArray::from(Vec::<&str>::new()); + let vals: ArrayRef = match val_field.data_type() { + DataType::Float64 => { + Arc::new(arrow_array::Float64Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Union(uf, _) => { + let ch: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + ch, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported map value type: {other:?}"), + }; + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + *ordered, + )) + } + other => panic!("empty_child_for: unhandled type {other:?}"), + } + } + + fn mk_dense_union( + fields: &UnionFields, + type_ids: Vec, + offsets: Vec, + provide: impl Fn(&Field) -> Option, + ) -> ArrayRef { + let children: Vec = fields + .iter() + .map(|(_, f)| provide(f).unwrap_or_else(|| empty_child_for(f.data_type()))) + .collect(); + + Arc::new( + UnionArray::try_new( + fields.clone(), + ScalarBuffer::::from(type_ids), + Some(ScalarBuffer::::from(offsets)), + children, + ) + .unwrap(), + ) as ArrayRef + } + + // Dates / times / timestamps from the Avro content block: + let date_a: i32 = 19_000; + let time_ms_a: i32 = 13 * 3_600_000 + 45 * 60_000 + 30_000 + 123; + let time_us_b: i64 = 23 * 3_600_000_000 + 59 * 60_000_000 + 59 * 1_000_000 + 999_999; + let ts_ms_2024_01_01: i64 = 1_704_067_200_000; + let ts_us_2024_01_01: i64 = ts_ms_2024_01_01 * 1000; + // Fixed / bytes-like values: + let fx8_a: [u8; 8] = *b"ABCDEFGH"; + let fx4_abcd: [u8; 4] = *b"ABCD"; + let fx4_misc: [u8; 4] = [0x00, 0x11, 0x22, 0x33]; + let fx10_ascii: [u8; 10] = *b"0123456789"; + let fx10_aa: [u8; 10] = [0xAA; 10]; + // Duration logical values as MonthDayNano: + let dur_a = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000_000); + let dur_b = IntervalMonthDayNanoType::make_value(12, 31, 999_000_000); + // UUID logical values (stored as 16-byte FixedSizeBinary in Arrow): + let uuid1 = uuid16_from_str("fe7bc30b-4ce8-4c5e-b67c-2234a2d38e66"); + let uuid2 = uuid16_from_str("0826cc06-d2e3-4599-b4ad-af5fa6905cdb"); + // Decimals from Avro content: + let dec_b_scale2_pos: i128 = 123_456; // "1234.56" bytes-decimal -> (precision=10, scale=2) + let dec_fix16_neg: i128 = -101; // "-1.01" fixed(16) decimal(10,2) + let dec_fix20_s4: i128 = 1_234_567_891_234; // "123456789.1234" fixed(20) decimal(20,4) + let dec_fix20_s4_neg: i128 = -123; // "-0.0123" fixed(20) decimal(20,4) + let path = "test/data/union_fields.avro"; + let actual = read_file(path, 1024, false); + let schema = actual.schema(); + // Helper to fetch union metadata for a column + let get_union = |name: &str| -> (UnionFields, UnionMode) { + let idx = schema.index_of(name).unwrap(); + match schema.field(idx).data_type() { + DataType::Union(f, m) => (f.clone(), *m), + other => panic!("{name} should be a Union, got {other:?}"), + } + }; + let mut expected_cols: Vec = Vec::with_capacity(schema.fields().len()); + // 1) ["null","int"]: Int32 (nullable) + expected_cols.push(Arc::new(Int32Array::from(vec![ + None, + Some(42), + None, + Some(0), + ]))); + // 2) ["string","null"]: Utf8 (nullable) + expected_cols.push(Arc::new(StringArray::from(vec![ + Some("s1"), + None, + Some("s3"), + Some(""), + ]))); + // 3) union_prim: ["boolean","int","long","float","double","bytes","string"] + { + let (uf, mode) = get_union("union_prim"); + assert!(matches!(mode, UnionMode::Dense)); + let generated_names: Vec<&str> = uf.iter().map(|(_, f)| f.name().as_str()).collect(); + let expected_names = vec![ + "boolean", "int", "long", "float", "double", "bytes", "string", + ]; + assert_eq!( + generated_names, expected_names, + "Field names for union_prim are incorrect" + ); + let tids = vec![ + tid_by_name(&uf, "long"), + tid_by_name(&uf, "int"), + tid_by_name(&uf, "float"), + tid_by_name(&uf, "double"), + ]; + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "int" => Some(Arc::new(Int32Array::from(vec![-1])) as ArrayRef), + "long" => Some(Arc::new(Int64Array::from(vec![1_234_567_890_123i64])) as ArrayRef), + "float" => { + Some(Arc::new(arrow_array::Float32Array::from(vec![1.25f32])) as ArrayRef) + } + "double" => { + Some(Arc::new(arrow_array::Float64Array::from(vec![-2.5f64])) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 4) union_bytes_vs_string: ["bytes","string"] + { + let (uf, _) = get_union("union_bytes_vs_string"); + let tids = vec![ + tid_by_name(&uf, "bytes"), + tid_by_name(&uf, "string"), + tid_by_name(&uf, "string"), + tid_by_name(&uf, "bytes"), + ]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "bytes" => Some( + Arc::new(BinaryArray::from(vec![&[0x00, 0xFF, 0x7F][..], &[][..]])) as ArrayRef, + ), + "string" => Some(Arc::new(StringArray::from(vec!["hello", "world"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 5) union_fixed_dur_decfix: [Fx8, Dur12, DecFix16(decimal(10,2))] + { + let (uf, _) = get_union("union_fixed_dur_decfix"); + let tid_fx8 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(8))); + let tid_dur = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) + ) + }); + let tid_dec = tid_by_dt(&uf, |dt| match dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => true, + DataType::Decimal128(10, 2) | DataType::Decimal256(10, 2) => true, + _ => false, + }); + let tids = vec![tid_fx8, tid_dur, tid_dec, tid_dur]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(8) => { + let it = [Some(fx8_a)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 8).unwrap(), + ) as ArrayRef) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + Some(Arc::new(arrow_array::IntervalMonthDayNanoArray::from(vec![ + dur_a, dur_b, + ])) as ArrayRef) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => { + let a = arrow_array::Decimal64Array::from_iter_values([dec_fix16_neg as i64]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(10, 2) => { + let a = arrow_array::Decimal128Array::from_iter_values([dec_fix16_neg]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal256(10, 2) => { + let a = arrow_array::Decimal256Array::from_iter_values([i256::from_i128( + dec_fix16_neg, + )]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + _ => None, + }); + let generated_names: Vec<&str> = uf.iter().map(|(_, f)| f.name().as_str()).collect(); + let expected_names = vec!["Fx8", "Dur12", "DecFix16"]; + assert_eq!( + generated_names, expected_names, + "Data type names were not generated correctly for union_fixed_dur_decfix" + ); + expected_cols.push(arr); + } + // 6) union_enum_records_array_map: [enum ColorU, record RecA, record RecB, array, map] + { + let (uf, _) = get_union("union_enum_records_array_map"); + let tid_enum = tid_by_dt(&uf, |dt| matches!(dt, DataType::Dictionary(_, _))); + let tid_reca = tid_by_dt(&uf, |dt| { + if let DataType::Struct(fs) = dt { + fs.len() == 2 && fs[0].name() == "a" && fs[1].name() == "b" + } else { + false + } + }); + let tid_recb = tid_by_dt(&uf, |dt| { + if let DataType::Struct(fs) = dt { + fs.len() == 2 && fs[0].name() == "x" && fs[1].name() == "y" + } else { + false + } + }); + let tid_arr = tid_by_dt(&uf, |dt| matches!(dt, DataType::List(_))); + let tids = vec![tid_enum, tid_reca, tid_recb, tid_arr]; + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32]); // "RED" + let values = + Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + DataType::Struct(fs) + if fs.len() == 2 && fs[0].name() == "a" && fs[1].name() == "b" => + { + let a = Int32Array::from(vec![7]); + let b = StringArray::from(vec!["x"]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(a), Arc::new(b)], + None, + )) as ArrayRef) + } + DataType::Struct(fs) + if fs.len() == 2 && fs[0].name() == "x" && fs[1].name() == "y" => + { + let x = Int64Array::from(vec![123_456_789i64]); + let y = BinaryArray::from(vec![&[0xFF, 0x00][..]]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(x), Arc::new(y)], + None, + )) as ArrayRef) + } + DataType::List(field) => { + let values = Int64Array::from(vec![1i64, 2, 3]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3])); + Some(Arc::new( + ListArray::try_new(field.clone(), offsets, Arc::new(values), None).unwrap(), + ) as ArrayRef) + } + DataType::Map(_, _) => None, + other => panic!("unexpected child {other:?}"), + }); + expected_cols.push(arr); + } + // 7) union_date_or_fixed4: [date32, fixed(4)] + { + let (uf, _) = get_union("union_date_or_fixed4"); + let tid_date = tid_by_dt(&uf, |dt| matches!(dt, DataType::Date32)); + let tid_fx4 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(4))); + let tids = vec![tid_date, tid_fx4, tid_date, tid_fx4]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Date32 => { + Some(Arc::new(arrow_array::Date32Array::from(vec![date_a, 0])) as ArrayRef) + } + DataType::FixedSizeBinary(4) => { + let it = [Some(fx4_abcd), Some(fx4_misc)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 4).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 8) union_time_millis_or_enum: [time-millis, enum OnOff] + { + let (uf, _) = get_union("union_time_millis_or_enum"); + let tid_ms = tid_by_dt(&uf, |dt| { + matches!(dt, DataType::Time32(arrow_schema::TimeUnit::Millisecond)) + }); + let tid_en = tid_by_dt(&uf, |dt| matches!(dt, DataType::Dictionary(_, _))); + let tids = vec![tid_ms, tid_en, tid_en, tid_ms]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Some(Arc::new(Time32MillisecondArray::from(vec![time_ms_a, 0])) as ArrayRef) + } + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32, 1]); // "ON", "OFF" + let values = Arc::new(StringArray::from(vec!["ON", "OFF"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + _ => None, + }); + expected_cols.push(arr); + } + // 9) union_time_micros_or_string: [time-micros, string] + { + let (uf, _) = get_union("union_time_micros_or_string"); + let tid_us = tid_by_dt(&uf, |dt| { + matches!(dt, DataType::Time64(arrow_schema::TimeUnit::Microsecond)) + }); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_s, tid_us, tid_s, tid_s]; + let offs = vec![0, 0, 1, 2]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Some(Arc::new(Time64MicrosecondArray::from(vec![time_us_b])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["evening", "night", ""])) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 10) union_ts_millis_utc_or_array: [timestamp-millis(TZ), array] + { + let (uf, _) = get_union("union_ts_millis_utc_or_array"); + let tid_ts = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) + ) + }); + let tid_arr = tid_by_dt(&uf, |dt| matches!(dt, DataType::List(_))); + let tids = vec![tid_ts, tid_arr, tid_arr, tid_ts]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01, + ts_ms_2024_01_01 + 86_400_000, + ]); + Some(Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) as ArrayRef) + } + DataType::List(field) => { + let values = Int32Array::from(vec![0, 1, 2, -1, 0, 1]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 6])); + Some(Arc::new( + ListArray::try_new(field.clone(), offsets, Arc::new(values), None).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 11) union_ts_micros_local_or_bytes: [local-timestamp-micros, bytes] + { + let (uf, _) = get_union("union_ts_micros_local_or_bytes"); + let tid_lts = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) + ) + }); + let tid_b = tid_by_name(&uf, "bytes"); + let tids = vec![tid_b, tid_lts, tid_b, tid_b]; + let offs = vec![0, 0, 1, 2]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None) => Some(Arc::new( + TimestampMicrosecondArray::from(vec![ts_us_2024_01_01]), + ) + as ArrayRef), + DataType::Binary => Some(Arc::new(BinaryArray::from(vec![ + &b"\x11\x22\x33"[..], + &b"\x00"[..], + &b"\x10\x20\x30\x40"[..], + ])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 12) union_uuid_or_fixed10: [uuid(string)->fixed(16), fixed(10)] + { + let (uf, _) = get_union("union_uuid_or_fixed10"); + let tid_fx16 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(16))); + let tid_fx10 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(10))); + let tids = vec![tid_fx16, tid_fx10, tid_fx16, tid_fx10]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(16) => { + let it = [Some(uuid1), Some(uuid2)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef) + } + DataType::FixedSizeBinary(10) => { + let it = [Some(fx10_ascii), Some(fx10_aa)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 10).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 13) union_dec_bytes_or_dec_fixed: [bytes dec(10,2), fixed(20) dec(20,4)] + { + let (uf, _) = get_union("union_dec_bytes_or_dec_fixed"); + let tid_b10s2 = tid_by_dt(&uf, |dt| match dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => true, + DataType::Decimal128(10, 2) | DataType::Decimal256(10, 2) => true, + _ => false, + }); + let tid_f20s4 = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Decimal128(20, 4) | DataType::Decimal256(20, 4) + ) + }); + let tids = vec![tid_b10s2, tid_f20s4, tid_b10s2, tid_f20s4]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal64(10, 2) => { + let a = Decimal64Array::from_iter_values([dec_b_scale2_pos as i64, 0i64]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(10, 2) => { + let a = Decimal128Array::from_iter_values([dec_b_scale2_pos, 0]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal256(10, 2) => { + let a = Decimal256Array::from_iter_values([ + i256::from_i128(dec_b_scale2_pos), + i256::from(0), + ]); + Some(Arc::new(a.with_precision_and_scale(10, 2).unwrap()) as ArrayRef) + } + DataType::Decimal128(20, 4) => { + let a = Decimal128Array::from_iter_values([dec_fix20_s4_neg, dec_fix20_s4]); + Some(Arc::new(a.with_precision_and_scale(20, 4).unwrap()) as ArrayRef) + } + DataType::Decimal256(20, 4) => { + let a = Decimal256Array::from_iter_values([ + i256::from_i128(dec_fix20_s4_neg), + i256::from_i128(dec_fix20_s4), + ]); + Some(Arc::new(a.with_precision_and_scale(20, 4).unwrap()) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 14) union_null_bytes_string: ["null","bytes","string"] + { + let (uf, _) = get_union("union_null_bytes_string"); + let tid_n = tid_by_name(&uf, "null"); + let tid_b = tid_by_name(&uf, "bytes"); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_n, tid_b, tid_s, tid_s]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "null" => Some(Arc::new(arrow_array::NullArray::new(1)) as ArrayRef), + "bytes" => Some(Arc::new(BinaryArray::from(vec![&b"\x01\x02"[..]])) as ArrayRef), + "string" => Some(Arc::new(StringArray::from(vec!["text", "u"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + // 15) array_of_union: array<[long,string]> + { + let idx = schema.index_of("array_of_union").unwrap(); + let dt = schema.field(idx).data_type().clone(); + let (item_field, _) = match &dt { + DataType::List(f) => (f.clone(), ()), + other => panic!("array_of_union must be List, got {other:?}"), + }; + let (uf, _) = match item_field.data_type() { + DataType::Union(f, m) => (f.clone(), m), + other => panic!("array_of_union items must be Union, got {other:?}"), + }; + let tid_l = tid_by_name(&uf, "long"); + let tid_s = tid_by_name(&uf, "string"); + let type_ids = vec![tid_l, tid_s, tid_l, tid_s, tid_l, tid_l, tid_s, tid_l]; + let offsets = vec![0, 0, 1, 1, 2, 3, 2, 4]; + let values_union = + mk_dense_union(&uf, type_ids, offsets, |f| match f.name().as_str() { + "long" => { + Some(Arc::new(Int64Array::from(vec![1i64, -5, 42, -1, 0])) as ArrayRef) + } + "string" => Some(Arc::new(StringArray::from(vec!["a", "", "z"])) as ArrayRef), + _ => None, + }); + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 5, 6, 8])); + expected_cols.push(Arc::new( + ListArray::try_new(item_field.clone(), list_offsets, values_union, None).unwrap(), + )); + } + // 16) map_of_union: map<[null,double]> + { + let idx = schema.index_of("map_of_union").unwrap(); + let dt = schema.field(idx).data_type().clone(); + let (entry_field, ordered) = match &dt { + DataType::Map(f, ordered) => (f.clone(), *ordered), + other => panic!("map_of_union must be Map, got {other:?}"), + }; + let DataType::Struct(entry_fields) = entry_field.data_type() else { + panic!("map entries must be struct") + }; + let key_field = entry_fields[0].clone(); + let val_field = entry_fields[1].clone(); + let keys = StringArray::from(vec!["a", "b", "x", "pi"]); + let rounded_pi = (std::f64::consts::PI * 100_000.0).round() / 100_000.0; + let values: ArrayRef = match val_field.data_type() { + DataType::Union(uf, _) => { + let tid_n = tid_by_name(uf, "null"); + let tid_d = tid_by_name(uf, "double"); + let tids = vec![tid_n, tid_d, tid_d, tid_d]; + let offs = vec![0, 0, 1, 2]; + mk_dense_union(uf, tids, offs, |f| match f.name().as_str() { + "null" => Some(Arc::new(NullArray::new(1)) as ArrayRef), + "double" => Some(Arc::new(arrow_array::Float64Array::from(vec![ + 2.5f64, -0.5f64, rounded_pi, + ])) as ArrayRef), + _ => None, + }) + } + DataType::Float64 => Arc::new(arrow_array::Float64Array::from(vec![ + None, + Some(2.5), + Some(-0.5), + Some(rounded_pi), + ])), + other => panic!("unexpected map value type {other:?}"), + }; + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, values], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3, 3, 4])); + expected_cols.push(Arc::new(MapArray::new( + entry_field, + offsets, + entries, + None, + ordered, + ))); + } + // 17) record_with_union_field: struct { id:int, u:[int,string] } + { + let idx = schema.index_of("record_with_union_field").unwrap(); + let DataType::Struct(rec_fields) = schema.field(idx).data_type() else { + panic!("record_with_union_field should be Struct") + }; + let id = Int32Array::from(vec![1, 2, 3, 4]); + let u_field = rec_fields.iter().find(|f| f.name() == "u").unwrap(); + let DataType::Union(uf, _) = u_field.data_type() else { + panic!("u must be Union") + }; + let tid_i = tid_by_name(uf, "int"); + let tid_s = tid_by_name(uf, "string"); + let tids = vec![tid_s, tid_i, tid_i, tid_s]; + let offs = vec![0, 0, 1, 1]; + let u = mk_dense_union(uf, tids, offs, |f| match f.name().as_str() { + "int" => Some(Arc::new(Int32Array::from(vec![99, 0])) as ArrayRef), + "string" => Some(Arc::new(StringArray::from(vec!["one", "four"])) as ArrayRef), + _ => None, + }); + let rec = StructArray::new(rec_fields.clone(), vec![Arc::new(id) as ArrayRef, u], None); + expected_cols.push(Arc::new(rec)); + } + // 18) union_ts_micros_utc_or_map: [timestamp-micros(TZ), map] + { + let (uf, _) = get_union("union_ts_micros_utc_or_map"); + let tid_ts = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, Some(_)) + ) + }); + let tid_map = tid_by_dt(&uf, |dt| matches!(dt, DataType::Map(_, _))); + let tids = vec![tid_ts, tid_map, tid_ts, tid_map]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => { + let a = TimestampMicrosecondArray::from(vec![ts_us_2024_01_01, 0i64]); + Some(Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) as ArrayRef) + } + DataType::Map(entry_field, ordered) => { + let DataType::Struct(fs) = entry_field.data_type() else { + panic!("map entries must be struct") + }; + let key_field = fs[0].clone(); + let val_field = fs[1].clone(); + assert_eq!(key_field.data_type(), &DataType::Utf8); + assert_eq!(val_field.data_type(), &DataType::Int64); + let keys = StringArray::from(vec!["k1", "k2", "n"]); + let vals = Int64Array::from(vec![1i64, 2, 0]); + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3])); + Some(Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + *ordered, + )) as ArrayRef) + } + _ => None, + }); + expected_cols.push(arr); + } + // 19) union_ts_millis_local_or_string: [local-timestamp-millis, string] + { + let (uf, _) = get_union("union_ts_millis_local_or_string"); + let tid_ts = tid_by_dt(&uf, |dt| { + matches!( + dt, + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) + ) + }); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_s, tid_ts, tid_s, tid_s]; + let offs = vec![0, 0, 1, 2]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None) => Some(Arc::new( + TimestampMillisecondArray::from(vec![ts_ms_2024_01_01]), + ) + as ArrayRef), + DataType::Utf8 => { + Some( + Arc::new(StringArray::from(vec!["local midnight", "done", ""])) as ArrayRef, + ) + } + _ => None, + }); + expected_cols.push(arr); + } + // 20) union_bool_or_string: ["boolean","string"] + { + let (uf, _) = get_union("union_bool_or_string"); + let tid_b = tid_by_name(&uf, "boolean"); + let tid_s = tid_by_name(&uf, "string"); + let tids = vec![tid_b, tid_s, tid_b, tid_s]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.name().as_str() { + "boolean" => Some(Arc::new(BooleanArray::from(vec![true, false])) as ArrayRef), + "string" => Some(Arc::new(StringArray::from(vec!["no", "yes"])) as ArrayRef), + _ => None, + }); + expected_cols.push(arr); + } + let expected = RecordBatch::try_new(schema.clone(), expected_cols).unwrap(); + assert_eq!( + actual, expected, + "full end-to-end equality for union_fields.avro" + ); + } + + #[test] + fn test_read_zero_byte_avro_file() { + let batch = read_file("test/data/zero_byte.avro", 3, false); + let schema = batch.schema(); + assert_eq!(schema.fields().len(), 1); + let field = schema.field(0); + assert_eq!(field.name(), "data"); + assert_eq!(field.data_type(), &DataType::Binary); + assert!(field.is_nullable()); + assert_eq!(batch.num_rows(), 3); + assert_eq!(batch.num_columns(), 1); + let binary_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(binary_array.is_null(0)); + assert!(binary_array.is_valid(1)); + assert_eq!(binary_array.value(1), b""); + assert!(binary_array.is_valid(2)); + assert_eq!(binary_array.value(2), b"some bytes"); + } + + #[test] + fn test_alltypes() { + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "id", + Arc::new(Int32Array::from(vec![4, 5, 6, 7, 2, 3, 0, 1])) as _, + true, + ), + ( + "bool_col", + Arc::new(BooleanArray::from_iter((0..8).map(|x| Some(x % 2 == 0)))) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "int_col", + Arc::new(Int32Array::from_iter_values((0..8).map(|x| x % 2))) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from_iter_values((0..8).map(|x| (x % 2) * 10))) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from_iter_values( + (0..8).map(|x| (x % 2) as f32 * 1.1), + )) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from_iter_values( + (0..8).map(|x| (x % 2) as f64 * 10.1), + )) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([ + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 51, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 52, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 50, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + [48, 49, 47, 48, 49, 47, 48, 57], + ])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values((0..8).map(|x| [48 + x % 2]))) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1235865600000000, // 2009-03-01T00:00:00.000 + 1235865660000000, // 2009-03-01T00:01:00.000 + 1238544000000000, // 2009-04-01T00:00:00.000 + 1238544060000000, // 2009-04-01T00:01:00.000 + 1233446400000000, // 2009-02-01T00:00:00.000 + 1233446460000000, // 2009-02-01T00:01:00.000 + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + + for file in files() { + let file = arrow_test_data(file); + + assert_eq!(read_file(&file, 8, false), expected); + assert_eq!(read_file(&file, 3, false), expected); + } + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_alltypes_dictionary() { + let file = "avro/alltypes_dictionary.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![Some(true), Some(false)])) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![0, 10])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![0.0, 1.1])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![0.0, 10.1])) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([b"01/01/09", b"01/01/09"])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values([b"0", b"1"])) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let batch_large = read_file(&file_path, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {file}" + ); + let batch_small = read_file(&file_path, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {file}" + ); + } + + #[test] + fn test_alltypes_nulls_plain() { + let file = "avro/alltypes_nulls_plain.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "string_col", + Arc::new(StringArray::from(vec![None::<&str>])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![None])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![None])) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![None])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![None])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![None])) as _, + true, + ), + ( + "bytes_col", + Arc::new(BinaryArray::from(vec![None::<&[u8]>])) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let batch_large = read_file(&file_path, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {file}" + ); + let batch_small = read_file(&file_path, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 3) does not match for file {file}" + ); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_binary() { + let file = arrow_test_data("avro/binary.avro"); + let batch = read_file(&file, 8, false); + let expected = RecordBatch::try_from_iter_with_nullable([( + "foo", + Arc::new(BinaryArray::from_iter_values(vec![ + b"\x00" as &[u8], + b"\x01" as &[u8], + b"\x02" as &[u8], + b"\x03" as &[u8], + b"\x04" as &[u8], + b"\x05" as &[u8], + b"\x06" as &[u8], + b"\x07" as &[u8], + b"\x08" as &[u8], + b"\t" as &[u8], + b"\n" as &[u8], + b"\x0b" as &[u8], + ])) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + // TODO: avoid requiring snappy for these files + #[cfg(feature = "snappy")] + fn test_decimal() { + // Choose expected Arrow types depending on the `small_decimals` feature flag. + // With `small_decimals` enabled, Decimal32/Decimal64 are used where their + // precision allows; otherwise, those cases resolve to Decimal128. + #[cfg(feature = "small_decimals")] + let files: [(&str, DataType, HashMap); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal64(13, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/int32_decimal.avro", + DataType::Decimal32(4, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/int64_decimal.avro", + DataType::Decimal64(10, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + HashMap::new(), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + HashMap::from([("avro.name".to_string(), "Decimal256Fixed".to_string())]), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal32(9, 2), + HashMap::from([("avro.name".to_string(), "Decimal32FixedLegacy".to_string())]), + ), + ( + "test/data/int128_decimal.avro", + DataType::Decimal128(38, 2), + HashMap::new(), + ), + ]; + #[cfg(not(feature = "small_decimals"))] + let files: [(&str, DataType, HashMap); 8] = [ + ( + "avro/fixed_length_decimal.avro", + DataType::Decimal128(25, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/fixed_length_decimal_legacy.avro", + DataType::Decimal128(13, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/int32_decimal.avro", + DataType::Decimal128(4, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "avro/int64_decimal.avro", + DataType::Decimal128(10, 2), + HashMap::from([ + ( + "avro.namespace".to_string(), + "topLevelRecord.value".to_string(), + ), + ("avro.name".to_string(), "fixed".to_string()), + ]), + ), + ( + "test/data/int256_decimal.avro", + DataType::Decimal256(76, 10), + HashMap::new(), + ), + ( + "test/data/fixed256_decimal.avro", + DataType::Decimal256(76, 10), + HashMap::from([("avro.name".to_string(), "Decimal256Fixed".to_string())]), + ), + ( + "test/data/fixed_length_decimal_legacy_32.avro", + DataType::Decimal128(9, 2), + HashMap::from([("avro.name".to_string(), "Decimal32FixedLegacy".to_string())]), + ), + ( + "test/data/int128_decimal.avro", + DataType::Decimal128(38, 2), + HashMap::new(), + ), + ]; + for (file, expected_dt, mut metadata) in files { + let (precision, scale) = match expected_dt { + DataType::Decimal32(p, s) + | DataType::Decimal64(p, s) + | DataType::Decimal128(p, s) + | DataType::Decimal256(p, s) => (p, s), + _ => unreachable!("Unexpected decimal type in test inputs"), + }; + assert!(scale >= 0, "test data uses non-negative scales only"); + let scale_u32 = scale as u32; + let file_path: String = if file.starts_with("avro/") { + arrow_test_data(file) + } else { + std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(file) + .to_string_lossy() + .into_owned() + }; + let pow10: i128 = 10i128.pow(scale_u32); + let values_i128: Vec = (1..=24).map(|n| (n as i128) * pow10).collect(); + let build_expected = |dt: &DataType, values: &[i128]| -> ArrayRef { + match *dt { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(p, s) => { + let it = values.iter().map(|&v| v as i32); + Arc::new( + Decimal32Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => { + let it = values.iter().map(|&v| v as i64); + Arc::new( + Decimal64Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal128(p, s) => { + let it = values.iter().copied(); + Arc::new( + Decimal128Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + DataType::Decimal256(p, s) => { + let it = values.iter().map(|&v| i256::from_i128(v)); + Arc::new( + Decimal256Array::from_iter_values(it) + .with_precision_and_scale(p, s) + .unwrap(), + ) + } + _ => unreachable!("Unexpected decimal type in test"), + } + }; + let actual_batch = read_file(&file_path, 8, false); + let actual_nullable = actual_batch.schema().field(0).is_nullable(); + let expected_array = build_expected(&expected_dt, &values_i128); + metadata.insert("precision".to_string(), precision.to_string()); + metadata.insert("scale".to_string(), scale.to_string()); + let field = + Field::new("value", expected_dt.clone(), actual_nullable).with_metadata(metadata); + let expected_schema = Arc::new(Schema::new(vec![field])); + let expected_batch = + RecordBatch::try_new(expected_schema.clone(), vec![expected_array]).unwrap(); + assert_eq!( + actual_batch, expected_batch, + "Decoded RecordBatch does not match for {file}" + ); + let actual_batch_small = read_file(&file_path, 3, false); + assert_eq!( + actual_batch_small, expected_batch, + "Decoded RecordBatch does not match for {file} with batch size 3" + ); + } + } + + #[test] + fn test_read_duration_logical_types_feature_toggle() -> Result<(), ArrowError> { + let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test/data/duration_logical_types.avro") + .to_string_lossy() + .into_owned(); + + let actual_batch = read_file(&file_path, 4, false); + + let expected_batch = { + #[cfg(feature = "avro_custom_types")] + { + let schema = Arc::new(Schema::new(vec![ + Field::new( + "duration_time_nanos", + DataType::Duration(TimeUnit::Nanosecond), + false, + ), + Field::new( + "duration_time_micros", + DataType::Duration(TimeUnit::Microsecond), + false, + ), + Field::new( + "duration_time_millis", + DataType::Duration(TimeUnit::Millisecond), + false, + ), + Field::new( + "duration_time_seconds", + DataType::Duration(TimeUnit::Second), + false, + ), + ])); + + let nanos = Arc::new(PrimitiveArray::::from(vec![ + 10, 20, 30, 40, + ])) as ArrayRef; + let micros = Arc::new(PrimitiveArray::::from(vec![ + 100, 200, 300, 400, + ])) as ArrayRef; + let millis = Arc::new(PrimitiveArray::::from(vec![ + 1000, 2000, 3000, 4000, + ])) as ArrayRef; + let seconds = Arc::new(PrimitiveArray::::from(vec![1, 2, 3, 4])) + as ArrayRef; + + RecordBatch::try_new(schema, vec![nanos, micros, millis, seconds])? + } + #[cfg(not(feature = "avro_custom_types"))] + { + let schema = Arc::new(Schema::new(vec![ + Field::new("duration_time_nanos", DataType::Int64, false).with_metadata( + [( + "logicalType".to_string(), + "arrow.duration-nanos".to_string(), + )] + .into(), + ), + Field::new("duration_time_micros", DataType::Int64, false).with_metadata( + [( + "logicalType".to_string(), + "arrow.duration-micros".to_string(), + )] + .into(), + ), + Field::new("duration_time_millis", DataType::Int64, false).with_metadata( + [( + "logicalType".to_string(), + "arrow.duration-millis".to_string(), + )] + .into(), + ), + Field::new("duration_time_seconds", DataType::Int64, false).with_metadata( + [( + "logicalType".to_string(), + "arrow.duration-seconds".to_string(), + )] + .into(), + ), + ])); + + let nanos = + Arc::new(PrimitiveArray::::from(vec![10, 20, 30, 40])) as ArrayRef; + let micros = Arc::new(PrimitiveArray::::from(vec![100, 200, 300, 400])) + as ArrayRef; + let millis = Arc::new(PrimitiveArray::::from(vec![ + 1000, 2000, 3000, 4000, + ])) as ArrayRef; + let seconds = + Arc::new(PrimitiveArray::::from(vec![1, 2, 3, 4])) as ArrayRef; + + RecordBatch::try_new(schema, vec![nanos, micros, millis, seconds])? + } + }; + + assert_eq!(actual_batch, expected_batch); + + Ok(()) + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_dict_pages_offset_zero() { + let file = arrow_test_data("avro/dict-page-offset-zero.avro"); + let batch = read_file(&file, 32, false); + let num_rows = batch.num_rows(); + let expected_field = Int32Array::from(vec![Some(1552); num_rows]); + let expected = RecordBatch::try_from_iter_with_nullable([( + "l_partkey", + Arc::new(expected_field) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_list_columns() { + let file = arrow_test_data("avro/list_columns.avro"); + let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); + { + { + let values = int64_list_builder.values(); + values.append_value(1); + values.append_value(2); + values.append_value(3); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_null(); + values.append_value(1); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_value(4); + } + int64_list_builder.append(true); + } + let int64_list = int64_list_builder.finish(); + let mut utf8_list_builder = ListBuilder::new(StringBuilder::new()); + { + { + let values = utf8_list_builder.values(); + values.append_value("abc"); + values.append_value("efg"); + values.append_value("hij"); + } + utf8_list_builder.append(true); + } + { + utf8_list_builder.append(false); + } + { + { + let values = utf8_list_builder.values(); + values.append_value("efg"); + values.append_null(); + values.append_value("hij"); + values.append_value("xyz"); + } + utf8_list_builder.append(true); + } + let utf8_list = utf8_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("int64_list", Arc::new(int64_list) as Arc, true), + ("utf8_list", Arc::new(utf8_list) as Arc, true), + ]) + .unwrap(); + let batch = read_file(&file, 8, false); + assert_eq!(batch, expected); + } + + #[test] + #[cfg(feature = "snappy")] + fn test_nested_lists() { + use arrow_data::ArrayDataBuilder; + let file = arrow_test_data("avro/nested_lists.snappy.avro"); + let inner_values = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("f"), + ]); + let inner_offsets = Buffer::from_slice_ref([0, 2, 3, 3, 4, 6, 8, 8, 9, 11, 13, 14, 14, 15]); + let inner_validity = [ + true, true, false, true, true, true, false, true, true, true, true, false, true, + ]; + let inner_null_buffer = Buffer::from_iter(inner_validity.iter().copied()); + let inner_field = Field::new("item", DataType::Utf8, true); + let inner_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(inner_field))) + .len(13) + .add_buffer(inner_offsets) + .add_child_data(inner_values.to_data()) + .null_bit_buffer(Some(inner_null_buffer)) + .build() + .unwrap(); + let inner_list_array = ListArray::from(inner_list_data); + let middle_offsets = Buffer::from_slice_ref([0, 2, 4, 6, 8, 11, 13]); + let middle_validity = [true; 6]; + let middle_null_buffer = Buffer::from_iter(middle_validity.iter().copied()); + let middle_field = Field::new("item", inner_list_array.data_type().clone(), true); + let middle_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(middle_field))) + .len(6) + .add_buffer(middle_offsets) + .add_child_data(inner_list_array.to_data()) + .null_bit_buffer(Some(middle_null_buffer)) + .build() + .unwrap(); + let middle_list_array = ListArray::from(middle_list_data); + let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all 3 rows valid + let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); + let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) + .len(3) + .add_buffer(outer_offsets) + .add_child_data(middle_list_array.to_data()) + .null_bit_buffer(Some(outer_null_buffer)) + .build() + .unwrap(); + let a_expected = ListArray::from(outer_list_data); + let b_expected = Int32Array::from(vec![1, 1, 1]); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a_expected) as Arc, true), + ("b", Arc::new(b_expected) as Arc, true), + ]) + .unwrap(); + let left = read_file(&file, 8, false); + assert_eq!(left, expected, "Mismatch for batch size=8"); + let left_small = read_file(&file, 3, false); + assert_eq!(left_small, expected, "Mismatch for batch size=3"); + } + + #[test] + fn test_simple() { + let tests = [ + ("avro/simple_enum.avro", 4, build_expected_enum(), 2), + ("avro/simple_fixed.avro", 2, build_expected_fixed(), 1), + ]; + + fn build_expected_enum() -> RecordBatch { + // Build the DictionaryArrays for f1, f2, f3 + let keys_f1 = Int32Array::from(vec![0, 1, 2, 3]); + let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); + let f1_dict = + DictionaryArray::::try_new(keys_f1, Arc::new(vals_f1)).unwrap(); + let keys_f2 = Int32Array::from(vec![2, 3, 0, 1]); + let vals_f2 = StringArray::from(vec!["e", "f", "g", "h"]); + let f2_dict = + DictionaryArray::::try_new(keys_f2, Arc::new(vals_f2)).unwrap(); + let keys_f3 = Int32Array::from(vec![Some(1), Some(2), None, Some(0)]); + let vals_f3 = StringArray::from(vec!["i", "j", "k"]); + let f3_dict = + DictionaryArray::::try_new(keys_f3, Arc::new(vals_f3)).unwrap(); + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let mut md_f1 = HashMap::new(); + md_f1.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["a","b","c","d"]"#.to_string(), + ); + md_f1.insert(AVRO_NAME_METADATA_KEY.to_string(), "enum1".to_string()); + md_f1.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns1".to_string()); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + let mut md_f2 = HashMap::new(); + md_f2.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["e","f","g","h"]"#.to_string(), + ); + md_f2.insert(AVRO_NAME_METADATA_KEY.to_string(), "enum2".to_string()); + md_f2.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns2".to_string()); + let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + let mut md_f3 = HashMap::new(); + md_f3.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + r#"["i","j","k"]"#.to_string(), + ); + md_f3.insert(AVRO_NAME_METADATA_KEY.to_string(), "enum3".to_string()); + md_f3.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns1".to_string()); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![f1_field, f2_field, f3_field])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1_dict) as Arc, + Arc::new(f2_dict) as Arc, + Arc::new(f3_dict) as Arc, + ], + ) + .unwrap() + } + + fn build_expected_fixed() -> RecordBatch { + let f1 = + FixedSizeBinaryArray::try_from_iter(vec![b"abcde", b"12345"].into_iter()).unwrap(); + let f2 = + FixedSizeBinaryArray::try_from_iter(vec![b"fghijklmno", b"1234567890"].into_iter()) + .unwrap(); + let f3 = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ABCDEF" as &[u8]), None].into_iter(), + 6, + ) + .unwrap(); + + // Add Avro named-type metadata for fixed fields + let mut md_f1 = HashMap::new(); + md_f1.insert( + crate::schema::AVRO_NAME_METADATA_KEY.to_string(), + "fixed1".to_string(), + ); + md_f1.insert( + crate::schema::AVRO_NAMESPACE_METADATA_KEY.to_string(), + "ns1".to_string(), + ); + + let mut md_f2 = HashMap::new(); + md_f2.insert( + crate::schema::AVRO_NAME_METADATA_KEY.to_string(), + "fixed2".to_string(), + ); + md_f2.insert( + crate::schema::AVRO_NAMESPACE_METADATA_KEY.to_string(), + "ns2".to_string(), + ); + + let mut md_f3 = HashMap::new(); + md_f3.insert( + crate::schema::AVRO_NAME_METADATA_KEY.to_string(), + "fixed3".to_string(), + ); + md_f3.insert( + crate::schema::AVRO_NAMESPACE_METADATA_KEY.to_string(), + "ns1".to_string(), + ); + + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", DataType::FixedSizeBinary(5), false).with_metadata(md_f1), + Field::new("f2", DataType::FixedSizeBinary(10), false).with_metadata(md_f2), + Field::new("f3", DataType::FixedSizeBinary(6), true).with_metadata(md_f3), + ])); + + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1) as Arc, + Arc::new(f2) as Arc, + Arc::new(f3) as Arc, + ], + ) + .unwrap() + } + for (file_name, batch_size, expected, alt_batch_size) in tests { + let file = arrow_test_data(file_name); + let actual = read_file(&file, batch_size, false); + assert_eq!(actual, expected); + let actual2 = read_file(&file, alt_batch_size, false); + assert_eq!(actual2, expected); + } + } + + #[test] + #[cfg(feature = "snappy")] + fn test_single_nan() { + let file = arrow_test_data("avro/single_nan.avro"); + let actual = read_file(&file, 1, false); + use arrow_array::Float64Array; + let schema = Arc::new(Schema::new(vec![Field::new( + "mycol", + DataType::Float64, + true, + )])); + let col = Float64Array::from(vec![None]); + let expected = RecordBatch::try_new(schema, vec![Arc::new(col)]).unwrap(); + assert_eq!(actual, expected); + let actual2 = read_file(&file, 2, false); + assert_eq!(actual2, expected); + } + + #[test] + fn test_duration_uuid() { + let batch = read_file("test/data/duration_uuid.avro", 4, false); + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 2); + assert_eq!(fields[0].name(), "duration_field"); + assert_eq!( + fields[0].data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano) + ); + assert_eq!(fields[1].name(), "uuid_field"); + assert_eq!(fields[1].data_type(), &DataType::FixedSizeBinary(16)); + assert_eq!(batch.num_rows(), 4); + assert_eq!(batch.num_columns(), 2); + let duration_array = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let expected_duration_array: IntervalMonthDayNanoArray = [ + Some(IntervalMonthDayNanoType::make_value(1, 15, 500_000_000)), + Some(IntervalMonthDayNanoType::make_value(0, 5, 2_500_000_000)), + Some(IntervalMonthDayNanoType::make_value(2, 0, 0)), + Some(IntervalMonthDayNanoType::make_value(12, 31, 999_000_000)), + ] + .iter() + .copied() + .collect(); + assert_eq!(&expected_duration_array, duration_array); + let uuid_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let expected_uuid_array = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + [ + Some([ + 0xfe, 0x7b, 0xc3, 0x0b, 0x4c, 0xe8, 0x4c, 0x5e, 0xb6, 0x7c, 0x22, 0x34, 0xa2, + 0xd3, 0x8e, 0x66, + ]), + Some([ + 0xb3, 0x3f, 0x2a, 0xd7, 0x97, 0xb4, 0x4d, 0xe1, 0x8b, 0xfe, 0x94, 0x94, 0x1d, + 0x60, 0x15, 0x6e, + ]), + Some([ + 0x5f, 0x74, 0x92, 0x64, 0x07, 0x4b, 0x40, 0x05, 0x84, 0xbf, 0x11, 0x5e, 0xa8, + 0x4e, 0xd2, 0x0a, + ]), + Some([ + 0x08, 0x26, 0xcc, 0x06, 0xd2, 0xe3, 0x45, 0x99, 0xb4, 0xad, 0xaf, 0x5f, 0xa6, + 0x90, 0x5c, 0xdb, + ]), + ] + .into_iter(), + 16, + ) + .unwrap(); + assert_eq!(&expected_uuid_array, uuid_array); + } + + #[test] + #[cfg(feature = "snappy")] + fn test_datapage_v2() { + let file = arrow_test_data("avro/datapage_v2.snappy.avro"); + let batch = read_file(&file, 8, false); + let a = StringArray::from(vec![ + Some("abc"), + Some("abc"), + Some("abc"), + None, + Some("abc"), + ]); + let b = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let c = Float64Array::from(vec![Some(2.0), Some(3.0), Some(4.0), Some(5.0), Some(2.0)]); + let d = BooleanArray::from(vec![ + Some(true), + Some(true), + Some(true), + Some(false), + Some(true), + ]); + let e_values = Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + ]); + let e_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, 3, 3, 3, 6, 8])); + let e_validity = Some(NullBuffer::from(vec![true, false, false, true, true])); + let field_e = Arc::new(Field::new("item", DataType::Int32, true)); + let e = ListArray::new(field_e, e_offsets, Arc::new(e_values), e_validity); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a) as Arc, true), + ("b", Arc::new(b) as Arc, true), + ("c", Arc::new(c) as Arc, true), + ("d", Arc::new(d) as Arc, true), + ("e", Arc::new(e) as Arc, true), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_nested_records() { + let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); + let f1_f1_2 = Int32Array::from(vec![10, 20]); + let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; + let f1_f1_3_1 = Float64Array::from(vec![rounded_pi, rounded_pi]); + let f1_f1_3 = StructArray::from(vec![( + Arc::new(Field::new("f1_3_1", DataType::Float64, false)), + Arc::new(f1_f1_3_1) as Arc, + )]); + // Add Avro named-type metadata to nested field f1_3 (ns3.record3) + let mut f1_3_md: HashMap = HashMap::new(); + f1_3_md.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns3".to_string()); + f1_3_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "record3".to_string()); + let f1_expected = StructArray::from(vec![ + ( + Arc::new(Field::new("f1_1", DataType::Utf8, false)), + Arc::new(f1_f1_1) as Arc, + ), + ( + Arc::new(Field::new("f1_2", DataType::Int32, false)), + Arc::new(f1_f1_2) as Arc, + ), + ( + Arc::new( + Field::new( + "f1_3", + DataType::Struct(Fields::from(vec![Field::new( + "f1_3_1", + DataType::Float64, + false, + )])), + false, + ) + .with_metadata(f1_3_md), + ), + Arc::new(f1_f1_3) as Arc, + ), + ]); + let f2_fields = [ + Field::new("f2_1", DataType::Boolean, false), + Field::new("f2_2", DataType::Float32, false), + ]; + let f2_struct_builder = StructBuilder::new( + f2_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![ + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, + ], + ); + let mut f2_list_builder = ListBuilder::new(f2_struct_builder); + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(1.2_f32); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(2.2_f32); + } + f2_list_builder.append(true); + } + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(false); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(10.2_f32); + } + f2_list_builder.append(true); + } + + let list_array_with_nullable_items = f2_list_builder.finish(); + // Add Avro named-type metadata to f2's list item (ns4.record4) + let mut f2_item_md: HashMap = HashMap::new(); + f2_item_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "record4".to_string()); + f2_item_md.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns4".to_string()); + let item_field = Arc::new( + Field::new( + "item", + list_array_with_nullable_items.values().data_type().clone(), + false, // items are non-nullable for f2 + ) + .with_metadata(f2_item_md), + ); + let list_data_type = DataType::List(item_field); + let f2_array_data = list_array_with_nullable_items + .to_data() + .into_builder() + .data_type(list_data_type) + .build() + .unwrap(); + let f2_expected = ListArray::from(f2_array_data); + let mut f3_struct_builder = StructBuilder::new( + vec![Arc::new(Field::new("f3_1", DataType::Utf8, false))], + vec![Box::new(StringBuilder::new()) as Box], + ); + f3_struct_builder.append(true); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_value("xyz"); + } + f3_struct_builder.append(false); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + let f3_expected = f3_struct_builder.finish(); + let f4_fields = [Field::new("f4_1", DataType::Int64, false)]; + let f4_struct_builder = StructBuilder::new( + f4_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![Box::new(Int64Builder::new()) as Box], + ); + let mut f4_list_builder = ListBuilder::new(f4_struct_builder); + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(200); + } + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + f4_list_builder.append(true); + } + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(300); + } + f4_list_builder.append(true); + } + let f4_expected = f4_list_builder.finish(); + // Add Avro named-type metadata to f4's list item (ns6.record6), item is nullable + let mut f4_item_md: HashMap = HashMap::new(); + f4_item_md.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns6".to_string()); + f4_item_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "record6".to_string()); + let f4_item_field = Arc::new( + Field::new("item", f4_expected.values().data_type().clone(), true) + .with_metadata(f4_item_md), + ); + let f4_list_data_type = DataType::List(f4_item_field); + let f4_array_data = f4_expected + .to_data() + .into_builder() + .data_type(f4_list_data_type) + .build() + .unwrap(); + let f4_expected = ListArray::from(f4_array_data); + // Build Schema with Avro named-type metadata on the top-level f1 and f3 fields + let mut f1_md: HashMap = HashMap::new(); + f1_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "record2".to_string()); + f1_md.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns2".to_string()); + let mut f3_md: HashMap = HashMap::new(); + f3_md.insert(AVRO_NAMESPACE_METADATA_KEY.to_string(), "ns5".to_string()); + f3_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "record5".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("f1", f1_expected.data_type().clone(), false).with_metadata(f1_md), + Field::new("f2", f2_expected.data_type().clone(), false), + Field::new("f3", f3_expected.data_type().clone(), true).with_metadata(f3_md), + Field::new("f4", f4_expected.data_type().clone(), false), + ]); + let expected = RecordBatch::try_new( + Arc::new(expected_schema), + vec![ + Arc::new(f1_expected) as Arc, + Arc::new(f2_expected) as Arc, + Arc::new(f3_expected) as Arc, + Arc::new(f4_expected) as Arc, + ], + ) + .unwrap(); + let file = arrow_test_data("avro/nested_records.avro"); + let batch_large = read_file(&file, 8, false); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 8)" + ); + let batch_small = read_file(&file, 3, false); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch does not match expected data for nested records (batch size 3)" + ); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_repeated_no_annotation() { + use arrow_data::ArrayDataBuilder; + let file = arrow_test_data("avro/repeated_no_annotation.avro"); + let batch_large = read_file(&file, 8, false); + // id column + let id_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + // Build the inner Struct + let number_array = Int64Array::from(vec![ + Some(5555555555), + Some(1111111111), + Some(1111111111), + Some(2222222222), + Some(3333333333), + ]); + let kind_array = + StringArray::from(vec![None, Some("home"), Some("home"), None, Some("mobile")]); + let phone_fields = Fields::from(vec![ + Field::new("number", DataType::Int64, true), + Field::new("kind", DataType::Utf8, true), + ]); + let phone_struct_data = ArrayDataBuilder::new(DataType::Struct(phone_fields)) + .len(5) + .child_data(vec![number_array.into_data(), kind_array.into_data()]) + .build() + .unwrap(); + let phone_struct_array = StructArray::from(phone_struct_data); + // Build List> with Avro named-type metadata on the *element* field + let phone_list_offsets = Buffer::from_slice_ref([0i32, 0, 0, 0, 1, 2, 5]); + let phone_list_validity = Buffer::from_iter([false, false, true, true, true, true]); + // The Avro schema names this inner record "phone" in namespace "topLevelRecord.phoneNumbers" + let mut phone_item_md = HashMap::new(); + phone_item_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "phone".to_string()); + phone_item_md.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "topLevelRecord.phoneNumbers".to_string(), + ); + let phone_item_field = Field::new("item", phone_struct_array.data_type().clone(), true) + .with_metadata(phone_item_md); + let phone_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(phone_item_field))) + .len(6) + .add_buffer(phone_list_offsets) + .null_bit_buffer(Some(phone_list_validity)) + .child_data(vec![phone_struct_array.into_data()]) + .build() + .unwrap(); + let phone_list_array = ListArray::from(phone_list_data); + // Wrap in Struct { phone: List<...> } + let phone_numbers_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_numbers_field = Field::new("phone", phone_list_array.data_type().clone(), true); + let phone_numbers_struct_data = + ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![phone_numbers_field]))) + .len(6) + .null_bit_buffer(Some(phone_numbers_validity)) + .child_data(vec![phone_list_array.into_data()]) + .build() + .unwrap(); + let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); + // Build the expected Schema, annotating the top-level "phoneNumbers" field with Avro name/namespace + let mut phone_numbers_md = HashMap::new(); + phone_numbers_md.insert( + AVRO_NAME_METADATA_KEY.to_string(), + "phoneNumbers".to_string(), + ); + phone_numbers_md.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "topLevelRecord".to_string(), + ); + let id_field = Field::new("id", DataType::Int32, true); + let phone_numbers_schema_field = Field::new( + "phoneNumbers", + phone_numbers_struct_array.data_type().clone(), + true, + ) + .with_metadata(phone_numbers_md); + let expected_schema = Schema::new(vec![id_field, phone_numbers_schema_field]); + // Final expected RecordBatch (arrays already carry matching list-element metadata) + let expected = RecordBatch::try_new( + Arc::new(expected_schema), + vec![ + Arc::new(id_array) as _, + Arc::new(phone_numbers_struct_array) as _, + ], + ) + .unwrap(); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_nonnullable_impala() { + let file = arrow_test_data("avro/nonnullable.impala.avro"); + let id = Int64Array::from(vec![Some(8)]); + let mut int_array_builder = ListBuilder::new(Int32Builder::new()); + { + let vb = int_array_builder.values(); + vb.append_value(-1); + } + int_array_builder.append(true); // finalize one sub-list + let int_array = int_array_builder.finish(); + let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + { + let inner_list_builder = iaa_builder.values(); + { + let vb = inner_list_builder.values(); + vb.append_value(-1); + vb.append_value(-2); + } + inner_list_builder.append(true); + inner_list_builder.append(true); + } + iaa_builder.append(true); + let int_array_array = iaa_builder.finish(); + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut int_map_builder = + MapBuilder::new(Some(field_names), StringBuilder::new(), Int32Builder::new()); + { + let (keys, vals) = int_map_builder.entries(); + keys.append_value("k1"); + vals.append_value(-1); + } + int_map_builder.append(true).unwrap(); // finalize map for row 0 + let int_map = int_map_builder.finish(); + let field_names2 = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut ima_builder = ListBuilder::new(MapBuilder::new( + Some(field_names2), + StringBuilder::new(), + Int32Builder::new(), + )); + { + let map_builder = ima_builder.values(); + map_builder.append(true).unwrap(); + { + let (keys, vals) = map_builder.entries(); + keys.append_value("k1"); + vals.append_value(1); + } + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + } + ima_builder.append(true); + let int_map_array_ = ima_builder.finish(); + // Helper metadata maps + let meta_nested_struct: HashMap = [ + ("avro.name", "nested_Struct"), + ("avro.namespace", "topLevelRecord"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + let meta_c: HashMap = [ + ("avro.name", "c"), + ("avro.namespace", "topLevelRecord.nested_Struct"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + let meta_d_item_struct: HashMap = [ + ("avro.name", "D"), + ("avro.namespace", "topLevelRecord.nested_Struct.c"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + let meta_g_value: HashMap = [ + ("avro.name", "G"), + ("avro.namespace", "topLevelRecord.nested_Struct"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + let meta_h: HashMap = [ + ("avro.name", "h"), + ("avro.namespace", "topLevelRecord.nested_Struct.G"), + ] + .into_iter() + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + // Types used multiple times below + let ef_struct_field = Arc::new( + Field::new( + "item", + DataType::Struct( + vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ] + .into(), + ), + true, + ) + .with_metadata(meta_d_item_struct.clone()), + ); + let d_inner_list_field = Arc::new(Field::new( + "item", + DataType::List(ef_struct_field.clone()), + true, + )); + let d_field = Field::new("D", DataType::List(d_inner_list_field.clone()), true); + // G.value.h.i : List + let i_list_field = Arc::new(Field::new("item", DataType::Float64, true)); + let i_field = Field::new("i", DataType::List(i_list_field.clone()), true); + // G.value.h : Struct<{ i: List }> with metadata (h) + let h_field = Field::new("h", DataType::Struct(vec![i_field.clone()].into()), true) + .with_metadata(meta_h.clone()); + // G.value : Struct<{ h: ... }> with metadata (G) + let g_value_struct_field = Field::new( + "value", + DataType::Struct(vec![h_field.clone()].into()), + true, + ) + .with_metadata(meta_g_value.clone()); + // entries struct for Map G + let entries_struct_field = Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + g_value_struct_field.clone(), + ] + .into(), + ), + false, + ); + // Top-level nested_Struct fields (include metadata on "c") + let a_field = Arc::new(Field::new("a", DataType::Int32, true)); + let b_field = Arc::new(Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + )); + let c_field = Arc::new( + Field::new("c", DataType::Struct(vec![d_field.clone()].into()), true) + .with_metadata(meta_c.clone()), + ); + let g_field = Arc::new(Field::new( + "G", + DataType::Map(Arc::new(entries_struct_field.clone()), false), + true, + )); + // Now create builders that match these exact field types (so nested types carry metadata) + let mut nested_sb = StructBuilder::new( + vec![ + a_field.clone(), + b_field.clone(), + c_field.clone(), + g_field.clone(), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(ListBuilder::new(Int32Builder::new())), + { + // builder for "c" with correctly typed "D" including metadata on inner list item + Box::new(StructBuilder::new( + vec![Arc::new(d_field.clone())], + vec![Box::new({ + let ef_struct_builder = StructBuilder::new( + vec![ + Arc::new(Field::new("e", DataType::Int32, true)), + Arc::new(Field::new("f", DataType::Utf8, true)), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ); + // Inner list that holds Struct with Avro named-type metadata ("D") + let list_of_ef = ListBuilder::new(ef_struct_builder) + .with_field(ef_struct_field.clone()); + // Outer list for "D" + ListBuilder::new(list_of_ef) + })], + )) + }, + { + let map_field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let i_list_builder = ListBuilder::new(Float64Builder::new()); + let h_struct_builder = StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(i_list_field.clone()), + true, + ))], + vec![Box::new(i_list_builder)], + ); + let g_value_builder = StructBuilder::new( + vec![Arc::new( + Field::new("h", DataType::Struct(vec![i_field.clone()].into()), true) + .with_metadata(meta_h.clone()), + )], + vec![Box::new(h_struct_builder)], + ); + // Use with_values_field to attach metadata to "value" field in the map's entries + let map_builder = MapBuilder::new( + Some(map_field_names), + StringBuilder::new(), + g_value_builder, + ) + .with_values_field(Arc::new( + Field::new( + "value", + DataType::Struct(vec![h_field.clone()].into()), + true, + ) + .with_metadata(meta_g_value.clone()), + )); + + Box::new(map_builder) + }, + ], + ); + nested_sb.append(true); + { + let a_builder = nested_sb.field_builder::(0).unwrap(); + a_builder.append_value(-1); + } + { + let b_builder = nested_sb + .field_builder::>(1) + .unwrap(); + { + let vb = b_builder.values(); + vb.append_value(-1); + } + b_builder.append(true); + } + { + let c_struct_builder = nested_sb.field_builder::(2).unwrap(); + c_struct_builder.append(true); + let d_list_builder = c_struct_builder + .field_builder::>>(0) + .unwrap(); + { + let sub_list_builder = d_list_builder.values(); + { + let ef_struct = sub_list_builder.values(); + ef_struct.append(true); + { + let e_b = ef_struct.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); + } + sub_list_builder.append(true); + } + d_list_builder.append(true); + } + } + { + let g_map_builder = nested_sb + .field_builder::>(3) + .unwrap(); + g_map_builder.append(true).unwrap(); + } + let nested_struct = nested_sb.finish(); + let schema = Arc::new(arrow_schema::Schema::new(vec![ + Field::new("ID", id.data_type().clone(), true), + Field::new("Int_Array", int_array.data_type().clone(), true), + Field::new("int_array_array", int_array_array.data_type().clone(), true), + Field::new("Int_Map", int_map.data_type().clone(), true), + Field::new("int_map_array", int_map_array_.data_type().clone(), true), + Field::new("nested_Struct", nested_struct.data_type().clone(), true) + .with_metadata(meta_nested_struct.clone()), + ])); + let expected = RecordBatch::try_new( + schema, + vec![ + Arc::new(id) as Arc, + Arc::new(int_array), + Arc::new(int_array_array), + Arc::new(int_map), + Arc::new(int_map_array_), + Arc::new(nested_struct), + ], + ) + .unwrap(); + let batch_large = read_file(&file, 8, false); + assert_eq!(batch_large, expected, "Mismatch for batch_size=8"); + let batch_small = read_file(&file, 3, false); + assert_eq!(batch_small, expected, "Mismatch for batch_size=3"); + } + + #[test] + fn test_nonnullable_impala_strict() { + let file = arrow_test_data("avro/nonnullable.impala.avro"); + let err = read_file_strict(&file, 8, false).unwrap_err(); + assert!(err.to_string().contains( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + )); + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_nullable_impala() { + let file = arrow_test_data("avro/nullable.impala.avro"); + let batch1 = read_file(&file, 3, false); + let batch2 = read_file(&file, 8, false); + assert_eq!(batch1, batch2); + let batch = batch1; + assert_eq!(batch.num_rows(), 7); + let id_array = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("id column should be an Int64Array"); + let expected_ids = [1, 2, 3, 4, 5, 6, 7]; + for (i, &expected_id) in expected_ids.iter().enumerate() { + assert_eq!(id_array.value(i), expected_id, "Mismatch in id at row {i}",); + } + let int_array = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("int_array column should be a ListArray"); + { + let offsets = int_array.value_offsets(); + let start = offsets[0] as usize; + let end = offsets[1] as usize; + let values = int_array + .values() + .as_any() + .downcast_ref::() + .expect("Values of int_array should be an Int32Array"); + let row0: Vec> = (start..end).map(|i| Some(values.value(i))).collect(); + assert_eq!( + row0, + vec![Some(1), Some(2), Some(3)], + "Mismatch in int_array row 0" + ); + } + let nested_struct = batch + .column(5) + .as_any() + .downcast_ref::() + .expect("nested_struct column should be a StructArray"); + let a_array = nested_struct + .column_by_name("A") + .expect("Field A should exist in nested_struct") + .as_any() + .downcast_ref::() + .expect("Field A should be an Int32Array"); + assert_eq!(a_array.value(0), 1, "Mismatch in nested_struct.A at row 0"); + assert!( + !a_array.is_valid(1), + "Expected null in nested_struct.A at row 1" + ); + assert!( + !a_array.is_valid(3), + "Expected null in nested_struct.A at row 3" + ); + assert_eq!(a_array.value(6), 7, "Mismatch in nested_struct.A at row 6"); + } + + #[test] + fn test_nullable_impala_strict() { + let file = arrow_test_data("avro/nullable.impala.avro"); + let err = read_file_strict(&file, 8, false).unwrap_err(); + assert!(err.to_string().contains( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + )); + } + + #[test] + fn test_nested_record_type_reuse() { + // The .avro file has the following schema: + // { + // "type" : "record", + // "name" : "Record", + // "fields" : [ { + // "name" : "nested", + // "type" : { + // "type" : "record", + // "name" : "Nested", + // "fields" : [ { + // "name" : "nested_int", + // "type" : "int" + // } ] + // } + // }, { + // "name" : "nestedRecord", + // "type" : "Nested" + // }, { + // "name" : "nestedArray", + // "type" : { + // "type" : "array", + // "items" : "Nested" + // } + // } ] + // } + let batch = read_file("test/data/nested_record_reuse.avro", 8, false); + let schema = batch.schema(); + + // Verify schema structure + assert_eq!(schema.fields().len(), 3); + let fields = schema.fields(); + assert_eq!(fields[0].name(), "nested"); + assert_eq!(fields[1].name(), "nestedRecord"); + assert_eq!(fields[2].name(), "nestedArray"); + assert!(matches!(fields[0].data_type(), DataType::Struct(_))); + assert!(matches!(fields[1].data_type(), DataType::Struct(_))); + assert!(matches!(fields[2].data_type(), DataType::List(_))); + + // Validate that the nested record type + if let DataType::Struct(nested_fields) = fields[0].data_type() { + assert_eq!(nested_fields.len(), 1); + assert_eq!(nested_fields[0].name(), "nested_int"); + assert_eq!(nested_fields[0].data_type(), &DataType::Int32); + } + + // Validate that the nested record type is reused + assert_eq!(fields[0].data_type(), fields[1].data_type()); + if let DataType::List(array_field) = fields[2].data_type() { + assert_eq!(array_field.data_type(), fields[0].data_type()); + } + + // Validate data + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 3); + + // Validate the first column (nested) + let nested_col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let nested_int_array = nested_col + .column_by_name("nested_int") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(nested_int_array.value(0), 42); + assert_eq!(nested_int_array.value(1), 99); + + // Validate the second column (nestedRecord) + let nested_record_col = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + let nested_record_int_array = nested_record_col + .column_by_name("nested_int") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(nested_record_int_array.value(0), 100); + assert_eq!(nested_record_int_array.value(1), 200); + + // Validate the third column (nestedArray) + let nested_array_col = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(nested_array_col.len(), 2); + let first_array_struct = nested_array_col.value(0); + let first_array_struct_array = first_array_struct + .as_any() + .downcast_ref::() + .unwrap(); + let first_array_int_values = first_array_struct_array + .column_by_name("nested_int") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(first_array_int_values.len(), 3); + assert_eq!(first_array_int_values.value(0), 1); + assert_eq!(first_array_int_values.value(1), 2); + assert_eq!(first_array_int_values.value(2), 3); + } + + #[test] + fn test_enum_type_reuse() { + // The .avro file has the following schema: + // { + // "type" : "record", + // "name" : "Record", + // "fields" : [ { + // "name" : "status", + // "type" : { + // "type" : "enum", + // "name" : "Status", + // "symbols" : [ "ACTIVE", "INACTIVE", "PENDING" ] + // } + // }, { + // "name" : "backupStatus", + // "type" : "Status" + // }, { + // "name" : "statusHistory", + // "type" : { + // "type" : "array", + // "items" : "Status" + // } + // } ] + // } + let batch = read_file("test/data/enum_reuse.avro", 8, false); + let schema = batch.schema(); + + // Verify schema structure + assert_eq!(schema.fields().len(), 3); + let fields = schema.fields(); + assert_eq!(fields[0].name(), "status"); + assert_eq!(fields[1].name(), "backupStatus"); + assert_eq!(fields[2].name(), "statusHistory"); + assert!(matches!(fields[0].data_type(), DataType::Dictionary(_, _))); + assert!(matches!(fields[1].data_type(), DataType::Dictionary(_, _))); + assert!(matches!(fields[2].data_type(), DataType::List(_))); + + if let DataType::Dictionary(key_type, value_type) = fields[0].data_type() { + assert_eq!(key_type.as_ref(), &DataType::Int32); + assert_eq!(value_type.as_ref(), &DataType::Utf8); + } + + // Validate that the enum types are reused + assert_eq!(fields[0].data_type(), fields[1].data_type()); + if let DataType::List(array_field) = fields[2].data_type() { + assert_eq!(array_field.data_type(), fields[0].data_type()); + } + + // Validate data - should have 2 rows + assert_eq!(batch.num_rows(), 2); + assert_eq!(batch.num_columns(), 3); + + // Get status enum values + let status_col = batch + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + let status_values = status_col + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + // First row should be "ACTIVE", second row should be "PENDING" + assert_eq!( + status_values.value(status_col.key(0).unwrap() as usize), + "ACTIVE" + ); + assert_eq!( + status_values.value(status_col.key(1).unwrap() as usize), + "PENDING" + ); + + // Get backupStatus enum values (same as status) + let backup_status_col = batch + .column(1) + .as_any() + .downcast_ref::>() + .unwrap(); + let backup_status_values = backup_status_col + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + // First row should be "INACTIVE", second row should be "ACTIVE" + assert_eq!( + backup_status_values.value(backup_status_col.key(0).unwrap() as usize), + "INACTIVE" + ); + assert_eq!( + backup_status_values.value(backup_status_col.key(1).unwrap() as usize), + "ACTIVE" + ); + + // Get statusHistory array + let status_history_col = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(status_history_col.len(), 2); + + // Validate first row's array data + let first_array_dict = status_history_col.value(0); + let first_array_dict_array = first_array_dict + .as_any() + .downcast_ref::>() + .unwrap(); + let first_array_values = first_array_dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + // First row: ["PENDING", "ACTIVE", "INACTIVE"] + assert_eq!(first_array_dict_array.len(), 3); + assert_eq!( + first_array_values.value(first_array_dict_array.key(0).unwrap() as usize), + "PENDING" + ); + assert_eq!( + first_array_values.value(first_array_dict_array.key(1).unwrap() as usize), + "ACTIVE" + ); + assert_eq!( + first_array_values.value(first_array_dict_array.key(2).unwrap() as usize), + "INACTIVE" + ); + } + + #[test] + fn comprehensive_e2e_test() { + let path = "test/data/comprehensive_e2e.avro"; + let batch = read_file(path, 1024, false); + let schema = batch.schema(); + + #[inline] + fn tid_by_name(fields: &UnionFields, want: &str) -> i8 { + for (tid, f) in fields.iter() { + if f.name() == want { + return tid; + } + } + panic!("union child '{want}' not found"); + } + + #[inline] + fn tid_by_dt(fields: &UnionFields, pred: impl Fn(&DataType) -> bool) -> i8 { + for (tid, f) in fields.iter() { + if pred(f.data_type()) { + return tid; + } + } + panic!("no union child matches predicate"); + } + + fn mk_dense_union( + fields: &UnionFields, + type_ids: Vec, + offsets: Vec, + provide: impl Fn(&Field) -> Option, + ) -> ArrayRef { + fn empty_child_for(dt: &DataType) -> Arc { + match dt { + DataType::Null => Arc::new(NullArray::new(0)), + DataType::Boolean => Arc::new(BooleanArray::from(Vec::::new())), + DataType::Int32 => Arc::new(Int32Array::from(Vec::::new())), + DataType::Int64 => Arc::new(Int64Array::from(Vec::::new())), + DataType::Float32 => Arc::new(Float32Array::from(Vec::::new())), + DataType::Float64 => Arc::new(Float64Array::from(Vec::::new())), + DataType::Binary => Arc::new(BinaryArray::from(Vec::<&[u8]>::new())), + DataType::Utf8 => Arc::new(StringArray::from(Vec::<&str>::new())), + DataType::Date32 => Arc::new(Date32Array::from(Vec::::new())), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Arc::new(Time32MillisecondArray::from(Vec::::new())) + } + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Arc::new(Time64MicrosecondArray::from(Vec::::new())) + } + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => { + let a = TimestampMicrosecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Interval(IntervalUnit::MonthDayNano) => Arc::new( + IntervalMonthDayNanoArray::from(Vec::::new()), + ), + DataType::FixedSizeBinary(sz) => Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size( + std::iter::empty::>>(), + *sz, + ) + .unwrap(), + ), + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(Vec::::new()); + let values = Arc::new(StringArray::from(Vec::<&str>::new())); + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + } + DataType::Struct(fields) => { + let children: Vec = fields + .iter() + .map(|f| empty_child_for(f.data_type()) as ArrayRef) + .collect(); + Arc::new(StructArray::new(fields.clone(), children, None)) + } + DataType::List(field) => { + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new( + ListArray::try_new( + field.clone(), + offsets, + empty_child_for(field.data_type()), + None, + ) + .unwrap(), + ) + } + DataType::Map(entry_field, is_sorted) => { + let (key_field, val_field) = match entry_field.data_type() { + DataType::Struct(fs) => (fs[0].clone(), fs[1].clone()), + other => panic!("unexpected map entries type: {other:?}"), + }; + let keys = StringArray::from(Vec::<&str>::new()); + let vals: ArrayRef = match val_field.data_type() { + DataType::Null => Arc::new(NullArray::new(0)) as ArrayRef, + DataType::Boolean => { + Arc::new(BooleanArray::from(Vec::::new())) as ArrayRef + } + DataType::Int32 => { + Arc::new(Int32Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Float32 => { + Arc::new(Float32Array::from(Vec::::new())) as ArrayRef + } + DataType::Float64 => { + Arc::new(Float64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Binary => { + Arc::new(BinaryArray::from(Vec::<&[u8]>::new())) as ArrayRef + } + DataType::Union(uf, _) => { + let children: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + children, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported map value type: {other:?}"), + }; + let entries = StructArray::new( + Fields::from(vec![ + key_field.as_ref().clone(), + val_field.as_ref().clone(), + ]), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + *is_sorted, + )) + } + other => panic!("empty_child_for: unhandled type {other:?}"), + } + } + let children: Vec = fields + .iter() + .map(|(_, f)| provide(f).unwrap_or_else(|| empty_child_for(f.data_type()))) + .collect(); + Arc::new( + UnionArray::try_new( + fields.clone(), + ScalarBuffer::::from(type_ids), + Some(ScalarBuffer::::from(offsets)), + children, + ) + .unwrap(), + ) as ArrayRef + } + + #[inline] + fn uuid16_from_str(s: &str) -> [u8; 16] { + let mut out = [0u8; 16]; + let mut idx = 0usize; + let mut hi: Option = None; + for ch in s.chars() { + if ch == '-' { + continue; + } + let v = ch.to_digit(16).expect("invalid hex digit in UUID") as u8; + if let Some(h) = hi { + out[idx] = (h << 4) | v; + idx += 1; + hi = None; + } else { + hi = Some(v); + } + } + assert_eq!(idx, 16, "UUID must decode to 16 bytes"); + out + } + let date_a: i32 = 19_000; // 2022-01-08 + let time_ms_a: i32 = 12 * 3_600_000 + 34 * 60_000 + 56_000 + 789; + let time_us_eod: i64 = 86_400_000_000 - 1; + let ts_ms_2024_01_01: i64 = 1_704_067_200_000; // 2024-01-01T00:00:00Z + let ts_us_2024_01_01: i64 = ts_ms_2024_01_01 * 1_000; + let dur_small = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000_000); + let dur_zero = IntervalMonthDayNanoType::make_value(0, 0, 0); + let dur_large = + IntervalMonthDayNanoType::make_value(12, 31, ((86_400_000 - 1) as i64) * 1_000_000); + let dur_2years = IntervalMonthDayNanoType::make_value(24, 0, 0); + let uuid1 = uuid16_from_str("fe7bc30b-4ce8-4c5e-b67c-2234a2d38e66"); + let uuid2 = uuid16_from_str("0826cc06-d2e3-4599-b4ad-af5fa6905cdb"); + + #[inline] + fn push_like( + reader_schema: &arrow_schema::Schema, + name: &str, + arr: ArrayRef, + fields: &mut Vec, + cols: &mut Vec, + ) { + let src = reader_schema + .field_with_name(name) + .unwrap_or_else(|_| panic!("source schema missing field '{name}'")); + let mut f = Field::new(name, arr.data_type().clone(), src.is_nullable()); + let md = src.metadata(); + if !md.is_empty() { + f = f.with_metadata(md.clone()); + } + fields.push(Arc::new(f)); + cols.push(arr); + } + + let mut fields: Vec = Vec::new(); + let mut columns: Vec = Vec::new(); + push_like( + schema.as_ref(), + "id", + Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "flag", + Arc::new(BooleanArray::from(vec![true, false, true, false])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "ratio_f32", + Arc::new(Float32Array::from(vec![1.25f32, -0.0, 3.5, 9.75])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "ratio_f64", + Arc::new(Float64Array::from(vec![2.5f64, -1.0, 7.0, -2.25])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "count_i32", + Arc::new(Int32Array::from(vec![7, -1, 0, 123])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "count_i64", + Arc::new(Int64Array::from(vec![ + 7_000_000_000i64, + -2, + 0, + -9_876_543_210i64, + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "opt_i32_nullfirst", + Arc::new(Int32Array::from(vec![None, Some(42), None, Some(0)])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "opt_str_nullsecond", + Arc::new(StringArray::from(vec![ + Some("alpha"), + None, + Some("s3"), + Some(""), + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + { + let uf = match schema + .field_with_name("tri_union_prim") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("tri_union_prim should be dense union, got {other:?}"), + }; + let tid_i = tid_by_name(&uf, "int"); + let tid_s = tid_by_name(&uf, "string"); + let tid_b = tid_by_name(&uf, "boolean"); + let tids = vec![tid_i, tid_s, tid_b, tid_s]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Int32 => Some(Arc::new(Int32Array::from(vec![0])) as ArrayRef), + DataType::Utf8 => Some(Arc::new(StringArray::from(vec!["hi", ""])) as ArrayRef), + DataType::Boolean => Some(Arc::new(BooleanArray::from(vec![true])) as ArrayRef), + _ => None, + }); + push_like( + schema.as_ref(), + "tri_union_prim", + arr, + &mut fields, + &mut columns, + ); + } + + push_like( + schema.as_ref(), + "str_utf8", + Arc::new(StringArray::from(vec!["hello", "", "world", "✓ unicode"])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "raw_bytes", + Arc::new(BinaryArray::from(vec![ + b"\x00\x01".as_ref(), + b"".as_ref(), + b"\xFF\x00".as_ref(), + b"\x10\x20\x30\x40".as_ref(), + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + { + let it = [ + Some(*b"0123456789ABCDEF"), + Some([0u8; 16]), + Some(*b"ABCDEFGHIJKLMNOP"), + Some([0xAA; 16]), + ] + .into_iter(); + let arr = + Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap()) + as ArrayRef; + push_like( + schema.as_ref(), + "fx16_plain", + arr, + &mut fields, + &mut columns, + ); + } + { + #[cfg(feature = "small_decimals")] + let dec10_2 = Arc::new( + Decimal64Array::from_iter_values([123456i64, -1, 0, 9_999_999_999i64]) + .with_precision_and_scale(10, 2) + .unwrap(), + ) as ArrayRef; + #[cfg(not(feature = "small_decimals"))] + let dec10_2 = Arc::new( + Decimal128Array::from_iter_values([123456i128, -1, 0, 9_999_999_999i128]) + .with_precision_and_scale(10, 2) + .unwrap(), + ) as ArrayRef; + push_like( + schema.as_ref(), + "dec_bytes_s10_2", + dec10_2, + &mut fields, + &mut columns, + ); + } + { + #[cfg(feature = "small_decimals")] + let dec20_4 = Arc::new( + Decimal128Array::from_iter_values([1_234_567_891_234i128, -420_000i128, 0, -1i128]) + .with_precision_and_scale(20, 4) + .unwrap(), + ) as ArrayRef; + #[cfg(not(feature = "small_decimals"))] + let dec20_4 = Arc::new( + Decimal128Array::from_iter_values([1_234_567_891_234i128, -420_000i128, 0, -1i128]) + .with_precision_and_scale(20, 4) + .unwrap(), + ) as ArrayRef; + push_like( + schema.as_ref(), + "dec_fix_s20_4", + dec20_4, + &mut fields, + &mut columns, + ); + } + { + let it = [Some(uuid1), Some(uuid2), Some(uuid1), Some(uuid2)].into_iter(); + let arr = + Arc::new(FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap()) + as ArrayRef; + push_like(schema.as_ref(), "uuid_str", arr, &mut fields, &mut columns); + } + push_like( + schema.as_ref(), + "d_date", + Arc::new(Date32Array::from(vec![date_a, 0, 1, 365])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "t_millis", + Arc::new(Time32MillisecondArray::from(vec![ + time_ms_a, + 0, + 1, + 86_400_000 - 1, + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "t_micros", + Arc::new(Time64MicrosecondArray::from(vec![ + time_us_eod, + 0, + 1, + 1_000_000, + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + { + let a = TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01, + -1, + ts_ms_2024_01_01 + 123, + 0, + ]) + .with_timezone("+00:00"); + push_like( + schema.as_ref(), + "ts_millis_utc", + Arc::new(a) as ArrayRef, + &mut fields, + &mut columns, + ); + } + { + let a = TimestampMicrosecondArray::from(vec![ + ts_us_2024_01_01, + 1, + ts_us_2024_01_01 + 456, + 0, + ]) + .with_timezone("+00:00"); + push_like( + schema.as_ref(), + "ts_micros_utc", + Arc::new(a) as ArrayRef, + &mut fields, + &mut columns, + ); + } + push_like( + schema.as_ref(), + "ts_millis_local", + Arc::new(TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01 + 86_400_000, + 0, + ts_ms_2024_01_01 + 789, + 123_456_789, + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + push_like( + schema.as_ref(), + "ts_micros_local", + Arc::new(TimestampMicrosecondArray::from(vec![ + ts_us_2024_01_01 + 123_456, + 0, + ts_us_2024_01_01 + 101_112, + 987_654_321, + ])) as ArrayRef, + &mut fields, + &mut columns, + ); + { + let v = vec![dur_small, dur_zero, dur_large, dur_2years]; + push_like( + schema.as_ref(), + "interval_mdn", + Arc::new(IntervalMonthDayNanoArray::from(v)) as ArrayRef, + &mut fields, + &mut columns, + ); + } + { + let keys = Int32Array::from(vec![1, 2, 3, 0]); // NEW, PROCESSING, DONE, UNKNOWN + let values = Arc::new(StringArray::from(vec![ + "UNKNOWN", + "NEW", + "PROCESSING", + "DONE", + ])) as ArrayRef; + let dict = DictionaryArray::::try_new(keys, values).unwrap(); + push_like( + schema.as_ref(), + "status", + Arc::new(dict) as ArrayRef, + &mut fields, + &mut columns, + ); + } + { + let list_field = match schema.field_with_name("arr_union").unwrap().data_type() { + DataType::List(f) => f.clone(), + other => panic!("arr_union should be List, got {other:?}"), + }; + let uf = match list_field.data_type() { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("arr_union item should be union, got {other:?}"), + }; + let tid_l = tid_by_name(&uf, "long"); + let tid_s = tid_by_name(&uf, "string"); + let tid_n = tid_by_name(&uf, "null"); + let type_ids = vec![ + tid_l, tid_s, tid_n, tid_l, tid_n, tid_s, tid_l, tid_l, tid_s, tid_n, tid_l, + ]; + let offsets = vec![0, 0, 0, 1, 1, 1, 2, 3, 2, 2, 4]; + let values = mk_dense_union(&uf, type_ids, offsets, |f| match f.data_type() { + DataType::Int64 => { + Some(Arc::new(Int64Array::from(vec![1i64, -3, 0, -1, 0])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["x", "z", "end"])) as ArrayRef) + } + DataType::Null => Some(Arc::new(NullArray::new(3)) as ArrayRef), + _ => None, + }); + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 4, 7, 8, 11])); + let arr = Arc::new(ListArray::try_new(list_field, list_offsets, values, None).unwrap()) + as ArrayRef; + push_like(schema.as_ref(), "arr_union", arr, &mut fields, &mut columns); + } + { + let (entry_field, entries_fields, uf, is_sorted) = + match schema.field_with_name("map_union").unwrap().data_type() { + DataType::Map(entry_field, is_sorted) => { + let fs = match entry_field.data_type() { + DataType::Struct(fs) => fs.clone(), + other => panic!("map entries must be struct, got {other:?}"), + }; + let val_f = fs[1].clone(); + let uf = match val_f.data_type() { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("map value must be union, got {other:?}"), + }; + (entry_field.clone(), fs, uf, *is_sorted) + } + other => panic!("map_union should be Map, got {other:?}"), + }; + let keys = StringArray::from(vec!["a", "b", "c", "neg", "pi", "ok"]); + let moff = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 4, 4, 6])); + let tid_null = tid_by_name(&uf, "null"); + let tid_d = tid_by_name(&uf, "double"); + let tid_s = tid_by_name(&uf, "string"); + let type_ids = vec![tid_d, tid_null, tid_s, tid_d, tid_d, tid_s]; + let offsets = vec![0, 0, 0, 1, 2, 1]; + let pi_5dp = (std::f64::consts::PI * 100_000.0).trunc() / 100_000.0; + let vals = mk_dense_union(&uf, type_ids, offsets, |f| match f.data_type() { + DataType::Float64 => { + Some(Arc::new(Float64Array::from(vec![1.5f64, -0.5, pi_5dp])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["yes", "true"])) as ArrayRef) + } + DataType::Null => Some(Arc::new(NullArray::new(2)) as ArrayRef), + _ => None, + }); + let entries = StructArray::new( + entries_fields.clone(), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let map = + Arc::new(MapArray::new(entry_field, moff, entries, None, is_sorted)) as ArrayRef; + push_like(schema.as_ref(), "map_union", map, &mut fields, &mut columns); + } + { + let fs = match schema.field_with_name("address").unwrap().data_type() { + DataType::Struct(fs) => fs.clone(), + other => panic!("address should be Struct, got {other:?}"), + }; + let street = Arc::new(StringArray::from(vec![ + "100 Main", + "", + "42 Galaxy Way", + "End Ave", + ])) as ArrayRef; + let zip = Arc::new(Int32Array::from(vec![12345, 0, 42424, 1])) as ArrayRef; + let country = Arc::new(StringArray::from(vec!["US", "CA", "US", "GB"])) as ArrayRef; + let arr = Arc::new(StructArray::new(fs, vec![street, zip, country], None)) as ArrayRef; + push_like(schema.as_ref(), "address", arr, &mut fields, &mut columns); + } + { + let fs = match schema.field_with_name("maybe_auth").unwrap().data_type() { + DataType::Struct(fs) => fs.clone(), + other => panic!("maybe_auth should be Struct, got {other:?}"), + }; + let user = + Arc::new(StringArray::from(vec!["alice", "bob", "carol", "dave"])) as ArrayRef; + let token_values: Vec> = vec![ + None, // row 1: null + Some(b"\x01\x02\x03".as_ref()), // row 2: bytes + None, // row 3: null + Some(b"".as_ref()), // row 4: empty bytes + ]; + let token = Arc::new(BinaryArray::from(token_values)) as ArrayRef; + let arr = Arc::new(StructArray::new(fs, vec![user, token], None)) as ArrayRef; + push_like( + schema.as_ref(), + "maybe_auth", + arr, + &mut fields, + &mut columns, + ); + } + { + let uf = match schema + .field_with_name("union_enum_record_array_map") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("union_enum_record_array_map should be union, got {other:?}"), + }; + let mut tid_enum: Option = None; + let mut tid_rec_a: Option = None; + let mut tid_array: Option = None; + let mut tid_map: Option = None; + let mut map_entry_field: Option = None; + let mut map_sorted: bool = false; + for (tid, f) in uf.iter() { + match f.data_type() { + DataType::Dictionary(_, _) => tid_enum = Some(tid), + DataType::Struct(childs) + if childs.len() == 2 + && childs[0].name() == "a" + && childs[1].name() == "b" => + { + tid_rec_a = Some(tid) + } + DataType::List(item) if matches!(item.data_type(), DataType::Int64) => { + tid_array = Some(tid) + } + DataType::Map(ef, is_sorted) => { + tid_map = Some(tid); + map_entry_field = Some(ef.clone()); + map_sorted = *is_sorted; + } + _ => {} + } + } + let (tid_enum, tid_rec_a, tid_array, tid_map) = ( + tid_enum.unwrap(), + tid_rec_a.unwrap(), + tid_array.unwrap(), + tid_map.unwrap(), + ); + let tids = vec![tid_enum, tid_rec_a, tid_array, tid_map]; + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32]); + let values = + Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + DataType::Struct(fs) + if fs.len() == 2 && fs[0].name() == "a" && fs[1].name() == "b" => + { + let a = Int32Array::from(vec![7]); + let b = StringArray::from(vec!["rec"]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(a), Arc::new(b)], + None, + )) as ArrayRef) + } + DataType::List(field) => { + let values = Int64Array::from(vec![1i64, 2, 3]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3])); + Some(Arc::new( + ListArray::try_new(field.clone(), offsets, Arc::new(values), None).unwrap(), + ) as ArrayRef) + } + DataType::Map(_, _) => { + let entry_field = map_entry_field.clone().unwrap(); + let (key_field, val_field) = match entry_field.data_type() { + DataType::Struct(fs) => (fs[0].clone(), fs[1].clone()), + _ => unreachable!(), + }; + let keys = StringArray::from(vec!["k"]); + let vals = StringArray::from(vec!["v"]); + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 1])); + Some(Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + map_sorted, + )) as ArrayRef) + } + _ => None, + }); + push_like( + schema.as_ref(), + "union_enum_record_array_map", + arr, + &mut fields, + &mut columns, + ); + } + { + let uf = match schema + .field_with_name("union_date_or_fixed4") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("union_date_or_fixed4 should be union, got {other:?}"), + }; + let tid_date = tid_by_dt(&uf, |dt| matches!(dt, DataType::Date32)); + let tid_fx4 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(4))); + let tids = vec![tid_date, tid_fx4, tid_date, tid_fx4]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Date32 => Some(Arc::new(Date32Array::from(vec![date_a, 0])) as ArrayRef), + DataType::FixedSizeBinary(4) => { + let it = [Some(*b"\x00\x11\x22\x33"), Some(*b"ABCD")].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 4).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + push_like( + schema.as_ref(), + "union_date_or_fixed4", + arr, + &mut fields, + &mut columns, + ); + } + { + let uf = match schema + .field_with_name("union_interval_or_string") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("union_interval_or_string should be union, got {other:?}"), + }; + let tid_dur = tid_by_dt(&uf, |dt| { + matches!(dt, DataType::Interval(IntervalUnit::MonthDayNano)) + }); + let tid_str = tid_by_dt(&uf, |dt| matches!(dt, DataType::Utf8)); + let tids = vec![tid_dur, tid_str, tid_dur, tid_str]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Interval(IntervalUnit::MonthDayNano) => Some(Arc::new( + IntervalMonthDayNanoArray::from(vec![dur_small, dur_large]), + ) + as ArrayRef), + DataType::Utf8 => Some(Arc::new(StringArray::from(vec![ + "duration-as-text", + "iso-8601-period-P1Y", + ])) as ArrayRef), + _ => None, + }); + push_like( + schema.as_ref(), + "union_interval_or_string", + arr, + &mut fields, + &mut columns, + ); + } + { + let uf = match schema + .field_with_name("union_uuid_or_fixed10") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("union_uuid_or_fixed10 should be union, got {other:?}"), + }; + let tid_uuid = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(16))); + let tid_fx10 = tid_by_dt(&uf, |dt| matches!(dt, DataType::FixedSizeBinary(10))); + let tids = vec![tid_uuid, tid_fx10, tid_uuid, tid_fx10]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(16) => { + let it = [Some(uuid1), Some(uuid2)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef) + } + DataType::FixedSizeBinary(10) => { + let fx10_a = [0xAAu8; 10]; + let fx10_b = [0x00u8, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99]; + let it = [Some(fx10_a), Some(fx10_b)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 10).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + push_like( + schema.as_ref(), + "union_uuid_or_fixed10", + arr, + &mut fields, + &mut columns, + ); + } + { + let list_field = match schema + .field_with_name("array_records_with_union") + .unwrap() + .data_type() + { + DataType::List(f) => f.clone(), + other => panic!("array_records_with_union should be List, got {other:?}"), + }; + let kv_fields = match list_field.data_type() { + DataType::Struct(fs) => fs.clone(), + other => panic!("array_records_with_union items must be Struct, got {other:?}"), + }; + let val_field = kv_fields + .iter() + .find(|f| f.name() == "val") + .unwrap() + .clone(); + let uf = match val_field.data_type() { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("KV.val should be union, got {other:?}"), + }; + let keys = Arc::new(StringArray::from(vec!["k1", "k2", "k", "k3", "x"])) as ArrayRef; + let tid_null = tid_by_name(&uf, "null"); + let tid_i = tid_by_name(&uf, "int"); + let tid_l = tid_by_name(&uf, "long"); + let type_ids = vec![tid_i, tid_null, tid_l, tid_null, tid_i]; + let offsets = vec![0, 0, 0, 1, 1]; + let vals = mk_dense_union(&uf, type_ids, offsets, |f| match f.data_type() { + DataType::Int32 => Some(Arc::new(Int32Array::from(vec![5, -5])) as ArrayRef), + DataType::Int64 => Some(Arc::new(Int64Array::from(vec![99i64])) as ArrayRef), + DataType::Null => Some(Arc::new(NullArray::new(2)) as ArrayRef), + _ => None, + }); + let values_struct = + Arc::new(StructArray::new(kv_fields.clone(), vec![keys, vals], None)) as ArrayRef; + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3, 4, 5])); + let arr = Arc::new( + ListArray::try_new(list_field, list_offsets, values_struct, None).unwrap(), + ) as ArrayRef; + push_like( + schema.as_ref(), + "array_records_with_union", + arr, + &mut fields, + &mut columns, + ); + } + { + let uf = match schema + .field_with_name("union_map_or_array_int") + .unwrap() + .data_type() + { + DataType::Union(f, UnionMode::Dense) => f.clone(), + other => panic!("union_map_or_array_int should be union, got {other:?}"), + }; + let tid_map = tid_by_dt(&uf, |dt| matches!(dt, DataType::Map(_, _))); + let tid_list = tid_by_dt(&uf, |dt| matches!(dt, DataType::List(_))); + let map_child: ArrayRef = { + let (entry_field, is_sorted) = match uf + .iter() + .find(|(tid, _)| *tid == tid_map) + .unwrap() + .1 + .data_type() + { + DataType::Map(ef, is_sorted) => (ef.clone(), *is_sorted), + _ => unreachable!(), + }; + let (key_field, val_field) = match entry_field.data_type() { + DataType::Struct(fs) => (fs[0].clone(), fs[1].clone()), + _ => unreachable!(), + }; + let keys = StringArray::from(vec!["x", "y", "only"]); + let vals = Int32Array::from(vec![1, 2, 10]); + let entries = StructArray::new( + Fields::from(vec![key_field.as_ref().clone(), val_field.as_ref().clone()]), + vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef], + None, + ); + let moff = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3])); + Arc::new(MapArray::new(entry_field, moff, entries, None, is_sorted)) as ArrayRef + }; + let list_child: ArrayRef = { + let list_field = match uf + .iter() + .find(|(tid, _)| *tid == tid_list) + .unwrap() + .1 + .data_type() + { + DataType::List(f) => f.clone(), + _ => unreachable!(), + }; + let values = Int32Array::from(vec![1, 2, 3, 0]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 4])); + Arc::new(ListArray::try_new(list_field, offsets, Arc::new(values), None).unwrap()) + as ArrayRef + }; + let tids = vec![tid_map, tid_list, tid_map, tid_list]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf, tids, offs, |f| match f.data_type() { + DataType::Map(_, _) => Some(map_child.clone()), + DataType::List(_) => Some(list_child.clone()), + _ => None, + }); + push_like( + schema.as_ref(), + "union_map_or_array_int", + arr, + &mut fields, + &mut columns, + ); + } + push_like( + schema.as_ref(), + "renamed_with_default", + Arc::new(Int32Array::from(vec![100, 42, 7, 42])) as ArrayRef, + &mut fields, + &mut columns, + ); + { + let fs = match schema.field_with_name("person").unwrap().data_type() { + DataType::Struct(fs) => fs.clone(), + other => panic!("person should be Struct, got {other:?}"), + }; + let name = + Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "Dave"])) as ArrayRef; + let age = Arc::new(Int32Array::from(vec![30, 0, 25, 41])) as ArrayRef; + let arr = Arc::new(StructArray::new(fs, vec![name, age], None)) as ArrayRef; + push_like(schema.as_ref(), "person", arr, &mut fields, &mut columns); + } + let expected = + RecordBatch::try_new(Arc::new(Schema::new(Fields::from(fields))), columns).unwrap(); + assert_eq!( + expected, batch, + "entire RecordBatch mismatch (schema, all columns, all rows)" + ); + } + #[test] + fn comprehensive_e2e_resolution_test() { + use serde_json::Value; + use std::collections::HashMap; + + // Build a reader schema that stresses Avro schema‑resolution + // + // Changes relative to writer schema: + // * Rename fields using writer aliases: id -> identifier, renamed_with_default -> old_count + // * Promote numeric types: count_i32 (int) -> long, ratio_f32 (float) -> double + // * Reorder many union branches (reverse), incl. nested unions + // * Reorder array/map union item/value branches + // * Rename nested Address field: street -> street_name (uses alias in writer) + // * Change Person type name/namespace: com.example.Person (matches writer alias) + // * Reverse top‑level field order + // + // Reader‑side aliases are added wherever names change (per Avro spec). + fn make_comprehensive_reader_schema(path: &str) -> AvroSchema { + fn set_type_string(f: &mut Value, new_ty: &str) { + if let Some(ty) = f.get_mut("type") { + match ty { + Value::String(_) | Value::Object(_) => { + *ty = Value::String(new_ty.to_string()); + } + Value::Array(arr) => { + for b in arr.iter_mut() { + match b { + Value::String(s) if s != "null" => { + *b = Value::String(new_ty.to_string()); + break; + } + Value::Object(_) => { + *b = Value::String(new_ty.to_string()); + break; + } + _ => {} + } + } + } + _ => {} + } + } + } + fn reverse_union_array(f: &mut Value) { + if let Some(arr) = f.get_mut("type").and_then(|t| t.as_array_mut()) { + arr.reverse(); + } + } + fn reverse_items_union(f: &mut Value) { + if let Some(obj) = f.get_mut("type").and_then(|t| t.as_object_mut()) { + if let Some(items) = obj.get_mut("items").and_then(|v| v.as_array_mut()) { + items.reverse(); + } + } + } + fn reverse_map_values_union(f: &mut Value) { + if let Some(obj) = f.get_mut("type").and_then(|t| t.as_object_mut()) { + if let Some(values) = obj.get_mut("values").and_then(|v| v.as_array_mut()) { + values.reverse(); + } + } + } + fn reverse_nested_union_in_record(f: &mut Value, field_name: &str) { + if let Some(obj) = f.get_mut("type").and_then(|t| t.as_object_mut()) { + if let Some(fields) = obj.get_mut("fields").and_then(|v| v.as_array_mut()) { + for ff in fields.iter_mut() { + if ff.get("name").and_then(|n| n.as_str()) == Some(field_name) { + if let Some(ty) = ff.get_mut("type") { + if let Some(arr) = ty.as_array_mut() { + arr.reverse(); + } + } + } + } + } + } + } + fn rename_nested_field_with_alias(f: &mut Value, old: &str, new: &str) { + if let Some(obj) = f.get_mut("type").and_then(|t| t.as_object_mut()) { + if let Some(fields) = obj.get_mut("fields").and_then(|v| v.as_array_mut()) { + for ff in fields.iter_mut() { + if ff.get("name").and_then(|n| n.as_str()) == Some(old) { + ff["name"] = Value::String(new.to_string()); + ff["aliases"] = Value::Array(vec![Value::String(old.to_string())]); + } + } + } + } + } + let mut root = load_writer_schema_json(path); + assert_eq!(root["type"], "record", "writer schema must be a record"); + let fields = root + .get_mut("fields") + .and_then(|f| f.as_array_mut()) + .expect("record has fields"); + for f in fields.iter_mut() { + let Some(name) = f.get("name").and_then(|n| n.as_str()) else { + continue; + }; + match name { + // Field aliasing (reader‑side aliases added) + "id" => { + f["name"] = Value::String("identifier".into()); + f["aliases"] = Value::Array(vec![Value::String("id".into())]); + } + "renamed_with_default" => { + f["name"] = Value::String("old_count".into()); + f["aliases"] = + Value::Array(vec![Value::String("renamed_with_default".into())]); + } + // Promotions + "count_i32" => set_type_string(f, "long"), + "ratio_f32" => set_type_string(f, "double"), + // Union reorder (exercise resolution) + "opt_str_nullsecond" => reverse_union_array(f), + "union_enum_record_array_map" => reverse_union_array(f), + "union_date_or_fixed4" => reverse_union_array(f), + "union_interval_or_string" => reverse_union_array(f), + "union_uuid_or_fixed10" => reverse_union_array(f), + "union_map_or_array_int" => reverse_union_array(f), + "maybe_auth" => reverse_nested_union_in_record(f, "token"), + // Array/Map unions + "arr_union" => reverse_items_union(f), + "map_union" => reverse_map_values_union(f), + // Nested rename using reader‑side alias + "address" => rename_nested_field_with_alias(f, "street", "street_name"), + // Type‑name alias for nested record + "person" => { + if let Some(tobj) = f.get_mut("type").and_then(|t| t.as_object_mut()) { + tobj.insert("name".to_string(), Value::String("Person".into())); + tobj.insert( + "namespace".to_string(), + Value::String("com.example".into()), + ); + tobj.insert( + "aliases".into(), + Value::Array(vec![ + Value::String("PersonV2".into()), + Value::String("com.example.v2.PersonV2".into()), + ]), + ); + } + } + _ => {} + } + } + fields.reverse(); + AvroSchema::new(root.to_string()) + } + + let path = "test/data/comprehensive_e2e.avro"; + let reader_schema = make_comprehensive_reader_schema(path); + let batch = read_alltypes_with_reader_schema(path, reader_schema.clone()); + + const UUID_EXT_KEY: &str = "ARROW:extension:name"; + const UUID_LOGICAL_KEY: &str = "logicalType"; + + let uuid_md_top: Option> = batch + .schema() + .field_with_name("uuid_str") + .ok() + .and_then(|f| { + let md = f.metadata(); + let has_ext = md.get(UUID_EXT_KEY).is_some(); + let is_uuid_logical = md + .get(UUID_LOGICAL_KEY) + .map(|v| v.trim_matches('"') == "uuid") + .unwrap_or(false); + if has_ext || is_uuid_logical { + Some(md.clone()) + } else { + None + } + }); + + let uuid_md_union: Option> = batch + .schema() + .field_with_name("union_uuid_or_fixed10") + .ok() + .and_then(|f| match f.data_type() { + DataType::Union(uf, _) => uf + .iter() + .find(|(_, child)| child.name() == "uuid") + .and_then(|(_, child)| { + let md = child.metadata(); + let has_ext = md.get(UUID_EXT_KEY).is_some(); + let is_uuid_logical = md + .get(UUID_LOGICAL_KEY) + .map(|v| v.trim_matches('"') == "uuid") + .unwrap_or(false); + if has_ext || is_uuid_logical { + Some(md.clone()) + } else { + None + } + }), + _ => None, + }); + + let add_uuid_ext_top = |f: Field| -> Field { + if let Some(md) = &uuid_md_top { + f.with_metadata(md.clone()) + } else { + f + } + }; + let add_uuid_ext_union = |f: Field| -> Field { + if let Some(md) = &uuid_md_union { + f.with_metadata(md.clone()) + } else { + f + } + }; + + #[inline] + fn uuid16_from_str(s: &str) -> [u8; 16] { + let mut out = [0u8; 16]; + let mut idx = 0usize; + let mut hi: Option = None; + for ch in s.chars() { + if ch == '-' { + continue; + } + let v = ch.to_digit(16).expect("invalid hex digit in UUID") as u8; + if let Some(h) = hi { + out[idx] = (h << 4) | v; + idx += 1; + hi = None; + } else { + hi = Some(v); + } + } + assert_eq!(idx, 16, "UUID must decode to 16 bytes"); + out + } + + fn mk_dense_union( + fields: &UnionFields, + type_ids: Vec, + offsets: Vec, + provide: impl Fn(&Field) -> Option, + ) -> ArrayRef { + fn empty_child_for(dt: &DataType) -> Arc { + match dt { + DataType::Null => Arc::new(NullArray::new(0)), + DataType::Boolean => Arc::new(BooleanArray::from(Vec::::new())), + DataType::Int32 => Arc::new(Int32Array::from(Vec::::new())), + DataType::Int64 => Arc::new(Int64Array::from(Vec::::new())), + DataType::Float32 => Arc::new(Float32Array::from(Vec::::new())), + DataType::Float64 => Arc::new(Float64Array::from(Vec::::new())), + DataType::Binary => Arc::new(BinaryArray::from(Vec::<&[u8]>::new())), + DataType::Utf8 => Arc::new(StringArray::from(Vec::<&str>::new())), + DataType::Date32 => Arc::new(Date32Array::from(Vec::::new())), + DataType::Time32(arrow_schema::TimeUnit::Millisecond) => { + Arc::new(Time32MillisecondArray::from(Vec::::new())) + } + DataType::Time64(arrow_schema::TimeUnit::Microsecond) => { + Arc::new(Time64MicrosecondArray::from(Vec::::new())) + } + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, tz) => { + let a = TimestampMillisecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, tz) => { + let a = TimestampMicrosecondArray::from(Vec::::new()); + Arc::new(if let Some(tz) = tz { + a.with_timezone(tz.clone()) + } else { + a + }) + } + DataType::Interval(IntervalUnit::MonthDayNano) => Arc::new( + IntervalMonthDayNanoArray::from(Vec::::new()), + ), + DataType::FixedSizeBinary(sz) => Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size( + std::iter::empty::>>(), + *sz, + ) + .unwrap(), + ), + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(Vec::::new()); + let values = Arc::new(StringArray::from(Vec::<&str>::new())); + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + } + DataType::Struct(fields) => { + let children: Vec = fields + .iter() + .map(|f| empty_child_for(f.data_type()) as ArrayRef) + .collect(); + Arc::new(StructArray::new(fields.clone(), children, None)) + } + DataType::List(field) => { + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new( + ListArray::try_new( + field.clone(), + offsets, + empty_child_for(field.data_type()), + None, + ) + .unwrap(), + ) + } + DataType::Map(entry_field, is_sorted) => { + let (key_field, val_field) = match entry_field.data_type() { + DataType::Struct(fs) => (fs[0].clone(), fs[1].clone()), + other => panic!("unexpected map entries type: {other:?}"), + }; + let keys = StringArray::from(Vec::<&str>::new()); + let vals: ArrayRef = match val_field.data_type() { + DataType::Null => Arc::new(NullArray::new(0)) as ArrayRef, + DataType::Boolean => { + Arc::new(BooleanArray::from(Vec::::new())) as ArrayRef + } + DataType::Int32 => { + Arc::new(Int32Array::from(Vec::::new())) as ArrayRef + } + DataType::Int64 => { + Arc::new(Int64Array::from(Vec::::new())) as ArrayRef + } + DataType::Float32 => { + Arc::new(Float32Array::from(Vec::::new())) as ArrayRef + } + DataType::Float64 => { + Arc::new(Float64Array::from(Vec::::new())) as ArrayRef + } + DataType::Utf8 => { + Arc::new(StringArray::from(Vec::<&str>::new())) as ArrayRef + } + DataType::Binary => { + Arc::new(BinaryArray::from(Vec::<&[u8]>::new())) as ArrayRef + } + DataType::Union(uf, _) => { + let children: Vec = uf + .iter() + .map(|(_, f)| empty_child_for(f.data_type())) + .collect(); + Arc::new( + UnionArray::try_new( + uf.clone(), + ScalarBuffer::::from(Vec::::new()), + Some(ScalarBuffer::::from(Vec::::new())), + children, + ) + .unwrap(), + ) as ArrayRef + } + other => panic!("unsupported map value type: {other:?}"), + }; + let entries = StructArray::new( + Fields::from(vec![ + key_field.as_ref().clone(), + val_field.as_ref().clone(), + ]), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0])); + Arc::new(MapArray::new( + entry_field.clone(), + offsets, + entries, + None, + *is_sorted, + )) + } + other => panic!("empty_child_for: unhandled type {other:?}"), + } + } + let children: Vec = fields + .iter() + .map(|(_, f)| provide(f).unwrap_or_else(|| empty_child_for(f.data_type()))) + .collect(); + Arc::new( + UnionArray::try_new( + fields.clone(), + ScalarBuffer::::from(type_ids), + Some(ScalarBuffer::::from(offsets)), + children, + ) + .unwrap(), + ) as ArrayRef + } + let date_a: i32 = 19_000; // 2022-01-08 + let time_ms_a: i32 = 12 * 3_600_000 + 34 * 60_000 + 56_000 + 789; + let time_us_eod: i64 = 86_400_000_000 - 1; + let ts_ms_2024_01_01: i64 = 1_704_067_200_000; // 2024-01-01T00:00:00Z + let ts_us_2024_01_01: i64 = ts_ms_2024_01_01 * 1_000; + let dur_small = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000_000); + let dur_zero = IntervalMonthDayNanoType::make_value(0, 0, 0); + let dur_large = + IntervalMonthDayNanoType::make_value(12, 31, ((86_400_000 - 1) as i64) * 1_000_000); + let dur_2years = IntervalMonthDayNanoType::make_value(24, 0, 0); + let uuid1 = uuid16_from_str("fe7bc30b-4ce8-4c5e-b67c-2234a2d38e66"); + let uuid2 = uuid16_from_str("0826cc06-d2e3-4599-b4ad-af5fa6905cdb"); + let item_name = Field::LIST_FIELD_DEFAULT_NAME; + let uf_tri = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("int", DataType::Int32, false), + Field::new("string", DataType::Utf8, false), + Field::new("boolean", DataType::Boolean, false), + ], + ) + .unwrap(); + let uf_arr_items = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("null", DataType::Null, false), + Field::new("string", DataType::Utf8, false), + Field::new("long", DataType::Int64, false), + ], + ) + .unwrap(); + let arr_items_field = Arc::new(Field::new( + item_name, + DataType::Union(uf_arr_items.clone(), UnionMode::Dense), + true, + )); + let uf_map_vals = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("string", DataType::Utf8, false), + Field::new("double", DataType::Float64, false), + Field::new("null", DataType::Null, false), + ], + ) + .unwrap(); + let map_entries_field = Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Union(uf_map_vals.clone(), UnionMode::Dense), + true, + ), + ])), + false, + )); + // Enum metadata for Color (now includes name/namespace) + let mut enum_md_color = { + let mut m = HashMap::::new(); + m.insert( + crate::schema::AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + serde_json::to_string(&vec!["RED", "GREEN", "BLUE"]).unwrap(), + ); + m + }; + enum_md_color.insert(AVRO_NAME_METADATA_KEY.to_string(), "Color".to_string()); + enum_md_color.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + let union_rec_a_fields = Fields::from(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ]); + let union_rec_b_fields = Fields::from(vec![ + Field::new("x", DataType::Int64, false), + Field::new("y", DataType::Binary, false), + ]); + let union_map_entries = Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ])), + false, + )); + let rec_a_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "RecA".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + m + }; + let rec_b_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "RecB".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + m + }; + let uf_union_big = UnionFields::try_new( + vec![0, 1, 2, 3, 4], + vec![ + Field::new( + "map", + DataType::Map(union_map_entries.clone(), false), + false, + ), + Field::new( + "array", + DataType::List(Arc::new(Field::new(item_name, DataType::Int64, false))), + false, + ), + Field::new( + "org.apache.arrow.avrotests.v1.types.RecB", + DataType::Struct(union_rec_b_fields.clone()), + false, + ) + .with_metadata(rec_b_md.clone()), + Field::new( + "org.apache.arrow.avrotests.v1.types.RecA", + DataType::Struct(union_rec_a_fields.clone()), + false, + ) + .with_metadata(rec_a_md.clone()), + Field::new( + "org.apache.arrow.avrotests.v1.types.Color", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ) + .with_metadata(enum_md_color.clone()), + ], + ) + .unwrap(); + let fx4_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Fx4".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1".to_string(), + ); + m + }; + let uf_date_fixed4 = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new( + "org.apache.arrow.avrotests.v1.Fx4", + DataType::FixedSizeBinary(4), + false, + ) + .with_metadata(fx4_md.clone()), + Field::new("date", DataType::Date32, false), + ], + ) + .unwrap(); + let dur12u_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Dur12U".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1".to_string(), + ); + m + }; + let uf_dur_or_str = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new("string", DataType::Utf8, false), + Field::new( + "org.apache.arrow.avrotests.v1.Dur12U", + DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano), + false, + ) + .with_metadata(dur12u_md.clone()), + ], + ) + .unwrap(); + let fx10_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Fx10".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1".to_string(), + ); + m + }; + let uf_uuid_or_fx10 = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new( + "org.apache.arrow.avrotests.v1.Fx10", + DataType::FixedSizeBinary(10), + false, + ) + .with_metadata(fx10_md.clone()), + add_uuid_ext_union(Field::new("uuid", DataType::FixedSizeBinary(16), false)), + ], + ) + .unwrap(); + let uf_kv_val = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("null", DataType::Null, false), + Field::new("int", DataType::Int32, false), + Field::new("long", DataType::Int64, false), + ], + ) + .unwrap(); + let kv_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "val", + DataType::Union(uf_kv_val.clone(), UnionMode::Dense), + true, + ), + ]); + let kv_item_field = Arc::new(Field::new( + item_name, + DataType::Struct(kv_fields.clone()), + false, + )); + let map_int_entries = Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ])), + false, + )); + let uf_map_or_array = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new( + "array", + DataType::List(Arc::new(Field::new(item_name, DataType::Int32, false))), + false, + ), + Field::new("map", DataType::Map(map_int_entries.clone(), false), false), + ], + ) + .unwrap(); + let mut enum_md_status = { + let mut m = HashMap::::new(); + m.insert( + crate::schema::AVRO_ENUM_SYMBOLS_METADATA_KEY.to_string(), + serde_json::to_string(&vec!["UNKNOWN", "NEW", "PROCESSING", "DONE"]).unwrap(), + ); + m + }; + enum_md_status.insert(AVRO_NAME_METADATA_KEY.to_string(), "Status".to_string()); + enum_md_status.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + let mut dec20_md = HashMap::::new(); + dec20_md.insert("precision".to_string(), "20".to_string()); + dec20_md.insert("scale".to_string(), "4".to_string()); + dec20_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "DecFix20".to_string()); + dec20_md.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + let mut dec10_md = HashMap::::new(); + dec10_md.insert("precision".to_string(), "10".to_string()); + dec10_md.insert("scale".to_string(), "2".to_string()); + let fx16_top_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Fx16".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + m + }; + let dur12_top_md = { + let mut m = HashMap::::new(); + m.insert(AVRO_NAME_METADATA_KEY.to_string(), "Dur12".to_string()); + m.insert( + AVRO_NAMESPACE_METADATA_KEY.to_string(), + "org.apache.arrow.avrotests.v1.types".to_string(), + ); + m + }; + #[cfg(feature = "small_decimals")] + let dec20_dt = DataType::Decimal128(20, 4); + #[cfg(not(feature = "small_decimals"))] + let dec20_dt = DataType::Decimal128(20, 4); + #[cfg(feature = "small_decimals")] + let dec10_dt = DataType::Decimal64(10, 2); + #[cfg(not(feature = "small_decimals"))] + let dec10_dt = DataType::Decimal128(10, 2); + let fields: Vec = vec![ + Arc::new(Field::new( + "person", + DataType::Struct(Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + ])), + false, + )), + Arc::new(Field::new("old_count", DataType::Int32, false)), + Arc::new(Field::new( + "union_map_or_array_int", + DataType::Union(uf_map_or_array.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new( + "array_records_with_union", + DataType::List(kv_item_field.clone()), + false, + )), + Arc::new(Field::new( + "union_uuid_or_fixed10", + DataType::Union(uf_uuid_or_fx10.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new( + "union_interval_or_string", + DataType::Union(uf_dur_or_str.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new( + "union_date_or_fixed4", + DataType::Union(uf_date_fixed4.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new( + "union_enum_record_array_map", + DataType::Union(uf_union_big.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new( + "maybe_auth", + DataType::Struct(Fields::from(vec![ + Field::new("user", DataType::Utf8, false), + Field::new("token", DataType::Binary, true), // [bytes,null] -> nullable bytes + ])), + false, + )), + Arc::new(Field::new( + "address", + DataType::Struct(Fields::from(vec![ + Field::new("street_name", DataType::Utf8, false), + Field::new("zip", DataType::Int32, false), + Field::new("country", DataType::Utf8, false), + ])), + false, + )), + Arc::new(Field::new( + "map_union", + DataType::Map(map_entries_field.clone(), false), + false, + )), + Arc::new(Field::new( + "arr_union", + DataType::List(arr_items_field.clone()), + false, + )), + Arc::new( + Field::new( + "status", + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)), + false, + ) + .with_metadata(enum_md_status.clone()), + ), + Arc::new( + Field::new( + "interval_mdn", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ) + .with_metadata(dur12_top_md.clone()), + ), + Arc::new(Field::new( + "ts_micros_local", + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, None), + false, + )), + Arc::new(Field::new( + "ts_millis_local", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, None), + false, + )), + Arc::new(Field::new( + "ts_micros_utc", + DataType::Timestamp(arrow_schema::TimeUnit::Microsecond, Some("+00:00".into())), + false, + )), + Arc::new(Field::new( + "ts_millis_utc", + DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, Some("+00:00".into())), + false, + )), + Arc::new(Field::new( + "t_micros", + DataType::Time64(arrow_schema::TimeUnit::Microsecond), + false, + )), + Arc::new(Field::new( + "t_millis", + DataType::Time32(arrow_schema::TimeUnit::Millisecond), + false, + )), + Arc::new(Field::new("d_date", DataType::Date32, false)), + Arc::new(add_uuid_ext_top(Field::new( + "uuid_str", + DataType::FixedSizeBinary(16), + false, + ))), + Arc::new(Field::new("dec_fix_s20_4", dec20_dt, false).with_metadata(dec20_md.clone())), + Arc::new( + Field::new("dec_bytes_s10_2", dec10_dt, false).with_metadata(dec10_md.clone()), + ), + Arc::new( + Field::new("fx16_plain", DataType::FixedSizeBinary(16), false) + .with_metadata(fx16_top_md.clone()), + ), + Arc::new(Field::new("raw_bytes", DataType::Binary, false)), + Arc::new(Field::new("str_utf8", DataType::Utf8, false)), + Arc::new(Field::new( + "tri_union_prim", + DataType::Union(uf_tri.clone(), UnionMode::Dense), + false, + )), + Arc::new(Field::new("opt_str_nullsecond", DataType::Utf8, true)), + Arc::new(Field::new("opt_i32_nullfirst", DataType::Int32, true)), + Arc::new(Field::new("count_i64", DataType::Int64, false)), + Arc::new(Field::new("count_i32", DataType::Int64, false)), + Arc::new(Field::new("ratio_f64", DataType::Float64, false)), + Arc::new(Field::new("ratio_f32", DataType::Float64, false)), + Arc::new(Field::new("flag", DataType::Boolean, false)), + Arc::new(Field::new("identifier", DataType::Int64, false)), + ]; + let expected_schema = Arc::new(arrow_schema::Schema::new(Fields::from(fields))); + let mut cols: Vec = vec![ + Arc::new(StructArray::new( + match expected_schema + .field_with_name("person") + .unwrap() + .data_type() + { + DataType::Struct(fs) => fs.clone(), + _ => unreachable!(), + }, + vec![ + Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "Dave"])) as ArrayRef, + Arc::new(Int32Array::from(vec![30, 0, 25, 41])) as ArrayRef, + ], + None, + )) as ArrayRef, + Arc::new(Int32Array::from(vec![100, 42, 7, 42])) as ArrayRef, + ]; + { + let map_child: ArrayRef = { + let keys = StringArray::from(vec!["x", "y", "only"]); + let vals = Int32Array::from(vec![1, 2, 10]); + let entries = StructArray::new( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, false), + ]), + vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef], + None, + ); + let moff = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3])); + Arc::new(MapArray::new( + map_int_entries.clone(), + moff, + entries, + None, + false, + )) as ArrayRef + }; + let list_child: ArrayRef = { + let values = Int32Array::from(vec![1, 2, 3, 0]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 4])); + Arc::new( + ListArray::try_new( + Arc::new(Field::new(item_name, DataType::Int32, false)), + offsets, + Arc::new(values), + None, + ) + .unwrap(), + ) as ArrayRef + }; + let tids = vec![1, 0, 1, 0]; + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf_map_or_array, tids, offs, |f| match f.name().as_str() { + "array" => Some(list_child.clone()), + "map" => Some(map_child.clone()), + _ => None, + }); + cols.push(arr); + } + { + let keys = Arc::new(StringArray::from(vec!["k1", "k2", "k", "k3", "x"])) as ArrayRef; + let type_ids = vec![1, 0, 2, 0, 1]; + let offsets = vec![0, 0, 0, 1, 1]; + let vals = mk_dense_union(&uf_kv_val, type_ids, offsets, |f| match f.data_type() { + DataType::Int32 => Some(Arc::new(Int32Array::from(vec![5, -5])) as ArrayRef), + DataType::Int64 => Some(Arc::new(Int64Array::from(vec![99i64])) as ArrayRef), + DataType::Null => Some(Arc::new(NullArray::new(2)) as ArrayRef), + _ => None, + }); + let values_struct = + Arc::new(StructArray::new(kv_fields.clone(), vec![keys, vals], None)); + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2, 3, 4, 5])); + let arr = Arc::new( + ListArray::try_new(kv_item_field.clone(), list_offsets, values_struct, None) + .unwrap(), + ) as ArrayRef; + cols.push(arr); + } + { + let type_ids = vec![1, 0, 1, 0]; // [uuid, fixed10, uuid, fixed10] but uf order = [fixed10, uuid] + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf_uuid_or_fx10, type_ids, offs, |f| match f.data_type() { + DataType::FixedSizeBinary(16) => { + let it = [Some(uuid1), Some(uuid2)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef) + } + DataType::FixedSizeBinary(10) => { + let fx10_a = [0xAAu8; 10]; + let fx10_b = [0x00u8, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99]; + let it = [Some(fx10_a), Some(fx10_b)].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 10).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + cols.push(arr); + } + { + let type_ids = vec![1, 0, 1, 0]; // [duration, string, duration, string] but uf order = [string, duration] + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf_dur_or_str, type_ids, offs, |f| match f.data_type() { + DataType::Interval(arrow_schema::IntervalUnit::MonthDayNano) => Some(Arc::new( + IntervalMonthDayNanoArray::from(vec![dur_small, dur_large]), + ) + as ArrayRef), + DataType::Utf8 => Some(Arc::new(StringArray::from(vec![ + "duration-as-text", + "iso-8601-period-P1Y", + ])) as ArrayRef), + _ => None, + }); + cols.push(arr); + } + { + let type_ids = vec![1, 0, 1, 0]; // [date, fixed, date, fixed] but uf order = [fixed, date] + let offs = vec![0, 0, 1, 1]; + let arr = mk_dense_union(&uf_date_fixed4, type_ids, offs, |f| match f.data_type() { + DataType::Date32 => Some(Arc::new(Date32Array::from(vec![date_a, 0])) as ArrayRef), + DataType::FixedSizeBinary(4) => { + let it = [Some(*b"\x00\x11\x22\x33"), Some(*b"ABCD")].into_iter(); + Some(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 4).unwrap(), + ) as ArrayRef) + } + _ => None, + }); + cols.push(arr); + } + { + let tids = vec![4, 3, 1, 0]; // uf order = [map(0), array(1), RecB(2), RecA(3), enum(4)] + let offs = vec![0, 0, 0, 0]; + let arr = mk_dense_union(&uf_union_big, tids, offs, |f| match f.data_type() { + DataType::Dictionary(_, _) => { + let keys = Int32Array::from(vec![0i32]); + let values = + Arc::new(StringArray::from(vec!["RED", "GREEN", "BLUE"])) as ArrayRef; + Some( + Arc::new(DictionaryArray::::try_new(keys, values).unwrap()) + as ArrayRef, + ) + } + DataType::Struct(fs) if fs == &union_rec_a_fields => { + let a = Int32Array::from(vec![7]); + let b = StringArray::from(vec!["rec"]); + Some(Arc::new(StructArray::new( + fs.clone(), + vec![Arc::new(a) as ArrayRef, Arc::new(b) as ArrayRef], + None, + )) as ArrayRef) + } + DataType::List(_) => { + let values = Int64Array::from(vec![1i64, 2, 3]); + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3])); + Some(Arc::new( + ListArray::try_new( + Arc::new(Field::new(item_name, DataType::Int64, false)), + offsets, + Arc::new(values), + None, + ) + .unwrap(), + ) as ArrayRef) + } + DataType::Map(_, _) => { + let keys = StringArray::from(vec!["k"]); + let vals = StringArray::from(vec!["v"]); + let entries = StructArray::new( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Utf8, false), + ]), + vec![Arc::new(keys) as ArrayRef, Arc::new(vals) as ArrayRef], + None, + ); + let moff = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 1])); + Some(Arc::new(MapArray::new( + union_map_entries.clone(), + moff, + entries, + None, + false, + )) as ArrayRef) + } + _ => None, + }); + cols.push(arr); + } + { + let fs = match expected_schema + .field_with_name("maybe_auth") + .unwrap() + .data_type() + { + DataType::Struct(fs) => fs.clone(), + _ => unreachable!(), + }; + let user = + Arc::new(StringArray::from(vec!["alice", "bob", "carol", "dave"])) as ArrayRef; + let token_values: Vec> = vec![ + None, + Some(b"\x01\x02\x03".as_ref()), + None, + Some(b"".as_ref()), + ]; + let token = Arc::new(BinaryArray::from(token_values)) as ArrayRef; + cols.push(Arc::new(StructArray::new(fs, vec![user, token], None)) as ArrayRef); + } + { + let fs = match expected_schema + .field_with_name("address") + .unwrap() + .data_type() + { + DataType::Struct(fs) => fs.clone(), + _ => unreachable!(), + }; + let street = Arc::new(StringArray::from(vec![ + "100 Main", + "", + "42 Galaxy Way", + "End Ave", + ])) as ArrayRef; + let zip = Arc::new(Int32Array::from(vec![12345, 0, 42424, 1])) as ArrayRef; + let country = Arc::new(StringArray::from(vec!["US", "CA", "US", "GB"])) as ArrayRef; + cols.push(Arc::new(StructArray::new(fs, vec![street, zip, country], None)) as ArrayRef); + } + { + let keys = StringArray::from(vec!["a", "b", "c", "neg", "pi", "ok"]); + let moff = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 3, 4, 4, 6])); + let tid_s = 0; // string + let tid_d = 1; // double + let tid_n = 2; // null + let type_ids = vec![tid_d, tid_n, tid_s, tid_d, tid_d, tid_s]; + let offsets = vec![0, 0, 0, 1, 2, 1]; + let pi_5dp = (std::f64::consts::PI * 100_000.0).trunc() / 100_000.0; + let vals = mk_dense_union(&uf_map_vals, type_ids, offsets, |f| match f.data_type() { + DataType::Float64 => { + Some(Arc::new(Float64Array::from(vec![1.5f64, -0.5, pi_5dp])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["yes", "true"])) as ArrayRef) + } + DataType::Null => Some(Arc::new(NullArray::new(1)) as ArrayRef), + _ => None, + }); + let entries = StructArray::new( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Union(uf_map_vals.clone(), UnionMode::Dense), + true, + ), + ]), + vec![Arc::new(keys) as ArrayRef, vals], + None, + ); + let map = Arc::new(MapArray::new( + map_entries_field.clone(), + moff, + entries, + None, + false, + )) as ArrayRef; + cols.push(map); + } + { + let type_ids = vec![ + 2, 1, 0, 2, 0, 1, 2, 2, 1, 0, + 2, // long,string,null,long,null,string,long,long,string,null,long + ]; + let offsets = vec![0, 0, 0, 1, 1, 1, 2, 3, 2, 2, 4]; + let values = + mk_dense_union(&uf_arr_items, type_ids, offsets, |f| match f.data_type() { + DataType::Int64 => { + Some(Arc::new(Int64Array::from(vec![1i64, -3, 0, -1, 0])) as ArrayRef) + } + DataType::Utf8 => { + Some(Arc::new(StringArray::from(vec!["x", "z", "end"])) as ArrayRef) + } + DataType::Null => Some(Arc::new(NullArray::new(3)) as ArrayRef), + _ => None, + }); + let list_offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 4, 7, 8, 11])); + let arr = Arc::new( + ListArray::try_new(arr_items_field.clone(), list_offsets, values, None).unwrap(), + ) as ArrayRef; + cols.push(arr); + } + { + let keys = Int32Array::from(vec![1, 2, 3, 0]); // NEW, PROCESSING, DONE, UNKNOWN + let values = Arc::new(StringArray::from(vec![ + "UNKNOWN", + "NEW", + "PROCESSING", + "DONE", + ])) as ArrayRef; + let dict = DictionaryArray::::try_new(keys, values).unwrap(); + cols.push(Arc::new(dict) as ArrayRef); + } + cols.push(Arc::new(IntervalMonthDayNanoArray::from(vec![ + dur_small, dur_zero, dur_large, dur_2years, + ])) as ArrayRef); + cols.push(Arc::new(TimestampMicrosecondArray::from(vec![ + ts_us_2024_01_01 + 123_456, + 0, + ts_us_2024_01_01 + 101_112, + 987_654_321, + ])) as ArrayRef); + cols.push(Arc::new(TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01 + 86_400_000, + 0, + ts_ms_2024_01_01 + 789, + 123_456_789, + ])) as ArrayRef); + { + let a = TimestampMicrosecondArray::from(vec![ + ts_us_2024_01_01, + 1, + ts_us_2024_01_01 + 456, + 0, + ]) + .with_timezone("+00:00"); + cols.push(Arc::new(a) as ArrayRef); + } + { + let a = TimestampMillisecondArray::from(vec![ + ts_ms_2024_01_01, + -1, + ts_ms_2024_01_01 + 123, + 0, + ]) + .with_timezone("+00:00"); + cols.push(Arc::new(a) as ArrayRef); + } + cols.push(Arc::new(Time64MicrosecondArray::from(vec![ + time_us_eod, + 0, + 1, + 1_000_000, + ])) as ArrayRef); + cols.push(Arc::new(Time32MillisecondArray::from(vec![ + time_ms_a, + 0, + 1, + 86_400_000 - 1, + ])) as ArrayRef); + cols.push(Arc::new(Date32Array::from(vec![date_a, 0, 1, 365])) as ArrayRef); + { + let it = [Some(uuid1), Some(uuid2), Some(uuid1), Some(uuid2)].into_iter(); + cols.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef); + } + { + #[cfg(feature = "small_decimals")] + let arr = Arc::new( + Decimal128Array::from_iter_values([1_234_567_891_234i128, -420_000i128, 0, -1i128]) + .with_precision_and_scale(20, 4) + .unwrap(), + ) as ArrayRef; + #[cfg(not(feature = "small_decimals"))] + let arr = Arc::new( + Decimal128Array::from_iter_values([1_234_567_891_234i128, -420_000i128, 0, -1i128]) + .with_precision_and_scale(20, 4) + .unwrap(), + ) as ArrayRef; + cols.push(arr); + } + { + #[cfg(feature = "small_decimals")] + let arr = Arc::new( + Decimal64Array::from_iter_values([123456i64, -1, 0, 9_999_999_999i64]) + .with_precision_and_scale(10, 2) + .unwrap(), + ) as ArrayRef; + #[cfg(not(feature = "small_decimals"))] + let arr = Arc::new( + Decimal128Array::from_iter_values([123456i128, -1, 0, 9_999_999_999i128]) + .with_precision_and_scale(10, 2) + .unwrap(), + ) as ArrayRef; + cols.push(arr); + } + { + let it = [ + Some(*b"0123456789ABCDEF"), + Some([0u8; 16]), + Some(*b"ABCDEFGHIJKLMNOP"), + Some([0xAA; 16]), + ] + .into_iter(); + cols.push(Arc::new( + FixedSizeBinaryArray::try_from_sparse_iter_with_size(it, 16).unwrap(), + ) as ArrayRef); + } + cols.push(Arc::new(BinaryArray::from(vec![ + b"\x00\x01".as_ref(), + b"".as_ref(), + b"\xFF\x00".as_ref(), + b"\x10\x20\x30\x40".as_ref(), + ])) as ArrayRef); + cols.push(Arc::new(StringArray::from(vec!["hello", "", "world", "✓ unicode"])) as ArrayRef); + { + let tids = vec![0, 1, 2, 1]; + let offs = vec![0, 0, 0, 1]; + let arr = mk_dense_union(&uf_tri, tids, offs, |f| match f.data_type() { + DataType::Int32 => Some(Arc::new(Int32Array::from(vec![0])) as ArrayRef), + DataType::Utf8 => Some(Arc::new(StringArray::from(vec!["hi", ""])) as ArrayRef), + DataType::Boolean => Some(Arc::new(BooleanArray::from(vec![true])) as ArrayRef), + _ => None, + }); + cols.push(arr); } + cols.push(Arc::new(StringArray::from(vec![ + Some("alpha"), + None, + Some("s3"), + Some(""), + ])) as ArrayRef); + cols.push(Arc::new(Int32Array::from(vec![None, Some(42), None, Some(0)])) as ArrayRef); + cols.push(Arc::new(Int64Array::from(vec![ + 7_000_000_000i64, + -2, + 0, + -9_876_543_210i64, + ])) as ArrayRef); + cols.push(Arc::new(Int64Array::from(vec![7i64, -1, 0, 123])) as ArrayRef); + cols.push(Arc::new(Float64Array::from(vec![2.5f64, -1.0, 7.0, -2.25])) as ArrayRef); + cols.push(Arc::new(Float64Array::from(vec![1.25f64, -0.0, 3.5, 9.75])) as ArrayRef); + cols.push(Arc::new(BooleanArray::from(vec![true, false, true, false])) as ArrayRef); + cols.push(Arc::new(Int64Array::from(vec![1, 2, 3, 4])) as ArrayRef); + let expected = RecordBatch::try_new(expected_schema, cols).unwrap(); + assert_eq!( + expected, batch, + "entire RecordBatch mismatch (schema, all columns, all rows)" + ); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 3466b064455f..648baa60c723 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -15,92 +15,190 @@ // specific language governing permissions and limitations // under the License. -use crate::codec::{AvroDataType, Codec, Nullability}; -use crate::reader::block::{Block, BlockDecoder}; +//! Avro Decoder for Arrow types. + +use crate::codec::{ + AvroDataType, AvroField, AvroLiteral, Codec, Promotion, ResolutionInfo, ResolvedRecord, + ResolvedUnion, +}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::reader::ReadOptions; -use crate::schema::*; +use crate::schema::Nullability; +#[cfg(feature = "small_decimals")] +use arrow_array::builder::{Decimal32Builder, Decimal64Builder}; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder, IntervalMonthDayNanoBuilder}; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + ArrowError, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DataType, Field as ArrowField, + FieldRef, Fields, Schema as ArrowSchema, SchemaRef, UnionFields, UnionMode, }; +#[cfg(feature = "small_decimals")] +use arrow_schema::{DECIMAL32_MAX_PRECISION, DECIMAL64_MAX_PRECISION}; +#[cfg(feature = "avro_custom_types")] +use arrow_select::take::{TakeOptions, take}; use std::cmp::Ordering; -use std::collections::HashMap; -use std::io::Read; use std::sync::Arc; +use strum_macros::AsRefStr; +use uuid::Uuid; + +const DEFAULT_CAPACITY: usize = 1024; + +/// Runtime plan for decoding reader-side `["null", T]` types. +#[derive(Clone, Copy, Debug)] +enum NullablePlan { + /// Writer actually wrote a union (branch tag present). + ReadTag, + /// Writer wrote a single (non-union) value resolved to the non-null branch + /// of the reader union; do NOT read a branch tag, but apply any promotion. + FromSingle { promotion: Promotion }, +} + +/// Macro to decode a decimal payload for a given width and integer type. +macro_rules! decode_decimal { + ($size:expr, $buf:expr, $builder:expr, $N:expr, $Int:ty) => {{ + let bytes = read_decimal_bytes_be::<{ $N }>($buf, $size)?; + $builder.append_value(<$Int>::from_be_bytes(bytes)); + }}; +} + +/// Macro to finish a decimal builder into an array with precision/scale and nulls. +macro_rules! flush_decimal { + ($builder:expr, $precision:expr, $scale:expr, $nulls:expr, $ArrayTy:ty) => {{ + let (_, vals, _) = $builder.finish().into_parts(); + let dec = <$ArrayTy>::try_new(vals, $nulls)? + .with_precision_and_scale(*$precision as u8, $scale.unwrap_or(0) as i8) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(dec) as ArrayRef + }}; +} + +/// Macro to append a default decimal value from two's-complement big-endian bytes +/// into the corresponding decimal builder, with compile-time constructed error text. +macro_rules! append_decimal_default { + ($lit:expr, $builder:expr, $N:literal, $Int:ty, $name:literal) => {{ + match $lit { + AvroLiteral::Bytes(b) => { + let ext = sign_cast_to::<$N>(b)?; + let val = <$Int>::from_be_bytes(ext); + $builder.append_value(val); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + concat!( + "Default for ", + $name, + " must be bytes (two's-complement big-endian)" + ) + .to_string(), + )), + } + }}; +} /// Decodes avro encoded data into [`RecordBatch`] -pub struct RecordDecoder { +#[derive(Debug)] +pub(crate) struct RecordDecoder { schema: SchemaRef, fields: Vec, - use_utf8view: bool, + projector: Option, } impl RecordDecoder { - /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with default options - pub fn try_new(data_type: &AvroDataType) -> Result { - Self::try_new_with_options(data_type, ReadOptions::default()) - } - - /// Create a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options + /// Creates a new [`RecordDecoder`] from the provided [`AvroDataType`] with additional options. /// /// This method allows you to customize how the Avro data is decoded into Arrow arrays. /// - /// # Parameters - /// * `data_type` - The Avro data type to decode - /// * `options` - Configuration options for decoding - pub fn try_new_with_options( - data_type: &AvroDataType, - options: ReadOptions, - ) -> Result { - match Decoder::try_new(data_type)? { - Decoder::Record(fields, encodings) => Ok(Self { - schema: Arc::new(ArrowSchema::new(fields)), - fields: encodings, - use_utf8view: options.use_utf8view(), - }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + /// # Arguments + /// * `data_type` - The Avro data type to decode. + /// * `use_utf8view` - A flag indicating whether to use `Utf8View` for string types. + /// + /// # Errors + /// This function will return an error if the provided `data_type` is not a `Record`. + pub(crate) fn try_new_with_options(data_type: &AvroDataType) -> Result { + match data_type.codec() { + Codec::Struct(reader_fields) => { + // Build Arrow schema fields and per-child decoders + let mut arrow_fields = Vec::with_capacity(reader_fields.len()); + let mut encodings = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.iter() { + arrow_fields.push(avro_field.field()); + encodings.push(Decoder::try_new(avro_field.data_type())?); + } + let projector = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Record(rec)) => { + Some(ProjectorBuilder::try_new(rec, reader_fields).build()?) + } + _ => None, + }; + Ok(Self { + schema: Arc::new(ArrowSchema::new(arrow_fields)), + fields: encodings, + projector, + }) + } + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } - pub fn schema(&self) -> &SchemaRef { + /// Returns the decoder's `SchemaRef` + pub(crate) fn schema(&self) -> &SchemaRef { &self.schema } /// Decode `count` records from `buf` - pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { + pub(crate) fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); - for _ in 0..count { - for field in &mut self.fields { - field.decode(&mut cursor)?; + match self.projector.as_mut() { + Some(proj) => { + for _ in 0..count { + proj.project_record(&mut cursor, &mut self.fields)?; + } + } + None => { + for _ in 0..count { + for field in &mut self.fields { + field.decode(&mut cursor)?; + } + } } } Ok(cursor.position()) } /// Flush the decoded records into a [`RecordBatch`] - pub fn flush(&mut self) -> Result { + pub(crate) fn flush(&mut self) -> Result { let arrays = self .fields .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), arrays) } } #[derive(Debug)] +struct EnumResolution { + mapping: Arc<[i32]>, + default_index: i32, +} + +#[derive(Debug, AsRefStr)] enum Decoder { Null(usize), Boolean(BooleanBufferBuilder), Int32(Vec), Int64(Vec), + #[cfg(feature = "avro_custom_types")] + DurationSecond(Vec), + #[cfg(feature = "avro_custom_types")] + DurationMillisecond(Vec), + #[cfg(feature = "avro_custom_types")] + DurationMicrosecond(Vec), + #[cfg(feature = "avro_custom_types")] + DurationNanosecond(Vec), Float32(Vec), Float64(Vec), Date32(Vec), @@ -108,13 +206,22 @@ enum Decoder { TimeMicros(Vec), TimestampMillis(bool, Vec), TimestampMicros(bool, Vec), + TimestampNanos(bool, Vec), + Int32ToInt64(Vec), + Int32ToFloat32(Vec), + Int32ToFloat64(Vec), + Int64ToFloat32(Vec), + Int64ToFloat64(Vec), + Float32ToFloat64(Vec), + BytesToString(OffsetBufferBuilder, Vec), + StringToBytes(OffsetBufferBuilder, Vec), Binary(OffsetBufferBuilder, Vec), /// String data encoded as UTF-8 bytes, mapped to Arrow's StringArray String(OffsetBufferBuilder, Vec), /// String data encoded as UTF-8 bytes, but mapped to Arrow's StringViewArray StringView(OffsetBufferBuilder, Vec), Array(FieldRef, OffsetBufferBuilder, Box), - Record(Fields, Vec), + Record(Fields, Vec, Option), Map( FieldRef, OffsetBufferBuilder, @@ -122,44 +229,170 @@ enum Decoder { Vec, Box, ), - Nullable(Nullability, NullBufferBuilder, Box), + Fixed(i32, Vec), + Enum(Vec, Arc<[String]>, Option), + Duration(IntervalMonthDayNanoBuilder), + Uuid(Vec), + #[cfg(feature = "small_decimals")] + Decimal32(usize, Option, Option, Decimal32Builder), + #[cfg(feature = "small_decimals")] + Decimal64(usize, Option, Option, Decimal64Builder), + Decimal128(usize, Option, Option, Decimal128Builder), + Decimal256(usize, Option, Option, Decimal256Builder), + #[cfg(feature = "avro_custom_types")] + RunEndEncoded(u8, usize, Box), + Union(UnionDecoder), + Nullable(Nullability, NullBufferBuilder, Box, NullablePlan), } impl Decoder { fn try_new(data_type: &AvroDataType) -> Result { - let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); - - let decoder = match data_type.codec() { - Codec::Null => Self::Null(0), - Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), - Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Int64 => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float32 => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Float64 => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::Binary => Self::Binary( + if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { + if info.writer_is_union && !info.reader_is_union { + let mut clone = data_type.clone(); + clone.resolution = None; // Build target base decoder without Union resolution + let target = Box::new(Self::try_new_internal(&clone)?); + let decoder = Self::Union( + UnionDecoderBuilder::new() + .with_resolved_union(info.clone()) + .with_target(target) + .build()?, + ); + return Ok(decoder); + } + } + Self::try_new_internal(data_type) + } + + fn try_new_internal(data_type: &AvroDataType) -> Result { + // Extract just the Promotion (if any) to simplify pattern matching + let promotion = match data_type.resolution.as_ref() { + Some(ResolutionInfo::Promotion(p)) => Some(p), + _ => None, + }; + let decoder = match (data_type.codec(), promotion) { + (Codec::Int64, Some(Promotion::IntToLong)) => { + Self::Int32ToInt64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::IntToFloat)) => { + Self::Int32ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::IntToDouble)) => { + Self::Int32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float32, Some(Promotion::LongToFloat)) => { + Self::Int64ToFloat32(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::LongToDouble)) => { + Self::Int64ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Float64, Some(Promotion::FloatToDouble)) => { + Self::Float32ToFloat64(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Utf8, Some(Promotion::BytesToString)) + | (Codec::Utf8View, Some(Promotion::BytesToString)) => Self::BytesToString( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Binary, Some(Promotion::StringToBytes)) => Self::StringToBytes( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + (Codec::Null, _) => Self::Null(0), + (Codec::Boolean, _) => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), + (Codec::Int32, _) => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Int64, _) => Self::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float32, _) => Self::Float32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Float64, _) => Self::Float64(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Binary, _) => Self::Binary( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + (Codec::Utf8, _) => Self::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8View => Self::StringView( + (Codec::Utf8View, _) => Self::StringView( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { + (Codec::Date32, _) => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimeMillis, _) => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimeMicros, _) => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::TimestampMillis(is_utc), _) => { Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::TimestampMicros(is_utc) => { + (Codec::TimestampMicros(is_utc), _) => { Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return nyi("decoding fixed"), - Codec::Interval => return nyi("decoding interval"), - Codec::List(item) => { + (Codec::TimestampNanos(is_utc), _) => { + Self::TimestampNanos(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + #[cfg(feature = "avro_custom_types")] + (Codec::DurationNanos, _) => { + Self::DurationNanosecond(Vec::with_capacity(DEFAULT_CAPACITY)) + } + #[cfg(feature = "avro_custom_types")] + (Codec::DurationMicros, _) => { + Self::DurationMicrosecond(Vec::with_capacity(DEFAULT_CAPACITY)) + } + #[cfg(feature = "avro_custom_types")] + (Codec::DurationMillis, _) => { + Self::DurationMillisecond(Vec::with_capacity(DEFAULT_CAPACITY)) + } + #[cfg(feature = "avro_custom_types")] + (Codec::DurationSeconds, _) => { + Self::DurationSecond(Vec::with_capacity(DEFAULT_CAPACITY)) + } + (Codec::Fixed(sz), _) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Decimal(precision, scale, size), _) => { + let p = *precision; + let s = *scale; + let prec = p as u8; + let scl = s.unwrap_or(0) as i8; + #[cfg(feature = "small_decimals")] + { + if p <= DECIMAL32_MAX_PRECISION as usize { + let builder = Decimal32Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal32(p, s, *size, builder) + } else if p <= DECIMAL64_MAX_PRECISION as usize { + let builder = Decimal64Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal64(p, s, *size, builder) + } else if p <= DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { + return Err(ArrowError::ParseError(format!( + "Decimal precision {p} exceeds maximum supported" + ))); + } + } + #[cfg(not(feature = "small_decimals"))] + { + if p <= DECIMAL128_MAX_PRECISION as usize { + let builder = Decimal128Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal128(p, s, *size, builder) + } else if p <= DECIMAL256_MAX_PRECISION as usize { + let builder = Decimal256Builder::with_capacity(DEFAULT_CAPACITY) + .with_precision_and_scale(prec, scl)?; + Self::Decimal256(p, s, *size, builder) + } else { + return Err(ArrowError::ParseError(format!( + "Decimal precision {p} exceeds maximum supported" + ))); + } + } + } + (Codec::Interval, _) => Self::Duration(IntervalMonthDayNanoBuilder::new()), + (Codec::List(item), _) => { let decoder = Self::try_new(item)?; Self::Array( Arc::new(item.field_with_name("item")), @@ -167,7 +400,17 @@ impl Decoder { Box::new(decoder), ) } - Codec::Struct(fields) => { + (Codec::Enum(symbols), _) => { + let res = match data_type.resolution.as_ref() { + Some(ResolutionInfo::EnumMapping(mapping)) => Some(EnumResolution { + mapping: mapping.mapping.clone(), + default_index: mapping.default_index, + }), + _ => None, + }; + Self::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone(), res) + } + (Codec::Struct(fields), _) => { let mut arrow_fields = Vec::with_capacity(fields.len()); let mut encodings = Vec::with_capacity(fields.len()); for avro_field in fields.iter() { @@ -175,10 +418,16 @@ impl Decoder { arrow_fields.push(avro_field.field()); encodings.push(encoding); } - Self::Record(arrow_fields.into(), encodings) + let projector = + if let Some(ResolutionInfo::Record(rec)) = data_type.resolution.as_ref() { + Some(ProjectorBuilder::try_new(rec, fields).build()?) + } else { + None + }; + Self::Record(arrow_fields.into(), encodings, projector) } - Codec::Map(child) => { - let val_field = child.field_with_name("value").with_nullable(true); + (Codec::Map(child), _) => { + let val_field = child.field_with_name("value"); let map_field = Arc::new(ArrowField::new( "entries", DataType::Struct(Fields::from(vec![ @@ -196,42 +445,407 @@ impl Decoder { Box::new(val_dec), ) } + (Codec::Uuid, _) => Self::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)), + (Codec::Union(encodings, fields, UnionMode::Dense), _) => { + let decoders = encodings + .iter() + .map(Self::try_new_internal) + .collect::, _>>()?; + if fields.len() != decoders.len() { + return Err(ArrowError::SchemaError(format!( + "Union has {} fields but {} decoders", + fields.len(), + decoders.len() + ))); + } + // Proactive guard: if a user provides a union with more branches than + // a 32-bit Avro index can address, fail fast with a clear message. + let branch_count = decoders.len(); + let max_addr = (i32::MAX as usize) + 1; + if branch_count > max_addr { + return Err(ArrowError::SchemaError(format!( + "Union has {branch_count} branches, which exceeds the maximum addressable \ + branches by an Avro int tag ({} + 1).", + i32::MAX + ))); + } + let mut builder = UnionDecoderBuilder::new() + .with_fields(fields.clone()) + .with_branches(decoders); + if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { + if info.reader_is_union { + builder = builder.with_resolved_union(info.clone()); + } + } + Self::Union(builder.build()?) + } + (Codec::Union(_, _, _), _) => { + return Err(ArrowError::NotYetImplemented( + "Sparse Arrow unions are not yet supported".to_string(), + )); + } + #[cfg(feature = "avro_custom_types")] + (Codec::RunEndEncoded(values_dt, width_bits_or_bytes), _) => { + let inner = Self::try_new(values_dt)?; + let byte_width: u8 = match *width_bits_or_bytes { + 2 | 4 | 8 => *width_bits_or_bytes, + 16 => 2, + 32 => 4, + 64 => 8, + other => { + return Err(ArrowError::InvalidArgumentError(format!( + "Unsupported run-end width {other} for RunEndEncoded; \ + expected 16/32/64 bits or 2/4/8 bytes" + ))); + } + }; + Self::RunEndEncoded(byte_width, 0, Box::new(inner)) + } }; - Ok(match data_type.nullability() { - Some(nullability) => Self::Nullable( - nullability, - NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), - ), + Some(nullability) => { + // Default to reading a union branch tag unless the resolution proves otherwise. + let mut plan = NullablePlan::ReadTag; + if let Some(ResolutionInfo::Union(info)) = data_type.resolution.as_ref() { + if !info.writer_is_union && info.reader_is_union { + if let Some(Some((_reader_idx, promo))) = info.writer_to_reader.first() { + plan = NullablePlan::FromSingle { promotion: *promo }; + } + } + } + Self::Nullable( + nullability, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(decoder), + plan, + ) + } None => decoder, }) } /// Append a null record - fn append_null(&mut self) { + fn append_null(&mut self) -> Result<(), ArrowError> { match self { Self::Null(count) => *count += 1, Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), Self::Int64(v) + | Self::Int32ToInt64(v) | Self::TimeMicros(v) | Self::TimestampMillis(_, v) - | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) | Self::StringView(offsets, _) => { + | Self::TimestampMicros(_, v) + | Self::TimestampNanos(_, v) => v.push(0), + #[cfg(feature = "avro_custom_types")] + Self::DurationSecond(v) + | Self::DurationMillisecond(v) + | Self::DurationMicrosecond(v) + | Self::DurationNanosecond(v) => v.push(0), + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => v.push(0.), + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => v.push(0.), + Self::Binary(offsets, _) + | Self::String(offsets, _) + | Self::StringView(offsets, _) + | Self::BytesToString(offsets, _) + | Self::StringToBytes(offsets, _) => { offsets.push_length(0); } - Self::Array(_, offsets, e) => { + Self::Uuid(v) => { + v.extend([0; 16]); + } + Self::Array(_, offsets, _) => { offsets.push_length(0); - e.append_null(); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), + Self::Record(_, e, _) => { + for encoding in e.iter_mut() { + encoding.append_null()?; + } + } Self::Map(_, _koff, moff, _, _) => { moff.push_length(0); } - Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), + Self::Fixed(sz, accum) => { + accum.extend(std::iter::repeat_n(0u8, *sz as usize)); + } + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, _, builder) => builder.append_value(0), + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, _, builder) => builder.append_value(0), + Self::Decimal128(_, _, _, builder) => builder.append_value(0), + Self::Decimal256(_, _, _, builder) => builder.append_value(i256::ZERO), + Self::Enum(indices, _, _) => indices.push(0), + Self::Duration(builder) => builder.append_null(), + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(_, len, inner) => { + *len += 1; + inner.append_null()?; + } + Self::Union(u) => u.append_null()?, + Self::Nullable(_, null_buffer, inner, _) => { + null_buffer.append(false); + inner.append_null()?; + } + } + Ok(()) + } + + /// Append a single default literal into the decoder's buffers + fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> { + match self { + Self::Nullable(_, nb, inner, _) => { + if matches!(lit, AvroLiteral::Null) { + nb.append(false); + inner.append_null() + } else { + nb.append(true); + inner.append_default(lit) + } + } + Self::Null(count) => match lit { + AvroLiteral::Null => { + *count += 1; + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Non-null default for null type".to_string(), + )), + }, + Self::Boolean(b) => match lit { + AvroLiteral::Boolean(v) => { + b.append(*v); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for boolean must be boolean".to_string(), + )), + }, + Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => match lit { + AvroLiteral::Int(i) => { + v.push(*i); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for int32/date32/time-millis must be int".to_string(), + )), + }, + #[cfg(feature = "avro_custom_types")] + Self::DurationSecond(v) + | Self::DurationMillisecond(v) + | Self::DurationMicrosecond(v) + | Self::DurationNanosecond(v) => match lit { + AvroLiteral::Long(i) => { + v.push(*i); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for duration long must be long".to_string(), + )), + }, + Self::Int64(v) + | Self::Int32ToInt64(v) + | Self::TimeMicros(v) + | Self::TimestampMillis(_, v) + | Self::TimestampMicros(_, v) + | Self::TimestampNanos(_, v) => match lit { + AvroLiteral::Long(i) => { + v.push(*i); + Ok(()) + } + AvroLiteral::Int(i) => { + v.push(*i as i64); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for long/time-micros/timestamp must be long or int".to_string(), + )), + }, + Self::Float32(v) | Self::Int32ToFloat32(v) | Self::Int64ToFloat32(v) => match lit { + AvroLiteral::Float(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for float must be float".to_string(), + )), + }, + Self::Float64(v) + | Self::Int32ToFloat64(v) + | Self::Int64ToFloat64(v) + | Self::Float32ToFloat64(v) => match lit { + AvroLiteral::Double(f) => { + v.push(*f); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for double must be double".to_string(), + )), + }, + Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => match lit { + AvroLiteral::Bytes(b) => { + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for bytes must be bytes".to_string(), + )), + }, + Self::BytesToString(offsets, values) + | Self::String(offsets, values) + | Self::StringView(offsets, values) => match lit { + AvroLiteral::String(s) => { + let b = s.as_bytes(); + offsets.push_length(b.len()); + values.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for string must be string".to_string(), + )), + }, + Self::Uuid(values) => match lit { + AvroLiteral::String(s) => { + let uuid = Uuid::try_parse(s).map_err(|e| { + ArrowError::InvalidArgumentError(format!("Invalid UUID default: {s} ({e})")) + })?; + values.extend_from_slice(uuid.as_bytes()); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for uuid must be string".to_string(), + )), + }, + Self::Fixed(sz, accum) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != *sz as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "Fixed default length {} does not match size {sz}", + b.len(), + ))); + } + accum.extend_from_slice(b); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for fixed must be bytes".to_string(), + )), + }, + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, _, builder) => { + append_decimal_default!(lit, builder, 4, i32, "decimal32") + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, _, builder) => { + append_decimal_default!(lit, builder, 8, i64, "decimal64") + } + Self::Decimal128(_, _, _, builder) => { + append_decimal_default!(lit, builder, 16, i128, "decimal128") + } + Self::Decimal256(_, _, _, builder) => { + append_decimal_default!(lit, builder, 32, i256, "decimal256") + } + Self::Duration(builder) => match lit { + AvroLiteral::Bytes(b) => { + if b.len() != 12 { + return Err(ArrowError::InvalidArgumentError(format!( + "Duration default must be exactly 12 bytes, got {}", + b.len() + ))); + } + let months = u32::from_le_bytes([b[0], b[1], b[2], b[3]]); + let days = u32::from_le_bytes([b[4], b[5], b[6], b[7]]); + let millis = u32::from_le_bytes([b[8], b[9], b[10], b[11]]); + let nanos = (millis as i64) * 1_000_000; + builder.append_value(IntervalMonthDayNano::new( + months as i32, + days as i32, + nanos, + )); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for duration must be 12-byte little-endian months/days/millis" + .to_string(), + )), + }, + Self::Array(_, offsets, inner) => match lit { + AvroLiteral::Array(items) => { + offsets.push_length(items.len()); + for item in items { + inner.append_default(item)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for array must be an array literal".to_string(), + )), + }, + Self::Map(_, koff, moff, kdata, valdec) => match lit { + AvroLiteral::Map(entries) => { + moff.push_length(entries.len()); + for (k, v) in entries { + let kb = k.as_bytes(); + koff.push_length(kb.len()); + kdata.extend_from_slice(kb); + valdec.append_default(v)?; + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for map must be a map/object literal".to_string(), + )), + }, + Self::Enum(indices, symbols, _) => match lit { + AvroLiteral::Enum(sym) => { + let pos = symbols.iter().position(|s| s == sym).ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Enum default symbol {sym:?} not in reader symbols" + )) + })?; + indices.push(pos as i32); + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for enum must be a symbol".to_string(), + )), + }, + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(_, len, inner) => { + *len += 1; + inner.append_default(lit) + } + Self::Union(u) => u.append_default(lit), + Self::Record(field_meta, decoders, projector) => match lit { + AvroLiteral::Map(entries) => { + for (i, dec) in decoders.iter_mut().enumerate() { + let name = field_meta[i].name(); + if let Some(sub) = entries.get(name) { + dec.append_default(sub)?; + } else if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null()?; + } + } + Ok(()) + } + AvroLiteral::Null => { + for (i, dec) in decoders.iter_mut().enumerate() { + if let Some(proj) = projector.as_ref() { + proj.project_default(dec, i)?; + } else { + dec.append_null()?; + } + } + Ok(()) + } + _ => Err(ArrowError::InvalidArgumentError( + "Default for record must be a map/object or null".to_string(), + )), + }, } } @@ -246,25 +860,51 @@ impl Decoder { Self::Int64(values) | Self::TimeMicros(values) | Self::TimestampMillis(_, values) - | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), + | Self::TimestampMicros(_, values) + | Self::TimestampNanos(_, values) => values.push(buf.get_long()?), + #[cfg(feature = "avro_custom_types")] + Self::DurationSecond(values) + | Self::DurationMillisecond(values) + | Self::DurationMicrosecond(values) + | Self::DurationNanosecond(values) => values.push(buf.get_long()?), Self::Float32(values) => values.push(buf.get_float()?), Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) + Self::Int32ToInt64(values) => values.push(buf.get_int()? as i64), + Self::Int32ToFloat32(values) => values.push(buf.get_int()? as f32), + Self::Int32ToFloat64(values) => values.push(buf.get_int()? as f64), + Self::Int64ToFloat32(values) => values.push(buf.get_long()? as f32), + Self::Int64ToFloat64(values) => values.push(buf.get_long()? as f64), + Self::Float32ToFloat64(values) => values.push(buf.get_float()? as f64), + Self::StringToBytes(offsets, values) + | Self::BytesToString(offsets, values) + | Self::Binary(offsets, values) | Self::String(offsets, values) | Self::StringView(offsets, values) => { let data = buf.get_bytes()?; offsets.push_length(data.len()); values.extend_from_slice(data); } + Self::Uuid(values) => { + let s_bytes = buf.get_bytes()?; + let s = std::str::from_utf8(s_bytes).map_err(|e| { + ArrowError::ParseError(format!("UUID bytes are not valid UTF-8: {e}")) + })?; + let uuid = Uuid::try_parse(s) + .map_err(|e| ArrowError::ParseError(format!("Failed to parse uuid: {e}")))?; + values.extend_from_slice(uuid.as_bytes()); + } Self::Array(_, off, encoding) => { let total_items = read_blocks(buf, |cursor| encoding.decode(cursor))?; off.push_length(total_items); } - Self::Record(_, encodings) => { + Self::Record(_, encodings, None) => { for encoding in encodings { encoding.decode(buf)?; } } + Self::Record(_, encodings, Some(proj)) => { + proj.project_record(buf, encodings)?; + } Self::Map(_, koff, moff, kdata, valdec) => { let newly_added = read_blocks(buf, |cur| { let kb = cur.get_bytes()?; @@ -274,22 +914,150 @@ impl Decoder { })?; moff.push_length(newly_added); } - Self::Nullable(nullability, nulls, e) => { - let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst); - nulls.append(is_valid); - match is_valid { - true => e.decode(buf)?, - false => e.append_null(), + Self::Fixed(sz, accum) => { + let fx = buf.get_fixed(*sz as usize)?; + accum.extend_from_slice(fx); + } + #[cfg(feature = "small_decimals")] + Self::Decimal32(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 4, i32); + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 8, i64); + } + Self::Decimal128(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 16, i128); + } + Self::Decimal256(_, _, size, builder) => { + decode_decimal!(size, buf, builder, 32, i256); + } + Self::Enum(indices, _, None) => { + indices.push(buf.get_int()?); + } + Self::Enum(indices, _, Some(res)) => { + let raw = buf.get_int()?; + let resolved = usize::try_from(raw) + .ok() + .and_then(|idx| res.mapping.get(idx).copied()) + .filter(|&idx| idx >= 0) + .unwrap_or(res.default_index); + if resolved >= 0 { + indices.push(resolved); + } else { + return Err(ArrowError::ParseError(format!( + "Enum symbol index {raw} not resolvable and no default provided", + ))); + } + } + Self::Duration(builder) => { + let b = buf.get_fixed(12)?; + let months = u32::from_le_bytes(b[0..4].try_into().unwrap()); + let days = u32::from_le_bytes(b[4..8].try_into().unwrap()); + let millis = u32::from_le_bytes(b[8..12].try_into().unwrap()); + let nanos = (millis as i64) * 1_000_000; + builder.append_value(IntervalMonthDayNano::new(months as i32, days as i32, nanos)); + } + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(_, len, inner) => { + *len += 1; + inner.decode(buf)?; + } + Self::Union(u) => u.decode(buf)?, + Self::Nullable(order, nb, encoding, plan) => { + match *plan { + NullablePlan::FromSingle { promotion } => { + encoding.decode_with_promotion(buf, promotion)?; + nb.append(true); + } + NullablePlan::ReadTag => { + let branch = buf.read_vlq()?; + let is_not_null = match *order { + Nullability::NullFirst => branch != 0, + Nullability::NullSecond => branch == 0, + }; + if is_not_null { + // It is important to decode before appending to null buffer in case of decode error + encoding.decode(buf)?; + } else { + encoding.append_null()?; + } + nb.append(is_not_null); + } } } } Ok(()) } + fn decode_with_promotion( + &mut self, + buf: &mut AvroCursor<'_>, + promotion: Promotion, + ) -> Result<(), ArrowError> { + #[cfg(feature = "avro_custom_types")] + if let Self::RunEndEncoded(_, len, inner) = self { + *len += 1; + return inner.decode_with_promotion(buf, promotion); + } + + macro_rules! promote_numeric_to { + ($variant:ident, $getter:ident, $to:ty) => {{ + match self { + Self::$variant(v) => { + let x = buf.$getter()?; + v.push(x as $to); + Ok(()) + } + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected {}, got {}", + stringify!($variant), + >::as_ref(other) + ))), + } + }}; + } + match promotion { + Promotion::Direct => self.decode(buf), + Promotion::IntToLong => promote_numeric_to!(Int64, get_int, i64), + Promotion::IntToFloat => promote_numeric_to!(Float32, get_int, f32), + Promotion::IntToDouble => promote_numeric_to!(Float64, get_int, f64), + Promotion::LongToFloat => promote_numeric_to!(Float32, get_long, f32), + Promotion::LongToDouble => promote_numeric_to!(Float64, get_long, f64), + Promotion::FloatToDouble => promote_numeric_to!(Float64, get_float, f64), + Promotion::StringToBytes => match self { + Self::Binary(offsets, values) | Self::StringToBytes(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + Ok(()) + } + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected bytes (Binary/StringToBytes), got {}", + >::as_ref(other) + ))), + }, + Promotion::BytesToString => match self { + Self::String(offsets, values) + | Self::StringView(offsets, values) + | Self::BytesToString(offsets, values) => { + let data = buf.get_bytes()?; + offsets.push_length(data.len()); + values.extend_from_slice(data); + Ok(()) + } + other => Err(ArrowError::ParseError(format!( + "Promotion {promotion} target mismatch: expected string (String/StringView/BytesToString), got {}", + >::as_ref(other) + ))), + }, + } + } + /// Flush decoded records to an [`ArrayRef`] fn flush(&mut self, nulls: Option) -> Result { Ok(match self { - Self::Nullable(_, n, e) => e.flush(n.finish())?, + Self::Nullable(_, n, e, _) => e.flush(n.finish())?, Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), @@ -309,23 +1077,51 @@ impl Decoder { flush_primitive::(values, nulls) .with_timezone_opt(is_utc.then(|| "+00:00")), ), + Self::TimestampNanos(is_utc, values) => Arc::new( + flush_primitive::(values, nulls) + .with_timezone_opt(is_utc.then(|| "+00:00")), + ), + #[cfg(feature = "avro_custom_types")] + Self::DurationSecond(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + #[cfg(feature = "avro_custom_types")] + Self::DurationMillisecond(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + #[cfg(feature = "avro_custom_types")] + Self::DurationMicrosecond(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + #[cfg(feature = "avro_custom_types")] + Self::DurationNanosecond(values) => { + Arc::new(flush_primitive::(values, nulls)) + } Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Binary(offsets, values) => { + Self::Int32ToInt64(values) => Arc::new(flush_primitive::(values, nulls)), + Self::Int32ToFloat32(values) | Self::Int64ToFloat32(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::Int32ToFloat64(values) + | Self::Int64ToFloat64(values) + | Self::Float32ToFloat64(values) => { + Arc::new(flush_primitive::(values, nulls)) + } + Self::StringToBytes(offsets, values) | Self::Binary(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); - Arc::new(BinaryArray::new(offsets, values, nulls)) + Arc::new(BinaryArray::try_new(offsets, values, nulls)?) } - Self::String(offsets, values) => { + Self::BytesToString(offsets, values) | Self::String(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values).into(); - Arc::new(StringArray::new(offsets, values, nulls)) + Arc::new(StringArray::try_new(offsets, values, nulls)?) } Self::StringView(offsets, values) => { let offsets = flush_offsets(offsets); let values = flush_values(values); - let array = StringArray::new(offsets, values.into(), nulls.clone()); - + let array = StringArray::try_new(offsets, values.into(), nulls.clone())?; let values: Vec<&str> = (0..array.len()) .map(|i| { if array.is_valid(i) { @@ -335,27 +1131,26 @@ impl Decoder { } }) .collect(); - Arc::new(StringViewArray::from(values)) } Self::Array(field, offsets, values) => { let values = values.flush(None)?; let offsets = flush_offsets(offsets); - Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + Arc::new(ListArray::try_new(field.clone(), offsets, values, nulls)?) } - Self::Record(fields, encodings) => { + Self::Record(fields, encodings, _) => { let arrays = encodings .iter_mut() .map(|x| x.flush(None)) .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + Arc::new(StructArray::try_new(fields.clone(), arrays, nulls)?) } Self::Map(map_field, k_off, m_off, kdata, valdec) => { let moff = flush_offsets(m_off); let koff = flush_offsets(k_off); let kd = flush_values(kdata).into(); let val_arr = valdec.flush(None)?; - let key_arr = StringArray::new(koff, kd, None); + let key_arr = StringArray::try_new(koff, kd, None)?; if key_arr.len() != val_arr.len() { return Err(ArrowError::InvalidArgumentError(format!( "Map keys length ({}) != map values length ({})", @@ -372,94 +1167,1000 @@ impl Decoder { ))); } } - let entries_struct = StructArray::new( - Fields::from(vec![ - Arc::new(ArrowField::new("key", DataType::Utf8, false)), - Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), - ]), - vec![Arc::new(key_arr), val_arr], - None, - ); - let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false); + let entries_fields = match map_field.data_type() { + DataType::Struct(fields) => fields.clone(), + other => { + return Err(ArrowError::InvalidArgumentError(format!( + "Map entries field must be a Struct, got {other:?}" + ))); + } + }; + let entries_struct = + StructArray::try_new(entries_fields, vec![Arc::new(key_arr), val_arr], None)?; + let map_arr = + MapArray::try_new(map_field.clone(), moff, entries_struct, nulls, false)?; Arc::new(map_arr) } + Self::Fixed(sz, accum) => { + let b: Buffer = flush_values(accum).into(); + let arr = FixedSizeBinaryArray::try_new(*sz, b, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(arr) + } + Self::Uuid(values) => { + let arr = FixedSizeBinaryArray::try_new(16, std::mem::take(values).into(), nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(arr) + } + #[cfg(feature = "small_decimals")] + Self::Decimal32(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal32Array) + } + #[cfg(feature = "small_decimals")] + Self::Decimal64(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal64Array) + } + Self::Decimal128(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal128Array) + } + Self::Decimal256(precision, scale, _, builder) => { + flush_decimal!(builder, precision, scale, nulls, Decimal256Array) + } + Self::Enum(indices, symbols, _) => flush_dict(indices, symbols, nulls)?, + Self::Duration(builder) => { + let (_, vals, _) = builder.finish().into_parts(); + let vals = IntervalMonthDayNanoArray::try_new(vals, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(vals) + } + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(width, len, inner) => { + let values = inner.flush(nulls)?; + let n = *len; + let arr = values.as_ref(); + let mut run_starts: Vec = Vec::with_capacity(n); + if n > 0 { + run_starts.push(0); + for i in 1..n { + if !values_equal_at(arr, i - 1, i) { + run_starts.push(i); + } + } + } + if n > (u32::MAX as usize) { + return Err(ArrowError::InvalidArgumentError(format!( + "RunEndEncoded length {n} exceeds maximum supported by UInt32 indices for take", + ))); + } + let run_count = run_starts.len(); + let take_idx: PrimitiveArray = + run_starts.iter().map(|&s| s as u32).collect(); + let per_run_values = if run_count == 0 { + values.slice(0, 0) + } else { + take(arr, &take_idx, Option::from(TakeOptions::default())).map_err(|e| { + ArrowError::ParseError(format!("take() for REE values failed: {e}")) + })? + }; + + macro_rules! build_run_array { + ($Native:ty, $ArrowTy:ty) => {{ + let mut ends: Vec<$Native> = Vec::with_capacity(run_count); + for (idx, &_start) in run_starts.iter().enumerate() { + let end = if idx + 1 < run_count { + run_starts[idx + 1] + } else { + n + }; + ends.push(end as $Native); + } + let ends: PrimitiveArray<$ArrowTy> = ends.into_iter().collect(); + let run_arr = RunArray::<$ArrowTy>::try_new(&ends, per_run_values.as_ref()) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Arc::new(run_arr) as ArrayRef + }}; + } + match *width { + 2 => { + if n > i16::MAX as usize { + return Err(ArrowError::InvalidArgumentError(format!( + "RunEndEncoded length {n} exceeds i16::MAX for run end width 2" + ))); + } + build_run_array!(i16, Int16Type) + } + 4 => build_run_array!(i32, Int32Type), + 8 => build_run_array!(i64, Int64Type), + other => { + return Err(ArrowError::InvalidArgumentError(format!( + "Unsupported run-end width {other} for RunEndEncoded" + ))); + } + } + } + Self::Union(u) => u.flush(nulls)?, }) } } -fn read_blocks( - buf: &mut AvroCursor, - decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, -) -> Result { - read_blockwise_items(buf, true, decode_entry) +// A lookup table for resolving fields between writer and reader schemas during record projection. +#[derive(Debug)] +struct DispatchLookupTable { + // Maps each reader field index `r` to the corresponding writer field index. + // + // Semantics: + // - `to_reader[r] >= 0`: The value is an index into the writer's fields. The value from + // the writer field is decoded, and `promotion[r]` is applied. + // - `to_reader[r] == NO_SOURCE` (-1): No matching writer field exists. The reader field's + // default value is used. + // + // Representation (`i8`): + // `i8` is used for a dense, cache-friendly dispatch table, consistent with Arrow's use of + // `i8` for union type IDs. This requires that writer field indices do not exceed `i8::MAX`. + // + // Invariants: + // - `to_reader.len() == promotion.len()` and matches the reader field count. + // - If `to_reader[r] == NO_SOURCE`, `promotion[r]` is ignored. + to_reader: Box<[i8]>, + // For each reader field `r`, specifies the `Promotion` to apply to the writer's value. + // + // This is used when a writer field's type can be promoted to a reader field's type + // (e.g., `Int` to `Long`). It is ignored if `to_reader[r] == NO_SOURCE`. + promotion: Box<[Promotion]>, } -fn read_blockwise_items( - buf: &mut AvroCursor, - read_size_after_negative: bool, - mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, -) -> Result { - let mut total = 0usize; - loop { - // Read the block count - // positive = that many items - // negative = that many items + read block size - // See: https://avro.apache.org/docs/1.11.1/specification/#maps - let block_count = buf.get_long()?; - match block_count.cmp(&0) { - Ordering::Equal => break, - Ordering::Less => { - // If block_count is negative, read the absolute value of count, - // then read the block size as a long and discard - let count = (-block_count) as usize; - if read_size_after_negative { - let _size_in_bytes = buf.get_long()?; - } - for _ in 0..count { - decode_fn(buf)?; +// Sentinel used in `DispatchLookupTable::to_reader` to mark +// "no matching writer field". +const NO_SOURCE: i8 = -1; + +impl DispatchLookupTable { + fn from_writer_to_reader( + promotion_map: &[Option<(usize, Promotion)>], + ) -> Result { + let mut to_reader = Vec::with_capacity(promotion_map.len()); + let mut promotion = Vec::with_capacity(promotion_map.len()); + for map in promotion_map { + match *map { + Some((idx, promo)) => { + let idx_i8 = i8::try_from(idx).map_err(|_| { + ArrowError::SchemaError(format!( + "Reader branch index {idx} exceeds i8 range (max {})", + i8::MAX + )) + })?; + to_reader.push(idx_i8); + promotion.push(promo); } - total += count; - } - Ordering::Greater => { - // If block_count is positive, decode that many items - let count = block_count as usize; - for _i in 0..count { - decode_fn(buf)?; + None => { + to_reader.push(NO_SOURCE); + promotion.push(Promotion::Direct); } - total += count; } } + Ok(Self { + to_reader: to_reader.into_boxed_slice(), + promotion: promotion.into_boxed_slice(), + }) } - Ok(total) -} -#[inline] -fn flush_values(values: &mut Vec) -> Vec { - std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) + // Resolve a writer branch index to (reader_idx, promotion) + #[inline] + fn resolve(&self, writer_index: usize) -> Option<(usize, Promotion)> { + let reader_index = *self.to_reader.get(writer_index)?; + (reader_index >= 0).then(|| (reader_index as usize, self.promotion[writer_index])) + } } -#[inline] -fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { - std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +#[derive(Debug)] +struct UnionDecoder { + fields: UnionFields, + type_ids: Vec, + offsets: Vec, + branches: Vec, + counts: Vec, + reader_type_codes: Vec, + default_emit_idx: usize, + null_emit_idx: usize, + plan: UnionReadPlan, } -#[inline] -fn flush_primitive( - values: &mut Vec, - nulls: Option, -) -> PrimitiveArray { - PrimitiveArray::new(flush_values(values).into(), nulls) +impl Default for UnionDecoder { + fn default() -> Self { + Self { + fields: UnionFields::empty(), + type_ids: Vec::new(), + offsets: Vec::new(), + branches: Vec::new(), + counts: Vec::new(), + reader_type_codes: Vec::new(), + default_emit_idx: 0, + null_emit_idx: 0, + plan: UnionReadPlan::Passthrough, + } + } } -const DEFAULT_CAPACITY: usize = 1024; +#[derive(Debug)] +enum UnionReadPlan { + ReaderUnion { + lookup_table: DispatchLookupTable, + }, + FromSingle { + reader_idx: usize, + promotion: Promotion, + }, + ToSingle { + target: Box, + lookup_table: DispatchLookupTable, + }, + Passthrough, +} -#[cfg(test)] +impl UnionDecoder { + fn try_new( + fields: UnionFields, + branches: Vec, + resolved: Option, + ) -> Result { + let reader_type_codes = fields.iter().map(|(tid, _)| tid).collect::>(); + let null_branch = branches.iter().position(|b| matches!(b, Decoder::Null(_))); + let default_emit_idx = 0; + let null_emit_idx = null_branch.unwrap_or(default_emit_idx); + let branch_len = branches.len().max(reader_type_codes.len()); + // Guard against impractically large unions that cannot be indexed by an Avro int + let max_addr = (i32::MAX as usize) + 1; + if branches.len() > max_addr { + return Err(ArrowError::SchemaError(format!( + "Reader union has {} branches, which exceeds the maximum addressable \ + branches by an Avro int tag ({} + 1).", + branches.len(), + i32::MAX + ))); + } + Ok(Self { + fields, + type_ids: Vec::with_capacity(DEFAULT_CAPACITY), + offsets: Vec::with_capacity(DEFAULT_CAPACITY), + branches, + counts: vec![0; branch_len], + reader_type_codes, + default_emit_idx, + null_emit_idx, + plan: Self::plan_from_resolved(resolved)?, + }) + } + + fn try_new_from_writer_union( + info: ResolvedUnion, + target: Box, + ) -> Result { + // This constructor is only for writer-union to single-type resolution + debug_assert!(info.writer_is_union && !info.reader_is_union); + let lookup_table = DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader)?; + Ok(Self { + plan: UnionReadPlan::ToSingle { + target, + lookup_table, + }, + ..Self::default() + }) + } + + fn plan_from_resolved(resolved: Option) -> Result { + let Some(info) = resolved else { + return Ok(UnionReadPlan::Passthrough); + }; + match (info.writer_is_union, info.reader_is_union) { + (true, true) => { + let lookup_table = + DispatchLookupTable::from_writer_to_reader(&info.writer_to_reader)?; + Ok(UnionReadPlan::ReaderUnion { lookup_table }) + } + (false, true) => { + let Some(&(reader_idx, promotion)) = + info.writer_to_reader.first().and_then(Option::as_ref) + else { + return Err(ArrowError::SchemaError( + "Writer type does not match any reader union branch".to_string(), + )); + }; + Ok(UnionReadPlan::FromSingle { + reader_idx, + promotion, + }) + } + (true, false) => Err(ArrowError::InvalidArgumentError( + "UnionDecoder::try_new cannot build writer-union to single; use UnionDecoderBuilder with a target" + .to_string(), + )), + // (false, false) is invalid and should never be constructed by the resolver. + _ => Err(ArrowError::SchemaError( + "ResolvedUnion constructed for non-union sides; resolver should return None" + .to_string(), + )), + } + } + + #[inline] + fn read_tag(buf: &mut AvroCursor<'_>) -> Result { + // Avro unions are encoded by first writing the zero-based branch index. + // In Avro 1.11.1 this is specified as an *int*; older specs said *long*, + // but both use zig-zag varint encoding, so decoding as long is compatible + // with either form and widely used in practice. + let raw = buf.get_long()?; + if raw < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {raw}" + ))); + } + usize::try_from(raw).map_err(|_| { + ArrowError::ParseError(format!( + "Union branch index {raw} does not fit into usize on this platform ({}-bit)", + (usize::BITS as usize) + )) + }) + } + + #[inline] + fn emit_to(&mut self, reader_idx: usize) -> Result<&mut Decoder, ArrowError> { + let branches_len = self.branches.len(); + let Some(reader_branch) = self.branches.get_mut(reader_idx) else { + return Err(ArrowError::ParseError(format!( + "Union branch index {reader_idx} out of range ({branches_len} branches)" + ))); + }; + self.type_ids.push(self.reader_type_codes[reader_idx]); + self.offsets.push(self.counts[reader_idx]); + self.counts[reader_idx] += 1; + Ok(reader_branch) + } + + #[inline] + fn on_decoder(&mut self, fallback_idx: usize, action: F) -> Result<(), ArrowError> + where + F: FnOnce(&mut Decoder) -> Result<(), ArrowError>, + { + if let UnionReadPlan::ToSingle { target, .. } = &mut self.plan { + return action(target); + } + let reader_idx = match &self.plan { + UnionReadPlan::FromSingle { reader_idx, .. } => *reader_idx, + _ => fallback_idx, + }; + self.emit_to(reader_idx).and_then(action) + } + + fn append_null(&mut self) -> Result<(), ArrowError> { + self.on_decoder(self.null_emit_idx, |decoder| decoder.append_null()) + } + + fn append_default(&mut self, lit: &AvroLiteral) -> Result<(), ArrowError> { + self.on_decoder(self.default_emit_idx, |decoder| decoder.append_default(lit)) + } + + fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + let (reader_idx, promotion) = match &mut self.plan { + UnionReadPlan::Passthrough => (Self::read_tag(buf)?, Promotion::Direct), + UnionReadPlan::ReaderUnion { lookup_table } => { + let idx = Self::read_tag(buf)?; + lookup_table.resolve(idx).ok_or_else(|| { + ArrowError::ParseError(format!( + "Union branch index {idx} not resolvable by reader schema" + )) + })? + } + UnionReadPlan::FromSingle { + reader_idx, + promotion, + } => (*reader_idx, *promotion), + UnionReadPlan::ToSingle { + target, + lookup_table, + } => { + let idx = Self::read_tag(buf)?; + return match lookup_table.resolve(idx) { + Some((_, promotion)) => target.decode_with_promotion(buf, promotion), + None => Err(ArrowError::ParseError(format!( + "Writer union branch {idx} does not resolve to reader type" + ))), + }; + } + }; + let decoder = self.emit_to(reader_idx)?; + decoder.decode_with_promotion(buf, promotion) + } + + fn flush(&mut self, nulls: Option) -> Result { + if let UnionReadPlan::ToSingle { target, .. } = &mut self.plan { + return target.flush(nulls); + } + debug_assert!( + nulls.is_none(), + "UnionArray does not accept a validity bitmap; \ + nulls should have been materialized as a Null child during decode" + ); + let children = self + .branches + .iter_mut() + .map(|d| d.flush(None)) + .collect::, _>>()?; + let arr = UnionArray::try_new( + self.fields.clone(), + flush_values(&mut self.type_ids).into_iter().collect(), + Some(flush_values(&mut self.offsets).into_iter().collect()), + children, + ) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(arr)) + } +} + +#[derive(Debug, Default)] +struct UnionDecoderBuilder { + fields: Option, + branches: Option>, + resolved: Option, + target: Option>, +} + +impl UnionDecoderBuilder { + fn new() -> Self { + Self::default() + } + + fn with_fields(mut self, fields: UnionFields) -> Self { + self.fields = Some(fields); + self + } + + fn with_branches(mut self, branches: Vec) -> Self { + self.branches = Some(branches); + self + } + + fn with_resolved_union(mut self, resolved_union: ResolvedUnion) -> Self { + self.resolved = Some(resolved_union); + self + } + + fn with_target(mut self, target: Box) -> Self { + self.target = Some(target); + self + } + + fn build(self) -> Result { + match (self.resolved, self.fields, self.branches, self.target) { + (resolved, Some(fields), Some(branches), None) => { + UnionDecoder::try_new(fields, branches, resolved) + } + (Some(info), None, None, Some(target)) + if info.writer_is_union && !info.reader_is_union => + { + UnionDecoder::try_new_from_writer_union(info, target) + } + _ => Err(ArrowError::InvalidArgumentError( + "Invalid UnionDecoderBuilder configuration: expected either \ + (fields + branches + resolved) with no target for reader-unions, or \ + (resolved + target) with no fields/branches for writer-union to single." + .to_string(), + )), + } + } +} + +#[derive(Debug, Copy, Clone)] +enum NegativeBlockBehavior { + ProcessItems, + SkipBySize, +} + +#[inline] +fn skip_blocks( + buf: &mut AvroCursor, + mut skip_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + process_blockwise( + buf, + move |c| skip_item(c), + NegativeBlockBehavior::SkipBySize, + ) +} + +#[inline] +fn flush_dict( + indices: &mut Vec, + symbols: &[String], + nulls: Option, +) -> Result { + let keys = flush_primitive::(indices, nulls); + let values = Arc::new(StringArray::from_iter_values( + symbols.iter().map(|s| s.as_str()), + )); + DictionaryArray::try_new(keys, values) + .map_err(|e| ArrowError::ParseError(e.to_string())) + .map(|arr| Arc::new(arr) as ArrayRef) +} + +#[inline] +fn read_blocks( + buf: &mut AvroCursor, + decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + process_blockwise(buf, decode_entry, NegativeBlockBehavior::ProcessItems) +} + +#[inline] +fn process_blockwise( + buf: &mut AvroCursor, + mut on_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, + negative_behavior: NegativeBlockBehavior, +) -> Result { + let mut total = 0usize; + loop { + // Read the block count + // positive = that many items + // negative = that many items + read block size + // See: https://avro.apache.org/docs/1.11.1/specification/#maps + let block_count = buf.get_long()?; + match block_count.cmp(&0) { + Ordering::Equal => break, + Ordering::Less => { + let count = (-block_count) as usize; + // A negative count is followed by a long of the size in bytes + let size_in_bytes = buf.get_long()? as usize; + match negative_behavior { + NegativeBlockBehavior::ProcessItems => { + // Process items one-by-one after reading size + for _ in 0..count { + on_item(buf)?; + } + } + NegativeBlockBehavior::SkipBySize => { + // Skip the entire block payload at once + let _ = buf.get_fixed(size_in_bytes)?; + } + } + total += count; + } + Ordering::Greater => { + let count = block_count as usize; + for _ in 0..count { + on_item(buf)?; + } + total += count; + } + } + } + Ok(total) +} + +#[inline] +fn flush_values(values: &mut Vec) -> Vec { + std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +} + +#[inline] +fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +#[inline] +fn flush_primitive( + values: &mut Vec, + nulls: Option, +) -> PrimitiveArray { + PrimitiveArray::new(flush_values(values).into(), nulls) +} + +#[inline] +fn read_decimal_bytes_be( + buf: &mut AvroCursor<'_>, + size: &Option, +) -> Result<[u8; N], ArrowError> { + match size { + Some(n) if *n == N => { + let raw = buf.get_fixed(N)?; + let mut arr = [0u8; N]; + arr.copy_from_slice(raw); + Ok(arr) + } + Some(n) => { + let raw = buf.get_fixed(*n)?; + sign_cast_to::(raw) + } + None => { + let raw = buf.get_bytes()?; + sign_cast_to::(raw) + } + } +} + +/// Sign-extend or (when larger) validate-and-truncate a big-endian two's-complement +/// integer into exactly `N` bytes. This matches Avro's decimal binary encoding: +/// the payload is a big-endian two's-complement integer, and when narrowing it must +/// be representable without changing sign or value. +/// +/// If `raw.len() < N`, the value is sign-extended. +/// If `raw.len() > N`, all truncated leading bytes must match the sign-extension byte +/// and the MSB of the first kept byte must match the sign (to avoid silent overflow). +#[inline] +fn sign_cast_to(raw: &[u8]) -> Result<[u8; N], ArrowError> { + let len = raw.len(); + // Fast path: exact width, just copy + if len == N { + let mut out = [0u8; N]; + out.copy_from_slice(raw); + return Ok(out); + } + // Determine sign byte from MSB of first byte (empty => positive) + let first = raw.first().copied().unwrap_or(0u8); + let sign_byte = if (first & 0x80) == 0 { 0x00 } else { 0xFF }; + // Pre-fill with sign byte to support sign extension + let mut out = [sign_byte; N]; + if len > N { + // Validate truncation: all dropped leading bytes must equal sign_byte, + // and the MSB of the first kept byte must match the sign. + let extra = len - N; + // Any non-sign byte in the truncated prefix indicates overflow + if raw[..extra].iter().any(|&b| b != sign_byte) { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + if N > 0 { + let first_kept = raw[extra]; + let sign_bit_mismatch = ((first_kept ^ sign_byte) & 0x80) != 0; + if sign_bit_mismatch { + return Err(ArrowError::ParseError(format!( + "Decimal value with {} bytes cannot be represented in {} bytes without overflow", + len, N + ))); + } + } + out.copy_from_slice(&raw[extra..]); + return Ok(out); + } + out[N - len..].copy_from_slice(raw); + Ok(out) +} + +#[cfg(feature = "avro_custom_types")] +#[inline] +fn values_equal_at(arr: &dyn Array, i: usize, j: usize) -> bool { + match (arr.is_null(i), arr.is_null(j)) { + (true, true) => true, + (true, false) | (false, true) => false, + (false, false) => { + let a = arr.slice(i, 1); + let b = arr.slice(j, 1); + a == b + } + } +} + +#[derive(Debug)] +struct Projector { + writer_to_reader: Arc<[Option]>, + skip_decoders: Vec>, + field_defaults: Vec>, + default_injections: Arc<[(usize, AvroLiteral)]>, +} + +#[derive(Debug)] +struct ProjectorBuilder<'a> { + rec: &'a ResolvedRecord, + reader_fields: Arc<[AvroField]>, +} + +impl<'a> ProjectorBuilder<'a> { + #[inline] + fn try_new(rec: &'a ResolvedRecord, reader_fields: &Arc<[AvroField]>) -> Self { + Self { + rec, + reader_fields: reader_fields.clone(), + } + } + + #[inline] + fn build(self) -> Result { + let reader_fields = self.reader_fields; + let mut field_defaults: Vec> = Vec::with_capacity(reader_fields.len()); + for avro_field in reader_fields.as_ref() { + if let Some(ResolutionInfo::DefaultValue(lit)) = + avro_field.data_type().resolution.as_ref() + { + field_defaults.push(Some(lit.clone())); + } else { + field_defaults.push(None); + } + } + let mut default_injections: Vec<(usize, AvroLiteral)> = + Vec::with_capacity(self.rec.default_fields.len()); + for &idx in self.rec.default_fields.as_ref() { + let lit = field_defaults + .get(idx) + .and_then(|lit| lit.clone()) + .unwrap_or(AvroLiteral::Null); + default_injections.push((idx, lit)); + } + let mut skip_decoders: Vec> = + Vec::with_capacity(self.rec.skip_fields.len()); + for datatype in self.rec.skip_fields.as_ref() { + let skipper = match datatype { + Some(datatype) => Some(Skipper::from_avro(datatype)?), + None => None, + }; + skip_decoders.push(skipper); + } + Ok(Projector { + writer_to_reader: self.rec.writer_to_reader.clone(), + skip_decoders, + field_defaults, + default_injections: default_injections.into(), + }) + } +} + +impl Projector { + #[inline] + fn project_default(&self, decoder: &mut Decoder, index: usize) -> Result<(), ArrowError> { + // SAFETY: `index` is obtained by listing the reader's record fields (i.e., from + // `decoders.iter_mut().enumerate()`), and `field_defaults` was built in + // `ProjectorBuilder::build` to have exactly one element per reader field. + // Therefore, `index < self.field_defaults.len()` always holds here, so + // `self.field_defaults[index]` cannot panic. We only take an immutable reference + // via `.as_ref()`, and `self` is borrowed immutably. + if let Some(default_literal) = self.field_defaults[index].as_ref() { + decoder.append_default(default_literal) + } else { + decoder.append_null() + } + } + + #[inline] + fn project_record( + &mut self, + buf: &mut AvroCursor<'_>, + encodings: &mut [Decoder], + ) -> Result<(), ArrowError> { + debug_assert_eq!( + self.writer_to_reader.len(), + self.skip_decoders.len(), + "internal invariant: mapping and skipper lists must have equal length" + ); + for (i, (mapping, skipper_opt)) in self + .writer_to_reader + .iter() + .zip(self.skip_decoders.iter_mut()) + .enumerate() + { + match (mapping, skipper_opt.as_mut()) { + (Some(reader_index), _) => encodings[*reader_index].decode(buf)?, + (None, Some(skipper)) => skipper.skip(buf)?, + (None, None) => { + return Err(ArrowError::SchemaError(format!( + "No skipper available for writer-only field at index {i}", + ))); + } + } + } + for (reader_index, lit) in self.default_injections.as_ref() { + encodings[*reader_index].append_default(lit)?; + } + Ok(()) + } +} + +/// Lightweight skipper for non‑projected writer fields +/// (fields present in the writer schema but omitted by the reader/projection); +/// per Avro 1.11.1 schema resolution these fields are ignored. +/// +/// +#[derive(Debug)] +enum Skipper { + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Bytes, + String, + TimeMicros, + TimestampMillis, + TimestampMicros, + TimestampNanos, + Fixed(usize), + Decimal(Option), + UuidString, + Enum, + DurationFixed12, + List(Box), + Map(Box), + Struct(Vec), + Union(Vec), + Nullable(Nullability, Box), + #[cfg(feature = "avro_custom_types")] + RunEndEncoded(Box), +} + +impl Skipper { + fn from_avro(dt: &AvroDataType) -> Result { + let mut base = match dt.codec() { + Codec::Null => Self::Null, + Codec::Boolean => Self::Boolean, + Codec::Int32 | Codec::Date32 | Codec::TimeMillis => Self::Int32, + Codec::Int64 => Self::Int64, + Codec::TimeMicros => Self::TimeMicros, + Codec::TimestampMillis(_) => Self::TimestampMillis, + Codec::TimestampMicros(_) => Self::TimestampMicros, + Codec::TimestampNanos(_) => Self::TimestampNanos, + #[cfg(feature = "avro_custom_types")] + Codec::DurationNanos + | Codec::DurationMicros + | Codec::DurationMillis + | Codec::DurationSeconds => Self::Int64, + Codec::Float32 => Self::Float32, + Codec::Float64 => Self::Float64, + Codec::Binary => Self::Bytes, + Codec::Utf8 | Codec::Utf8View => Self::String, + Codec::Fixed(sz) => Self::Fixed(*sz as usize), + Codec::Decimal(_, _, size) => Self::Decimal(*size), + Codec::Uuid => Self::UuidString, // encoded as string + Codec::Enum(_) => Self::Enum, + Codec::List(item) => Self::List(Box::new(Skipper::from_avro(item)?)), + Codec::Struct(fields) => Self::Struct( + fields + .iter() + .map(|f| Skipper::from_avro(f.data_type())) + .collect::>()?, + ), + Codec::Map(values) => Self::Map(Box::new(Skipper::from_avro(values)?)), + Codec::Interval => Self::DurationFixed12, + Codec::Union(encodings, _, _) => { + let max_addr = (i32::MAX as usize) + 1; + if encodings.len() > max_addr { + return Err(ArrowError::SchemaError(format!( + "Writer union has {} branches, which exceeds the maximum addressable \ + branches by an Avro int tag ({} + 1).", + encodings.len(), + i32::MAX + ))); + } + Self::Union( + encodings + .iter() + .map(Skipper::from_avro) + .collect::>()?, + ) + } + #[cfg(feature = "avro_custom_types")] + Codec::RunEndEncoded(inner, _w) => { + Self::RunEndEncoded(Box::new(Skipper::from_avro(inner)?)) + } + }; + if let Some(n) = dt.nullability() { + base = Self::Nullable(n, Box::new(base)); + } + Ok(base) + } + + fn skip(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + match self { + Self::Null => Ok(()), + Self::Boolean => { + buf.get_bool()?; + Ok(()) + } + Self::Int32 => { + buf.get_int()?; + Ok(()) + } + Self::Int64 + | Self::TimeMicros + | Self::TimestampMillis + | Self::TimestampMicros + | Self::TimestampNanos => { + buf.get_long()?; + Ok(()) + } + Self::Float32 => { + buf.get_float()?; + Ok(()) + } + Self::Float64 => { + buf.get_double()?; + Ok(()) + } + Self::Bytes | Self::String | Self::UuidString => { + buf.get_bytes()?; + Ok(()) + } + Self::Fixed(sz) => { + buf.get_fixed(*sz)?; + Ok(()) + } + Self::Decimal(size) => { + if let Some(s) = size { + buf.get_fixed(*s) + } else { + buf.get_bytes() + }?; + Ok(()) + } + Self::Enum => { + buf.get_int()?; + Ok(()) + } + Self::DurationFixed12 => { + buf.get_fixed(12)?; + Ok(()) + } + Self::List(item) => { + skip_blocks(buf, |c| item.skip(c))?; + Ok(()) + } + Self::Map(value) => { + skip_blocks(buf, |c| { + c.get_bytes()?; // key + value.skip(c) + })?; + Ok(()) + } + Self::Struct(fields) => { + for f in fields.iter_mut() { + f.skip(buf)? + } + Ok(()) + } + Self::Union(encodings) => { + // Union tag must be ZigZag-decoded + let raw = buf.get_long()?; + if raw < 0 { + return Err(ArrowError::ParseError(format!( + "Negative union branch index {raw}" + ))); + } + let idx: usize = usize::try_from(raw).map_err(|_| { + ArrowError::ParseError(format!( + "Union branch index {raw} does not fit into usize on this platform ({}-bit)", + (usize::BITS as usize) + )) + })?; + let Some(encoding) = encodings.get_mut(idx) else { + return Err(ArrowError::ParseError(format!( + "Union branch index {idx} out of range for skipper ({} branches)", + encodings.len() + ))); + }; + encoding.skip(buf) + } + Self::Nullable(order, inner) => { + let branch = buf.read_vlq()?; + let is_not_null = match *order { + Nullability::NullFirst => branch != 0, + Nullability::NullSecond => branch == 0, + }; + if is_not_null { + inner.skip(buf)?; + } + Ok(()) + } + #[cfg(feature = "avro_custom_types")] + Self::RunEndEncoded(inner) => inner.skip(buf), + } + } +} + +#[cfg(test)] mod tests { use super::*; - use arrow_array::{ - cast::AsArray, Array, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, - IntervalMonthDayNanoArray, ListArray, MapArray, StringArray, StructArray, - }; + use crate::codec::AvroFieldBuilder; + use crate::schema::{Attributes, ComplexType, Field, PrimitiveType, Record, Schema, TypeName}; + use arrow_array::cast::AsArray; + use indexmap::IndexMap; + use std::collections::HashMap; fn encode_avro_int(value: i32) -> Vec { let mut buf = Vec::new(); @@ -493,40 +2194,512 @@ mod tests { AvroDataType::new(codec, Default::default(), None) } - #[test] - fn test_map_decoding_one_entry() { - let value_type = avro_from_codec(Codec::Utf8); - let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); - let mut decoder = Decoder::try_new(&map_type).unwrap(); - // Encode a single map with one entry: {"hello": "world"} - let mut data = Vec::new(); - data.extend_from_slice(&encode_avro_long(1)); - data.extend_from_slice(&encode_avro_bytes(b"hello")); // key - data.extend_from_slice(&encode_avro_bytes(b"world")); // value - data.extend_from_slice(&encode_avro_long(0)); - let mut cursor = AvroCursor::new(&data); - decoder.decode(&mut cursor).unwrap(); - let array = decoder.flush(None).unwrap(); - let map_arr = array.as_any().downcast_ref::().unwrap(); - assert_eq!(map_arr.len(), 1); // one map - assert_eq!(map_arr.value_length(0), 1); - let entries = map_arr.value(0); - let struct_entries = entries.as_any().downcast_ref::().unwrap(); - assert_eq!(struct_entries.len(), 1); - let key_arr = struct_entries - .column_by_name("key") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - let val_arr = struct_entries - .column_by_name("value") - .unwrap() - .as_any() - .downcast_ref::() - .unwrap(); - assert_eq!(key_arr.value(0), "hello"); - assert_eq!(val_arr.value(0), "world"); + fn resolved_root_datatype( + writer: Schema<'static>, + reader: Schema<'static>, + use_utf8view: bool, + strict_mode: bool, + ) -> AvroDataType { + // Wrap writer schema in a single-field record + let writer_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: writer, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + + // Wrap reader schema in a single-field record + let reader_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: reader, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + + // Build resolved record, then extract the inner field's resolved AvroDataType + let field = AvroFieldBuilder::new(&writer_record) + .with_reader_schema(&reader_record) + .with_utf8view(use_utf8view) + .with_strict_mode(strict_mode) + .build() + .expect("schema resolution should succeed"); + + match field.data_type().codec() { + Codec::Struct(fields) => fields[0].data_type().clone(), + other => panic!("expected wrapper struct, got {other:?}"), + } + } + + fn decoder_for_promotion( + writer: PrimitiveType, + reader: PrimitiveType, + use_utf8view: bool, + ) -> Decoder { + let ws = Schema::TypeName(TypeName::Primitive(writer)); + let rs = Schema::TypeName(TypeName::Primitive(reader)); + let dt = resolved_root_datatype(ws, rs, use_utf8view, false); + Decoder::try_new(&dt).unwrap() + } + + fn make_avro_dt(codec: Codec, nullability: Option) -> AvroDataType { + AvroDataType::new(codec, HashMap::new(), nullability) + } + + #[cfg(feature = "avro_custom_types")] + fn encode_vlq_u64(mut x: u64) -> Vec { + let mut out = Vec::with_capacity(10); + while x >= 0x80 { + out.push((x as u8) | 0x80); + x >>= 7; + } + out.push(x as u8); + out + } + + #[test] + fn test_union_resolution_writer_union_reader_union_reorder_and_promotion_dense() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + + let dt = resolved_root_datatype(ws, rs, false, false); + let mut dec = Decoder::try_new(&dt).unwrap(); + + let mut rec1 = encode_avro_long(0); + rec1.extend(encode_avro_int(7)); + let mut cur1 = AvroCursor::new(&rec1); + dec.decode(&mut cur1).unwrap(); + + let mut rec2 = encode_avro_long(1); + rec2.extend(encode_avro_bytes("abc".as_bytes())); + let mut cur2 = AvroCursor::new(&rec2); + dec.decode(&mut cur2).unwrap(); + + let arr = dec.flush(None).unwrap(); + let ua = arr + .as_any() + .downcast_ref::() + .expect("dense union output"); + + assert_eq!( + ua.type_id(0), + 1, + "first value must select reader 'long' branch" + ); + assert_eq!(ua.value_offset(0), 0); + + assert_eq!( + ua.type_id(1), + 0, + "second value must select reader 'string' branch" + ); + assert_eq!(ua.value_offset(1), 0); + + let long_child = ua.child(1).as_any().downcast_ref::().unwrap(); + assert_eq!(long_child.len(), 1); + assert_eq!(long_child.value(0), 7); + + let str_child = ua.child(0).as_any().downcast_ref::().unwrap(); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "abc"); + } + + #[test] + fn test_union_resolution_writer_union_reader_nonunion_promotion_int_to_long() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + + let dt = resolved_root_datatype(ws, rs, false, false); + let mut dec = Decoder::try_new(&dt).unwrap(); + + let mut data = encode_avro_long(0); + data.extend(encode_avro_int(5)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + + let arr = dec.flush(None).unwrap(); + let out = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(out.len(), 1); + assert_eq!(out.value(0), 5); + } + + #[test] + fn test_union_resolution_writer_union_reader_nonunion_mismatch_errors() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + ]); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)); + + let dt = resolved_root_datatype(ws, rs, false, false); + let mut dec = Decoder::try_new(&dt).unwrap(); + + let mut data = encode_avro_long(1); + data.extend(encode_avro_bytes("z".as_bytes())); + let mut cur = AvroCursor::new(&data); + let res = dec.decode(&mut cur); + assert!( + res.is_err(), + "expected error when writer union branch does not resolve to reader non-union type" + ); + } + + #[test] + fn test_union_resolution_writer_nonunion_reader_union_selects_matching_branch() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + + let dt = resolved_root_datatype(ws, rs, false, false); + let mut dec = Decoder::try_new(&dt).unwrap(); + + let data = encode_avro_int(6); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + + let arr = dec.flush(None).unwrap(); + let ua = arr + .as_any() + .downcast_ref::() + .expect("dense union output"); + assert_eq!(ua.len(), 1); + assert_eq!( + ua.type_id(0), + 1, + "must resolve to reader 'long' branch (type_id 1)" + ); + assert_eq!(ua.value_offset(0), 0); + + let long_child = ua.child(1).as_any().downcast_ref::().unwrap(); + assert_eq!(long_child.len(), 1); + assert_eq!(long_child.value(0), 6); + + let str_child = ua.child(0).as_any().downcast_ref::().unwrap(); + assert_eq!(str_child.len(), 0, "string branch must be empty"); + } + + #[test] + fn test_union_resolution_writer_union_reader_union_unmapped_branch_errors() { + let ws = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), + ]); + let rs = Schema::Union(vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + ]); + + let dt = resolved_root_datatype(ws, rs, false, false); + let mut dec = Decoder::try_new(&dt).unwrap(); + + let mut data = encode_avro_long(1); + data.push(1); + let mut cur = AvroCursor::new(&data); + let res = dec.decode(&mut cur); + assert!( + res.is_err(), + "expected error for unmapped writer 'boolean' branch" + ); + } + + #[test] + fn test_schema_resolution_promotion_int_to_long() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Long, false); + assert!(matches!(dec, Decoder::Int32ToInt64(_))); + for v in [0, 1, -2, 123456] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0); + assert_eq!(a.value(1), 1); + assert_eq!(a.value(2), -2); + assert_eq!(a.value(3), 123456); + } + + #[test] + fn test_schema_resolution_promotion_int_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int32ToFloat32(_))); + for v in [0, 42, -7] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 42.0); + assert_eq!(a.value(2), -7.0); + } + + #[test] + fn test_schema_resolution_promotion_int_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Int, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int32ToFloat64(_))); + for v in [1, -1, 10_000] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 1.0); + assert_eq!(a.value(1), -1.0); + assert_eq!(a.value(2), 10_000.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_float() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Float, false); + assert!(matches!(dec, Decoder::Int64ToFloat32(_))); + for v in [0_i64, 1_000_000_i64, -123_i64] { + let data = encode_avro_long(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.0); + assert_eq!(a.value(1), 1_000_000.0); + assert_eq!(a.value(2), -123.0); + } + + #[test] + fn test_schema_resolution_promotion_long_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Long, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Int64ToFloat64(_))); + for v in [2_i64, -2_i64, 9_223_372_i64] { + let data = encode_avro_long(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 2.0); + assert_eq!(a.value(1), -2.0); + assert_eq!(a.value(2), 9_223_372.0); + } + + #[test] + fn test_schema_resolution_promotion_float_to_double() { + let mut dec = decoder_for_promotion(PrimitiveType::Float, PrimitiveType::Double, false); + assert!(matches!(dec, Decoder::Float32ToFloat64(_))); + for v in [0.5_f32, -3.25_f32, 1.0e6_f32] { + let data = v.to_le_bytes().to_vec(); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 0.5_f64); + assert_eq!(a.value(1), -3.25_f64); + assert_eq!(a.value(2), 1.0e6_f64); + } + + #[test] + fn test_schema_resolution_promotion_bytes_to_string_utf8() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, false); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + for s in ["hello", "world", "héllo"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), "hello"); + assert_eq!(a.value(1), "world"); + assert_eq!(a.value(2), "héllo"); + } + + #[test] + fn test_schema_resolution_promotion_bytes_to_string_utf8view_enabled() { + let mut dec = decoder_for_promotion(PrimitiveType::Bytes, PrimitiveType::String, true); + assert!(matches!(dec, Decoder::BytesToString(_, _))); + let data = encode_avro_bytes("abc".as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), "abc"); + } + + #[test] + fn test_schema_resolution_promotion_string_to_bytes() { + let mut dec = decoder_for_promotion(PrimitiveType::String, PrimitiveType::Bytes, false); + assert!(matches!(dec, Decoder::StringToBytes(_, _))); + for s in ["", "abc", "data"] { + let data = encode_avro_bytes(s.as_bytes()); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), b""); + assert_eq!(a.value(1), b"abc"); + assert_eq!(a.value(2), "data".as_bytes()); + } + + #[test] + fn test_schema_resolution_no_promotion_passthrough_int() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + // Wrap both in a synthetic single-field record and resolve with AvroFieldBuilder + let writer_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: ws, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + let reader_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: rs, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + let field = AvroFieldBuilder::new(&writer_record) + .with_reader_schema(&reader_record) + .with_utf8view(false) + .with_strict_mode(false) + .build() + .unwrap(); + // Extract the resolved inner field's AvroDataType + let dt = match field.data_type().codec() { + Codec::Struct(fields) => fields[0].data_type().clone(), + other => panic!("expected wrapper struct, got {other:?}"), + }; + let mut dec = Decoder::try_new(&dt).unwrap(); + assert!(matches!(dec, Decoder::Int32(_))); + for v in [7, -9] { + let data = encode_avro_int(v); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + } + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 7); + assert_eq!(a.value(1), -9); + } + + #[test] + fn test_schema_resolution_illegal_promotion_int_to_boolean_errors() { + let ws = Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)); + let rs = Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)); + let writer_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: ws, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + let reader_record = Schema::Complex(ComplexType::Record(Record { + name: "Root", + namespace: None, + doc: None, + aliases: vec![], + fields: vec![Field { + name: "v", + r#type: rs, + default: None, + doc: None, + aliases: vec![], + }], + attributes: Attributes::default(), + })); + let res = AvroFieldBuilder::new(&writer_record) + .with_reader_schema(&reader_record) + .with_utf8view(false) + .with_strict_mode(false) + .build(); + assert!(res.is_err(), "expected error for illegal promotion"); + } + + #[test] + fn test_map_decoding_one_entry() { + let value_type = avro_from_codec(Codec::Utf8); + let map_type = avro_from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type).unwrap(); + // Encode a single map with one entry: {"hello": "world"} + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); // key + data.extend_from_slice(&encode_avro_bytes(b"world")); // value + data.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); // one map + assert_eq!(map_arr.value_length(0), 1); + let entries = map_arr.value(0); + let struct_entries = entries.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_entries.len(), 1); + let key_arr = struct_entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let val_arr = struct_entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(key_arr.value(0), "hello"); + assert_eq!(val_arr.value(0), "world"); } #[test] @@ -542,6 +2715,95 @@ mod tests { assert_eq!(map_arr.value_length(0), 0); } + #[test] + fn test_fixed_decoding() { + let avro_type = avro_from_codec(Codec::Fixed(3)); + let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder"); + + let data1 = [1u8, 2, 3]; + let mut cursor1 = AvroCursor::new(&data1); + decoder + .decode(&mut cursor1) + .expect("Failed to decode data1"); + assert_eq!(cursor1.position(), 3, "Cursor should advance by fixed size"); + let data2 = [4u8, 5, 6]; + let mut cursor2 = AvroCursor::new(&data2); + decoder + .decode(&mut cursor2) + .expect("Failed to decode data2"); + assert_eq!(cursor2.position(), 3, "Cursor should advance by fixed size"); + let array = decoder.flush(None).expect("Failed to flush decoder"); + assert_eq!(array.len(), 2, "Array should contain two items"); + let fixed_size_binary_array = array + .as_any() + .downcast_ref::() + .expect("Failed to downcast to FixedSizeBinaryArray"); + assert_eq!( + fixed_size_binary_array.value_length(), + 3, + "Fixed size of binary values should be 3" + ); + assert_eq!( + fixed_size_binary_array.value(0), + &[1, 2, 3], + "First item mismatch" + ); + assert_eq!( + fixed_size_binary_array.value(1), + &[4, 5, 6], + "Second item mismatch" + ); + } + + #[test] + fn test_fixed_decoding_empty() { + let avro_type = avro_from_codec(Codec::Fixed(5)); + let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder"); + + let array = decoder + .flush(None) + .expect("Failed to flush decoder for empty input"); + + assert_eq!(array.len(), 0, "Array should be empty"); + let fixed_size_binary_array = array + .as_any() + .downcast_ref::() + .expect("Failed to downcast to FixedSizeBinaryArray for empty array"); + + assert_eq!( + fixed_size_binary_array.value_length(), + 5, + "Fixed size of binary values should be 5 as per type" + ); + } + + #[test] + fn test_uuid_decoding() { + let avro_type = avro_from_codec(Codec::Uuid); + let mut decoder = Decoder::try_new(&avro_type).expect("Failed to create decoder"); + let uuid_str = "f81d4fae-7dec-11d0-a765-00a0c91e6bf6"; + let data = encode_avro_bytes(uuid_str.as_bytes()); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).expect("Failed to decode data"); + assert_eq!( + cursor.position(), + data.len(), + "Cursor should advance by varint size + data size" + ); + let array = decoder.flush(None).expect("Failed to flush decoder"); + let fixed_size_binary_array = array + .as_any() + .downcast_ref::() + .expect("Array should be a FixedSizeBinaryArray"); + assert_eq!(fixed_size_binary_array.len(), 1); + assert_eq!(fixed_size_binary_array.value_length(), 16); + let expected_bytes = [ + 0xf8, 0x1d, 0x4f, 0xae, 0x7d, 0xec, 0x11, 0xd0, 0xa7, 0x65, 0x00, 0xa0, 0xc9, 0x1e, + 0x6b, 0xf6, + ]; + assert_eq!(fixed_size_binary_array.value(0), &expected_bytes); + } + #[test] fn test_array_decoding() { let item_dt = avro_from_codec(Codec::Int32); @@ -634,4 +2896,1862 @@ mod tests { assert_eq!(list_arr.len(), 1); assert_eq!(list_arr.value_length(0), 0); } + + #[test] + fn test_decimal_decoding_fixed256() { + let dt = avro_from_codec(Codec::Decimal(50, Some(2), Some(32))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_fixed128() { + let dt = avro_from_codec(Codec::Decimal(28, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_fixed32_from_32byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(32))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + + #[test] + fn test_decimal_decoding_fixed32_from_16byte_fixed_storage() { + let dt = avro_from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls() { + let dt = avro_from_codec(Codec::Decimal(4, Some(1), None)); + let inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + NullablePlan::ReadTag, + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls_fixed_size_narrow_result() { + let dt = avro_from_codec(Codec::Decimal(6, Some(2), Some(16))); + let inner = Decoder::try_new(&dt).unwrap(); + let mut decoder = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + NullablePlan::ReadTag, + ); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0xE2, 0x40, + ]; + let row3 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, + 0x1D, 0xC0, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + #[cfg(feature = "small_decimals")] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + #[cfg(not(feature = "small_decimals"))] + { + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + } + + #[test] + fn test_enum_decoding() { + let symbols: Arc<[String]> = vec!["A", "B", "C"].into_iter().map(String::from).collect(); + let avro_type = avro_from_codec(Codec::Enum(symbols.clone())); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dict_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_array.len(), 3); + let values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "A"); + assert_eq!(values.value(1), "B"); + assert_eq!(values.value(2), "C"); + assert_eq!(dict_array.keys().values(), &[2, 0, 1]); + } + + #[test] + fn test_enum_decoding_with_nulls() { + let symbols: Arc<[String]> = vec!["X", "Y"].into_iter().map(String::from).collect(); + let enum_codec = Codec::Enum(symbols.clone()); + let avro_type = + AvroDataType::new(enum_codec, Default::default(), Some(Nullability::NullFirst)); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_int(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dict_array = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_array.len(), 3); + assert!(dict_array.is_valid(0)); + assert!(dict_array.is_null(1)); + assert!(dict_array.is_valid(2)); + let expected_keys = Int32Array::from(vec![Some(1), None, Some(0)]); + assert_eq!(dict_array.keys(), &expected_keys); + let values = dict_array + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "X"); + assert_eq!(values.value(1), "Y"); + } + + #[test] + fn test_duration_decoding_with_nulls() { + let duration_codec = Codec::Interval; + let avro_type = AvroDataType::new( + duration_codec, + Default::default(), + Some(Nullability::NullFirst), + ); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + // First value: 1 month, 2 days, 3 millis + data.extend_from_slice(&encode_avro_long(1)); // not null + let mut duration1 = Vec::new(); + duration1.extend_from_slice(&1u32.to_le_bytes()); + duration1.extend_from_slice(&2u32.to_le_bytes()); + duration1.extend_from_slice(&3u32.to_le_bytes()); + data.extend_from_slice(&duration1); + // Second value: null + data.extend_from_slice(&encode_avro_long(0)); // null + data.extend_from_slice(&encode_avro_long(1)); // not null + let mut duration2 = Vec::new(); + duration2.extend_from_slice(&4u32.to_le_bytes()); + duration2.extend_from_slice(&5u32.to_le_bytes()); + duration2.extend_from_slice(&6u32.to_le_bytes()); + data.extend_from_slice(&duration2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let interval_array = array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(interval_array.len(), 3); + assert!(interval_array.is_valid(0)); + assert!(interval_array.is_null(1)); + assert!(interval_array.is_valid(2)); + let expected = IntervalMonthDayNanoArray::from(vec![ + Some(IntervalMonthDayNano { + months: 1, + days: 2, + nanoseconds: 3_000_000, + }), + None, + Some(IntervalMonthDayNano { + months: 4, + days: 5, + nanoseconds: 6_000_000, + }), + ]); + assert_eq!(interval_array, &expected); + } + + #[test] + fn test_duration_decoding_empty() { + let duration_codec = Codec::Interval; + let avro_type = AvroDataType::new(duration_codec, Default::default(), None); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let array = decoder.flush(None).unwrap(); + assert_eq!(array.len(), 0); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn test_duration_seconds_decoding() { + let avro_type = AvroDataType::new(Codec::DurationSeconds, Default::default(), None); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + // Three values: 0, -1, 2 + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_long(-1)); + data.extend_from_slice(&encode_avro_long(2)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dur = array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dur.values(), &[0, -1, 2]); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn test_duration_milliseconds_decoding() { + let avro_type = AvroDataType::new(Codec::DurationMillis, Default::default(), None); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + for v in [1i64, 0, -2] { + data.extend_from_slice(&encode_avro_long(v)); + } + let mut cursor = AvroCursor::new(&data); + for _ in 0..3 { + decoder.decode(&mut cursor).unwrap(); + } + let array = decoder.flush(None).unwrap(); + let dur = array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dur.values(), &[1, 0, -2]); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn test_duration_microseconds_decoding() { + let avro_type = AvroDataType::new(Codec::DurationMicros, Default::default(), None); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + for v in [5i64, -6, 7] { + data.extend_from_slice(&encode_avro_long(v)); + } + let mut cursor = AvroCursor::new(&data); + for _ in 0..3 { + decoder.decode(&mut cursor).unwrap(); + } + let array = decoder.flush(None).unwrap(); + let dur = array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dur.values(), &[5, -6, 7]); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn test_duration_nanoseconds_decoding() { + let avro_type = AvroDataType::new(Codec::DurationNanos, Default::default(), None); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + let mut data = Vec::new(); + for v in [8i64, 9, -10] { + data.extend_from_slice(&encode_avro_long(v)); + } + let mut cursor = AvroCursor::new(&data); + for _ in 0..3 { + decoder.decode(&mut cursor).unwrap(); + } + let array = decoder.flush(None).unwrap(); + let dur = array + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(dur.values(), &[8, 9, -10]); + } + + #[test] + fn test_nullable_decode_error_bitmap_corruption() { + // Nullable Int32 with ['T','null'] encoding (NullSecond) + let avro_type = AvroDataType::new( + Codec::Int32, + Default::default(), + Some(Nullability::NullSecond), + ); + let mut decoder = Decoder::try_new(&avro_type).unwrap(); + + // Row 1: union branch 1 (null) + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_int(1)); + + // Row 2: union branch 0 (non-null) but missing the int payload -> decode error + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_int(0)); // branch = 0 => non-null + + // Row 3: union branch 0 (non-null) with correct int payload -> should succeed + let mut row3 = Vec::new(); + row3.extend_from_slice(&encode_avro_int(0)); // branch + row3.extend_from_slice(&encode_avro_int(42)); // actual value + + decoder.decode(&mut AvroCursor::new(&row1)).unwrap(); + assert!(decoder.decode(&mut AvroCursor::new(&row2)).is_err()); // decode error + decoder.decode(&mut AvroCursor::new(&row3)).unwrap(); + + let array = decoder.flush(None).unwrap(); + + // Should contain 2 elements: row1 (null) and row3 (42) + assert_eq!(array.len(), 2); + let int_array = array.as_any().downcast_ref::().unwrap(); + assert!(int_array.is_null(0)); // row1 is null + assert_eq!(int_array.value(1), 42); // row3 value is 42 + } + + #[test] + fn test_enum_mapping_reordered_symbols() { + let reader_symbols: Arc<[String]> = + vec!["B".to_string(), "C".to_string(), "A".to_string()].into(); + let mapping: Arc<[i32]> = Arc::from(vec![2, 0, 1]); + let default_index: i32 = -1; + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![2, 0, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "B"); + assert_eq!(values.value(1), "C"); + assert_eq!(values.value(2), "A"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_and_out_of_range_fall_back_to_default() { + let reader_symbols: Arc<[String]> = vec!["A".to_string(), "B".to_string()].into(); + let default_index: i32 = 1; + let mapping: Arc<[i32]> = Arc::from(vec![0, 1]); + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols.clone(), + Some(EnumResolution { + mapping, + default_index, + }), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(99)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + dec.decode(&mut cur).unwrap(); + let arr = dec.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + let expected_keys = Int32Array::from(vec![0, 1, 1]); + assert_eq!(dict.keys(), &expected_keys); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(0), "A"); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_enum_mapping_unknown_symbol_without_default_errors() { + let reader_symbols: Arc<[String]> = vec!["A".to_string()].into(); + let default_index: i32 = -1; // indicates no default at type-level + let mapping: Arc<[i32]> = Arc::from(vec![-1]); + let mut dec = Decoder::Enum( + Vec::with_capacity(DEFAULT_CAPACITY), + reader_symbols, + Some(EnumResolution { + mapping, + default_index, + }), + ); + let data = encode_avro_int(0); + let mut cur = AvroCursor::new(&data); + let err = dec + .decode(&mut cur) + .expect_err("expected decode error for unresolved enum without default"); + let msg = err.to_string(); + assert!( + msg.contains("not resolvable") && msg.contains("no default"), + "unexpected error message: {msg}" + ); + } + + fn make_record_resolved_decoder( + reader_fields: &[(&str, DataType, bool)], + writer_to_reader: Vec>, + skip_decoders: Vec>, + ) -> Decoder { + let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); + let mut encodings: Vec = Vec::with_capacity(reader_fields.len()); + for (name, dt, nullable) in reader_fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + let enc = match dt { + DataType::Int32 => Decoder::Int32(Vec::new()), + DataType::Int64 => Decoder::Int64(Vec::new()), + DataType::Utf8 => { + Decoder::String(OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::new()) + } + other => panic!("Unsupported test reader field type: {other:?}"), + }; + encodings.push(enc); + } + let fields: Fields = field_refs.into(); + Decoder::Record( + fields, + encodings, + Some(Projector { + writer_to_reader: Arc::from(writer_to_reader), + skip_decoders, + field_defaults: vec![None; reader_fields.len()], + default_injections: Arc::from(Vec::<(usize, AvroLiteral)>::new()), + }), + ) + } + + #[test] + fn test_skip_writer_trailing_field_int32() { + let mut dec = make_record_resolved_decoder( + &[("id", arrow_schema::DataType::Int32, false)], + vec![Some(0), None], + vec![None, Some(super::Skipper::Int32)], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(7)); + data.extend_from_slice(&encode_avro_int(999)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let struct_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(struct_arr.len(), 1); + let id = struct_arr + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 7); + } + + #[test] + fn test_skip_writer_middle_field_string() { + let mut dec = make_record_resolved_decoder( + &[ + ("id", DataType::Int32, false), + ("score", DataType::Int64, false), + ], + vec![Some(0), None, Some(1)], + vec![None, Some(Skipper::String), None], + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(42)); + data.extend_from_slice(&encode_avro_bytes(b"abcdef")); + data.extend_from_slice(&encode_avro_long(1000)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let score = s + .column_by_name("score") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 42); + assert_eq!(score.value(0), 1000); + } + + #[test] + fn test_skip_writer_array_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(super::Skipper::List(Box::new(Skipper::Int32))), None], + ); + let mut array_payload = Vec::new(); + array_payload.extend_from_slice(&encode_avro_int(1)); + array_payload.extend_from_slice(&encode_avro_int(2)); + array_payload.extend_from_slice(&encode_avro_int(3)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-3)); + data.extend_from_slice(&encode_avro_long(array_payload.len() as i64)); + data.extend_from_slice(&array_payload); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(5)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 5); + } + + #[test] + fn test_skip_writer_map_with_negative_block_count_fast() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![Some(Skipper::Map(Box::new(Skipper::Int32))), None], + ); + let mut entries = Vec::new(); + entries.extend_from_slice(&encode_avro_bytes(b"k1")); + entries.extend_from_slice(&encode_avro_int(10)); + entries.extend_from_slice(&encode_avro_bytes(b"k2")); + entries.extend_from_slice(&encode_avro_int(20)); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(-2)); + data.extend_from_slice(&encode_avro_long(entries.len() as i64)); + data.extend_from_slice(&entries); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_int(123)); + let mut cur = AvroCursor::new(&data); + dec.decode(&mut cur).unwrap(); + assert_eq!(cur.position(), data.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 1); + assert_eq!(id.value(0), 123); + } + + #[test] + fn test_skip_writer_nullable_field_union_nullfirst() { + let mut dec = make_record_resolved_decoder( + &[("id", DataType::Int32, false)], + vec![None, Some(0)], + vec![ + Some(super::Skipper::Nullable( + Nullability::NullFirst, + Box::new(super::Skipper::Int32), + )), + None, + ], + ); + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(0)); + row1.extend_from_slice(&encode_avro_int(5)); + let mut row2 = Vec::new(); + row2.extend_from_slice(&encode_avro_long(1)); + row2.extend_from_slice(&encode_avro_int(123)); + row2.extend_from_slice(&encode_avro_int(7)); + let mut cur1 = AvroCursor::new(&row1); + let mut cur2 = AvroCursor::new(&row2); + dec.decode(&mut cur1).unwrap(); + dec.decode(&mut cur2).unwrap(); + assert_eq!(cur1.position(), row1.len()); + assert_eq!(cur2.position(), row2.len()); + let arr = dec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.len(), 2); + assert_eq!(id.value(0), 5); + assert_eq!(id.value(1), 7); + } + + fn make_dense_union_avro( + children: Vec<(Codec, &'_ str, DataType)>, + type_ids: Vec, + ) -> AvroDataType { + let mut avro_children: Vec = Vec::with_capacity(children.len()); + let mut fields: Vec = Vec::with_capacity(children.len()); + for (codec, name, dt) in children.into_iter() { + avro_children.push(AvroDataType::new(codec, Default::default(), None)); + fields.push(arrow_schema::Field::new(name, dt, true)); + } + let union_fields = UnionFields::try_new(type_ids, fields).unwrap(); + let union_codec = Codec::Union(avro_children.into(), union_fields, UnionMode::Dense); + AvroDataType::new(union_codec, Default::default(), None) + } + + #[test] + fn test_union_dense_two_children_custom_type_ids() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![2, 5], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let mut r1 = Vec::new(); + r1.extend_from_slice(&encode_avro_long(0)); + r1.extend_from_slice(&encode_avro_int(7)); + let mut r2 = Vec::new(); + r2.extend_from_slice(&encode_avro_long(1)); + r2.extend_from_slice(&encode_avro_bytes(b"x")); + let mut r3 = Vec::new(); + r3.extend_from_slice(&encode_avro_long(0)); + r3.extend_from_slice(&encode_avro_int(-1)); + dec.decode(&mut AvroCursor::new(&r1)).unwrap(); + dec.decode(&mut AvroCursor::new(&r2)).unwrap(); + dec.decode(&mut AvroCursor::new(&r3)).unwrap(); + let array = dec.flush(None).unwrap(); + let ua = array + .as_any() + .downcast_ref::() + .expect("expected UnionArray"); + assert_eq!(ua.len(), 3); + assert_eq!(ua.type_id(0), 2); + assert_eq!(ua.type_id(1), 5); + assert_eq!(ua.type_id(2), 2); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + assert_eq!(ua.value_offset(2), 1); + let int_child = ua + .child(2) + .as_any() + .downcast_ref::() + .expect("int child"); + assert_eq!(int_child.len(), 2); + assert_eq!(int_child.value(0), 7); + assert_eq!(int_child.value(1), -1); + let str_child = ua + .child(5) + .as_any() + .downcast_ref::() + .expect("string child"); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "x"); + } + + #[test] + fn test_union_dense_with_null_and_string_children() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Null, "n", DataType::Null), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![42, 7], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let r1 = encode_avro_long(0); + let mut r2 = Vec::new(); + r2.extend_from_slice(&encode_avro_long(1)); + r2.extend_from_slice(&encode_avro_bytes(b"abc")); + let r3 = encode_avro_long(0); + dec.decode(&mut AvroCursor::new(&r1)).unwrap(); + dec.decode(&mut AvroCursor::new(&r2)).unwrap(); + dec.decode(&mut AvroCursor::new(&r3)).unwrap(); + let array = dec.flush(None).unwrap(); + let ua = array + .as_any() + .downcast_ref::() + .expect("expected UnionArray"); + assert_eq!(ua.len(), 3); + assert_eq!(ua.type_id(0), 42); + assert_eq!(ua.type_id(1), 7); + assert_eq!(ua.type_id(2), 42); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + assert_eq!(ua.value_offset(2), 1); + let null_child = ua + .child(42) + .as_any() + .downcast_ref::() + .expect("null child"); + assert_eq!(null_child.len(), 2); + let str_child = ua + .child(7) + .as_any() + .downcast_ref::() + .expect("string child"); + assert_eq!(str_child.len(), 1); + assert_eq!(str_child.value(0), "abc"); + } + + #[test] + fn test_union_decode_negative_branch_index_errors() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![0, 1], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let row = encode_avro_long(-1); // decodes back to -1 + let err = dec + .decode(&mut AvroCursor::new(&row)) + .expect_err("expected error for negative branch index"); + let msg = err.to_string(); + assert!( + msg.contains("Negative union branch index"), + "unexpected error message: {msg}" + ); + } + + #[test] + fn test_union_decode_out_of_range_branch_index_errors() { + let union_dt = make_dense_union_avro( + vec![ + (Codec::Int32, "i", DataType::Int32), + (Codec::Utf8, "s", DataType::Utf8), + ], + vec![10, 11], + ); + let mut dec = Decoder::try_new(&union_dt).unwrap(); + let row = encode_avro_long(2); + let err = dec + .decode(&mut AvroCursor::new(&row)) + .expect_err("expected error for out-of-range branch index"); + let msg = err.to_string(); + assert!( + msg.contains("out of range"), + "unexpected error message: {msg}" + ); + } + + #[test] + fn test_union_sparse_mode_not_supported() { + let children: Vec = vec![ + AvroDataType::new(Codec::Int32, Default::default(), None), + AvroDataType::new(Codec::Utf8, Default::default(), None), + ]; + let uf = UnionFields::try_new( + vec![1, 3], + vec![ + arrow_schema::Field::new("i", DataType::Int32, true), + arrow_schema::Field::new("s", DataType::Utf8, true), + ], + ) + .unwrap(); + let codec = Codec::Union(children.into(), uf, UnionMode::Sparse); + let dt = AvroDataType::new(codec, Default::default(), None); + let err = Decoder::try_new(&dt).expect_err("sparse union should not be supported"); + let msg = err.to_string(); + assert!( + msg.contains("Sparse Arrow unions are not yet supported"), + "unexpected error message: {msg}" + ); + } + + fn make_record_decoder_with_projector_defaults( + reader_fields: &[(&str, DataType, bool)], + field_defaults: Vec>, + default_injections: Vec<(usize, AvroLiteral)>, + writer_to_reader_len: usize, + ) -> Decoder { + assert_eq!( + field_defaults.len(), + reader_fields.len(), + "field_defaults must have one entry per reader field" + ); + let mut field_refs: Vec = Vec::with_capacity(reader_fields.len()); + let mut encodings: Vec = Vec::with_capacity(reader_fields.len()); + for (name, dt, nullable) in reader_fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + let enc = match dt { + DataType::Int32 => Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), + DataType::Int64 => Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)), + DataType::Utf8 => Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ), + other => panic!("Unsupported test field type in helper: {other:?}"), + }; + encodings.push(enc); + } + let fields: Fields = field_refs.into(); + let skip_decoders: Vec> = + (0..writer_to_reader_len).map(|_| None::).collect(); + let projector = Projector { + writer_to_reader: Arc::from(vec![None; writer_to_reader_len]), + skip_decoders, + field_defaults, + default_injections: Arc::from(default_injections), + }; + Decoder::Record(fields, encodings, Some(projector)) + } + + #[test] + fn test_default_append_int32_and_int64_from_int_and_long() { + let mut d_i32 = Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)); + d_i32.append_default(&AvroLiteral::Int(42)).unwrap(); + let arr = d_i32.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 1); + assert_eq!(a.value(0), 42); + let mut d_i64 = Decoder::Int64(Vec::with_capacity(DEFAULT_CAPACITY)); + d_i64.append_default(&AvroLiteral::Int(5)).unwrap(); + d_i64.append_default(&AvroLiteral::Long(7)).unwrap(); + let arr64 = d_i64.flush(None).unwrap(); + let a64 = arr64.as_any().downcast_ref::().unwrap(); + assert_eq!(a64.len(), 2); + assert_eq!(a64.value(0), 5); + assert_eq!(a64.value(1), 7); + } + + #[test] + fn test_default_append_floats_and_doubles() { + let mut d_f32 = Decoder::Float32(Vec::with_capacity(DEFAULT_CAPACITY)); + d_f32.append_default(&AvroLiteral::Float(1.5)).unwrap(); + let arr32 = d_f32.flush(None).unwrap(); + let a = arr32.as_any().downcast_ref::().unwrap(); + assert_eq!(a.value(0), 1.5); + let mut d_f64 = Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)); + d_f64.append_default(&AvroLiteral::Double(2.25)).unwrap(); + let arr64 = d_f64.flush(None).unwrap(); + let b = arr64.as_any().downcast_ref::().unwrap(); + assert_eq!(b.value(0), 2.25); + } + + #[test] + fn test_default_append_string_and_bytes() { + let mut d_str = Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + d_str + .append_default(&AvroLiteral::String("hi".into())) + .unwrap(); + let s_arr = d_str.flush(None).unwrap(); + let arr = s_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(arr.value(0), "hi"); + let mut d_bytes = Decoder::Binary( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + d_bytes + .append_default(&AvroLiteral::Bytes(vec![1, 2, 3])) + .unwrap(); + let b_arr = d_bytes.flush(None).unwrap(); + let barr = b_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(barr.value(0), &[1, 2, 3]); + let mut d_str_err = Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + ); + let err = d_str_err + .append_default(&AvroLiteral::Bytes(vec![0x61, 0x62])) + .unwrap_err(); + assert!( + err.to_string() + .contains("Default for string must be string"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_nullable_int32_null_and_value() { + let inner = Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY)); + let mut dec = Decoder::Nullable( + Nullability::NullFirst, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + NullablePlan::ReadTag, + ); + dec.append_default(&AvroLiteral::Null).unwrap(); + dec.append_default(&AvroLiteral::Int(11)).unwrap(); + let arr = dec.flush(None).unwrap(); + let a = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(a.len(), 2); + assert!(a.is_null(0)); + assert_eq!(a.value(1), 11); + } + + #[test] + fn test_default_append_array_of_ints() { + let list_dt = avro_from_codec(Codec::List(Arc::new(avro_from_codec(Codec::Int32)))); + let mut d = Decoder::try_new(&list_dt).unwrap(); + let items = vec![ + AvroLiteral::Int(1), + AvroLiteral::Int(2), + AvroLiteral::Int(3), + ]; + d.append_default(&AvroLiteral::Array(items)).unwrap(); + let arr = d.flush(None).unwrap(); + let list = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(list.len(), 1); + assert_eq!(list.value_length(0), 3); + let vals = list.values().as_any().downcast_ref::().unwrap(); + assert_eq!(vals.values(), &[1, 2, 3]); + } + + #[test] + fn test_default_append_map_string_to_int() { + let map_dt = avro_from_codec(Codec::Map(Arc::new(avro_from_codec(Codec::Int32)))); + let mut d = Decoder::try_new(&map_dt).unwrap(); + let mut m: IndexMap = IndexMap::new(); + m.insert("k1".to_string(), AvroLiteral::Int(10)); + m.insert("k2".to_string(), AvroLiteral::Int(20)); + d.append_default(&AvroLiteral::Map(m)).unwrap(); + let arr = d.flush(None).unwrap(); + let map = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(map.len(), 1); + assert_eq!(map.value_length(0), 2); + let binding = map.value(0); + let entries = binding.as_any().downcast_ref::().unwrap(); + let k = entries + .column_by_name("key") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let v = entries + .column_by_name("value") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let keys: std::collections::HashSet<&str> = (0..k.len()).map(|i| k.value(i)).collect(); + assert_eq!(keys, ["k1", "k2"].into_iter().collect()); + let vals: std::collections::HashSet = (0..v.len()).map(|i| v.value(i)).collect(); + assert_eq!(vals, [10, 20].into_iter().collect()); + } + + #[test] + fn test_default_append_enum_by_symbol() { + let symbols: Arc<[String]> = vec!["A".into(), "B".into(), "C".into()].into(); + let mut d = Decoder::Enum(Vec::with_capacity(DEFAULT_CAPACITY), symbols.clone(), None); + d.append_default(&AvroLiteral::Enum("B".into())).unwrap(); + let arr = d.flush(None).unwrap(); + let dict = arr + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict.len(), 1); + let expected = Int32Array::from(vec![1]); + assert_eq!(dict.keys(), &expected); + let values = dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(values.value(1), "B"); + } + + #[test] + fn test_default_append_uuid_and_type_error() { + let mut d = Decoder::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)); + let uuid_str = "123e4567-e89b-12d3-a456-426614174000"; + d.append_default(&AvroLiteral::String(uuid_str.into())) + .unwrap(); + let arr_ref = d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value_length(), 16); + assert_eq!(arr.len(), 1); + let mut d2 = Decoder::Uuid(Vec::with_capacity(DEFAULT_CAPACITY)); + let err = d2 + .append_default(&AvroLiteral::Bytes(vec![0u8; 16])) + .unwrap_err(); + assert!( + err.to_string().contains("Default for uuid must be string"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_fixed_and_length_mismatch() { + let mut d = Decoder::Fixed(4, Vec::with_capacity(DEFAULT_CAPACITY)); + d.append_default(&AvroLiteral::Bytes(vec![1, 2, 3, 4])) + .unwrap(); + let arr_ref = d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.value_length(), 4); + assert_eq!(arr.value(0), &[1, 2, 3, 4]); + let mut d_err = Decoder::Fixed(4, Vec::with_capacity(DEFAULT_CAPACITY)); + let err = d_err + .append_default(&AvroLiteral::Bytes(vec![1, 2, 3])) + .unwrap_err(); + assert!( + err.to_string().contains("Fixed default length"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_duration_and_length_validation() { + let dt = avro_from_codec(Codec::Interval); + let mut d = Decoder::try_new(&dt).unwrap(); + let mut bytes = Vec::with_capacity(12); + bytes.extend_from_slice(&1u32.to_le_bytes()); + bytes.extend_from_slice(&2u32.to_le_bytes()); + bytes.extend_from_slice(&3u32.to_le_bytes()); + d.append_default(&AvroLiteral::Bytes(bytes)).unwrap(); + let arr_ref = d.flush(None).unwrap(); + let arr = arr_ref + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(arr.len(), 1); + let v = arr.value(0); + assert_eq!(v.months, 1); + assert_eq!(v.days, 2); + assert_eq!(v.nanoseconds, 3_000_000); + let mut d_err = Decoder::try_new(&avro_from_codec(Codec::Interval)).unwrap(); + let err = d_err + .append_default(&AvroLiteral::Bytes(vec![0u8; 11])) + .unwrap_err(); + assert!( + err.to_string() + .contains("Duration default must be exactly 12 bytes"), + "unexpected error: {err:?}" + ); + } + + #[test] + fn test_default_append_decimal256_from_bytes() { + let dt = avro_from_codec(Codec::Decimal(50, Some(2), Some(32))); + let mut d = Decoder::try_new(&dt).unwrap(); + let pos: [u8; 32] = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x30, 0x39, + ]; + d.append_default(&AvroLiteral::Bytes(pos.to_vec())).unwrap(); + let neg: [u8; 32] = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0xFF, 0xFF, 0x85, + ]; + d.append_default(&AvroLiteral::Bytes(neg.to_vec())).unwrap(); + let arr = d.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_record_append_default_map_missing_fields_uses_projector_field_defaults() { + let field_defaults = vec![None, Some(AvroLiteral::String("hi".into()))]; + let mut rec = make_record_decoder_with_projector_defaults( + &[("a", DataType::Int32, false), ("b", DataType::Utf8, false)], + field_defaults, + vec![], + 0, + ); + let mut map: IndexMap = IndexMap::new(); + map.insert("a".to_string(), AvroLiteral::Int(7)); + rec.append_default(&AvroLiteral::Map(map)).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.value(0), 7); + assert_eq!(b.value(0), "hi"); + } + + #[test] + fn test_record_append_default_null_uses_projector_field_defaults() { + let field_defaults = vec![ + Some(AvroLiteral::Int(5)), + Some(AvroLiteral::String("x".into())), + ]; + let mut rec = make_record_decoder_with_projector_defaults( + &[("a", DataType::Int32, false), ("b", DataType::Utf8, false)], + field_defaults, + vec![], + 0, + ); + rec.append_default(&AvroLiteral::Null).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(a.value(0), 5); + assert_eq!(b.value(0), "x"); + } + + #[test] + fn test_record_append_default_missing_fields_without_projector_defaults_yields_type_nulls_or_empties() + { + let fields = vec![("a", DataType::Int32, true), ("b", DataType::Utf8, true)]; + let mut field_refs: Vec = Vec::new(); + let mut encoders: Vec = Vec::new(); + for (name, dt, nullable) in &fields { + field_refs.push(Arc::new(ArrowField::new(*name, dt.clone(), *nullable))); + } + let enc_a = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(Decoder::Int32(Vec::with_capacity(DEFAULT_CAPACITY))), + NullablePlan::ReadTag, + ); + let enc_b = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(Decoder::String( + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Vec::with_capacity(DEFAULT_CAPACITY), + )), + NullablePlan::ReadTag, + ); + encoders.push(enc_a); + encoders.push(enc_b); + let projector = Projector { + writer_to_reader: Arc::from(vec![]), + skip_decoders: vec![], + field_defaults: vec![None, None], // no defaults -> append_null + default_injections: Arc::from(Vec::<(usize, AvroLiteral)>::new()), + }; + let mut rec = Decoder::Record(field_refs.into(), encoders, Some(projector)); + let mut map: IndexMap = IndexMap::new(); + map.insert("a".to_string(), AvroLiteral::Int(9)); + rec.append_default(&AvroLiteral::Map(map)).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let a = s + .column_by_name("a") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let b = s + .column_by_name("b") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert!(a.is_valid(0)); + assert_eq!(a.value(0), 9); + assert!(b.is_null(0)); + } + + #[test] + fn test_projector_default_injection_when_writer_lacks_fields() { + let defaults = vec![None, None]; + let injections = vec![ + (0, AvroLiteral::Int(99)), + (1, AvroLiteral::String("alice".into())), + ]; + let mut rec = make_record_decoder_with_projector_defaults( + &[ + ("id", DataType::Int32, false), + ("name", DataType::Utf8, false), + ], + defaults, + injections, + 0, + ); + rec.decode(&mut AvroCursor::new(&[])).unwrap(); + let arr = rec.flush(None).unwrap(); + let s = arr.as_any().downcast_ref::().unwrap(); + let id = s + .column_by_name("id") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let name = s + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(id.value(0), 99); + assert_eq!(name.value(0), "alice"); + } + + #[test] + fn union_type_ids_are_not_child_indexes() { + let encodings: Vec = + vec![avro_from_codec(Codec::Int32), avro_from_codec(Codec::Utf8)]; + let fields: UnionFields = [ + (42_i8, Arc::new(ArrowField::new("a", DataType::Int32, true))), + (7_i8, Arc::new(ArrowField::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + let dt = avro_from_codec(Codec::Union( + encodings.into(), + fields.clone(), + UnionMode::Dense, + )); + let mut dec = Decoder::try_new(&dt).expect("decoder"); + let mut b1 = encode_avro_long(1); + b1.extend(encode_avro_bytes("hi".as_bytes())); + dec.decode(&mut AvroCursor::new(&b1)).expect("decode b1"); + let mut b0 = encode_avro_long(0); + b0.extend(encode_avro_int(5)); + dec.decode(&mut AvroCursor::new(&b0)).expect("decode b0"); + let arr = dec.flush(None).expect("flush"); + let ua = arr.as_any().downcast_ref::().expect("union"); + assert_eq!(ua.len(), 2); + assert_eq!(ua.type_id(0), 7, "type id must come from UnionFields"); + assert_eq!(ua.type_id(1), 42, "type id must come from UnionFields"); + assert_eq!(ua.value_offset(0), 0); + assert_eq!(ua.value_offset(1), 0); + let utf8_child = ua.child(7).as_any().downcast_ref::().unwrap(); + assert_eq!(utf8_child.len(), 1); + assert_eq!(utf8_child.value(0), "hi"); + let int_child = ua.child(42).as_any().downcast_ref::().unwrap(); + assert_eq!(int_child.len(), 1); + assert_eq!(int_child.value(0), 5); + let type_ids: Vec = fields.iter().map(|(tid, _)| tid).collect(); + assert_eq!(type_ids, vec![42_i8, 7_i8]); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn skipper_from_avro_maps_custom_duration_variants_to_int64() -> Result<(), ArrowError> { + for codec in [ + Codec::DurationNanos, + Codec::DurationMicros, + Codec::DurationMillis, + Codec::DurationSeconds, + ] { + let dt = make_avro_dt(codec.clone(), None); + let s = Skipper::from_avro(&dt)?; + match s { + Skipper::Int64 => {} + other => panic!("expected Int64 skipper for {:?}, got {:?}", codec, other), + } + } + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn skipper_skip_consumes_one_long_for_custom_durations() -> Result<(), ArrowError> { + let values: [i64; 7] = [0, 1, -1, 150, -150, i64::MAX / 3, i64::MIN / 3]; + for codec in [ + Codec::DurationNanos, + Codec::DurationMicros, + Codec::DurationMillis, + Codec::DurationSeconds, + ] { + let dt = make_avro_dt(codec.clone(), None); + let mut s = Skipper::from_avro(&dt)?; + for &v in &values { + let bytes = encode_avro_long(v); + let mut cursor = AvroCursor::new(&bytes); + s.skip(&mut cursor)?; + assert_eq!( + cursor.position(), + bytes.len(), + "did not consume all bytes for {:?} value {}", + codec, + v + ); + } + } + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn skipper_nullable_custom_duration_respects_null_first() -> Result<(), ArrowError> { + let dt = make_avro_dt(Codec::DurationNanos, Some(Nullability::NullFirst)); + let mut s = Skipper::from_avro(&dt)?; + match &s { + Skipper::Nullable(Nullability::NullFirst, inner) => match **inner { + Skipper::Int64 => {} + ref other => panic!("expected inner Int64, got {:?}", other), + }, + other => panic!("expected Nullable(NullFirst, Int64), got {:?}", other), + } + { + let buf = encode_vlq_u64(0); + let mut cursor = AvroCursor::new(&buf); + s.skip(&mut cursor)?; + assert_eq!(cursor.position(), 1, "expected to consume only tag=0"); + } + { + let mut buf = encode_vlq_u64(1); + buf.extend(encode_avro_long(0)); + let mut cursor = AvroCursor::new(&buf); + s.skip(&mut cursor)?; + assert_eq!(cursor.position(), 2, "expected to consume tag=1 + long(0)"); + } + + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn skipper_nullable_custom_duration_respects_null_second() -> Result<(), ArrowError> { + let dt = make_avro_dt(Codec::DurationMicros, Some(Nullability::NullSecond)); + let mut s = Skipper::from_avro(&dt)?; + match &s { + Skipper::Nullable(Nullability::NullSecond, inner) => match **inner { + Skipper::Int64 => {} + ref other => panic!("expected inner Int64, got {:?}", other), + }, + other => panic!("expected Nullable(NullSecond, Int64), got {:?}", other), + } + { + let buf = encode_vlq_u64(1); + let mut cursor = AvroCursor::new(&buf); + s.skip(&mut cursor)?; + assert_eq!(cursor.position(), 1, "expected to consume only tag=1"); + } + { + let mut buf = encode_vlq_u64(0); + buf.extend(encode_avro_long(-1)); + let mut cursor = AvroCursor::new(&buf); + s.skip(&mut cursor)?; + assert_eq!( + cursor.position(), + 1 + encode_avro_long(-1).len(), + "expected to consume tag=0 + long(-1)" + ); + } + Ok(()) + } + + #[test] + fn skipper_interval_is_fixed12_and_skips_12_bytes() -> Result<(), ArrowError> { + let dt = make_avro_dt(Codec::Interval, None); + let mut s = Skipper::from_avro(&dt)?; + match s { + Skipper::DurationFixed12 => {} + other => panic!("expected DurationFixed12, got {:?}", other), + } + let payload = vec![0u8; 12]; + let mut cursor = AvroCursor::new(&payload); + s.skip(&mut cursor)?; + assert_eq!(cursor.position(), 12, "expected to consume 12 fixed bytes"); + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_width16_int32_basic_grouping() { + use arrow_array::RunArray; + use std::sync::Arc; + let inner = avro_from_codec(Codec::Int32); + let ree = AvroDataType::new( + Codec::RunEndEncoded(Arc::new(inner), 16), + Default::default(), + None, + ); + let mut dec = Decoder::try_new(&ree).expect("create REE decoder"); + for v in [1, 1, 1, 2, 2, 3, 3, 3, 3] { + let bytes = encode_avro_int(v); + dec.decode(&mut AvroCursor::new(&bytes)).expect("decode"); + } + let arr = dec.flush(None).expect("flush"); + let ra = arr + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(ra.len(), 9); + assert_eq!(ra.run_ends().values(), &[3, 5, 9]); + let vals = ra + .values() + .as_ref() + .as_any() + .downcast_ref::() + .expect("values Int32"); + assert_eq!(vals.values(), &[1, 2, 3]); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_width32_nullable_values_group_nulls() { + use arrow_array::RunArray; + use std::sync::Arc; + let inner = AvroDataType::new( + Codec::Int32, + Default::default(), + Some(Nullability::NullSecond), + ); + let ree = AvroDataType::new( + Codec::RunEndEncoded(Arc::new(inner), 32), + Default::default(), + None, + ); + let mut dec = Decoder::try_new(&ree).expect("create REE decoder"); + let seq: [Option; 8] = [ + None, + None, + Some(7), + Some(7), + Some(7), + None, + Some(5), + Some(5), + ]; + for item in seq { + let mut bytes = Vec::new(); + match item { + None => bytes.extend_from_slice(&encode_vlq_u64(1)), + Some(v) => { + bytes.extend_from_slice(&encode_vlq_u64(0)); + bytes.extend_from_slice(&encode_avro_int(v)); + } + } + dec.decode(&mut AvroCursor::new(&bytes)).expect("decode"); + } + let arr = dec.flush(None).expect("flush"); + let ra = arr + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(ra.len(), 8); + assert_eq!(ra.run_ends().values(), &[2, 5, 6, 8]); + let vals = ra + .values() + .as_ref() + .as_any() + .downcast_ref::() + .expect("values Int32 (nullable)"); + assert_eq!(vals.len(), 4); + assert!(vals.is_null(0)); + assert_eq!(vals.value(1), 7); + assert!(vals.is_null(2)); + assert_eq!(vals.value(3), 5); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_decode_with_promotion_int_to_double_via_nullable_from_single() { + use arrow_array::RunArray; + let inner_values = Decoder::Float64(Vec::with_capacity(DEFAULT_CAPACITY)); + let ree = Decoder::RunEndEncoded( + 8, /* bytes => Int64 run-ends */ + 0, + Box::new(inner_values), + ); + let mut dec = Decoder::Nullable( + Nullability::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(ree), + NullablePlan::FromSingle { + promotion: Promotion::IntToDouble, + }, + ); + for v in [1, 1, 2, 2, 2] { + let bytes = encode_avro_int(v); + dec.decode(&mut AvroCursor::new(&bytes)).expect("decode"); + } + let arr = dec.flush(None).expect("flush"); + let ra = arr + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(ra.len(), 5); + assert_eq!(ra.run_ends().values(), &[2, 5]); + let vals = ra + .values() + .as_ref() + .as_any() + .downcast_ref::() + .expect("values Float64"); + assert_eq!(vals.values(), &[1.0, 2.0]); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_unsupported_run_end_width_errors() { + use std::sync::Arc; + let inner = avro_from_codec(Codec::Int32); + let dt = AvroDataType::new( + Codec::RunEndEncoded(Arc::new(inner), 3), + Default::default(), + None, + ); + let err = Decoder::try_new(&dt).expect_err("must reject unsupported width"); + let msg = err.to_string(); + assert!( + msg.contains("Unsupported run-end width") + && msg.contains("16/32/64 bits or 2/4/8 bytes"), + "unexpected error message: {msg}" + ); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_empty_input_is_empty_runarray() { + use arrow_array::RunArray; + use std::sync::Arc; + let inner = avro_from_codec(Codec::Utf8); + let dt = AvroDataType::new( + Codec::RunEndEncoded(Arc::new(inner), 4), + Default::default(), + None, + ); + let mut dec = Decoder::try_new(&dt).expect("create REE decoder"); + let arr = dec.flush(None).expect("flush"); + let ra = arr + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(ra.len(), 0); + assert_eq!(ra.run_ends().len(), 0); + assert_eq!(ra.values().len(), 0); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_strings_grouping_width32_bits() { + use arrow_array::RunArray; + use std::sync::Arc; + let inner = avro_from_codec(Codec::Utf8); + let dt = AvroDataType::new( + Codec::RunEndEncoded(Arc::new(inner), 32), + Default::default(), + None, + ); + let mut dec = Decoder::try_new(&dt).expect("create REE decoder"); + for s in ["a", "a", "bb", "bb", "bb", "a"] { + let bytes = encode_avro_bytes(s.as_bytes()); + dec.decode(&mut AvroCursor::new(&bytes)).expect("decode"); + } + let arr = dec.flush(None).expect("flush"); + let ra = arr + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(ra.run_ends().values(), &[2, 5, 6]); + let vals = ra + .values() + .as_ref() + .as_any() + .downcast_ref::() + .expect("values String"); + assert_eq!(vals.len(), 3); + assert_eq!(vals.value(0), "a"); + assert_eq!(vals.value(1), "bb"); + assert_eq!(vals.value(2), "a"); + } + + #[cfg(not(feature = "avro_custom_types"))] + #[test] + fn test_no_custom_types_feature_smoke_decodes_plain_int32() { + let dt = avro_from_codec(Codec::Int32); + let mut dec = Decoder::try_new(&dt).expect("create Int32 decoder"); + for v in [1, 2, 3] { + let bytes = encode_avro_int(v); + dec.decode(&mut AvroCursor::new(&bytes)).expect("decode"); + } + let arr = dec.flush(None).expect("flush"); + let a = arr + .as_any() + .downcast_ref::() + .expect("Int32Array"); + assert_eq!(a.values(), &[1, 2, 3]); + } + + #[test] + fn test_timestamp_nanos_decoding_utc() { + let avro_type = avro_from_codec(Codec::TimestampNanos(true)); + let mut decoder = Decoder::try_new(&avro_type).expect("create TimestampNanos decoder"); + let mut data = Vec::new(); + for v in [0_i64, 1_i64, -1_i64, 1_234_567_890_i64] { + data.extend_from_slice(&encode_avro_long(v)); + } + let mut cur = AvroCursor::new(&data); + for _ in 0..4 { + decoder.decode(&mut cur).expect("decode nanos ts"); + } + let array = decoder.flush(None).expect("flush nanos ts"); + let ts = array + .as_any() + .downcast_ref::() + .expect("TimestampNanosecondArray"); + assert_eq!(ts.values(), &[0, 1, -1, 1_234_567_890]); + match ts.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => { + assert_eq!(tz.as_deref(), Some("+00:00")); + } + other => panic!("expected Timestamp(Nanosecond, Some(\"+00:00\")), got {other:?}"), + } + } + + #[test] + fn test_timestamp_nanos_decoding_local() { + let avro_type = avro_from_codec(Codec::TimestampNanos(false)); + let mut decoder = Decoder::try_new(&avro_type).expect("create TimestampNanos decoder"); + let mut data = Vec::new(); + for v in [10_i64, 20_i64, -30_i64] { + data.extend_from_slice(&encode_avro_long(v)); + } + let mut cur = AvroCursor::new(&data); + for _ in 0..3 { + decoder.decode(&mut cur).expect("decode nanos ts"); + } + let array = decoder.flush(None).expect("flush nanos ts"); + let ts = array + .as_any() + .downcast_ref::() + .expect("TimestampNanosecondArray"); + assert_eq!(ts.values(), &[10, 20, -30]); + match ts.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => { + assert_eq!(tz.as_deref(), None); + } + other => panic!("expected Timestamp(Nanosecond, None), got {other:?}"), + } + } + + #[test] + fn test_timestamp_nanos_decoding_with_nulls() { + let avro_type = AvroDataType::new( + Codec::TimestampNanos(false), + Default::default(), + Some(Nullability::NullFirst), + ); + let mut decoder = Decoder::try_new(&avro_type).expect("create nullable TimestampNanos"); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_long(42)); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_long(-7)); + let mut cur = AvroCursor::new(&data); + for _ in 0..3 { + decoder.decode(&mut cur).expect("decode nullable nanos ts"); + } + let array = decoder.flush(None).expect("flush nullable nanos ts"); + let ts = array + .as_any() + .downcast_ref::() + .expect("TimestampNanosecondArray"); + assert_eq!(ts.len(), 3); + assert!(ts.is_valid(0)); + assert!(ts.is_null(1)); + assert!(ts.is_valid(2)); + assert_eq!(ts.value(0), 42); + assert_eq!(ts.value(2), -7); + match ts.data_type() { + DataType::Timestamp(arrow_schema::TimeUnit::Nanosecond, tz) => { + assert_eq!(tz.as_deref(), None); + } + other => panic!("expected Timestamp(Nanosecond, None), got {other:?}"), + } + } } diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index b198a0d66f24..c0b471b466ea 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -84,7 +84,7 @@ fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { #[cold] fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { let mut value = 0; - for (count, byte) in buf.iter().take(10).enumerate() { + for (count, _byte) in buf.iter().take(10).enumerate() { let byte = buf[count]; value |= u64::from(byte & 0x7F) << (count * 7); if byte <= 0x7F { diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index c3e4549c8c38..819ea1f16e9b 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -15,12 +15,69 @@ // specific language governing permissions and limitations // under the License. +//! Avro Schema representations for Arrow. + +#[cfg(feature = "canonical_extension_types")] +use arrow_schema::extension::ExtensionType; +use arrow_schema::{ + ArrowError, DataType, Field as ArrowField, IntervalUnit, Schema as ArrowSchema, TimeUnit, + UnionMode, +}; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use serde_json::{Map as JsonMap, Value, json}; +#[cfg(feature = "sha256")] +use sha2::{Digest, Sha256}; +use std::borrow::Cow; +use std::cmp::PartialEq; +use std::collections::hash_map::Entry; +use std::collections::{HashMap, HashSet}; +use strum_macros::AsRefStr; + +/// The Avro single‑object encoding “magic” bytes (`0xC3 0x01`) +pub const SINGLE_OBJECT_MAGIC: [u8; 2] = [0xC3, 0x01]; + +/// The Confluent "magic" byte (`0x00`) +pub const CONFLUENT_MAGIC: [u8; 1] = [0x00]; + +/// The maximum possible length of a prefix. +/// SHA256 (32) + single-object magic (2) +pub const MAX_PREFIX_LEN: usize = 34; -/// The metadata key used for storing the JSON encoded [`Schema`] +/// The metadata key used for storing the JSON encoded `Schema` pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; +/// Metadata key used to represent Avro enum symbols in an Arrow schema. +pub const AVRO_ENUM_SYMBOLS_METADATA_KEY: &str = "avro.enum.symbols"; + +/// Metadata key used to store the default value of a field in an Avro schema. +pub const AVRO_FIELD_DEFAULT_METADATA_KEY: &str = "avro.field.default"; + +/// Metadata key used to store the name of a type in an Avro schema. +pub const AVRO_NAME_METADATA_KEY: &str = "avro.name"; + +/// Metadata key used to store the name of a type in an Avro schema. +pub const AVRO_NAMESPACE_METADATA_KEY: &str = "avro.namespace"; + +/// Metadata key used to store the documentation for a type in an Avro schema. +pub const AVRO_DOC_METADATA_KEY: &str = "avro.doc"; + +/// Default name for the root record in an Avro schema. +pub const AVRO_ROOT_RECORD_DEFAULT_NAME: &str = "topLevelRecord"; + +/// Avro types are not nullable, with nullability instead encoded as a union +/// where one of the variants is the null type. +/// +/// To accommodate this, we specially case two-variant unions where one of the +/// variants is the null type, and use this to derive arrow's notion of nullability +#[derive(Debug, Copy, Clone, PartialEq, Default)] +pub(crate) enum Nullability { + /// The nulls are encoded as the first union variant + #[default] + NullFirst, + /// The nulls are encoded as the second union variant + NullSecond, +} + /// Either a [`PrimitiveType`] or a reference to a previously defined named type /// /// @@ -29,7 +86,7 @@ pub const SCHEMA_METADATA_KEY: &str = "avro.schema"; /// A type name in an Avro schema /// /// This represents the different ways a type can be referenced in an Avro schema. -pub enum TypeName<'a> { +pub(crate) enum TypeName<'a> { /// A primitive type like null, boolean, int, etc. Primitive(PrimitiveType), /// A reference to another named type @@ -39,9 +96,10 @@ pub enum TypeName<'a> { /// A primitive type /// /// -#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, AsRefStr)] #[serde(rename_all = "camelCase")] -pub enum PrimitiveType { +#[strum(serialize_all = "lowercase")] +pub(crate) enum PrimitiveType { /// null: no value Null, /// boolean: a binary value @@ -60,21 +118,21 @@ pub enum PrimitiveType { String, } -/// Additional attributes within a [`Schema`] +/// Additional attributes within a `Schema` /// /// #[derive(Debug, Clone, PartialEq, Eq, Default, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -pub struct Attributes<'a> { +pub(crate) struct Attributes<'a> { /// A logical type name /// /// #[serde(default)] - pub logical_type: Option<&'a str>, + pub(crate) logical_type: Option<&'a str>, /// Additional JSON attributes #[serde(flatten)] - pub additional: HashMap<&'a str, serde_json::Value>, + pub(crate) additional: HashMap<&'a str, Value>, } impl Attributes<'_> { @@ -90,13 +148,13 @@ impl Attributes<'_> { /// A type definition that is not a variant of [`ComplexType`] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] -pub struct Type<'a> { +pub(crate) struct Type<'a> { /// The type of this Avro data structure #[serde(borrow)] - pub r#type: TypeName<'a>, + pub(crate) r#type: TypeName<'a>, /// Additional attributes associated with this type #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, } /// An Avro schema @@ -105,7 +163,7 @@ pub struct Type<'a> { /// See for more details. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(untagged)] -pub enum Schema<'a> { +pub(crate) enum Schema<'a> { /// A direct type name (primitive or reference) #[serde(borrow)] TypeName(TypeName<'a>), @@ -125,7 +183,7 @@ pub enum Schema<'a> { /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] -pub enum ComplexType<'a> { +pub(crate) enum ComplexType<'a> { /// Record type: a sequence of fields with names and types #[serde(borrow)] Record(Record<'a>), @@ -147,125 +205,1580 @@ pub enum ComplexType<'a> { /// /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Record<'a> { +pub(crate) struct Record<'a> { /// Name of the record #[serde(borrow)] - pub name: &'a str, + pub(crate) name: &'a str, /// Optional namespace for the record, provides a way to organize names #[serde(borrow, default)] - pub namespace: Option<&'a str>, + pub(crate) namespace: Option<&'a str>, /// Optional documentation string for the record #[serde(borrow, default)] - pub doc: Option<&'a str>, + pub(crate) doc: Option>, /// Alternative names for this record #[serde(borrow, default)] - pub aliases: Vec<&'a str>, + pub(crate) aliases: Vec<&'a str>, /// The fields contained in this record #[serde(borrow)] - pub fields: Vec>, + pub(crate) fields: Vec>, /// Additional attributes for this record #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, +} + +fn deserialize_default<'de, D>(deserializer: D) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + Value::deserialize(deserializer).map(Some) } /// A field within a [`Record`] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Field<'a> { +pub(crate) struct Field<'a> { /// Name of the field within the record #[serde(borrow)] - pub name: &'a str, + pub(crate) name: &'a str, /// Optional documentation for this field #[serde(borrow, default)] - pub doc: Option<&'a str>, + pub(crate) doc: Option>, /// The field's type definition #[serde(borrow)] - pub r#type: Schema<'a>, + pub(crate) r#type: Schema<'a>, /// Optional default value for this field + #[serde(deserialize_with = "deserialize_default", default)] + pub(crate) default: Option, + /// Alternative names (aliases) for this field (Avro spec: field-level aliases). + /// Borrowed from input JSON where possible. #[serde(borrow, default)] - pub default: Option<&'a str>, + pub(crate) aliases: Vec<&'a str>, } /// An enumeration /// /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Enum<'a> { +pub(crate) struct Enum<'a> { /// Name of the enum #[serde(borrow)] - pub name: &'a str, + pub(crate) name: &'a str, /// Optional namespace for the enum, provides organizational structure #[serde(borrow, default)] - pub namespace: Option<&'a str>, + pub(crate) namespace: Option<&'a str>, /// Optional documentation string describing the enum #[serde(borrow, default)] - pub doc: Option<&'a str>, + pub(crate) doc: Option>, /// Alternative names for this enum #[serde(borrow, default)] - pub aliases: Vec<&'a str>, + pub(crate) aliases: Vec<&'a str>, /// The symbols (values) that this enum can have #[serde(borrow)] - pub symbols: Vec<&'a str>, + pub(crate) symbols: Vec<&'a str>, /// Optional default value for this enum #[serde(borrow, default)] - pub default: Option<&'a str>, + pub(crate) default: Option<&'a str>, /// Additional attributes for this enum #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, } /// An array /// /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Array<'a> { +pub(crate) struct Array<'a> { /// The schema for items in this array #[serde(borrow)] - pub items: Box>, + pub(crate) items: Box>, /// Additional attributes for this array #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, } /// A map /// /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Map<'a> { +pub(crate) struct Map<'a> { /// The schema for values in this map #[serde(borrow)] - pub values: Box>, + pub(crate) values: Box>, /// Additional attributes for this map #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, } /// A fixed length binary array /// /// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Fixed<'a> { +pub(crate) struct Fixed<'a> { /// Name of the fixed type #[serde(borrow)] - pub name: &'a str, + pub(crate) name: &'a str, /// Optional namespace for the fixed type #[serde(borrow, default)] - pub namespace: Option<&'a str>, + pub(crate) namespace: Option<&'a str>, /// Alternative names for this fixed type #[serde(borrow, default)] - pub aliases: Vec<&'a str>, + pub(crate) aliases: Vec<&'a str>, /// The number of bytes in this fixed type - pub size: usize, + pub(crate) size: usize, /// Additional attributes for this fixed type #[serde(flatten)] - pub attributes: Attributes<'a>, + pub(crate) attributes: Attributes<'a>, +} + +#[derive(Debug, Copy, Clone, PartialEq, Default)] +pub(crate) struct AvroSchemaOptions { + pub(crate) null_order: Option, + pub(crate) strip_metadata: bool, +} + +/// A wrapper for an Avro schema in its JSON string representation. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct AvroSchema { + /// The Avro schema as a JSON string. + pub json_string: String, +} + +impl TryFrom<&ArrowSchema> for AvroSchema { + type Error = ArrowError; + + /// Converts an `ArrowSchema` to `AvroSchema`, delegating to + /// `AvroSchema::from_arrow_with_options` with `None` so that the + /// union null ordering is decided by `Nullability::default()`. + fn try_from(schema: &ArrowSchema) -> Result { + AvroSchema::from_arrow_with_options(schema, None) + } +} + +impl AvroSchema { + /// Creates a new `AvroSchema` from a JSON string. + pub fn new(json_string: String) -> Self { + Self { json_string } + } + + pub(crate) fn schema(&self) -> Result, ArrowError> { + serde_json::from_str(self.json_string.as_str()) + .map_err(|e| ArrowError::ParseError(format!("Invalid Avro schema JSON: {e}"))) + } + + /// Returns the fingerprint of the schema, computed using the specified [`FingerprintAlgorithm`]. + /// + /// The fingerprint is computed over the schema's Parsed Canonical Form + /// as defined by the Avro specification. Depending on `hash_type`, this + /// will return one of the supported [`Fingerprint`] variants: + /// - [`Fingerprint::Rabin`] for [`FingerprintAlgorithm::Rabin`] + /// - `Fingerprint::MD5` for `FingerprintAlgorithm::MD5` + /// - `Fingerprint::SHA256` for `FingerprintAlgorithm::SHA256` + /// + /// Note: [`FingerprintAlgorithm::Id`] or [`FingerprintAlgorithm::Id64`] cannot be used to generate a fingerprint + /// and will result in an error. If you intend to use a Schema Registry ID-based + /// wire format, either use [`SchemaStore::set`] or load the [`Fingerprint::Id`] directly via [`Fingerprint::load_fingerprint_id`] or for + /// [`Fingerprint::Id64`] via [`Fingerprint::load_fingerprint_id64`]. + /// + /// See also: + /// + /// # Errors + /// Returns an error if deserializing the schema fails, if generating the + /// canonical form of the schema fails, or if `hash_type` is [`FingerprintAlgorithm::Id`]. + /// + /// # Examples + /// ``` + /// use arrow_avro::schema::{AvroSchema, FingerprintAlgorithm}; + /// + /// let avro = AvroSchema::new("\"string\"".to_string()); + /// let fp = avro.fingerprint(FingerprintAlgorithm::Rabin).unwrap(); + /// ``` + pub fn fingerprint(&self, hash_type: FingerprintAlgorithm) -> Result { + Self::generate_fingerprint(&self.schema()?, hash_type) + } + + pub(crate) fn generate_fingerprint( + schema: &Schema, + hash_type: FingerprintAlgorithm, + ) -> Result { + let canonical = Self::generate_canonical_form(schema).map_err(|e| { + ArrowError::ComputeError(format!("Failed to generate canonical form for schema: {e}")) + })?; + match hash_type { + FingerprintAlgorithm::Rabin => { + Ok(Fingerprint::Rabin(compute_fingerprint_rabin(&canonical))) + } + FingerprintAlgorithm::Id | FingerprintAlgorithm::Id64 => Err(ArrowError::SchemaError( + "FingerprintAlgorithm of Id or Id64 cannot be used to generate a fingerprint; \ + if using Fingerprint::Id, pass the registry ID in instead using the set method." + .to_string(), + )), + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => Ok(Fingerprint::MD5(compute_fingerprint_md5(&canonical))), + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => { + Ok(Fingerprint::SHA256(compute_fingerprint_sha256(&canonical))) + } + } + } + + /// Generates the Parsed Canonical Form for the given `Schema`. + /// + /// The canonical form is a standardized JSON representation of the schema, + /// primarily used for generating a schema fingerprint for equality checking. + /// + /// This form strips attributes that do not affect the schema's identity, + /// such as `doc` fields, `aliases`, and any properties not defined in the + /// Avro specification. + /// + /// + pub(crate) fn generate_canonical_form(schema: &Schema) -> Result { + build_canonical(schema, None) + } + + /// Build Avro JSON from an Arrow [`ArrowSchema`], applying the given null‑union order and optionally stripping internal Arrow metadata. + /// + /// If the input Arrow schema already contains Avro JSON in + /// [`SCHEMA_METADATA_KEY`], that JSON is returned verbatim to preserve + /// the exact header encoding alignment; otherwise, a new JSON is generated + /// honoring `null_union_order` at **all nullable sites**. + pub(crate) fn from_arrow_with_options( + schema: &ArrowSchema, + options: Option, + ) -> Result { + let opts = options.unwrap_or_default(); + let order = opts.null_order.unwrap_or_default(); + let strip = opts.strip_metadata; + if !strip { + if let Some(json) = schema.metadata.get(SCHEMA_METADATA_KEY) { + return Ok(AvroSchema::new(json.clone())); + } + } + let mut name_gen = NameGenerator::default(); + let fields_json = schema + .fields() + .iter() + .map(|f| arrow_field_to_avro(f, &mut name_gen, order, strip)) + .collect::, _>>()?; + let record_name = schema + .metadata + .get(AVRO_NAME_METADATA_KEY) + .map_or(AVRO_ROOT_RECORD_DEFAULT_NAME, |s| s.as_str()); + let mut record = JsonMap::with_capacity(schema.metadata.len() + 4); + record.insert("type".into(), Value::String("record".into())); + record.insert( + "name".into(), + Value::String(sanitise_avro_name(record_name)), + ); + if let Some(ns) = schema.metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + record.insert("namespace".into(), Value::String(ns.clone())); + } + if let Some(doc) = schema.metadata.get(AVRO_DOC_METADATA_KEY) { + record.insert("doc".into(), Value::String(doc.clone())); + } + record.insert("fields".into(), Value::Array(fields_json)); + extend_with_passthrough_metadata(&mut record, &schema.metadata); + let json_string = serde_json::to_string(&Value::Object(record)) + .map_err(|e| ArrowError::SchemaError(format!("Serializing Avro JSON failed: {e}")))?; + Ok(AvroSchema::new(json_string)) + } +} + +/// A stack-allocated, fixed-size buffer for the prefix. +#[derive(Debug, Copy, Clone)] +pub(crate) struct Prefix { + buf: [u8; MAX_PREFIX_LEN], + len: u8, +} + +impl Prefix { + #[inline] + pub(crate) fn as_slice(&self) -> &[u8] { + &self.buf[..self.len as usize] + } +} + +/// Defines the strategy for generating the per-record prefix for an Avro binary stream. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] +pub enum FingerprintStrategy { + /// Use the 64-bit Rabin fingerprint (default for single-object encoding). + #[default] + Rabin, + /// Use a Confluent Schema Registry 32-bit ID. + Id(u32), + /// Use an Apicurio Schema Registry 64-bit ID. + Id64(u64), + #[cfg(feature = "md5")] + /// Use the 128-bit MD5 fingerprint. + MD5, + #[cfg(feature = "sha256")] + /// Use the 256-bit SHA-256 fingerprint. + SHA256, +} + +impl From for FingerprintStrategy { + fn from(f: Fingerprint) -> Self { + Self::from(&f) + } +} + +impl From for FingerprintStrategy { + fn from(f: FingerprintAlgorithm) -> Self { + match f { + FingerprintAlgorithm::Rabin => FingerprintStrategy::Rabin, + FingerprintAlgorithm::Id => FingerprintStrategy::Id(0), + FingerprintAlgorithm::Id64 => FingerprintStrategy::Id64(0), + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => FingerprintStrategy::MD5, + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => FingerprintStrategy::SHA256, + } + } +} + +impl From<&Fingerprint> for FingerprintStrategy { + fn from(f: &Fingerprint) -> Self { + match f { + Fingerprint::Rabin(_) => FingerprintStrategy::Rabin, + Fingerprint::Id(_) => FingerprintStrategy::Id(0), + Fingerprint::Id64(_) => FingerprintStrategy::Id64(0), + #[cfg(feature = "md5")] + Fingerprint::MD5(_) => FingerprintStrategy::MD5, + #[cfg(feature = "sha256")] + Fingerprint::SHA256(_) => FingerprintStrategy::SHA256, + } + } +} + +/// Supported fingerprint algorithms for Avro schema identification. +/// For use with Confluent Schema Registry IDs, set to None. +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)] +pub enum FingerprintAlgorithm { + /// 64‑bit CRC‑64‑AVRO Rabin fingerprint. + #[default] + Rabin, + /// Represents a 32 bit fingerprint not based on a hash algorithm, (e.g., a 32-bit Schema Registry ID.) + Id, + /// Represents a 64 bit fingerprint not based on a hash algorithm, (e.g., a 64-bit Schema Registry ID.) + Id64, + #[cfg(feature = "md5")] + /// 128-bit MD5 message digest. + MD5, + #[cfg(feature = "sha256")] + /// 256-bit SHA-256 digest. + SHA256, +} + +/// Allow easy extraction of the algorithm used to create a fingerprint. +impl From<&Fingerprint> for FingerprintAlgorithm { + fn from(fp: &Fingerprint) -> Self { + match fp { + Fingerprint::Rabin(_) => FingerprintAlgorithm::Rabin, + Fingerprint::Id(_) => FingerprintAlgorithm::Id, + Fingerprint::Id64(_) => FingerprintAlgorithm::Id64, + #[cfg(feature = "md5")] + Fingerprint::MD5(_) => FingerprintAlgorithm::MD5, + #[cfg(feature = "sha256")] + Fingerprint::SHA256(_) => FingerprintAlgorithm::SHA256, + } + } +} + +impl From for FingerprintAlgorithm { + fn from(s: FingerprintStrategy) -> Self { + Self::from(&s) + } +} + +impl From<&FingerprintStrategy> for FingerprintAlgorithm { + fn from(s: &FingerprintStrategy) -> Self { + match s { + FingerprintStrategy::Rabin => FingerprintAlgorithm::Rabin, + FingerprintStrategy::Id(_) => FingerprintAlgorithm::Id, + FingerprintStrategy::Id64(_) => FingerprintAlgorithm::Id64, + #[cfg(feature = "md5")] + FingerprintStrategy::MD5 => FingerprintAlgorithm::MD5, + #[cfg(feature = "sha256")] + FingerprintStrategy::SHA256 => FingerprintAlgorithm::SHA256, + } + } +} + +/// A schema fingerprint in one of the supported formats. +/// +/// This is used as the key inside `SchemaStore` `HashMap`. Each `SchemaStore` +/// instance always stores only one variant, matching its configured +/// `FingerprintAlgorithm`, but the enum makes the API uniform. +/// +/// +/// +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum Fingerprint { + /// A 64-bit Rabin fingerprint. + Rabin(u64), + /// A 32-bit Schema Registry ID. + Id(u32), + /// A 64-bit Schema Registry ID. + Id64(u64), + #[cfg(feature = "md5")] + /// A 128-bit MD5 fingerprint. + MD5([u8; 16]), + #[cfg(feature = "sha256")] + /// A 256-bit SHA-256 fingerprint. + SHA256([u8; 32]), +} + +impl From for Fingerprint { + fn from(s: FingerprintStrategy) -> Self { + Self::from(&s) + } +} + +impl From<&FingerprintStrategy> for Fingerprint { + fn from(s: &FingerprintStrategy) -> Self { + match s { + FingerprintStrategy::Rabin => Fingerprint::Rabin(0), + FingerprintStrategy::Id(id) => Fingerprint::Id(*id), + FingerprintStrategy::Id64(id) => Fingerprint::Id64(*id), + #[cfg(feature = "md5")] + FingerprintStrategy::MD5 => Fingerprint::MD5([0; 16]), + #[cfg(feature = "sha256")] + FingerprintStrategy::SHA256 => Fingerprint::SHA256([0; 32]), + } + } +} + +impl From for Fingerprint { + fn from(s: FingerprintAlgorithm) -> Self { + match s { + FingerprintAlgorithm::Rabin => Fingerprint::Rabin(0), + FingerprintAlgorithm::Id => Fingerprint::Id(0), + FingerprintAlgorithm::Id64 => Fingerprint::Id64(0), + #[cfg(feature = "md5")] + FingerprintAlgorithm::MD5 => Fingerprint::MD5([0; 16]), + #[cfg(feature = "sha256")] + FingerprintAlgorithm::SHA256 => Fingerprint::SHA256([0; 32]), + } + } +} + +impl Fingerprint { + /// Loads the 32-bit Schema Registry fingerprint (Confluent Schema Registry ID). + /// + /// The provided `id` is in big-endian wire order; this converts it to host order + /// and returns `Fingerprint::Id`. + /// + /// # Returns + /// A `Fingerprint::Id` variant containing the 32-bit fingerprint. + pub fn load_fingerprint_id(id: u32) -> Self { + Fingerprint::Id(u32::from_be(id)) + } + + /// Loads the 64-bit Schema Registry fingerprint (Apicurio Schema Registry ID). + /// + /// The provided `id` is in big-endian wire order; this converts it to host order + /// and returns `Fingerprint::Id64`. + /// + /// # Returns + /// A `Fingerprint::Id64` variant containing the 64-bit fingerprint. + pub fn load_fingerprint_id64(id: u64) -> Self { + Fingerprint::Id64(u64::from_be(id)) + } + + /// Constructs a serialized prefix represented as a `Vec` based on the variant of the enum. + /// + /// This method serializes data in different formats depending on the variant of `self`: + /// - **`Id(id)`**: Uses the Confluent wire format, which includes a predefined magic header (`CONFLUENT_MAGIC`) + /// followed by the big-endian byte representation of the `id`. + /// - **`Id64(id)`**: Uses the Apicurio wire format, which includes a predefined magic header (`CONFLUENT_MAGIC`) + /// followed by the big-endian 8-byte representation of the `id`. + /// - **`Rabin(val)`**: Uses the Avro single-object specification format. This includes a different magic header + /// (`SINGLE_OBJECT_MAGIC`) followed by the little-endian byte representation of the `val`. + /// - **`MD5(bytes)`** (optional, `md5` feature enabled): A non-standard extension that adds the + /// `SINGLE_OBJECT_MAGIC` header followed by the provided `bytes`. + /// - **`SHA256(bytes)`** (optional, `sha256` feature enabled): Similar to the `MD5` variant, this is + /// a non-standard extension that attaches the `SINGLE_OBJECT_MAGIC` header followed by the given `bytes`. + /// + /// # Returns + /// + /// A `Prefix` containing the serialized prefix data. + /// + /// # Features + /// + /// - You can optionally enable the `md5` feature to include the `MD5` variant. + /// - You can optionally enable the `sha256` feature to include the `SHA256` variant. + /// + pub(crate) fn make_prefix(&self) -> Prefix { + let mut buf = [0u8; MAX_PREFIX_LEN]; + let len = match self { + Self::Id(val) => write_prefix(&mut buf, &CONFLUENT_MAGIC, &val.to_be_bytes()), + Self::Id64(val) => write_prefix(&mut buf, &CONFLUENT_MAGIC, &val.to_be_bytes()), + Self::Rabin(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, &val.to_le_bytes()), + #[cfg(feature = "md5")] + Self::MD5(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, val), + #[cfg(feature = "sha256")] + Self::SHA256(val) => write_prefix(&mut buf, &SINGLE_OBJECT_MAGIC, val), + }; + Prefix { buf, len } + } +} + +fn write_prefix( + buf: &mut [u8; MAX_PREFIX_LEN], + magic: &[u8; MAGIC_LEN], + payload: &[u8; PAYLOAD_LEN], +) -> u8 { + debug_assert!(MAGIC_LEN + PAYLOAD_LEN <= MAX_PREFIX_LEN); + let total = MAGIC_LEN + PAYLOAD_LEN; + let prefix_slice = &mut buf[..total]; + prefix_slice[..MAGIC_LEN].copy_from_slice(magic); + prefix_slice[MAGIC_LEN..total].copy_from_slice(payload); + total as u8 +} + +/// An in-memory cache of Avro schemas, indexed by their fingerprint. +/// +/// `SchemaStore` provides a mechanism to store and retrieve Avro schemas efficiently. +/// Each schema is associated with a unique [`Fingerprint`], which is generated based +/// on the schema's canonical form and a specific hashing algorithm. +/// +/// A `SchemaStore` instance is configured to use a single [`FingerprintAlgorithm`] such as Rabin, +/// MD5 (not yet supported), or SHA256 (not yet supported) for all its operations. +/// This ensures consistency when generating fingerprints and looking up schemas. +/// All schemas registered will have their fingerprint computed with this algorithm, and +/// lookups must use a matching fingerprint. +/// +/// # Examples +/// +/// ```no_run +/// // Create a new store with the default Rabin fingerprinting. +/// use arrow_avro::schema::{AvroSchema, SchemaStore}; +/// +/// let mut store = SchemaStore::new(); +/// let schema = AvroSchema::new("\"string\"".to_string()); +/// // Register the schema to get its fingerprint. +/// let fingerprint = store.register(schema.clone()).unwrap(); +/// // Use the fingerprint to look up the schema. +/// let retrieved_schema = store.lookup(&fingerprint).cloned(); +/// assert_eq!(retrieved_schema, Some(schema)); +/// ``` +#[derive(Debug, Clone, Default)] +pub struct SchemaStore { + /// The hashing algorithm used for generating fingerprints. + fingerprint_algorithm: FingerprintAlgorithm, + /// A map from a schema's fingerprint to the schema itself. + schemas: HashMap, +} + +impl TryFrom> for SchemaStore { + type Error = ArrowError; + + /// Creates a `SchemaStore` from a HashMap of schemas. + /// Each schema in the HashMap is registered with the new store. + fn try_from(schemas: HashMap) -> Result { + Ok(Self { + schemas, + ..Self::default() + }) + } +} + +impl SchemaStore { + /// Creates an empty `SchemaStore` using the default fingerprinting algorithm (64-bit Rabin). + pub fn new() -> Self { + Self::default() + } + + /// Creates an empty `SchemaStore` using the default fingerprinting algorithm (64-bit Rabin). + pub fn new_with_type(fingerprint_algorithm: FingerprintAlgorithm) -> Self { + Self { + fingerprint_algorithm, + ..Self::default() + } + } + + /// Registers a schema with the store and the provided fingerprint. + /// Note: Confluent wire format implementations should leverage this method. + /// + /// A schema is set in the store, using the provided fingerprint. If a schema + /// with the same fingerprint does not already exist in the store, the new schema + /// is inserted. If the fingerprint already exists, the existing schema is not overwritten. + /// + /// # Arguments + /// + /// * `fingerprint` - A reference to the `Fingerprint` of the schema to register. + /// * `schema` - The `AvroSchema` to register. + /// + /// # Returns + /// + /// A `Result` returning the provided `Fingerprint` of the schema if successful, + /// or an `ArrowError` on failure. + pub fn set( + &mut self, + fingerprint: Fingerprint, + schema: AvroSchema, + ) -> Result { + match self.schemas.entry(fingerprint) { + Entry::Occupied(entry) => { + if entry.get() != &schema { + return Err(ArrowError::ComputeError(format!( + "Schema fingerprint collision detected for fingerprint {fingerprint:?}" + ))); + } + } + Entry::Vacant(entry) => { + entry.insert(schema); + } + } + Ok(fingerprint) + } + + /// Registers a schema with the store and returns its fingerprint. + /// + /// A fingerprint is calculated for the given schema using the store's configured + /// hash type. If a schema with the same fingerprint does not already exist in the + /// store, the new schema is inserted. If the fingerprint already exists, the + /// existing schema is not overwritten. If FingerprintAlgorithm is set to Id or Id64, this + /// method will return an error. Confluent wire format implementations should leverage the + /// set method instead. + /// + /// # Arguments + /// + /// * `schema` - The `AvroSchema` to register. + /// + /// # Returns + /// + /// A `Result` containing the `Fingerprint` of the schema if successful, + /// or an `ArrowError` on failure. + pub fn register(&mut self, schema: AvroSchema) -> Result { + if self.fingerprint_algorithm == FingerprintAlgorithm::Id + || self.fingerprint_algorithm == FingerprintAlgorithm::Id64 + { + return Err(ArrowError::SchemaError( + "Invalid FingerprintAlgorithm; unable to generate fingerprint. \ + Use the set method directly instead, providing a valid fingerprint" + .to_string(), + )); + } + let fingerprint = + AvroSchema::generate_fingerprint(&schema.schema()?, self.fingerprint_algorithm)?; + self.set(fingerprint, schema)?; + Ok(fingerprint) + } + + /// Looks up a schema by its `Fingerprint`. + /// + /// # Arguments + /// + /// * `fingerprint` - A reference to the `Fingerprint` of the schema to look up. + /// + /// # Returns + /// + /// An `Option` containing a clone of the `AvroSchema` if found, otherwise `None`. + pub fn lookup(&self, fingerprint: &Fingerprint) -> Option<&AvroSchema> { + self.schemas.get(fingerprint) + } + + /// Returns a `Vec` containing **all unique [`Fingerprint`]s** currently + /// held by this [`SchemaStore`]. + /// + /// The order of the returned fingerprints is unspecified and should not be + /// relied upon. + pub fn fingerprints(&self) -> Vec { + self.schemas.keys().copied().collect() + } + + /// Returns the `FingerprintAlgorithm` used by the `SchemaStore` for fingerprinting. + pub(crate) fn fingerprint_algorithm(&self) -> FingerprintAlgorithm { + self.fingerprint_algorithm + } +} + +fn quote(s: &str) -> Result { + serde_json::to_string(s) + .map_err(|e| ArrowError::ComputeError(format!("Failed to quote string: {e}"))) +} + +// Avro names are defined by a `name` and an optional `namespace`. +// The full name is composed of the namespace and the name, separated by a dot. +// +// Avro specification defines two ways to specify a full name: +// 1. The `name` attribute contains the full name (e.g., "a.b.c.d"). +// In this case, the `namespace` attribute is ignored. +// 2. The `name` attribute contains the simple name (e.g., "d") and the +// `namespace` attribute contains the namespace (e.g., "a.b.c"). +// +// Each part of the name must match the regex `^[A-Za-z_][A-Za-z0-9_]*$`. +// Complex paths with quotes or backticks like `a."hi".b` are not supported. +// +// This function constructs the full name and extracts the namespace, +// handling both ways of specifying the name. It prioritizes a namespace +// defined within the `name` attribute itself, then the explicit `namespace_attr`, +// and finally the `enclosing_ns`. +pub(crate) fn make_full_name( + name: &str, + namespace_attr: Option<&str>, + enclosing_ns: Option<&str>, +) -> (String, Option) { + // `name` already contains a dot then treat as full-name, ignore namespace. + if let Some((ns, _)) = name.rsplit_once('.') { + return (name.to_string(), Some(ns.to_string())); + } + match namespace_attr.or(enclosing_ns) { + Some(ns) => (format!("{ns}.{name}"), Some(ns.to_string())), + None => (name.to_string(), None), + } +} + +fn build_canonical(schema: &Schema, enclosing_ns: Option<&str>) -> Result { + Ok(match schema { + Schema::TypeName(tn) | Schema::Type(Type { r#type: tn, .. }) => match tn { + TypeName::Primitive(pt) => quote(pt.as_ref())?, + TypeName::Ref(name) => { + let (full_name, _) = make_full_name(name, None, enclosing_ns); + quote(&full_name)? + } + }, + Schema::Union(branches) => format!( + "[{}]", + branches + .iter() + .map(|b| build_canonical(b, enclosing_ns)) + .collect::, _>>()? + .join(",") + ), + Schema::Complex(ct) => match ct { + ComplexType::Record(r) => { + let (full_name, child_ns) = make_full_name(r.name, r.namespace, enclosing_ns); + let fields = r + .fields + .iter() + .map(|f| { + // PCF [STRIP] per Avro spec: keep only attributes relevant to parsing + // ("name" and "type" for fields) and **strip others** such as doc, + // default, order, and **aliases**. This preserves canonicalization. See: + // https://avro.apache.org/docs/1.11.1/specification/#parsing-canonical-form-for-schemas + let field_type = + build_canonical(&f.r#type, child_ns.as_deref().or(enclosing_ns))?; + Ok(format!( + r#"{{"name":{},"type":{}}}"#, + quote(f.name)?, + field_type + )) + }) + .collect::, ArrowError>>()? + .join(","); + format!( + r#"{{"name":{},"type":"record","fields":[{fields}]}}"#, + quote(&full_name)?, + ) + } + ComplexType::Enum(e) => { + let (full_name, _) = make_full_name(e.name, e.namespace, enclosing_ns); + let symbols = e + .symbols + .iter() + .map(|s| quote(s)) + .collect::, _>>()? + .join(","); + format!( + r#"{{"name":{},"type":"enum","symbols":[{symbols}]}}"#, + quote(&full_name)? + ) + } + ComplexType::Array(arr) => format!( + r#"{{"type":"array","items":{}}}"#, + build_canonical(&arr.items, enclosing_ns)? + ), + ComplexType::Map(map) => format!( + r#"{{"type":"map","values":{}}}"#, + build_canonical(&map.values, enclosing_ns)? + ), + ComplexType::Fixed(f) => { + let (full_name, _) = make_full_name(f.name, f.namespace, enclosing_ns); + format!( + r#"{{"name":{},"type":"fixed","size":{}}}"#, + quote(&full_name)?, + f.size + ) + } + }, + }) +} + +/// 64‑bit Rabin fingerprint as described in the Avro spec. +const EMPTY: u64 = 0xc15d_213a_a4d7_a795; + +/// Build one entry of the polynomial‑division table. +/// +/// We cannot yet write `for _ in 0..8` here: `for` loops rely on +/// `Iterator::next`, which is not `const` on stable Rust. Until the +/// `const_for` feature (tracking issue #87575) is stabilized, a `while` +/// loop is the only option in a `const fn` +const fn one_entry(i: usize) -> u64 { + let mut fp = i as u64; + let mut j = 0; + while j < 8 { + fp = (fp >> 1) ^ (EMPTY & (0u64.wrapping_sub(fp & 1))); + j += 1; + } + fp +} + +/// Build the full 256‑entry table at compile time. +/// +/// We cannot yet write `for _ in 0..256` here: `for` loops rely on +/// `Iterator::next`, which is not `const` on stable Rust. Until the +/// `const_for` feature (tracking issue #87575) is stabilized, a `while` +/// loop is the only option in a `const fn` +const fn build_table() -> [u64; 256] { + let mut table = [0u64; 256]; + let mut i = 0; + while i < 256 { + table[i] = one_entry(i); + i += 1; + } + table +} + +/// The pre‑computed table. +static FINGERPRINT_TABLE: [u64; 256] = build_table(); + +/// Computes the 64-bit Rabin fingerprint for a given canonical schema string. +/// This implementation is based on the Avro specification for schema fingerprinting. +pub(crate) fn compute_fingerprint_rabin(canonical_form: &str) -> u64 { + let mut fp = EMPTY; + for &byte in canonical_form.as_bytes() { + let idx = ((fp as u8) ^ byte) as usize; + fp = (fp >> 8) ^ FINGERPRINT_TABLE[idx]; + } + fp +} + +#[cfg(feature = "md5")] +/// Compute the **128‑bit MD5** fingerprint of the canonical form. +/// +/// Returns a 16‑byte array (`[u8; 16]`) containing the full MD5 digest, +/// exactly as required by the Avro specification. +#[inline] +pub(crate) fn compute_fingerprint_md5(canonical_form: &str) -> [u8; 16] { + let digest = md5::compute(canonical_form.as_bytes()); + digest.0 +} + +#[cfg(feature = "sha256")] +/// Compute the **256‑bit SHA‑256** fingerprint of the canonical form. +/// +/// Returns a 32‑byte array (`[u8; 32]`) containing the full SHA‑256 digest. +#[inline] +pub(crate) fn compute_fingerprint_sha256(canonical_form: &str) -> [u8; 32] { + let mut hasher = Sha256::new(); + hasher.update(canonical_form.as_bytes()); + let digest = hasher.finalize(); + digest.into() +} + +#[inline] +fn is_internal_arrow_key(key: &str) -> bool { + key.starts_with("ARROW:") || key == SCHEMA_METADATA_KEY +} + +/// Copies Arrow schema metadata entries to the provided JSON map, +/// skipping keys that are Avro-reserved, internal Arrow keys, or +/// nested under the `avro.schema.` namespace. Values that parse as +/// JSON are inserted as JSON; otherwise the raw string is preserved. +fn extend_with_passthrough_metadata( + target: &mut JsonMap, + metadata: &HashMap, +) { + for (meta_key, meta_val) in metadata { + if meta_key.starts_with("avro.") || is_internal_arrow_key(meta_key) { + continue; + } + let json_val = + serde_json::from_str(meta_val).unwrap_or_else(|_| Value::String(meta_val.clone())); + target.insert(meta_key.clone(), json_val); + } +} + +// Sanitize an arbitrary string so it is a valid Avro field or type name +fn sanitise_avro_name(base_name: &str) -> String { + if base_name.is_empty() { + return "_".to_owned(); + } + let mut out: String = base_name + .chars() + .map(|char| { + if char.is_ascii_alphanumeric() || char == '_' { + char + } else { + '_' + } + }) + .collect(); + if out.as_bytes()[0].is_ascii_digit() { + out.insert(0, '_'); + } + out +} + +#[derive(Default)] +struct NameGenerator { + used: HashSet, + counters: HashMap, +} + +impl NameGenerator { + fn make_unique(&mut self, field_name: &str) -> String { + let field_name = sanitise_avro_name(field_name); + if self.used.insert(field_name.clone()) { + self.counters.insert(field_name.clone(), 1); + return field_name; + } + let counter = self.counters.entry(field_name.clone()).or_insert(1); + loop { + let candidate = format!("{field_name}_{}", *counter); + if self.used.insert(candidate.clone()) { + return candidate; + } + *counter += 1; + } + } +} + +fn merge_extras(schema: Value, extras: JsonMap) -> Value { + if extras.is_empty() { + return schema; + } + match schema { + Value::Object(mut map) => { + map.extend(extras); + Value::Object(map) + } + Value::Array(mut union) => { + // For unions, we cannot attach attributes to the array itself (per Avro spec). + // As a fallback for extension metadata, attach extras to the first non-null branch object. + if let Some(non_null) = union.iter_mut().find(|val| val.as_str() != Some("null")) { + let original = std::mem::take(non_null); + *non_null = merge_extras(original, extras); + } + Value::Array(union) + } + primitive => { + let mut map = JsonMap::with_capacity(extras.len() + 1); + map.insert("type".into(), primitive); + map.extend(extras); + Value::Object(map) + } + } +} + +#[inline] +fn is_avro_json_null(v: &Value) -> bool { + matches!(v, Value::String(s) if s == "null") +} + +fn wrap_nullable(inner: Value, null_order: Nullability) -> Value { + let null = Value::String("null".into()); + match inner { + Value::Array(mut union) => { + // If this site is already a union and already contains "null", + // preserve the branch order exactly. Reordering "null" breaks + // the correspondence between Arrow union child order (type_ids) + // and the Avro branch index written on the wire. + if union.iter().any(is_avro_json_null) { + return Value::Array(union); + } + // Otherwise, inject "null" without reordering existing branches. + match null_order { + Nullability::NullFirst => union.insert(0, null), + Nullability::NullSecond => union.push(null), + } + Value::Array(union) + } + other => match null_order { + Nullability::NullFirst => Value::Array(vec![null, other]), + Nullability::NullSecond => Value::Array(vec![other, null]), + }, + } +} + +fn min_fixed_bytes_for_precision(p: usize) -> usize { + // From the spec: max precision for n=1..=32 bytes: + // [2,4,6,9,11,14,16,18,21,23,26,28,31,33,35,38,40,43,45,47,50,52,55,57,59,62,64,67,69,71,74,76] + const MAX_P: [usize; 32] = [ + 2, 4, 6, 9, 11, 14, 16, 18, 21, 23, 26, 28, 31, 33, 35, 38, 40, 43, 45, 47, 50, 52, 55, 57, + 59, 62, 64, 67, 69, 71, 74, 76, + ]; + for (i, &max_p) in MAX_P.iter().enumerate() { + if p <= max_p { + return i + 1; + } + } + 32 // saturate at Decimal256 +} + +fn union_branch_signature(branch: &Value) -> Result { + match branch { + Value::String(t) => Ok(format!("P:{t}")), + Value::Object(map) => { + let t = map.get("type").and_then(|v| v.as_str()).ok_or_else(|| { + ArrowError::SchemaError("Union branch object missing string 'type'".into()) + })?; + match t { + "record" | "enum" | "fixed" => { + let name = map.get("name").and_then(|v| v.as_str()).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Union branch '{t}' missing required 'name'" + )) + })?; + Ok(format!("N:{t}:{name}")) + } + "array" | "map" => Ok(format!("C:{t}")), + other => Ok(format!("P:{other}")), + } + } + Value::Array(_) => Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )), + _ => Err(ArrowError::SchemaError( + "Invalid JSON for Avro union branch".into(), + )), + } +} + +fn datatype_to_avro( + dt: &DataType, + field_name: &str, + metadata: &HashMap, + name_gen: &mut NameGenerator, + null_order: Nullability, + strip: bool, +) -> Result<(Value, JsonMap), ArrowError> { + let mut extras = JsonMap::new(); + let mut handle_decimal = |precision: &u8, scale: &i8| -> Result { + if *scale < 0 { + return Err(ArrowError::SchemaError(format!( + "Invalid Avro decimal for field '{field_name}': scale ({scale}) must be >= 0" + ))); + } + if (*scale as usize) > (*precision as usize) { + return Err(ArrowError::SchemaError(format!( + "Invalid Avro decimal for field '{field_name}': scale ({scale}) \ + must be <= precision ({precision})" + ))); + } + let mut meta = JsonMap::from_iter([ + ("logicalType".into(), json!("decimal")), + ("precision".into(), json!(*precision)), + ("scale".into(), json!(*scale)), + ]); + let mut fixed_size = metadata.get("size").and_then(|v| v.parse::().ok()); + let carries_name = metadata.contains_key(AVRO_NAME_METADATA_KEY) + || metadata.contains_key(AVRO_NAMESPACE_METADATA_KEY); + if fixed_size.is_none() && carries_name { + fixed_size = Some(min_fixed_bytes_for_precision(*precision as usize)); + } + if let Some(size) = fixed_size { + meta.insert("type".into(), json!("fixed")); + meta.insert("size".into(), json!(size)); + let chosen_name = metadata + .get(AVRO_NAME_METADATA_KEY) + .map(|s| sanitise_avro_name(s)) + .unwrap_or_else(|| name_gen.make_unique(field_name)); + meta.insert("name".into(), json!(chosen_name)); + if let Some(ns) = metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + meta.insert("namespace".into(), json!(ns)); + } + } else { + // default to bytes-backed decimal + meta.insert("type".into(), json!("bytes")); + } + Ok(Value::Object(meta)) + }; + let val = match dt { + DataType::Null => Value::String("null".into()), + DataType::Boolean => Value::String("boolean".into()), + DataType::Int8 | DataType::Int16 | DataType::UInt8 | DataType::UInt16 | DataType::Int32 => { + Value::String("int".into()) + } + DataType::UInt32 | DataType::Int64 | DataType::UInt64 => Value::String("long".into()), + DataType::Float16 | DataType::Float32 => Value::String("float".into()), + DataType::Float64 => Value::String("double".into()), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View => Value::String("string".into()), + DataType::Binary | DataType::LargeBinary => Value::String("bytes".into()), + DataType::BinaryView => { + if !strip { + extras.insert("arrowBinaryView".into(), Value::Bool(true)); + } + Value::String("bytes".into()) + } + DataType::FixedSizeBinary(len) => { + let md_is_uuid = metadata + .get("logicalType") + .map(|s| s.trim_matches('"') == "uuid") + .unwrap_or(false); + #[cfg(feature = "canonical_extension_types")] + let ext_is_uuid = metadata + .get(arrow_schema::extension::EXTENSION_TYPE_NAME_KEY) + .map(|v| v == arrow_schema::extension::Uuid::NAME || v == "uuid") + .unwrap_or(false); + #[cfg(not(feature = "canonical_extension_types"))] + let ext_is_uuid = false; + let is_uuid = (*len == 16) && (md_is_uuid || ext_is_uuid); + if is_uuid { + json!({ "type": "string", "logicalType": "uuid" }) + } else { + let chosen_name = metadata + .get(AVRO_NAME_METADATA_KEY) + .map(|s| sanitise_avro_name(s)) + .unwrap_or_else(|| name_gen.make_unique(field_name)); + let mut obj = JsonMap::from_iter([ + ("type".into(), json!("fixed")), + ("name".into(), json!(chosen_name)), + ("size".into(), json!(len)), + ]); + if let Some(ns) = metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + obj.insert("namespace".into(), json!(ns)); + } + Value::Object(obj) + } + } + #[cfg(feature = "small_decimals")] + DataType::Decimal32(precision, scale) | DataType::Decimal64(precision, scale) => { + handle_decimal(precision, scale)? + } + DataType::Decimal128(precision, scale) | DataType::Decimal256(precision, scale) => { + handle_decimal(precision, scale)? + } + DataType::Date32 => json!({ "type": "int", "logicalType": "date" }), + DataType::Date64 => json!({ "type": "long", "logicalType": "local-timestamp-millis" }), + DataType::Time32(unit) => match unit { + TimeUnit::Millisecond => json!({ "type": "int", "logicalType": "time-millis" }), + TimeUnit::Second => { + if !strip { + extras.insert("arrowTimeUnit".into(), Value::String("second".into())); + } + Value::String("int".into()) + } + _ => Value::String("int".into()), + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => json!({ "type": "long", "logicalType": "time-micros" }), + TimeUnit::Nanosecond => { + if !strip { + extras.insert("arrowTimeUnit".into(), Value::String("nanosecond".into())); + } + Value::String("long".into()) + } + _ => Value::String("long".into()), + }, + DataType::Timestamp(unit, tz) => { + let logical_type = match (unit, tz.is_some()) { + (TimeUnit::Millisecond, true) => "timestamp-millis", + (TimeUnit::Millisecond, false) => "local-timestamp-millis", + (TimeUnit::Microsecond, true) => "timestamp-micros", + (TimeUnit::Microsecond, false) => "local-timestamp-micros", + (TimeUnit::Nanosecond, true) => "timestamp-nanos", + (TimeUnit::Nanosecond, false) => "local-timestamp-nanos", + (TimeUnit::Second, _) => { + if !strip { + extras.insert("arrowTimeUnit".into(), Value::String("second".into())); + } + return Ok((Value::String("long".into()), extras)); + } + }; + if !strip && matches!(unit, TimeUnit::Nanosecond) { + extras.insert("arrowTimeUnit".into(), Value::String("nanosecond".into())); + } + json!({ "type": "long", "logicalType": logical_type }) + } + #[cfg(not(feature = "avro_custom_types"))] + DataType::Duration(_unit) => Value::String("long".into()), + #[cfg(feature = "avro_custom_types")] + DataType::Duration(unit) => { + // When the feature is enabled, create an Avro schema object + // with the correct `logicalType` annotation. + let logical_type = match unit { + TimeUnit::Second => "arrow.duration-seconds", + TimeUnit::Millisecond => "arrow.duration-millis", + TimeUnit::Microsecond => "arrow.duration-micros", + TimeUnit::Nanosecond => "arrow.duration-nanos", + }; + json!({ "type": "long", "logicalType": logical_type }) + } + DataType::Interval(IntervalUnit::MonthDayNano) => { + // Avro duration logical type: fixed(12) with months/days/millis per spec. + let chosen_name = metadata + .get(AVRO_NAME_METADATA_KEY) + .map(|s| sanitise_avro_name(s)) + .unwrap_or_else(|| name_gen.make_unique(field_name)); + let mut obj = JsonMap::from_iter([ + ("type".into(), json!("fixed")), + ("name".into(), json!(chosen_name)), + ("size".into(), json!(12)), + ("logicalType".into(), json!("duration")), + ]); + if let Some(ns) = metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + obj.insert("namespace".into(), json!(ns)); + } + json!(obj) + } + DataType::Interval(IntervalUnit::YearMonth) => { + if !strip { + extras.insert( + "arrowIntervalUnit".into(), + Value::String("yearmonth".into()), + ); + } + Value::String("long".into()) + } + DataType::Interval(IntervalUnit::DayTime) => { + if !strip { + extras.insert("arrowIntervalUnit".into(), Value::String("daytime".into())); + } + Value::String("long".into()) + } + DataType::List(child) | DataType::LargeList(child) => { + if matches!(dt, DataType::LargeList(_)) && !strip { + extras.insert("arrowLargeList".into(), Value::Bool(true)); + } + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + strip, + )?; + json!({ + "type": "array", + "items": items_schema + }) + } + DataType::ListView(child) | DataType::LargeListView(child) => { + if matches!(dt, DataType::LargeListView(_)) && !strip { + extras.insert("arrowLargeList".into(), Value::Bool(true)); + } + if !strip { + extras.insert("arrowListView".into(), Value::Bool(true)); + } + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + strip, + )?; + json!({ + "type": "array", + "items": items_schema + }) + } + DataType::FixedSizeList(child, len) => { + if !strip { + extras.insert("arrowFixedSize".into(), json!(len)); + } + let items_schema = process_datatype( + child.data_type(), + child.name(), + child.metadata(), + name_gen, + null_order, + child.is_nullable(), + strip, + )?; + json!({ + "type": "array", + "items": items_schema + }) + } + DataType::Map(entries, _) => { + let value_field = match entries.data_type() { + DataType::Struct(fs) => &fs[1], + _ => { + return Err(ArrowError::SchemaError( + "Map 'entries' field must be Struct(key,value)".into(), + )); + } + }; + let values_schema = process_datatype( + value_field.data_type(), + value_field.name(), + value_field.metadata(), + name_gen, + null_order, + value_field.is_nullable(), + strip, + )?; + json!({ + "type": "map", + "values": values_schema + }) + } + DataType::Struct(fields) => { + let avro_fields = fields + .iter() + .map(|field| arrow_field_to_avro(field, name_gen, null_order, strip)) + .collect::, _>>()?; + // Prefer avro.name/avro.namespace when provided on the struct field metadata + let chosen_name = metadata + .get(AVRO_NAME_METADATA_KEY) + .map(|s| sanitise_avro_name(s)) + .unwrap_or_else(|| name_gen.make_unique(field_name)); + let mut obj = JsonMap::from_iter([ + ("type".into(), json!("record")), + ("name".into(), json!(chosen_name)), + ("fields".into(), Value::Array(avro_fields)), + ]); + if let Some(ns) = metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + obj.insert("namespace".into(), json!(ns)); + } + Value::Object(obj) + } + DataType::Dictionary(_, value) => { + if let Some(j) = metadata.get(AVRO_ENUM_SYMBOLS_METADATA_KEY) { + let symbols: Vec<&str> = + serde_json::from_str(j).map_err(|e| ArrowError::ParseError(e.to_string()))?; + // Prefer avro.name/namespace when provided for enums + let chosen_name = metadata + .get(AVRO_NAME_METADATA_KEY) + .map(|s| sanitise_avro_name(s)) + .unwrap_or_else(|| name_gen.make_unique(field_name)); + let mut obj = JsonMap::from_iter([ + ("type".into(), json!("enum")), + ("name".into(), json!(chosen_name)), + ("symbols".into(), json!(symbols)), + ]); + if let Some(ns) = metadata.get(AVRO_NAMESPACE_METADATA_KEY) { + obj.insert("namespace".into(), json!(ns)); + } + Value::Object(obj) + } else { + process_datatype( + value.as_ref(), + field_name, + metadata, + name_gen, + null_order, + false, + strip, + )? + } + } + #[cfg(feature = "avro_custom_types")] + DataType::RunEndEncoded(run_ends, values) => { + let bits = match run_ends.data_type() { + DataType::Int16 => 16, + DataType::Int32 => 32, + DataType::Int64 => 64, + other => { + return Err(ArrowError::SchemaError(format!( + "RunEndEncoded requires Int16/Int32/Int64 for run_ends, found: {other:?}" + ))); + } + }; + // Build the value site schema, preserving its own nullability + let (value_schema, value_extras) = datatype_to_avro( + values.data_type(), + values.name(), + values.metadata(), + name_gen, + null_order, + strip, + )?; + let mut merged = merge_extras(value_schema, value_extras); + if values.is_nullable() { + merged = wrap_nullable(merged, null_order); + } + let mut extras = JsonMap::new(); + extras.insert("logicalType".into(), json!("arrow.run-end-encoded")); + extras.insert("arrow.runEndIndexBits".into(), json!(bits)); + return Ok((merged, extras)); + } + #[cfg(not(feature = "avro_custom_types"))] + DataType::RunEndEncoded(_run_ends, values) => { + let (value_schema, _extras) = datatype_to_avro( + values.data_type(), + values.name(), + values.metadata(), + name_gen, + null_order, + strip, + )?; + return Ok((value_schema, JsonMap::new())); + } + DataType::Union(fields, mode) => { + let mut branches: Vec = Vec::with_capacity(fields.len()); + let mut type_ids: Vec = Vec::with_capacity(fields.len()); + for (type_id, field_ref) in fields.iter() { + // NOTE: `process_datatype` would wrap nullability; force is_nullable=false here. + let (branch_schema, _branch_extras) = datatype_to_avro( + field_ref.data_type(), + field_ref.name(), + field_ref.metadata(), + name_gen, + null_order, + strip, + )?; + // Avro unions cannot immediately contain another union + if matches!(branch_schema, Value::Array(_)) { + return Err(ArrowError::SchemaError( + "Avro union may not immediately contain another union".into(), + )); + } + branches.push(branch_schema); + type_ids.push(type_id as i32); + } + let mut seen: HashSet = HashSet::with_capacity(branches.len()); + for b in &branches { + let sig = union_branch_signature(b)?; + if !seen.insert(sig) { + return Err(ArrowError::SchemaError( + "Avro union contains duplicate branch types (disallowed by spec)".into(), + )); + } + } + if !strip { + extras.insert( + "arrowUnionMode".into(), + Value::String( + match mode { + UnionMode::Sparse => "sparse", + UnionMode::Dense => "dense", + } + .to_string(), + ), + ); + extras.insert( + "arrowUnionTypeIds".into(), + Value::Array(type_ids.into_iter().map(|id| json!(id)).collect()), + ); + } + Value::Array(branches) + } + #[cfg(not(feature = "small_decimals"))] + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Arrow type {other:?} has no Avro representation" + ))); + } + }; + Ok((val, extras)) +} + +fn process_datatype( + dt: &DataType, + field_name: &str, + metadata: &HashMap, + name_gen: &mut NameGenerator, + null_order: Nullability, + is_nullable: bool, + strip: bool, +) -> Result { + let (schema, extras) = datatype_to_avro(dt, field_name, metadata, name_gen, null_order, strip)?; + let mut merged = merge_extras(schema, extras); + if is_nullable { + merged = wrap_nullable(merged, null_order) + } + Ok(merged) +} + +fn arrow_field_to_avro( + field: &ArrowField, + name_gen: &mut NameGenerator, + null_order: Nullability, + strip: bool, +) -> Result { + let avro_name = sanitise_avro_name(field.name()); + let schema_value = process_datatype( + field.data_type(), + &avro_name, + field.metadata(), + name_gen, + null_order, + field.is_nullable(), + strip, + )?; + // Build the field map + let mut map = JsonMap::with_capacity(field.metadata().len() + 3); + map.insert("name".into(), Value::String(avro_name)); + map.insert("type".into(), schema_value); + // Transfer selected metadata + for (meta_key, meta_val) in field.metadata() { + if is_internal_arrow_key(meta_key) { + continue; + } + match meta_key.as_str() { + AVRO_DOC_METADATA_KEY => { + map.insert("doc".into(), Value::String(meta_val.clone())); + } + AVRO_FIELD_DEFAULT_METADATA_KEY => { + let default_value = serde_json::from_str(meta_val) + .unwrap_or_else(|_| Value::String(meta_val.clone())); + map.insert("default".into(), default_value); + } + _ => { + let json_val = serde_json::from_str(meta_val) + .unwrap_or_else(|_| Value::String(meta_val.clone())); + map.insert(meta_key.clone(), json_val); + } + } + } + Ok(Value::Object(map)) } #[cfg(test)] mod tests { use super::*; - use crate::codec::{AvroDataType, AvroField}; - use arrow_schema::{DataType, Fields, TimeUnit}; + use crate::codec::{AvroField, AvroFieldBuilder}; + use arrow_schema::{DataType, Fields, SchemaBuilder, TimeUnit, UnionFields}; use serde_json::json; + use std::sync::Arc; + + fn int_schema() -> Schema<'static> { + Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)) + } + + fn record_schema() -> Schema<'static> { + Schema::Complex(ComplexType::Record(Record { + name: "record1", + namespace: Some("test.namespace"), + doc: Some(Cow::from("A test record")), + aliases: vec![], + fields: vec![ + Field { + name: "field1", + doc: Some(Cow::from("An integer field")), + r#type: int_schema(), + default: None, + aliases: vec![], + }, + Field { + name: "field2", + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + default: None, + aliases: vec![], + }, + ], + attributes: Attributes::default(), + })) + } + + fn single_field_schema(field: ArrowField) -> arrow_schema::Schema { + let mut sb = SchemaBuilder::new(); + sb.push(field); + sb.finish() + } + + fn assert_json_contains(avro_json: &str, needle: &str) { + assert!( + avro_json.contains(needle), + "JSON did not contain `{needle}` : {avro_json}" + ) + } #[test] fn test_deserialize() { @@ -370,6 +1883,7 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]), default: None, + aliases: vec![], },], attributes: Default::default(), })) @@ -401,6 +1915,7 @@ mod tests { doc: None, r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, + aliases: vec![], }, Field { name: "next", @@ -410,6 +1925,7 @@ mod tests { Schema::TypeName(TypeName::Ref("LongList")), ]), default: None, + aliases: vec![], } ], attributes: Attributes::default(), @@ -463,6 +1979,7 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]), default: None, + aliases: vec![], }, Field { name: "timestamp_col", @@ -472,27 +1989,31 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]), default: None, + aliases: vec![], } ], attributes: Default::default(), })) ); let codec = AvroField::try_from(&schema).unwrap(); - assert_eq!( - codec.field(), - arrow_schema::Field::new( - "topLevelRecord", - DataType::Struct(Fields::from(vec![ - arrow_schema::Field::new("id", DataType::Int32, true), - arrow_schema::Field::new( - "timestamp_col", - DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), - true - ), - ])), - false - ) - ); + let expected_arrow_field = arrow_schema::Field::new( + "topLevelRecord", + DataType::Struct(Fields::from(vec![ + arrow_schema::Field::new("id", DataType::Int32, true), + arrow_schema::Field::new( + "timestamp_col", + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + true, + ), + ])), + false, + ) + .with_metadata(std::collections::HashMap::from([( + AVRO_NAME_METADATA_KEY.to_string(), + "topLevelRecord".to_string(), + )])); + + assert_eq!(codec.field(), expected_arrow_field); let schema: Schema = serde_json::from_str( r#"{ @@ -527,6 +2048,7 @@ mod tests { attributes: Default::default(), })), default: None, + aliases: vec![], }, Field { name: "clientProtocol", @@ -536,12 +2058,14 @@ mod tests { Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), ]), default: None, + aliases: vec![], }, Field { name: "serverHash", doc: None, r#type: Schema::TypeName(TypeName::Ref("MD5")), default: None, + aliases: vec![], }, Field { name: "meta", @@ -556,10 +2080,1061 @@ mod tests { })), ]), default: None, + aliases: vec![], } ], attributes: Default::default(), })) ); } + + #[test] + fn test_canonical_form_generation_comprehensive_record() { + // NOTE: This schema is identical to the one used in test_deserialize_comprehensive. + let json_str = r#"{ + "type": "record", + "name": "E2eComprehensive", + "namespace": "org.apache.arrow.avrotests.v1", + "doc": "Comprehensive Avro writer schema to exercise arrow-avro Reader/Decoder paths.", + "fields": [ + {"name": "id", "type": "long", "doc": "Primary row id", "aliases": ["identifier"]}, + {"name": "flag", "type": "boolean", "default": true, "doc": "A sample boolean with default true"}, + {"name": "ratio_f32", "type": "float", "default": 0.0, "doc": "Float32 example"}, + {"name": "ratio_f64", "type": "double", "default": 0.0, "doc": "Float64 example"}, + {"name": "count_i32", "type": "int", "default": 0, "doc": "Int32 example"}, + {"name": "count_i64", "type": "long", "default": 0, "doc": "Int64 example"}, + {"name": "opt_i32_nullfirst", "type": ["null", "int"], "default": null, "doc": "Nullable int (null-first)"}, + {"name": "opt_str_nullsecond", "type": ["string", "null"], "default": "", "aliases": ["old_opt_str"], "doc": "Nullable string (null-second). Default is empty string."}, + {"name": "tri_union_prim", "type": ["int", "string", "boolean"], "default": 0, "doc": "Union[int, string, boolean] with default on first branch (int=0)."}, + {"name": "str_utf8", "type": "string", "default": "default", "doc": "Plain Utf8 string (Reader may use Utf8View)."}, + {"name": "raw_bytes", "type": "bytes", "default": "", "doc": "Raw bytes field"}, + {"name": "fx16_plain", "type": {"type": "fixed", "name": "Fx16", "namespace": "org.apache.arrow.avrotests.v1.types", "aliases": ["Fixed16Old"], "size": 16}, "doc": "Plain fixed(16)"}, + {"name": "dec_bytes_s10_2", "type": {"type": "bytes", "logicalType": "decimal", "precision": 10, "scale": 2}, "doc": "Decimal encoded on bytes, precision 10, scale 2"}, + {"name": "dec_fix_s20_4", "type": {"type": "fixed", "name": "DecFix20", "namespace": "org.apache.arrow.avrotests.v1.types", "size": 20, "logicalType": "decimal", "precision": 20, "scale": 4}, "doc": "Decimal encoded on fixed(20), precision 20, scale 4"}, + {"name": "uuid_str", "type": {"type": "string", "logicalType": "uuid"}, "doc": "UUID logical type on string"}, + {"name": "d_date", "type": {"type": "int", "logicalType": "date"}, "doc": "Date32: days since 1970-01-01"}, + {"name": "t_millis", "type": {"type": "int", "logicalType": "time-millis"}, "doc": "Time32-millis"}, + {"name": "t_micros", "type": {"type": "long", "logicalType": "time-micros"}, "doc": "Time64-micros"}, + {"name": "ts_millis_utc", "type": {"type": "long", "logicalType": "timestamp-millis"}, "doc": "Timestamp ms (UTC)"}, + {"name": "ts_micros_utc", "type": {"type": "long", "logicalType": "timestamp-micros"}, "doc": "Timestamp µs (UTC)"}, + {"name": "ts_millis_local", "type": {"type": "long", "logicalType": "local-timestamp-millis"}, "doc": "Local timestamp ms"}, + {"name": "ts_micros_local", "type": {"type": "long", "logicalType": "local-timestamp-micros"}, "doc": "Local timestamp µs"}, + {"name": "interval_mdn", "type": {"type": "fixed", "name": "Dur12", "namespace": "org.apache.arrow.avrotests.v1.types", "size": 12, "logicalType": "duration"}, "doc": "Duration: fixed(12) little-endian (months, days, millis)"}, + {"name": "status", "type": {"type": "enum", "name": "Status", "namespace": "org.apache.arrow.avrotests.v1.types", "symbols": ["UNKNOWN", "NEW", "PROCESSING", "DONE"], "aliases": ["State"], "doc": "Processing status enum with default"}, "default": "UNKNOWN", "doc": "Enum field using default when resolving"}, + {"name": "arr_union", "type": {"type": "array", "items": ["long", "string", "null"]}, "default": [], "doc": "Array whose items are a union[long,string,null]"}, + {"name": "map_union", "type": {"type": "map", "values": ["null", "double", "string"]}, "default": {}, "doc": "Map whose values are a union[null,double,string]"}, + {"name": "address", "type": {"type": "record", "name": "Address", "namespace": "org.apache.arrow.avrotests.v1.types", "doc": "Postal address with defaults and field alias", "fields": [ + {"name": "street", "type": "string", "default": "", "aliases": ["street_name"], "doc": "Street (field alias = street_name)"}, + {"name": "zip", "type": "int", "default": 0, "doc": "ZIP/postal code"}, + {"name": "country", "type": "string", "default": "US", "doc": "Country code"} + ]}, "doc": "Embedded Address record"}, + {"name": "maybe_auth", "type": {"type": "record", "name": "MaybeAuth", "namespace": "org.apache.arrow.avrotests.v1.types", "doc": "Optional auth token model", "fields": [ + {"name": "user", "type": "string", "doc": "Username"}, + {"name": "token", "type": ["null", "bytes"], "default": null, "doc": "Nullable auth token"} + ]}}, + {"name": "union_enum_record_array_map", "type": [ + {"type": "enum", "name": "Color", "namespace": "org.apache.arrow.avrotests.v1.types", "symbols": ["RED", "GREEN", "BLUE"], "doc": "Color enum"}, + {"type": "record", "name": "RecA", "namespace": "org.apache.arrow.avrotests.v1.types", "fields": [{"name": "a", "type": "int"}, {"name": "b", "type": "string"}]}, + {"type": "record", "name": "RecB", "namespace": "org.apache.arrow.avrotests.v1.types", "fields": [{"name": "x", "type": "long"}, {"name": "y", "type": "bytes"}]}, + {"type": "array", "items": "long"}, + {"type": "map", "values": "string"} + ], "doc": "Union of enum, two records, array, and map"}, + {"name": "union_date_or_fixed4", "type": [ + {"type": "int", "logicalType": "date"}, + {"type": "fixed", "name": "Fx4", "size": 4} + ], "doc": "Union of date(int) or fixed(4)"}, + {"name": "union_interval_or_string", "type": [ + {"type": "fixed", "name": "Dur12U", "size": 12, "logicalType": "duration"}, + "string" + ], "doc": "Union of duration(fixed12) or string"}, + {"name": "union_uuid_or_fixed10", "type": [ + {"type": "string", "logicalType": "uuid"}, + {"type": "fixed", "name": "Fx10", "size": 10} + ], "doc": "Union of UUID string or fixed(10)"}, + {"name": "array_records_with_union", "type": {"type": "array", "items": { + "type": "record", "name": "KV", "namespace": "org.apache.arrow.avrotests.v1.types", + "fields": [ + {"name": "key", "type": "string"}, + {"name": "val", "type": ["null", "int", "long"], "default": null} + ] + }}, "doc": "Array", "default": []}, + {"name": "union_map_or_array_int", "type": [ + {"type": "map", "values": "int"}, + {"type": "array", "items": "int"} + ], "doc": "Union[map, array]"}, + {"name": "renamed_with_default", "type": "int", "default": 42, "aliases": ["old_count"], "doc": "Field with alias and default"}, + {"name": "person", "type": {"type": "record", "name": "PersonV2", "namespace": "com.example.v2", "aliases": ["com.example.Person"], "doc": "Person record with alias pointing to previous namespace/name", "fields": [ + {"name": "name", "type": "string"}, + {"name": "age", "type": "int", "default": 0} + ]}, "doc": "Record using type alias for schema evolution tests"} + ] + }"#; + let avro = AvroSchema::new(json_str.to_string()); + let parsed = avro.schema().expect("schema should deserialize"); + let expected_canonical_form = r#"{"name":"org.apache.arrow.avrotests.v1.E2eComprehensive","type":"record","fields":[{"name":"id","type":"long"},{"name":"flag","type":"boolean"},{"name":"ratio_f32","type":"float"},{"name":"ratio_f64","type":"double"},{"name":"count_i32","type":"int"},{"name":"count_i64","type":"long"},{"name":"opt_i32_nullfirst","type":["null","int"]},{"name":"opt_str_nullsecond","type":["string","null"]},{"name":"tri_union_prim","type":["int","string","boolean"]},{"name":"str_utf8","type":"string"},{"name":"raw_bytes","type":"bytes"},{"name":"fx16_plain","type":{"name":"org.apache.arrow.avrotests.v1.types.Fx16","type":"fixed","size":16}},{"name":"dec_bytes_s10_2","type":"bytes"},{"name":"dec_fix_s20_4","type":{"name":"org.apache.arrow.avrotests.v1.types.DecFix20","type":"fixed","size":20}},{"name":"uuid_str","type":"string"},{"name":"d_date","type":"int"},{"name":"t_millis","type":"int"},{"name":"t_micros","type":"long"},{"name":"ts_millis_utc","type":"long"},{"name":"ts_micros_utc","type":"long"},{"name":"ts_millis_local","type":"long"},{"name":"ts_micros_local","type":"long"},{"name":"interval_mdn","type":{"name":"org.apache.arrow.avrotests.v1.types.Dur12","type":"fixed","size":12}},{"name":"status","type":{"name":"org.apache.arrow.avrotests.v1.types.Status","type":"enum","symbols":["UNKNOWN","NEW","PROCESSING","DONE"]}},{"name":"arr_union","type":{"type":"array","items":["long","string","null"]}},{"name":"map_union","type":{"type":"map","values":["null","double","string"]}},{"name":"address","type":{"name":"org.apache.arrow.avrotests.v1.types.Address","type":"record","fields":[{"name":"street","type":"string"},{"name":"zip","type":"int"},{"name":"country","type":"string"}]}},{"name":"maybe_auth","type":{"name":"org.apache.arrow.avrotests.v1.types.MaybeAuth","type":"record","fields":[{"name":"user","type":"string"},{"name":"token","type":["null","bytes"]}]}},{"name":"union_enum_record_array_map","type":[{"name":"org.apache.arrow.avrotests.v1.types.Color","type":"enum","symbols":["RED","GREEN","BLUE"]},{"name":"org.apache.arrow.avrotests.v1.types.RecA","type":"record","fields":[{"name":"a","type":"int"},{"name":"b","type":"string"}]},{"name":"org.apache.arrow.avrotests.v1.types.RecB","type":"record","fields":[{"name":"x","type":"long"},{"name":"y","type":"bytes"}]},{"type":"array","items":"long"},{"type":"map","values":"string"}]},{"name":"union_date_or_fixed4","type":["int",{"name":"org.apache.arrow.avrotests.v1.Fx4","type":"fixed","size":4}]},{"name":"union_interval_or_string","type":[{"name":"org.apache.arrow.avrotests.v1.Dur12U","type":"fixed","size":12},"string"]},{"name":"union_uuid_or_fixed10","type":["string",{"name":"org.apache.arrow.avrotests.v1.Fx10","type":"fixed","size":10}]},{"name":"array_records_with_union","type":{"type":"array","items":{"name":"org.apache.arrow.avrotests.v1.types.KV","type":"record","fields":[{"name":"key","type":"string"},{"name":"val","type":["null","int","long"]}]}}},{"name":"union_map_or_array_int","type":[{"type":"map","values":"int"},{"type":"array","items":"int"}]},{"name":"renamed_with_default","type":"int"},{"name":"person","type":{"name":"com.example.v2.PersonV2","type":"record","fields":[{"name":"name","type":"string"},{"name":"age","type":"int"}]}}]}"#; + let canonical_form = + AvroSchema::generate_canonical_form(&parsed).expect("canonical form should be built"); + assert_eq!( + canonical_form, expected_canonical_form, + "Canonical form must match Avro spec PCF exactly" + ); + } + + #[test] + fn test_new_schema_store() { + let store = SchemaStore::new(); + assert!(store.schemas.is_empty()); + } + + #[test] + fn test_try_from_schemas_rabin() { + let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(), + int_avro_schema.clone(), + ); + schemas.insert( + record_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(), + record_avro_schema.clone(), + ); + let store = SchemaStore::try_from(schemas).unwrap(); + let int_fp = int_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(); + assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); + let rec_fp = record_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(); + assert_eq!(store.lookup(&rec_fp).cloned(), Some(record_avro_schema)); + } + + #[test] + fn test_try_from_with_duplicates() { + let int_avro_schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let record_avro_schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + let mut schemas: HashMap = HashMap::new(); + schemas.insert( + int_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(), + int_avro_schema.clone(), + ); + schemas.insert( + record_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(), + record_avro_schema.clone(), + ); + // Insert duplicate of int schema + schemas.insert( + int_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(), + int_avro_schema.clone(), + ); + let store = SchemaStore::try_from(schemas).unwrap(); + assert_eq!(store.schemas.len(), 2); + let int_fp = int_avro_schema + .fingerprint(FingerprintAlgorithm::Rabin) + .unwrap(); + assert_eq!(store.lookup(&int_fp).cloned(), Some(int_avro_schema)); + } + + #[test] + fn test_register_and_lookup_rabin() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp_enum = store.register(schema.clone()).unwrap(); + match fp_enum { + Fingerprint::Rabin(fp_val) => { + assert_eq!( + store.lookup(&Fingerprint::Rabin(fp_val)).cloned(), + Some(schema.clone()) + ); + assert!( + store + .lookup(&Fingerprint::Rabin(fp_val.wrapping_add(1))) + .is_none() + ); + } + Fingerprint::Id(_id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + Fingerprint::Id64(_id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "md5")] + Fingerprint::MD5(_id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + #[cfg(feature = "sha256")] + Fingerprint::SHA256(_id) => { + unreachable!("This test should only generate Rabin fingerprints") + } + } + } + + #[test] + fn test_set_and_lookup_id() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let id = 42u32; + let fp = Fingerprint::Id(id); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema.clone())); + assert!(store.lookup(&Fingerprint::Id(id.wrapping_add(1))).is_none()); + } + + #[test] + fn test_set_and_lookup_id64() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let id64: u64 = 0xDEAD_BEEF_DEAD_BEEF; + let fp = Fingerprint::Id64(id64); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp, "set should return the same Id64 fingerprint"); + assert_eq!( + store.lookup(&fp).cloned(), + Some(schema.clone()), + "lookup should find the schema by Id64" + ); + assert!( + store + .lookup(&Fingerprint::Id64(id64.wrapping_add(1))) + .is_none(), + "lookup with a different Id64 must return None" + ); + } + + #[test] + fn test_fingerprint_id64_conversions() { + let algo_from_fp = FingerprintAlgorithm::from(&Fingerprint::Id64(123)); + assert_eq!(algo_from_fp, FingerprintAlgorithm::Id64); + let fp_from_algo = Fingerprint::from(FingerprintAlgorithm::Id64); + assert!(matches!(fp_from_algo, Fingerprint::Id64(0))); + let strategy_from_fp = FingerprintStrategy::from(Fingerprint::Id64(5)); + assert!(matches!(strategy_from_fp, FingerprintStrategy::Id64(0))); + let algo_from_strategy = FingerprintAlgorithm::from(strategy_from_fp); + assert_eq!(algo_from_strategy, FingerprintAlgorithm::Id64); + } + + #[test] + fn test_register_duplicate_schema() { + let mut store = SchemaStore::new(); + let schema1 = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let schema2 = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fingerprint1 = store.register(schema1).unwrap(); + let fingerprint2 = store.register(schema2).unwrap(); + assert_eq!(fingerprint1, fingerprint2); + assert_eq!(store.schemas.len(), 1); + } + + #[test] + fn test_set_and_lookup_with_provided_fingerprint() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint(FingerprintAlgorithm::Rabin).unwrap(); + let out_fp = store.set(fp, schema.clone()).unwrap(); + assert_eq!(out_fp, fp); + assert_eq!(store.lookup(&fp).cloned(), Some(schema)); + } + + #[test] + fn test_set_duplicate_same_schema_ok() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let fp = schema.fingerprint(FingerprintAlgorithm::Rabin).unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + let _ = store.set(fp, schema.clone()).unwrap(); + assert_eq!(store.schemas.len(), 1); + } + + #[test] + fn test_set_duplicate_different_schema_collision_error() { + let mut store = SchemaStore::new(); + let schema1 = AvroSchema::new(serde_json::to_string(&int_schema()).unwrap()); + let schema2 = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + // Use the same Fingerprint::Id to simulate a collision across different schemas + let fp = Fingerprint::Id(123); + let _ = store.set(fp, schema1).unwrap(); + let err = store.set(fp, schema2).unwrap_err(); + let msg = format!("{err}"); + assert!(msg.contains("Schema fingerprint collision")); + } + + #[test] + fn test_canonical_form_generation_primitive() { + let schema = int_schema(); + let canonical_form = AvroSchema::generate_canonical_form(&schema).unwrap(); + assert_eq!(canonical_form, r#""int""#); + } + + #[test] + fn test_canonical_form_generation_record() { + let schema = record_schema(); + let expected_canonical_form = r#"{"name":"test.namespace.record1","type":"record","fields":[{"name":"field1","type":"int"},{"name":"field2","type":"string"}]}"#; + let canonical_form = AvroSchema::generate_canonical_form(&schema).unwrap(); + assert_eq!(canonical_form, expected_canonical_form); + } + + #[test] + fn test_fingerprint_calculation() { + let canonical_form = r#"{"fields":[{"name":"a","type":"long"},{"name":"b","type":"string"}],"name":"test","type":"record"}"#; + let expected_fingerprint = 10505236152925314060; + let fingerprint = compute_fingerprint_rabin(canonical_form); + assert_eq!(fingerprint, expected_fingerprint); + } + + #[test] + fn test_register_and_lookup_complex_schema() { + let mut store = SchemaStore::new(); + let schema = AvroSchema::new(serde_json::to_string(&record_schema()).unwrap()); + let canonical_form = r#"{"name":"test.namespace.record1","type":"record","fields":[{"name":"field1","type":"int"},{"name":"field2","type":"string"}]}"#; + let expected_fingerprint = Fingerprint::Rabin(compute_fingerprint_rabin(canonical_form)); + let fingerprint = store.register(schema.clone()).unwrap(); + assert_eq!(fingerprint, expected_fingerprint); + let looked_up = store.lookup(&fingerprint).cloned(); + assert_eq!(looked_up, Some(schema)); + } + + #[test] + fn test_fingerprints_returns_all_keys() { + let mut store = SchemaStore::new(); + let fp_int = store + .register(AvroSchema::new( + serde_json::to_string(&int_schema()).unwrap(), + )) + .unwrap(); + let fp_record = store + .register(AvroSchema::new( + serde_json::to_string(&record_schema()).unwrap(), + )) + .unwrap(); + let fps = store.fingerprints(); + assert_eq!(fps.len(), 2); + assert!(fps.contains(&fp_int)); + assert!(fps.contains(&fp_record)); + } + + #[test] + fn test_canonical_form_strips_attributes() { + let schema_with_attrs = Schema::Complex(ComplexType::Record(Record { + name: "record_with_attrs", + namespace: None, + doc: Some(Cow::from("This doc should be stripped")), + aliases: vec!["alias1", "alias2"], + fields: vec![Field { + name: "f1", + doc: Some(Cow::from("field doc")), + r#type: Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: Attributes { + logical_type: None, + additional: HashMap::from([("precision", json!(4))]), + }, + }), + default: None, + aliases: vec![], + }], + attributes: Attributes { + logical_type: None, + additional: HashMap::from([("custom_attr", json!("value"))]), + }, + })); + let expected_canonical_form = r#"{"name":"record_with_attrs","type":"record","fields":[{"name":"f1","type":"bytes"}]}"#; + let canonical_form = AvroSchema::generate_canonical_form(&schema_with_attrs).unwrap(); + assert_eq!(canonical_form, expected_canonical_form); + } + + #[test] + fn test_primitive_mappings() { + let cases = vec![ + (DataType::Boolean, "\"boolean\""), + (DataType::Int8, "\"int\""), + (DataType::Int16, "\"int\""), + (DataType::Int32, "\"int\""), + (DataType::Int64, "\"long\""), + (DataType::UInt8, "\"int\""), + (DataType::UInt16, "\"int\""), + (DataType::UInt32, "\"long\""), + (DataType::UInt64, "\"long\""), + (DataType::Float16, "\"float\""), + (DataType::Float32, "\"float\""), + (DataType::Float64, "\"double\""), + (DataType::Utf8, "\"string\""), + (DataType::Binary, "\"bytes\""), + ]; + for (dt, avro_token) in cases { + let field = ArrowField::new("col", dt.clone(), false); + let arrow_schema = single_field_schema(field); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, avro_token); + } + } + + #[test] + fn test_temporal_mappings() { + let cases = vec![ + (DataType::Date32, "\"logicalType\":\"date\""), + ( + DataType::Time32(TimeUnit::Millisecond), + "\"logicalType\":\"time-millis\"", + ), + ( + DataType::Time64(TimeUnit::Microsecond), + "\"logicalType\":\"time-micros\"", + ), + ( + DataType::Timestamp(TimeUnit::Millisecond, None), + "\"logicalType\":\"local-timestamp-millis\"", + ), + ( + DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())), + "\"logicalType\":\"timestamp-micros\"", + ), + ]; + for (dt, needle) in cases { + let field = ArrowField::new("ts", dt.clone(), true); + let arrow_schema = single_field_schema(field); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, needle); + } + } + + #[test] + fn test_decimal_and_uuid() { + let decimal_field = ArrowField::new("amount", DataType::Decimal128(25, 2), false); + let dec_schema = single_field_schema(decimal_field); + let avro_dec = AvroSchema::try_from(&dec_schema).unwrap(); + assert_json_contains(&avro_dec.json_string, "\"logicalType\":\"decimal\""); + assert_json_contains(&avro_dec.json_string, "\"precision\":25"); + assert_json_contains(&avro_dec.json_string, "\"scale\":2"); + let mut md = HashMap::new(); + md.insert("logicalType".into(), "uuid".into()); + let uuid_field = + ArrowField::new("id", DataType::FixedSizeBinary(16), false).with_metadata(md); + let uuid_schema = single_field_schema(uuid_field); + let avro_uuid = AvroSchema::try_from(&uuid_schema).unwrap(); + assert_json_contains(&avro_uuid.json_string, "\"logicalType\":\"uuid\""); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_interval_duration() { + let interval_field = ArrowField::new( + "span", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ); + let s = single_field_schema(interval_field); + let avro = AvroSchema::try_from(&s).unwrap(); + assert_json_contains(&avro.json_string, "\"logicalType\":\"duration\""); + assert_json_contains(&avro.json_string, "\"size\":12"); + let dur_field = ArrowField::new("latency", DataType::Duration(TimeUnit::Nanosecond), false); + let s2 = single_field_schema(dur_field); + let avro2 = AvroSchema::try_from(&s2).unwrap(); + assert_json_contains( + &avro2.json_string, + "\"logicalType\":\"arrow.duration-nanos\"", + ); + } + + #[test] + fn test_complex_types() { + let list_dt = DataType::List(Arc::new(ArrowField::new("item", DataType::Int32, true))); + let list_schema = single_field_schema(ArrowField::new("numbers", list_dt, false)); + let avro_list = AvroSchema::try_from(&list_schema).unwrap(); + assert_json_contains(&avro_list.json_string, "\"type\":\"array\""); + assert_json_contains(&avro_list.json_string, "\"items\""); + let value_field = ArrowField::new("value", DataType::Boolean, true); + let entries_struct = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + value_field.clone(), + ])), + false, + ); + let map_dt = DataType::Map(Arc::new(entries_struct), false); + let map_schema = single_field_schema(ArrowField::new("props", map_dt, false)); + let avro_map = AvroSchema::try_from(&map_schema).unwrap(); + assert_json_contains(&avro_map.json_string, "\"type\":\"map\""); + assert_json_contains(&avro_map.json_string, "\"values\""); + let struct_dt = DataType::Struct(Fields::from(vec![ + ArrowField::new("f1", DataType::Int64, false), + ArrowField::new("f2", DataType::Utf8, true), + ])); + let struct_schema = single_field_schema(ArrowField::new("person", struct_dt, true)); + let avro_struct = AvroSchema::try_from(&struct_schema).unwrap(); + assert_json_contains(&avro_struct.json_string, "\"type\":\"record\""); + assert_json_contains(&avro_struct.json_string, "\"null\""); + } + + #[test] + fn test_enum_dictionary() { + let mut md = HashMap::new(); + md.insert( + AVRO_ENUM_SYMBOLS_METADATA_KEY.into(), + "[\"OPEN\",\"CLOSED\"]".into(), + ); + let enum_dt = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let field = ArrowField::new("status", enum_dt, false).with_metadata(md); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"type\":\"enum\""); + assert_json_contains(&avro.json_string, "\"symbols\":[\"OPEN\",\"CLOSED\"]"); + } + + #[test] + fn test_run_end_encoded() { + let ree_dt = DataType::RunEndEncoded( + Arc::new(ArrowField::new("run_ends", DataType::Int32, false)), + Arc::new(ArrowField::new("values", DataType::Utf8, false)), + ); + let s = single_field_schema(ArrowField::new("text", ree_dt, false)); + let avro = AvroSchema::try_from(&s).unwrap(); + assert_json_contains(&avro.json_string, "\"string\""); + } + + #[test] + fn test_dense_union() { + let uf: UnionFields = vec![ + (2i8, Arc::new(ArrowField::new("a", DataType::Int32, false))), + (7i8, Arc::new(ArrowField::new("b", DataType::Utf8, true))), + ] + .into_iter() + .collect(); + let union_dt = DataType::Union(uf, UnionMode::Dense); + let s = single_field_schema(ArrowField::new("u", union_dt, false)); + let avro = + AvroSchema::try_from(&s).expect("Arrow Union -> Avro union conversion should succeed"); + let v: serde_json::Value = serde_json::from_str(&avro.json_string).unwrap(); + let fields = v + .get("fields") + .and_then(|x| x.as_array()) + .expect("fields array"); + let u_field = fields + .iter() + .find(|f| f.get("name").and_then(|n| n.as_str()) == Some("u")) + .expect("field 'u'"); + let union = u_field.get("type").expect("u.type"); + let arr = union.as_array().expect("u.type must be Avro union array"); + assert_eq!(arr.len(), 2, "expected two union branches"); + let first = &arr[0]; + let obj = first + .as_object() + .expect("first branch should be an object with metadata"); + assert_eq!(obj.get("type").and_then(|t| t.as_str()), Some("int")); + assert_eq!( + obj.get("arrowUnionMode").and_then(|m| m.as_str()), + Some("dense") + ); + let type_ids: Vec = obj + .get("arrowUnionTypeIds") + .and_then(|a| a.as_array()) + .expect("arrowUnionTypeIds array") + .iter() + .map(|n| n.as_i64().expect("i64")) + .collect(); + assert_eq!(type_ids, vec![2, 7], "type id ordering should be preserved"); + assert_eq!(arr[1], Value::String("string".into())); + } + + #[test] + fn round_trip_primitive() { + let arrow_schema = ArrowSchema::new(vec![ArrowField::new("f1", DataType::Int32, false)]); + let avro_schema = AvroSchema::try_from(&arrow_schema).unwrap(); + let decoded = avro_schema.schema().unwrap(); + assert!(matches!(decoded, Schema::Complex(_))); + } + + #[test] + fn test_name_generator_sanitization_and_uniqueness() { + let f1 = ArrowField::new("weird-name", DataType::FixedSizeBinary(8), false); + let f2 = ArrowField::new("weird name", DataType::FixedSizeBinary(8), false); + let f3 = ArrowField::new("123bad", DataType::FixedSizeBinary(8), false); + let arrow_schema = ArrowSchema::new(vec![f1, f2, f3]); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains(&avro.json_string, "\"name\":\"weird_name\""); + assert_json_contains(&avro.json_string, "\"name\":\"weird_name_1\""); + assert_json_contains(&avro.json_string, "\"name\":\"_123bad\""); + } + + #[test] + fn test_date64_logical_type_mapping() { + let field = ArrowField::new("d", DataType::Date64, true); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains( + &avro.json_string, + "\"logicalType\":\"local-timestamp-millis\"", + ); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_duration_list_extras_propagated() { + let child = ArrowField::new("lat", DataType::Duration(TimeUnit::Microsecond), false); + let list_dt = DataType::List(Arc::new(child)); + let arrow_schema = single_field_schema(ArrowField::new("durations", list_dt, false)); + let avro = AvroSchema::try_from(&arrow_schema).unwrap(); + assert_json_contains( + &avro.json_string, + "\"logicalType\":\"arrow.duration-micros\"", + ); + } + + #[test] + fn test_interval_yearmonth_extra() { + let field = ArrowField::new("iv", DataType::Interval(IntervalUnit::YearMonth), false); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowIntervalUnit\":\"yearmonth\""); + } + + #[test] + fn test_interval_daytime_extra() { + let field = ArrowField::new("iv_dt", DataType::Interval(IntervalUnit::DayTime), false); + let schema = single_field_schema(field); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowIntervalUnit\":\"daytime\""); + } + + #[test] + fn test_fixed_size_list_extra() { + let child = ArrowField::new("item", DataType::Int32, false); + let dt = DataType::FixedSizeList(Arc::new(child), 3); + let schema = single_field_schema(ArrowField::new("triples", dt, false)); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains(&avro.json_string, "\"arrowFixedSize\":3"); + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_map_duration_value_extra() { + let val_field = ArrowField::new("value", DataType::Duration(TimeUnit::Second), true); + let entries_struct = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + val_field, + ])), + false, + ); + let map_dt = DataType::Map(Arc::new(entries_struct), false); + let schema = single_field_schema(ArrowField::new("metrics", map_dt, false)); + let avro = AvroSchema::try_from(&schema).unwrap(); + assert_json_contains( + &avro.json_string, + "\"logicalType\":\"arrow.duration-seconds\"", + ); + } + + #[test] + fn test_schema_with_non_string_defaults_decodes_successfully() { + let schema_json = r#"{ + "type": "record", + "name": "R", + "fields": [ + {"name": "a", "type": "int", "default": 0}, + {"name": "b", "type": {"type": "array", "items": "long"}, "default": [1, 2, 3]}, + {"name": "c", "type": {"type": "map", "values": "double"}, "default": {"x": 1.5, "y": 2.5}}, + {"name": "inner", "type": {"type": "record", "name": "Inner", "fields": [ + {"name": "flag", "type": "boolean", "default": true}, + {"name": "name", "type": "string", "default": "hi"} + ]}, "default": {"flag": false, "name": "d"}}, + {"name": "u", "type": ["int", "null"], "default": 42} + ] + }"#; + let schema: Schema = serde_json::from_str(schema_json).expect("schema should parse"); + match &schema { + Schema::Complex(ComplexType::Record(_)) => {} + other => panic!("expected record schema, got: {:?}", other), + } + // Avro to Arrow conversion + let field = crate::codec::AvroField::try_from(&schema) + .expect("Avro->Arrow conversion should succeed"); + let arrow_field = field.field(); + // Build expected Arrow field + let expected_list_item = ArrowField::new( + arrow_schema::Field::LIST_FIELD_DEFAULT_NAME, + DataType::Int64, + false, + ); + let expected_b = ArrowField::new("b", DataType::List(Arc::new(expected_list_item)), false); + + let expected_map_value = ArrowField::new("value", DataType::Float64, false); + let expected_entries = ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + expected_map_value, + ])), + false, + ); + let expected_c = + ArrowField::new("c", DataType::Map(Arc::new(expected_entries), false), false); + let mut inner_md = std::collections::HashMap::new(); + inner_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "Inner".to_string()); + let expected_inner = ArrowField::new( + "inner", + DataType::Struct(Fields::from(vec![ + ArrowField::new("flag", DataType::Boolean, false), + ArrowField::new("name", DataType::Utf8, false), + ])), + false, + ) + .with_metadata(inner_md); + let mut root_md = std::collections::HashMap::new(); + root_md.insert(AVRO_NAME_METADATA_KEY.to_string(), "R".to_string()); + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ + ArrowField::new("a", DataType::Int32, false), + expected_b, + expected_c, + expected_inner, + ArrowField::new("u", DataType::Int32, true), + ])), + false, + ) + .with_metadata(root_md); + assert_eq!(arrow_field, expected); + } + + #[test] + fn default_order_is_consistent() { + let arrow_schema = ArrowSchema::new(vec![ArrowField::new("s", DataType::Utf8, true)]); + let a = AvroSchema::try_from(&arrow_schema).unwrap().json_string; + let b = AvroSchema::from_arrow_with_options(&arrow_schema, None); + assert_eq!(a, b.unwrap().json_string); + } + + #[test] + fn test_union_branch_missing_name_errors() { + for t in ["record", "enum", "fixed"] { + let branch = json!({ "type": t }); + let err = union_branch_signature(&branch).unwrap_err().to_string(); + assert!( + err.contains(&format!("Union branch '{t}' missing required 'name'")), + "expected missing-name error for {t}, got: {err}" + ); + } + } + + #[test] + fn test_union_branch_named_type_signature_includes_name() { + let rec = json!({ "type": "record", "name": "Foo" }); + assert_eq!(union_branch_signature(&rec).unwrap(), "N:record:Foo"); + let en = json!({ "type": "enum", "name": "Color", "symbols": ["R", "G", "B"] }); + assert_eq!(union_branch_signature(&en).unwrap(), "N:enum:Color"); + let fx = json!({ "type": "fixed", "name": "Bytes16", "size": 16 }); + assert_eq!(union_branch_signature(&fx).unwrap(), "N:fixed:Bytes16"); + } + + #[test] + fn test_record_field_alias_resolution_without_default() { + let writer_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"old","type":"int"}] + }"#; + let reader_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"new","aliases":["old"],"type":"int"}] + }"#; + let writer: Schema = serde_json::from_str(writer_json).unwrap(); + let reader: Schema = serde_json::from_str(reader_json).unwrap(); + let resolved = AvroFieldBuilder::new(&writer) + .with_reader_schema(&reader) + .with_utf8view(false) + .with_strict_mode(false) + .build() + .unwrap(); + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ArrowField::new( + "new", + DataType::Int32, + false, + )])), + false, + ); + assert_eq!(resolved.field(), expected); + } + + #[test] + fn test_record_field_alias_ambiguous_in_strict_mode_errors() { + let writer_json = r#"{ + "type":"record", + "name":"R", + "fields":[ + {"name":"a","type":"int","aliases":["old"]}, + {"name":"b","type":"int","aliases":["old"]} + ] + }"#; + let reader_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"target","type":"int","aliases":["old"]}] + }"#; + let writer: Schema = serde_json::from_str(writer_json).unwrap(); + let reader: Schema = serde_json::from_str(reader_json).unwrap(); + let err = AvroFieldBuilder::new(&writer) + .with_reader_schema(&reader) + .with_utf8view(false) + .with_strict_mode(true) + .build() + .unwrap_err() + .to_string(); + assert!( + err.contains("Ambiguous alias 'old'"), + "expected ambiguous-alias error, got: {err}" + ); + } + + #[test] + fn test_pragmatic_writer_field_alias_mapping_non_strict() { + let writer_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"before","type":"int","aliases":["now"]}] + }"#; + let reader_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"now","type":"int"}] + }"#; + let writer: Schema = serde_json::from_str(writer_json).unwrap(); + let reader: Schema = serde_json::from_str(reader_json).unwrap(); + let resolved = AvroFieldBuilder::new(&writer) + .with_reader_schema(&reader) + .with_utf8view(false) + .with_strict_mode(false) + .build() + .unwrap(); + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ArrowField::new( + "now", + DataType::Int32, + false, + )])), + false, + ); + assert_eq!(resolved.field(), expected); + } + + #[test] + fn test_missing_reader_field_null_first_no_default_is_ok() { + let writer_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"a","type":"int"}] + }"#; + let reader_json = r#"{ + "type":"record", + "name":"R", + "fields":[ + {"name":"a","type":"int"}, + {"name":"b","type":["null","int"]} + ] + }"#; + let writer: Schema = serde_json::from_str(writer_json).unwrap(); + let reader: Schema = serde_json::from_str(reader_json).unwrap(); + let resolved = AvroFieldBuilder::new(&writer) + .with_reader_schema(&reader) + .with_utf8view(false) + .with_strict_mode(false) + .build() + .unwrap(); + let expected = ArrowField::new( + "R", + DataType::Struct(Fields::from(vec![ + ArrowField::new("a", DataType::Int32, false), + ArrowField::new("b", DataType::Int32, true).with_metadata(HashMap::from([( + AVRO_FIELD_DEFAULT_METADATA_KEY.to_string(), + "null".to_string(), + )])), + ])), + false, + ); + assert_eq!(resolved.field(), expected); + } + + #[test] + fn test_missing_reader_field_null_second_without_default_errors() { + let writer_json = r#"{ + "type":"record", + "name":"R", + "fields":[{"name":"a","type":"int"}] + }"#; + let reader_json = r#"{ + "type":"record", + "name":"R", + "fields":[ + {"name":"a","type":"int"}, + {"name":"b","type":["int","null"]} + ] + }"#; + let writer: Schema = serde_json::from_str(writer_json).unwrap(); + let reader: Schema = serde_json::from_str(reader_json).unwrap(); + let err = AvroFieldBuilder::new(&writer) + .with_reader_schema(&reader) + .with_utf8view(false) + .with_strict_mode(false) + .build() + .unwrap_err() + .to_string(); + assert!( + err.contains("must have a default value"), + "expected missing-default error, got: {err}" + ); + } + + #[test] + fn test_from_arrow_with_options_respects_schema_metadata_when_not_stripping() { + let field = ArrowField::new("x", DataType::Int32, true); + let injected_json = + r#"{"type":"record","name":"Injected","fields":[{"name":"ignored","type":"int"}]}"# + .to_string(); + let mut md = HashMap::new(); + md.insert(SCHEMA_METADATA_KEY.to_string(), injected_json.clone()); + md.insert("custom".to_string(), "123".to_string()); + let arrow_schema = ArrowSchema::new_with_metadata(vec![field], md); + let opts = AvroSchemaOptions { + null_order: Some(Nullability::NullSecond), + strip_metadata: false, + }; + let out = AvroSchema::from_arrow_with_options(&arrow_schema, Some(opts)).unwrap(); + assert_eq!( + out.json_string, injected_json, + "When strip_metadata=false and avro.schema is present, return the embedded JSON verbatim" + ); + let v: Value = serde_json::from_str(&out.json_string).unwrap(); + assert_eq!(v.get("type").and_then(|t| t.as_str()), Some("record")); + assert_eq!(v.get("name").and_then(|n| n.as_str()), Some("Injected")); + } + + #[test] + fn test_from_arrow_with_options_ignores_schema_metadata_when_stripping_and_keeps_passthrough() { + let field = ArrowField::new("x", DataType::Int32, true); + let injected_json = + r#"{"type":"record","name":"Injected","fields":[{"name":"ignored","type":"int"}]}"# + .to_string(); + let mut md = HashMap::new(); + md.insert(SCHEMA_METADATA_KEY.to_string(), injected_json); + md.insert("custom_meta".to_string(), "7".to_string()); + let arrow_schema = ArrowSchema::new_with_metadata(vec![field], md); + let opts = AvroSchemaOptions { + null_order: Some(Nullability::NullFirst), + strip_metadata: true, + }; + let out = AvroSchema::from_arrow_with_options(&arrow_schema, Some(opts)).unwrap(); + assert_json_contains(&out.json_string, "\"type\":\"record\""); + assert_json_contains(&out.json_string, "\"name\":\"topLevelRecord\""); + assert_json_contains(&out.json_string, "\"custom_meta\":7"); + } + + #[test] + fn test_from_arrow_with_options_null_first_for_nullable_primitive() { + let field = ArrowField::new("s", DataType::Utf8, true); + let arrow_schema = single_field_schema(field); + let opts = AvroSchemaOptions { + null_order: Some(Nullability::NullFirst), + strip_metadata: true, + }; + let out = AvroSchema::from_arrow_with_options(&arrow_schema, Some(opts)).unwrap(); + let v: Value = serde_json::from_str(&out.json_string).unwrap(); + let arr = v["fields"][0]["type"] + .as_array() + .expect("nullable primitive should be Avro union array"); + assert_eq!(arr[0], Value::String("null".into())); + assert_eq!(arr[1], Value::String("string".into())); + } + + #[test] + fn test_from_arrow_with_options_null_second_for_nullable_primitive() { + let field = ArrowField::new("s", DataType::Utf8, true); + let arrow_schema = single_field_schema(field); + let opts = AvroSchemaOptions { + null_order: Some(Nullability::NullSecond), + strip_metadata: true, + }; + let out = AvroSchema::from_arrow_with_options(&arrow_schema, Some(opts)).unwrap(); + let v: Value = serde_json::from_str(&out.json_string).unwrap(); + let arr = v["fields"][0]["type"] + .as_array() + .expect("nullable primitive should be Avro union array"); + assert_eq!(arr[0], Value::String("string".into())); + assert_eq!(arr[1], Value::String("null".into())); + } + + #[test] + fn test_from_arrow_with_options_union_extras_respected_by_strip_metadata() { + let uf: UnionFields = vec![ + (2i8, Arc::new(ArrowField::new("a", DataType::Int32, false))), + (7i8, Arc::new(ArrowField::new("b", DataType::Utf8, false))), + ] + .into_iter() + .collect(); + let union_dt = DataType::Union(uf, UnionMode::Dense); + let arrow_schema = single_field_schema(ArrowField::new("u", union_dt, true)); + let with_extras = AvroSchema::from_arrow_with_options( + &arrow_schema, + Some(AvroSchemaOptions { + null_order: Some(Nullability::NullFirst), + strip_metadata: false, + }), + ) + .unwrap(); + let v_with: Value = serde_json::from_str(&with_extras.json_string).unwrap(); + let union_arr = v_with["fields"][0]["type"].as_array().expect("union array"); + let first_obj = union_arr + .iter() + .find(|b| b.is_object()) + .expect("expected an object branch with extras"); + let obj = first_obj.as_object().unwrap(); + assert_eq!(obj.get("type").and_then(|t| t.as_str()), Some("int")); + assert_eq!( + obj.get("arrowUnionMode").and_then(|m| m.as_str()), + Some("dense") + ); + let type_ids: Vec = obj["arrowUnionTypeIds"] + .as_array() + .expect("arrowUnionTypeIds array") + .iter() + .map(|n| n.as_i64().expect("i64")) + .collect(); + assert_eq!(type_ids, vec![2, 7]); + let stripped = AvroSchema::from_arrow_with_options( + &arrow_schema, + Some(AvroSchemaOptions { + null_order: Some(Nullability::NullFirst), + strip_metadata: true, + }), + ) + .unwrap(); + let v_stripped: Value = serde_json::from_str(&stripped.json_string).unwrap(); + let union_arr2 = v_stripped["fields"][0]["type"] + .as_array() + .expect("union array"); + assert!( + !union_arr2.iter().any(|b| b + .as_object() + .is_some_and(|m| m.contains_key("arrowUnionMode"))), + "extras must be removed when strip_metadata=true" + ); + assert_eq!(union_arr2[0], Value::String("null".into())); + assert_eq!(union_arr2[1], Value::String("int".into())); + assert_eq!(union_arr2[2], Value::String("string".into())); + } } diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs new file mode 100644 index 000000000000..ef9e02c8faf1 --- /dev/null +++ b/arrow-avro/src/writer/encoder.rs @@ -0,0 +1,3048 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro Encoder for Arrow types. + +use crate::codec::{AvroDataType, AvroField, Codec}; +use crate::schema::{Fingerprint, Nullability, Prefix}; +use arrow_array::cast::AsArray; +use arrow_array::types::{ + ArrowPrimitiveType, Date32Type, DurationMicrosecondType, DurationMillisecondType, + DurationNanosecondType, DurationSecondType, Float32Type, Float64Type, Int16Type, Int32Type, + Int64Type, IntervalDayTimeType, IntervalMonthDayNanoType, IntervalYearMonthType, + Time32MillisecondType, Time64MicrosecondType, TimestampMicrosecondType, + TimestampMillisecondType, +}; +use arrow_array::types::{ + RunEndIndexType, Time32SecondType, TimestampNanosecondType, TimestampSecondType, +}; +use arrow_array::{ + Array, BinaryViewArray, Decimal128Array, Decimal256Array, DictionaryArray, + FixedSizeBinaryArray, FixedSizeListArray, GenericBinaryArray, GenericListArray, + GenericListViewArray, GenericStringArray, LargeListArray, LargeListViewArray, ListArray, + ListViewArray, MapArray, OffsetSizeTrait, PrimitiveArray, RecordBatch, RunArray, StringArray, + StringViewArray, StructArray, UnionArray, +}; +#[cfg(feature = "small_decimals")] +use arrow_array::{Decimal32Array, Decimal64Array}; +use arrow_buffer::{ArrowNativeType, NullBuffer}; +use arrow_schema::{ + ArrowError, DataType, Field, IntervalUnit, Schema as ArrowSchema, TimeUnit, UnionMode, +}; +use std::io::Write; +use std::sync::Arc; +use uuid::Uuid; + +/// Encode a single Avro-`long` using ZigZag + variable length, buffered. +/// +/// Spec: +#[inline] +pub(crate) fn write_long(out: &mut W, value: i64) -> Result<(), ArrowError> { + let mut zz = ((value << 1) ^ (value >> 63)) as u64; + // At most 10 bytes for 64-bit varint + let mut buf = [0u8; 10]; + let mut i = 0; + while (zz & !0x7F) != 0 { + buf[i] = ((zz & 0x7F) as u8) | 0x80; + i += 1; + zz >>= 7; + } + buf[i] = (zz & 0x7F) as u8; + i += 1; + out.write_all(&buf[..i]) + .map_err(|e| ArrowError::IoError(format!("write long: {e}"), e)) +} + +#[inline] +fn write_int(out: &mut W, value: i32) -> Result<(), ArrowError> { + write_long(out, value as i64) +} + +#[inline] +fn write_len_prefixed(out: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { + write_long(out, bytes.len() as i64)?; + out.write_all(bytes) + .map_err(|e| ArrowError::IoError(format!("write bytes: {e}"), e)) +} + +#[inline] +fn write_bool(out: &mut W, v: bool) -> Result<(), ArrowError> { + out.write_all(&[if v { 1 } else { 0 }]) + .map_err(|e| ArrowError::IoError(format!("write bool: {e}"), e)) +} + +/// Minimal two's-complement big-endian representation helper for Avro decimal (bytes). +/// +/// For positive numbers, trim leading 0x00 until an essential byte is reached. +/// For negative numbers, trim leading 0xFF until an essential byte is reached. +/// The resulting slice still encodes the same signed value. +/// +/// See Avro spec: decimal over `bytes` uses two's-complement big-endian +/// representation of the unscaled integer value. 1.11.1 specification. +#[inline] +fn minimal_twos_complement(be: &[u8]) -> &[u8] { + if be.is_empty() { + return be; + } + let sign_byte = if (be[0] & 0x80) != 0 { 0xFF } else { 0x00 }; + let mut k = 0usize; + while k < be.len() && be[k] == sign_byte { + k += 1; + } + if k == 0 { + return be; + } + if k == be.len() { + return &be[be.len() - 1..]; + } + let drop = if ((be[k] ^ sign_byte) & 0x80) == 0 { + k + } else { + k - 1 + }; + &be[drop..] +} + +/// Sign-extend (or validate/truncate) big-endian integer bytes to exactly `n` bytes. +/// +/// +/// - If shorter than `n`, the slice is sign-extended by left-padding with the +/// sign byte (`0x00` for positive, `0xFF` for negative). +/// - If longer than `n`, the slice is truncated from the left. An overflow error +/// is returned if any of the truncated bytes are not redundant sign bytes, +/// or if the resulting value's sign bit would differ from the original. +/// - If the slice is already `n` bytes long, it is copied. +/// +/// Used for encoding Avro decimal values into `fixed(N)` fields. +#[inline] +fn write_sign_extended( + out: &mut W, + src_be: &[u8], + n: usize, +) -> Result<(), ArrowError> { + let len = src_be.len(); + if len == n { + return out + .write_all(src_be) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)); + } + let sign_byte = if len > 0 && (src_be[0] & 0x80) != 0 { + 0xFF + } else { + 0x00 + }; + if len > n { + let extra = len - n; + if n == 0 && src_be.iter().all(|&b| b == sign_byte) { + return Ok(()); + } + // All truncated bytes must equal the sign byte, and the MSB of the first + // retained byte must match the sign (otherwise overflow). + if src_be[..extra].iter().any(|&b| b != sign_byte) + || ((src_be[extra] ^ sign_byte) & 0x80) != 0 + { + return Err(ArrowError::InvalidArgumentError(format!( + "Decimal value with {len} bytes cannot be represented in {n} bytes without overflow", + ))); + } + return out + .write_all(&src_be[extra..]) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)); + } + // len < n: prepend sign bytes (sign extension) then the payload + let pad_len = n - len; + // Fixed-size stack pads to avoid heap allocation on the hot path + const ZPAD: [u8; 64] = [0x00; 64]; + const FPAD: [u8; 64] = [0xFF; 64]; + let pad = if sign_byte == 0x00 { + &ZPAD[..] + } else { + &FPAD[..] + }; + // Emit padding in 64‑byte chunks (minimizes write calls without allocating), + // then write the original bytes. + let mut rem = pad_len; + while rem >= pad.len() { + out.write_all(pad) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e))?; + rem -= pad.len(); + } + if rem > 0 { + out.write_all(&pad[..rem]) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e))?; + } + out.write_all(src_be) + .map_err(|e| ArrowError::IoError(format!("write decimal fixed: {e}"), e)) +} + +/// Write the union branch index for an optional field. +/// +/// Branch index is 0-based per Avro unions: +/// - Null-first (default): null => 0, value => 1 +/// - Null-second (Impala): value => 0, null => 1 +fn write_optional_index( + out: &mut W, + is_null: bool, + null_order: Nullability, +) -> Result<(), ArrowError> { + let byte = union_value_branch_byte(null_order, is_null); + out.write_all(&[byte]) + .map_err(|e| ArrowError::IoError(format!("write union branch: {e}"), e)) +} + +#[derive(Debug, Clone)] +enum NullState<'a> { + NonNullable, + NullableNoNulls { + union_value_byte: u8, + }, + Nullable { + nulls: &'a NullBuffer, + null_order: Nullability, + }, +} + +/// Arrow to Avro FieldEncoder: +/// - Holds the inner `Encoder` (by value) +/// - Carries the per-site nullability **state** as a single enum that enforces invariants +pub(crate) struct FieldEncoder<'a> { + encoder: Encoder<'a>, + null_state: NullState<'a>, +} + +impl<'a> FieldEncoder<'a> { + fn make_encoder( + array: &'a dyn Array, + plan: &FieldPlan, + nullability: Option, + ) -> Result { + let encoder = match plan { + FieldPlan::Scalar => match array.data_type() { + DataType::Null => Encoder::Null, + DataType::Boolean => Encoder::Boolean(BooleanEncoder(array.as_boolean())), + DataType::Utf8 => { + Encoder::Utf8(Utf8GenericEncoder::(array.as_string::())) + } + DataType::LargeUtf8 => { + Encoder::Utf8Large(Utf8GenericEncoder::(array.as_string::())) + } + DataType::Utf8View => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected StringViewArray".into()) + })?; + Encoder::Utf8View(Utf8ViewEncoder(arr)) + } + DataType::BinaryView => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected BinaryViewArray".into()) + })?; + Encoder::BinaryView(BinaryViewEncoder(arr)) + } + DataType::Int32 => Encoder::Int(IntEncoder(array.as_primitive::())), + DataType::Int64 => Encoder::Long(LongEncoder(array.as_primitive::())), + DataType::Date32 => Encoder::Date32(IntEncoder(array.as_primitive::())), + DataType::Date64 => { + return Err(ArrowError::NotYetImplemented( + "Avro logical type 'date' is days since epoch (int). Arrow Date64 (ms) has no direct Avro logical type; cast to Date32 or to a Timestamp." + .into(), + )); + } + DataType::Time32(TimeUnit::Second) => Encoder::Time32SecsToMillis( + Time32SecondsToMillisEncoder(array.as_primitive::()), + ), + DataType::Time32(TimeUnit::Millisecond) => { + Encoder::Time32Millis(IntEncoder(array.as_primitive::())) + } + DataType::Time32(TimeUnit::Microsecond) => { + return Err(ArrowError::InvalidArgumentError( + "Arrow Time32 only supports Second or Millisecond. Use Time64 for microseconds." + .into(), + )); + } + DataType::Time32(TimeUnit::Nanosecond) => { + return Err(ArrowError::InvalidArgumentError( + "Arrow Time32 only supports Second or Millisecond. Use Time64 for nanoseconds." + .into(), + )); + } + DataType::Time64(TimeUnit::Microsecond) => Encoder::Time64Micros(LongEncoder( + array.as_primitive::(), + )), + DataType::Time64(TimeUnit::Nanosecond) => { + return Err(ArrowError::NotYetImplemented( + "Avro writer does not support time-nanos; cast to Time64(Microsecond)." + .into(), + )); + } + DataType::Time64(TimeUnit::Millisecond) => { + return Err(ArrowError::InvalidArgumentError( + "Arrow Time64 with millisecond unit is not a valid Arrow type (use Time32 for millis)." + .into(), + )); + } + DataType::Time64(TimeUnit::Second) => { + return Err(ArrowError::InvalidArgumentError( + "Arrow Time64 with second unit is not a valid Arrow type (use Time32 for seconds)." + .into(), + )); + } + DataType::Float32 => { + Encoder::Float32(F32Encoder(array.as_primitive::())) + } + DataType::Float64 => { + Encoder::Float64(F64Encoder(array.as_primitive::())) + } + DataType::Binary => Encoder::Binary(BinaryEncoder(array.as_binary::())), + DataType::LargeBinary => { + Encoder::LargeBinary(BinaryEncoder(array.as_binary::())) + } + DataType::FixedSizeBinary(_len) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected FixedSizeBinaryArray".into()) + })?; + Encoder::Fixed(FixedEncoder(arr)) + } + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => { + Encoder::TimestampSecsToMillis(TimestampSecondsToMillisEncoder( + array.as_primitive::(), + )) + } + TimeUnit::Millisecond => Encoder::TimestampMillis(LongEncoder( + array.as_primitive::(), + )), + TimeUnit::Microsecond => Encoder::TimestampMicros(LongEncoder( + array.as_primitive::(), + )), + TimeUnit::Nanosecond => Encoder::TimestampNanos(LongEncoder( + array.as_primitive::(), + )), + }, + DataType::Interval(unit) => match unit { + IntervalUnit::MonthDayNano => Encoder::IntervalMonthDayNano(DurationEncoder( + array.as_primitive::(), + )), + IntervalUnit::YearMonth => Encoder::IntervalYearMonth(DurationEncoder( + array.as_primitive::(), + )), + IntervalUnit::DayTime => Encoder::IntervalDayTime(DurationEncoder( + array.as_primitive::(), + )), + }, + DataType::Duration(tu) => match tu { + TimeUnit::Second => Encoder::DurationSeconds(LongEncoder( + array.as_primitive::(), + )), + TimeUnit::Millisecond => Encoder::DurationMillis(LongEncoder( + array.as_primitive::(), + )), + TimeUnit::Microsecond => Encoder::DurationMicros(LongEncoder( + array.as_primitive::(), + )), + TimeUnit::Nanosecond => Encoder::DurationNanos(LongEncoder( + array.as_primitive::(), + )), + }, + other => { + return Err(ArrowError::NotYetImplemented(format!( + "Avro scalar type not yet supported: {other:?}" + ))); + } + }, + FieldPlan::Struct { bindings } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected StructArray".into()))?; + Encoder::Struct(Box::new(StructEncoder::try_new(arr, bindings)?)) + } + FieldPlan::List { + items_nullability, + item_plan, + } => match array.data_type() { + DataType::List(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected ListArray".into()))?; + Encoder::List(Box::new(ListEncoder32::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::LargeList(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected LargeListArray".into()))?; + Encoder::LargeList(Box::new(ListEncoder64::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::ListView(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected ListViewArray".into()))?; + Encoder::ListView(Box::new(ListViewEncoder32::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::LargeListView(_) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected LargeListViewArray".into()) + })?; + Encoder::LargeListView(Box::new(ListViewEncoder64::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + DataType::FixedSizeList(_, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected FixedSizeListArray".into()) + })?; + Encoder::FixedSizeList(Box::new(FixedSizeListEncoder::try_new( + arr, + *items_nullability, + item_plan.as_ref(), + )?)) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro array site requires Arrow List/LargeList/ListView/LargeListView/FixedSizeList, found: {other:?}" + ))); + } + }, + FieldPlan::Decimal { size } => match array.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(_, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal32Array".into()))?; + Encoder::Decimal32(DecimalEncoder::<4, Decimal32Array>::new(arr, *size)) + } + #[cfg(feature = "small_decimals")] + DataType::Decimal64(_, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected Decimal64Array".into()))?; + Encoder::Decimal64(DecimalEncoder::<8, Decimal64Array>::new(arr, *size)) + } + DataType::Decimal128(_, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected Decimal128Array".into()) + })?; + Encoder::Decimal128(DecimalEncoder::<16, Decimal128Array>::new(arr, *size)) + } + DataType::Decimal256(_, _) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected Decimal256Array".into()) + })?; + Encoder::Decimal256(DecimalEncoder::<32, Decimal256Array>::new(arr, *size)) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro decimal site requires Arrow Decimal 32, 64, 128, or 256, found: {other:?}" + ))); + } + }, + FieldPlan::Uuid => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Expected FixedSizeBinaryArray".into()) + })?; + Encoder::Uuid(UuidEncoder(arr)) + } + FieldPlan::Map { + values_nullability, + value_plan, + } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected MapArray".into()))?; + Encoder::Map(Box::new(MapEncoder::try_new( + arr, + *values_nullability, + value_plan.as_ref(), + )?)) + } + FieldPlan::Enum { symbols } => match array.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 || **value_dt != DataType::Utf8 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + let dict = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::SchemaError("Expected DictionaryArray".into()) + })?; + let values = dict + .values() + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::SchemaError("Dictionary values must be Utf8".into()) + })?; + if values.len() != symbols.len() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol length {} != dictionary size {}", + symbols.len(), + values.len() + ))); + } + for i in 0..values.len() { + if values.value(i) != symbols[i].as_str() { + return Err(ArrowError::SchemaError(format!( + "Enum symbol mismatch at {i}: schema='{}' dict='{}'", + symbols[i], + values.value(i) + ))); + } + } + let keys = dict.keys(); + Encoder::Enum(EnumEncoder { keys }) + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro enum site requires DataType::Dictionary, found: {other:?}" + ))); + } + }, + FieldPlan::Union { bindings } => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| ArrowError::SchemaError("Expected UnionArray".into()))?; + Encoder::Union(Box::new(UnionEncoder::try_new(arr, bindings)?)) + } + FieldPlan::RunEndEncoded { + values_nullability, + value_plan, + } => { + // Helper closure to build a typed RunEncodedEncoder + let build = |run_arr_any: &'a dyn Array| -> Result, ArrowError> { + if let Some(arr) = run_arr_any.as_any().downcast_ref::>() { + return Ok(Encoder::RunEncoded16(Box::new(RunEncodedEncoder::< + Int16Type, + >::new( + arr, + FieldEncoder::make_encoder( + arr.values().as_ref(), + value_plan.as_ref(), + *values_nullability, + )?, + )))); + } + if let Some(arr) = run_arr_any.as_any().downcast_ref::>() { + return Ok(Encoder::RunEncoded32(Box::new(RunEncodedEncoder::< + Int32Type, + >::new( + arr, + FieldEncoder::make_encoder( + arr.values().as_ref(), + value_plan.as_ref(), + *values_nullability, + )?, + )))); + } + if let Some(arr) = run_arr_any.as_any().downcast_ref::>() { + return Ok(Encoder::RunEncoded64(Box::new(RunEncodedEncoder::< + Int64Type, + >::new( + arr, + FieldEncoder::make_encoder( + arr.values().as_ref(), + value_plan.as_ref(), + *values_nullability, + )?, + )))); + } + Err(ArrowError::SchemaError( + "Unsupported run-ends index type for RunEndEncoded; expected Int16/Int32/Int64" + .into(), + )) + }; + build(array)? + } + }; + // Compute the effective null state from writer-declared nullability and data nulls. + let null_state = match nullability { + None => NullState::NonNullable, + Some(null_order) => { + match array.nulls() { + Some(nulls) if array.null_count() > 0 => { + NullState::Nullable { nulls, null_order } + } + _ => NullState::NullableNoNulls { + // Nullable site with no null buffer for this view + union_value_byte: union_value_branch_byte(null_order, false), + }, + } + } + }; + Ok(Self { + encoder, + null_state, + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match &self.null_state { + NullState::NonNullable => {} + NullState::NullableNoNulls { union_value_byte } => out + .write_all(&[*union_value_byte]) + .map_err(|e| ArrowError::IoError(format!("write union value branch: {e}"), e))?, + NullState::Nullable { nulls, null_order } if nulls.is_null(idx) => { + return write_optional_index(out, true, *null_order); // no value to write + } + NullState::Nullable { null_order, .. } => { + write_optional_index(out, false, *null_order)?; + } + } + self.encoder.encode(out, idx) + } +} + +fn union_value_branch_byte(null_order: Nullability, is_null: bool) -> u8 { + let nulls_first = null_order == Nullability::default(); + if nulls_first == is_null { 0x00 } else { 0x02 } +} + +/// Per‑site encoder plan for a field. This mirrors the Avro structure, so nested +/// optional branch order can be honored exactly as declared by the schema. +#[derive(Debug, Clone)] +enum FieldPlan { + /// Non-nested scalar/logical type + Scalar, + /// Record/Struct with Avro‑ordered children + Struct { bindings: Vec }, + /// Array with item‑site nullability and nested plan + List { + items_nullability: Option, + item_plan: Box, + }, + /// Avro decimal logical type (bytes or fixed). `size=None` => bytes(decimal), `Some(n)` => fixed(n) + Decimal { size: Option }, + /// Avro UUID logical type (fixed) + Uuid, + /// Avro map with value‑site nullability and nested plan + Map { + values_nullability: Option, + value_plan: Box, + }, + /// Avro enum; maps to Arrow Dictionary with dictionary values + /// exactly equal and ordered as the Avro enum `symbols`. + Enum { symbols: Arc<[String]> }, + /// Avro union, maps to Arrow Union. + Union { bindings: Vec }, + /// Avro RunEndEncoded site. Values are encoded per logical row by mapping the + /// row index to its containing run and emitting that run's value with `value_plan`. + RunEndEncoded { + values_nullability: Option, + value_plan: Box, + }, +} + +#[derive(Debug, Clone)] +struct FieldBinding { + /// Index of the Arrow field/column associated with this Avro field site + arrow_index: usize, + /// Nullability/order for this site (None for required fields) + nullability: Option, + /// Nested plan for this site + plan: FieldPlan, +} + +/// Builder for `RecordEncoder` write plan +#[derive(Debug)] +pub(crate) struct RecordEncoderBuilder<'a> { + avro_root: &'a AvroField, + arrow_schema: &'a ArrowSchema, + fingerprint: Option, +} + +impl<'a> RecordEncoderBuilder<'a> { + /// Create a new builder from the Avro root and Arrow schema. + pub(crate) fn new(avro_root: &'a AvroField, arrow_schema: &'a ArrowSchema) -> Self { + Self { + avro_root, + arrow_schema, + fingerprint: None, + } + } + + pub(crate) fn with_fingerprint(mut self, fingerprint: Option) -> Self { + self.fingerprint = fingerprint; + self + } + + /// Build the `RecordEncoder` by walking the Avro **record** root in Avro order, + /// resolving each field to an Arrow index by name. + pub(crate) fn build(self) -> Result { + let avro_root_dt = self.avro_root.data_type(); + let Codec::Struct(root_fields) = avro_root_dt.codec() else { + return Err(ArrowError::SchemaError( + "Top-level Avro schema must be a record/struct".into(), + )); + }; + let mut columns = Vec::with_capacity(root_fields.len()); + for root_field in root_fields.as_ref() { + let name = root_field.name(); + let arrow_index = self.arrow_schema.index_of(name).map_err(|e| { + ArrowError::SchemaError(format!("Schema mismatch for field '{name}': {e}")) + })?; + columns.push(FieldBinding { + arrow_index, + nullability: root_field.data_type().nullability(), + plan: FieldPlan::build( + root_field.data_type(), + self.arrow_schema.field(arrow_index), + )?, + }); + } + Ok(RecordEncoder { + columns, + prefix: self.fingerprint.map(|fp| fp.make_prefix()), + }) + } +} + +/// A pre-computed plan for encoding a `RecordBatch` to Avro. +/// +/// Derived from an Avro schema and an Arrow schema. It maps +/// top-level Avro fields to Arrow columns and contains a nested encoding plan +/// for each column. +#[derive(Debug, Clone)] +pub(crate) struct RecordEncoder { + columns: Vec, + /// Optional pre-built, variable-length prefix written before each record. + prefix: Option, +} + +impl RecordEncoder { + fn prepare_for_batch<'a>( + &'a self, + batch: &'a RecordBatch, + ) -> Result>, ArrowError> { + let arrays = batch.columns(); + let mut out = Vec::with_capacity(self.columns.len()); + for col_plan in self.columns.iter() { + let arrow_index = col_plan.arrow_index; + let array = arrays.get(arrow_index).ok_or_else(|| { + ArrowError::SchemaError(format!("Column index {arrow_index} out of range")) + })?; + #[cfg(not(feature = "avro_custom_types"))] + let site_nullability = match &col_plan.plan { + FieldPlan::RunEndEncoded { .. } => None, + _ => col_plan.nullability, + }; + #[cfg(feature = "avro_custom_types")] + let site_nullability = col_plan.nullability; + out.push(FieldEncoder::make_encoder( + array.as_ref(), + &col_plan.plan, + site_nullability, + )?); + } + Ok(out) + } + + /// Encode a `RecordBatch` using this encoder plan. + /// + /// Tip: Wrap `out` in a `std::io::BufWriter` to reduce the overhead of many small writes. + pub(crate) fn encode( + &self, + out: &mut W, + batch: &RecordBatch, + ) -> Result<(), ArrowError> { + let mut column_encoders = self.prepare_for_batch(batch)?; + let n = batch.num_rows(); + match self.prefix { + Some(prefix) => { + for row in 0..n { + out.write_all(prefix.as_slice()) + .map_err(|e| ArrowError::IoError(format!("write prefix: {e}"), e))?; + for enc in column_encoders.iter_mut() { + enc.encode(out, row)?; + } + } + } + None => { + for row in 0..n { + for enc in column_encoders.iter_mut() { + enc.encode(out, row)?; + } + } + } + } + Ok(()) + } +} + +fn find_struct_child_index(fields: &arrow_schema::Fields, name: &str) -> Option { + fields.iter().position(|f| f.name() == name) +} + +fn find_map_value_field_index(fields: &arrow_schema::Fields) -> Option { + // Prefer common Arrow field names; fall back to second child if exactly two + find_struct_child_index(fields, "value") + .or_else(|| find_struct_child_index(fields, "values")) + .or_else(|| if fields.len() == 2 { Some(1) } else { None }) +} + +impl FieldPlan { + fn build(avro_dt: &AvroDataType, arrow_field: &Field) -> Result { + #[cfg(not(feature = "avro_custom_types"))] + if let DataType::RunEndEncoded(_re_field, values_field) = arrow_field.data_type() { + let values_nullability = avro_dt.nullability(); + let value_site_dt: &AvroDataType = match avro_dt.codec() { + Codec::Union(branches, _, _) => branches + .iter() + .find(|b| !matches!(b.codec(), Codec::Null)) + .ok_or_else(|| { + ArrowError::SchemaError( + "Avro union at RunEndEncoded site has no non-null branch".into(), + ) + })?, + _ => avro_dt, + }; + return Ok(FieldPlan::RunEndEncoded { + values_nullability, + value_plan: Box::new(FieldPlan::build(value_site_dt, values_field.as_ref())?), + }); + } + if let DataType::FixedSizeBinary(len) = arrow_field.data_type() { + // Extension-based detection (only when the feature is enabled) + let ext_is_uuid = { + #[cfg(feature = "canonical_extension_types")] + { + matches!( + arrow_field.extension_type_name(), + Some("arrow.uuid") | Some("uuid") + ) + } + #[cfg(not(feature = "canonical_extension_types"))] + { + false + } + }; + let md_is_uuid = arrow_field + .metadata() + .get("logicalType") + .map(|s| s.as_str()) + == Some("uuid"); + if ext_is_uuid || md_is_uuid { + if *len != 16 { + return Err(ArrowError::InvalidArgumentError( + "logicalType=uuid requires FixedSizeBinary(16)".into(), + )); + } + return Ok(FieldPlan::Uuid); + } + } + match avro_dt.codec() { + Codec::Struct(avro_fields) => { + let fields = match arrow_field.data_type() { + DataType::Struct(struct_fields) => struct_fields, + other => { + return Err(ArrowError::SchemaError(format!( + "Avro struct maps to Arrow Struct, found: {other:?}" + ))); + } + }; + let mut bindings = Vec::with_capacity(avro_fields.len()); + for avro_field in avro_fields.iter() { + let name = avro_field.name().to_string(); + let idx = find_struct_child_index(fields, &name).ok_or_else(|| { + ArrowError::SchemaError(format!( + "Struct field '{name}' not present in Arrow field '{}'", + arrow_field.name() + )) + })?; + bindings.push(FieldBinding { + arrow_index: idx, + nullability: avro_field.data_type().nullability(), + plan: FieldPlan::build(avro_field.data_type(), fields[idx].as_ref())?, + }); + } + Ok(FieldPlan::Struct { bindings }) + } + Codec::List(items_dt) => match arrow_field.data_type() { + DataType::List(field_ref) + | DataType::LargeList(field_ref) + | DataType::ListView(field_ref) + | DataType::LargeListView(field_ref) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + DataType::FixedSizeList(field_ref, _len) => Ok(FieldPlan::List { + items_nullability: items_dt.nullability(), + item_plan: Box::new(FieldPlan::build(items_dt.as_ref(), field_ref.as_ref())?), + }), + other => Err(ArrowError::SchemaError(format!( + "Avro array maps to Arrow List/LargeList/ListView/LargeListView/FixedSizeList, found: {other:?}" + ))), + }, + Codec::Map(values_dt) => { + let entries_field = match arrow_field.data_type() { + DataType::Map(entries, _sorted) => entries.as_ref(), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map maps to Arrow DataType::Map, found: {other:?}" + ))); + } + }; + let entries_struct_fields = match entries_field.data_type() { + DataType::Struct(fs) => fs, + other => { + return Err(ArrowError::SchemaError(format!( + "Arrow Map entries must be Struct, found: {other:?}" + ))); + } + }; + let value_idx = + find_map_value_field_index(entries_struct_fields).ok_or_else(|| { + ArrowError::SchemaError("Map entries struct missing value field".into()) + })?; + let value_field = entries_struct_fields[value_idx].as_ref(); + let value_plan = FieldPlan::build(values_dt.as_ref(), value_field)?; + Ok(FieldPlan::Map { + values_nullability: values_dt.nullability(), + value_plan: Box::new(value_plan), + }) + } + Codec::Enum(symbols) => match arrow_field.data_type() { + DataType::Dictionary(key_dt, value_dt) => { + if **key_dt != DataType::Int32 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + if **value_dt != DataType::Utf8 { + return Err(ArrowError::SchemaError( + "Avro enum requires Dictionary".into(), + )); + } + Ok(FieldPlan::Enum { + symbols: symbols.clone(), + }) + } + other => Err(ArrowError::SchemaError(format!( + "Avro enum maps to Arrow Dictionary, found: {other:?}" + ))), + }, + // decimal site (bytes or fixed(N)) with precision/scale validation + Codec::Decimal(precision, scale_opt, fixed_size_opt) => { + let (ap, as_) = match arrow_field.data_type() { + #[cfg(feature = "small_decimals")] + DataType::Decimal32(p, s) => (*p as usize, *s as i32), + #[cfg(feature = "small_decimals")] + DataType::Decimal64(p, s) => (*p as usize, *s as i32), + DataType::Decimal128(p, s) => (*p as usize, *s as i32), + DataType::Decimal256(p, s) => (*p as usize, *s as i32), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro decimal requires Arrow decimal, got {other:?} for field '{}'", + arrow_field.name() + ))); + } + }; + let sc = scale_opt.unwrap_or(0) as i32; // Avro scale defaults to 0 if absent + if ap != *precision || as_ != sc { + return Err(ArrowError::SchemaError(format!( + "Decimal precision/scale mismatch for field '{}': Avro({precision},{sc}) vs Arrow({ap},{as_})", + arrow_field.name() + ))); + } + Ok(FieldPlan::Decimal { + size: *fixed_size_opt, + }) + } + Codec::Interval => match arrow_field.data_type() { + DataType::Interval( + IntervalUnit::MonthDayNano | IntervalUnit::YearMonth | IntervalUnit::DayTime, + ) => Ok(FieldPlan::Scalar), + other => Err(ArrowError::SchemaError(format!( + "Avro duration logical type requires Arrow Interval(MonthDayNano), found: {other:?}" + ))), + }, + Codec::Union(avro_branches, _, UnionMode::Dense) => { + let arrow_union_fields = match arrow_field.data_type() { + DataType::Union(fields, UnionMode::Dense) => fields, + DataType::Union(_, UnionMode::Sparse) => { + return Err(ArrowError::NotYetImplemented( + "Sparse Arrow unions are not yet supported".to_string(), + )); + } + other => { + return Err(ArrowError::SchemaError(format!( + "Avro union maps to Arrow Union, found: {other:?}" + ))); + } + }; + if avro_branches.len() != arrow_union_fields.len() { + return Err(ArrowError::SchemaError(format!( + "Mismatched number of branches between Avro union ({}) and Arrow union ({}) for field '{}'", + avro_branches.len(), + arrow_union_fields.len(), + arrow_field.name() + ))); + } + let bindings = avro_branches + .iter() + .zip(arrow_union_fields.iter()) + .enumerate() + .map(|(i, (avro_branch, (_, arrow_child_field)))| { + Ok(FieldBinding { + arrow_index: i, + nullability: avro_branch.nullability(), + plan: FieldPlan::build(avro_branch, arrow_child_field)?, + }) + }) + .collect::, ArrowError>>()?; + Ok(FieldPlan::Union { bindings }) + } + Codec::Union(_, _, UnionMode::Sparse) => Err(ArrowError::NotYetImplemented( + "Sparse Arrow unions are not yet supported".to_string(), + )), + #[cfg(feature = "avro_custom_types")] + Codec::RunEndEncoded(values_dt, _width_code) => { + let values_field = match arrow_field.data_type() { + DataType::RunEndEncoded(_run_ends_field, values_field) => values_field.as_ref(), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro RunEndEncoded maps to Arrow DataType::RunEndEncoded, found: {other:?}" + ))); + } + }; + Ok(FieldPlan::RunEndEncoded { + values_nullability: values_dt.nullability(), + value_plan: Box::new(FieldPlan::build(values_dt.as_ref(), values_field)?), + }) + } + _ => Ok(FieldPlan::Scalar), + } + } +} + +enum Encoder<'a> { + Boolean(BooleanEncoder<'a>), + Int(IntEncoder<'a, Int32Type>), + Long(LongEncoder<'a, Int64Type>), + TimestampMicros(LongEncoder<'a, TimestampMicrosecondType>), + TimestampMillis(LongEncoder<'a, TimestampMillisecondType>), + TimestampNanos(LongEncoder<'a, TimestampNanosecondType>), + TimestampSecsToMillis(TimestampSecondsToMillisEncoder<'a>), + Date32(IntEncoder<'a, Date32Type>), + Time32SecsToMillis(Time32SecondsToMillisEncoder<'a>), + Time32Millis(IntEncoder<'a, Time32MillisecondType>), + Time64Micros(LongEncoder<'a, Time64MicrosecondType>), + DurationSeconds(LongEncoder<'a, DurationSecondType>), + DurationMillis(LongEncoder<'a, DurationMillisecondType>), + DurationMicros(LongEncoder<'a, DurationMicrosecondType>), + DurationNanos(LongEncoder<'a, DurationNanosecondType>), + Float32(F32Encoder<'a>), + Float64(F64Encoder<'a>), + Binary(BinaryEncoder<'a, i32>), + LargeBinary(BinaryEncoder<'a, i64>), + Utf8(Utf8Encoder<'a>), + Utf8Large(Utf8LargeEncoder<'a>), + Utf8View(Utf8ViewEncoder<'a>), + BinaryView(BinaryViewEncoder<'a>), + List(Box>), + LargeList(Box>), + ListView(Box>), + LargeListView(Box>), + FixedSizeList(Box>), + Struct(Box>), + /// Avro `fixed` encoder (raw bytes, no length) + Fixed(FixedEncoder<'a>), + /// Avro `uuid` logical type encoder (string with RFC‑4122 hyphenated text) + Uuid(UuidEncoder<'a>), + /// Avro `duration` logical type (Arrow Interval(MonthDayNano)) encoder + IntervalMonthDayNano(DurationEncoder<'a, IntervalMonthDayNanoType>), + /// Avro `duration` logical type (Arrow Interval(YearMonth)) encoder + IntervalYearMonth(DurationEncoder<'a, IntervalYearMonthType>), + /// Avro `duration` logical type (Arrow Interval(DayTime)) encoder + IntervalDayTime(DurationEncoder<'a, IntervalDayTimeType>), + #[cfg(feature = "small_decimals")] + Decimal32(Decimal32Encoder<'a>), + #[cfg(feature = "small_decimals")] + Decimal64(Decimal64Encoder<'a>), + Decimal128(Decimal128Encoder<'a>), + Decimal256(Decimal256Encoder<'a>), + /// Avro `enum` encoder: writes the key (int) as the enum index. + Enum(EnumEncoder<'a>), + Map(Box>), + Union(Box>), + /// Run-end encoded values with specific run-end index widths + RunEncoded16(Box>), + RunEncoded32(Box>), + RunEncoded64(Box>), + Null, +} + +impl<'a> Encoder<'a> { + /// Encode the value at `idx`. + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + match self { + Encoder::Boolean(e) => e.encode(out, idx), + Encoder::Int(e) => e.encode(out, idx), + Encoder::Long(e) => e.encode(out, idx), + Encoder::TimestampMicros(e) => e.encode(out, idx), + Encoder::TimestampMillis(e) => e.encode(out, idx), + Encoder::TimestampNanos(e) => e.encode(out, idx), + Encoder::TimestampSecsToMillis(e) => e.encode(out, idx), + Encoder::Date32(e) => e.encode(out, idx), + Encoder::Time32SecsToMillis(e) => e.encode(out, idx), + Encoder::Time32Millis(e) => e.encode(out, idx), + Encoder::Time64Micros(e) => e.encode(out, idx), + Encoder::DurationSeconds(e) => e.encode(out, idx), + Encoder::DurationMicros(e) => e.encode(out, idx), + Encoder::DurationMillis(e) => e.encode(out, idx), + Encoder::DurationNanos(e) => e.encode(out, idx), + Encoder::Float32(e) => e.encode(out, idx), + Encoder::Float64(e) => e.encode(out, idx), + Encoder::Binary(e) => e.encode(out, idx), + Encoder::LargeBinary(e) => e.encode(out, idx), + Encoder::Utf8(e) => e.encode(out, idx), + Encoder::Utf8Large(e) => e.encode(out, idx), + Encoder::Utf8View(e) => e.encode(out, idx), + Encoder::BinaryView(e) => e.encode(out, idx), + Encoder::List(e) => e.encode(out, idx), + Encoder::LargeList(e) => e.encode(out, idx), + Encoder::ListView(e) => e.encode(out, idx), + Encoder::LargeListView(e) => e.encode(out, idx), + Encoder::FixedSizeList(e) => e.encode(out, idx), + Encoder::Struct(e) => e.encode(out, idx), + Encoder::Fixed(e) => (e).encode(out, idx), + Encoder::Uuid(e) => (e).encode(out, idx), + Encoder::IntervalMonthDayNano(e) => (e).encode(out, idx), + Encoder::IntervalYearMonth(e) => (e).encode(out, idx), + Encoder::IntervalDayTime(e) => (e).encode(out, idx), + #[cfg(feature = "small_decimals")] + Encoder::Decimal32(e) => (e).encode(out, idx), + #[cfg(feature = "small_decimals")] + Encoder::Decimal64(e) => (e).encode(out, idx), + Encoder::Decimal128(e) => (e).encode(out, idx), + Encoder::Decimal256(e) => (e).encode(out, idx), + Encoder::Map(e) => (e).encode(out, idx), + Encoder::Enum(e) => (e).encode(out, idx), + Encoder::Union(e) => (e).encode(out, idx), + Encoder::RunEncoded16(e) => (e).encode(out, idx), + Encoder::RunEncoded32(e) => (e).encode(out, idx), + Encoder::RunEncoded64(e) => (e).encode(out, idx), + Encoder::Null => Ok(()), + } + } +} + +struct BooleanEncoder<'a>(&'a arrow_array::BooleanArray); +impl BooleanEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_bool(out, self.0.value(idx)) + } +} + +/// Generic Avro `int` encoder for primitive arrays with `i32` native values. +struct IntEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray

); +impl<'a, P: ArrowPrimitiveType> IntEncoder<'a, P> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_int(out, self.0.value(idx)) + } +} + +/// Generic Avro `long` encoder for primitive arrays with `i64` native values. +struct LongEncoder<'a, P: ArrowPrimitiveType>(&'a PrimitiveArray

); +impl<'a, P: ArrowPrimitiveType> LongEncoder<'a, P> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_long(out, self.0.value(idx)) + } +} + +/// Time32(Second) to Avro time-millis (int), via safe scaling by 1000 +struct Time32SecondsToMillisEncoder<'a>(&'a PrimitiveArray); +impl<'a> Time32SecondsToMillisEncoder<'a> { + #[inline] + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let secs = self.0.value(idx); + let millis = secs.checked_mul(1000).ok_or_else(|| { + ArrowError::InvalidArgumentError("time32(secs) * 1000 overflowed".into()) + })?; + write_int(out, millis) + } +} + +/// Timestamp(Second) to Avro timestamp-millis (long), via safe scaling by 1000 +struct TimestampSecondsToMillisEncoder<'a>(&'a PrimitiveArray); +impl<'a> TimestampSecondsToMillisEncoder<'a> { + #[inline] + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let secs = self.0.value(idx); + let millis = secs.checked_mul(1000).ok_or_else(|| { + ArrowError::InvalidArgumentError("timestamp(secs) * 1000 overflowed".into()) + })?; + write_long(out, millis) + } +} + +/// Unified binary encoder generic over offset size (i32/i64). +struct BinaryEncoder<'a, O: OffsetSizeTrait>(&'a GenericBinaryArray); +impl<'a, O: OffsetSizeTrait> BinaryEncoder<'a, O> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx)) + } +} + +/// BinaryView (byte view) encoder. +struct BinaryViewEncoder<'a>(&'a BinaryViewArray); +impl BinaryViewEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx)) + } +} + +/// StringView encoder. +struct Utf8ViewEncoder<'a>(&'a StringViewArray); +impl Utf8ViewEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx).as_bytes()) + } +} + +struct F32Encoder<'a>(&'a arrow_array::Float32Array); +impl F32Encoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // Avro float: 4 bytes, IEEE-754 little-endian + let bits = self.0.value(idx).to_bits(); + out.write_all(&bits.to_le_bytes()) + .map_err(|e| ArrowError::IoError(format!("write f32: {e}"), e)) + } +} + +struct F64Encoder<'a>(&'a arrow_array::Float64Array); +impl F64Encoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // Avro double: 8 bytes, IEEE-754 little-endian + let bits = self.0.value(idx).to_bits(); + out.write_all(&bits.to_le_bytes()) + .map_err(|e| ArrowError::IoError(format!("write f64: {e}"), e)) + } +} + +struct Utf8GenericEncoder<'a, O: OffsetSizeTrait>(&'a GenericStringArray); + +impl<'a, O: OffsetSizeTrait> Utf8GenericEncoder<'a, O> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + write_len_prefixed(out, self.0.value(idx).as_bytes()) + } +} + +type Utf8Encoder<'a> = Utf8GenericEncoder<'a, i32>; +type Utf8LargeEncoder<'a> = Utf8GenericEncoder<'a, i64>; + +/// Internal key array kind used by Map encoder. +enum KeyKind<'a> { + Utf8(&'a GenericStringArray), + LargeUtf8(&'a GenericStringArray), +} +struct MapEncoder<'a> { + map: &'a MapArray, + keys: KeyKind<'a>, + values: FieldEncoder<'a>, + keys_offset: usize, + values_offset: usize, +} + +impl<'a> MapEncoder<'a> { + fn try_new( + map: &'a MapArray, + values_nullability: Option, + value_plan: &FieldPlan, + ) -> Result { + let keys_arr = map.keys(); + let keys_kind = match keys_arr.data_type() { + DataType::Utf8 => KeyKind::Utf8(keys_arr.as_string::()), + DataType::LargeUtf8 => KeyKind::LargeUtf8(keys_arr.as_string::()), + other => { + return Err(ArrowError::SchemaError(format!( + "Avro map requires string keys; Arrow key type must be Utf8/LargeUtf8, found: {other:?}" + ))); + } + }; + Ok(Self { + map, + keys: keys_kind, + values: FieldEncoder::make_encoder( + map.values().as_ref(), + value_plan, + values_nullability, + )?, + keys_offset: keys_arr.offset(), + values_offset: map.values().offset(), + }) + } + + fn encode_map_entries( + out: &mut W, + keys: &GenericStringArray, + keys_offset: usize, + start: usize, + end: usize, + mut write_item: impl FnMut(&mut W, usize) -> Result<(), ArrowError>, + ) -> Result<(), ArrowError> + where + W: Write + ?Sized, + O: OffsetSizeTrait, + { + encode_blocked_range(out, start, end, |out, j| { + let j_key = j.saturating_sub(keys_offset); + write_len_prefixed(out, keys.value(j_key).as_bytes())?; + write_item(out, j) + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.map.offsets(); + let start = offsets[idx] as usize; + let end = offsets[idx + 1] as usize; + let write_item = |out: &mut W, j: usize| { + let j_val = j.saturating_sub(self.values_offset); + self.values.encode(out, j_val) + }; + match self.keys { + KeyKind::Utf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + arr, + self.keys_offset, + start, + end, + write_item, + ), + KeyKind::LargeUtf8(arr) => MapEncoder::<'a>::encode_map_entries( + out, + arr, + self.keys_offset, + start, + end, + write_item, + ), + } + } +} + +/// Avro `enum` encoder for Arrow `DictionaryArray`. +/// +/// Per Avro spec, an enum is encoded as an **int** equal to the +/// zero-based position of the symbol in the schema’s `symbols` list. +/// We validate at construction that the dictionary values equal the symbols, +/// so we can directly write the key value here. +struct EnumEncoder<'a> { + keys: &'a PrimitiveArray, +} +impl EnumEncoder<'_> { + fn encode(&mut self, out: &mut W, row: usize) -> Result<(), ArrowError> { + write_int(out, self.keys.value(row)) + } +} + +struct UnionEncoder<'a> { + encoders: Vec>, + array: &'a UnionArray, + type_id_to_encoder_index: Vec>, +} + +impl<'a> UnionEncoder<'a> { + fn try_new(array: &'a UnionArray, field_bindings: &[FieldBinding]) -> Result { + let DataType::Union(fields, UnionMode::Dense) = array.data_type() else { + return Err(ArrowError::SchemaError("Expected Dense UnionArray".into())); + }; + if fields.len() != field_bindings.len() { + return Err(ArrowError::SchemaError(format!( + "Mismatched number of union branches between Arrow array ({}) and encoding plan ({})", + fields.len(), + field_bindings.len() + ))); + } + let max_type_id = fields.iter().map(|(tid, _)| tid).max().unwrap_or(0); + let mut type_id_to_encoder_index: Vec> = + vec![None; (max_type_id + 1) as usize]; + let mut encoders = Vec::with_capacity(fields.len()); + for (i, (type_id, _)) in fields.iter().enumerate() { + let binding = field_bindings + .get(i) + .ok_or_else(|| ArrowError::SchemaError("Binding and field mismatch".to_string()))?; + encoders.push(FieldEncoder::make_encoder( + array.child(type_id).as_ref(), + &binding.plan, + binding.nullability, + )?); + type_id_to_encoder_index[type_id as usize] = Some(i); + } + Ok(Self { + encoders, + array, + type_id_to_encoder_index, + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // SAFETY: `idx` is always in bounds because: + // 1. The encoder is called from `RecordEncoder::encode,` which iterates over `0..batch.num_rows()` + // 2. `self.array` is a column from the same batch, so its length equals `batch.num_rows()` + // 3. `type_ids()` returns a buffer with exactly `self.array.len()` entries (one per logical element) + let type_id = self.array.type_ids()[idx]; + let encoder_index = self + .type_id_to_encoder_index + .get(type_id as usize) + .and_then(|opt| *opt) + .ok_or_else(|| ArrowError::SchemaError(format!("Invalid type_id {type_id}")))?; + write_int(out, encoder_index as i32)?; + let encoder = self.encoders.get_mut(encoder_index).ok_or_else(|| { + ArrowError::SchemaError(format!("Invalid encoder index {encoder_index}")) + })?; + encoder.encode(out, self.array.value_offset(idx)) + } +} + +struct StructEncoder<'a> { + encoders: Vec>, +} + +impl<'a> StructEncoder<'a> { + fn try_new( + array: &'a StructArray, + field_bindings: &[FieldBinding], + ) -> Result { + let mut encoders = Vec::with_capacity(field_bindings.len()); + for field_binding in field_bindings { + let idx = field_binding.arrow_index; + let column = array.columns().get(idx).ok_or_else(|| { + ArrowError::SchemaError(format!("Struct child index {idx} out of range")) + })?; + let encoder = FieldEncoder::make_encoder( + column.as_ref(), + &field_binding.plan, + field_binding.nullability, + )?; + encoders.push(encoder); + } + Ok(Self { encoders }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + for encoder in self.encoders.iter_mut() { + encoder.encode(out, idx)?; + } + Ok(()) + } +} + +/// Encode a blocked range of items with Avro array block framing. +/// +/// `write_item` must take `(out, index)` to maintain the "out-first" convention. +fn encode_blocked_range( + out: &mut W, + start: usize, + end: usize, + mut write_item: F, +) -> Result<(), ArrowError> +where + F: FnMut(&mut W, usize) -> Result<(), ArrowError>, +{ + let len = end.saturating_sub(start); + if len == 0 { + // Zero-length terminator per Avro spec. + write_long(out, 0)?; + return Ok(()); + } + // Emit a single positive block for performance, then the end marker. + write_long(out, len as i64)?; + for row in start..end { + write_item(out, row)?; + } + write_long(out, 0)?; + Ok(()) +} + +struct ListEncoder<'a, O: OffsetSizeTrait> { + list: &'a GenericListArray, + values: FieldEncoder<'a>, + values_offset: usize, +} + +type ListEncoder32<'a> = ListEncoder<'a, i32>; +type ListEncoder64<'a> = ListEncoder<'a, i64>; + +impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> { + fn try_new( + list: &'a GenericListArray, + items_nullability: Option, + item_plan: &FieldPlan, + ) -> Result { + Ok(Self { + list, + values: FieldEncoder::make_encoder( + list.values().as_ref(), + item_plan, + items_nullability, + )?, + values_offset: list.values().offset(), + }) + } + + fn encode_list_range( + &mut self, + out: &mut W, + start: usize, + end: usize, + ) -> Result<(), ArrowError> { + encode_blocked_range(out, start, end, |out, row| { + self.values + .encode(out, row.saturating_sub(self.values_offset)) + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let offsets = self.list.offsets(); + let start = offsets[idx].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Error converting offset[{idx}] to usize")) + })?; + let end = offsets[idx + 1].to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting offset[{}] to usize", + idx + 1 + )) + })?; + self.encode_list_range(out, start, end) + } +} + +/// ListView encoder using `(offset, size)` buffers. +struct ListViewEncoder<'a, O: OffsetSizeTrait> { + list: &'a GenericListViewArray, + values: FieldEncoder<'a>, + values_offset: usize, +} +type ListViewEncoder32<'a> = ListViewEncoder<'a, i32>; +type ListViewEncoder64<'a> = ListViewEncoder<'a, i64>; + +impl<'a, O: OffsetSizeTrait> ListViewEncoder<'a, O> { + fn try_new( + list: &'a GenericListViewArray, + items_nullability: Option, + item_plan: &FieldPlan, + ) -> Result { + Ok(Self { + list, + values: FieldEncoder::make_encoder( + list.values().as_ref(), + item_plan, + items_nullability, + )?, + values_offset: list.values().offset(), + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let start = self.list.value_offset(idx).to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!( + "Error converting value_offset[{idx}] to usize" + )) + })?; + let len = self.list.value_size(idx).to_usize().ok_or_else(|| { + ArrowError::InvalidArgumentError(format!("Error converting value_size[{idx}] to usize")) + })?; + let start = start + self.values_offset; + let end = start + len; + encode_blocked_range(out, start, end, |out, row| { + self.values + .encode(out, row.saturating_sub(self.values_offset)) + }) + } +} + +/// FixedSizeList encoder. +struct FixedSizeListEncoder<'a> { + list: &'a FixedSizeListArray, + values: FieldEncoder<'a>, + values_offset: usize, + elem_len: usize, +} + +impl<'a> FixedSizeListEncoder<'a> { + fn try_new( + list: &'a FixedSizeListArray, + items_nullability: Option, + item_plan: &FieldPlan, + ) -> Result { + Ok(Self { + list, + values: FieldEncoder::make_encoder( + list.values().as_ref(), + item_plan, + items_nullability, + )?, + values_offset: list.values().offset(), + elem_len: list.value_length() as usize, + }) + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + // Starting index is relative to values() start + let rel = self.list.value_offset(idx) as usize; + let start = self.values_offset + rel; + let end = start + self.elem_len; + encode_blocked_range(out, start, end, |out, row| { + self.values + .encode(out, row.saturating_sub(self.values_offset)) + }) + } +} + +/// Avro `fixed` encoder for Arrow `FixedSizeBinaryArray`. +/// Spec: a fixed is encoded as exactly `size` bytes, with no length prefix. +struct FixedEncoder<'a>(&'a FixedSizeBinaryArray); +impl FixedEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let v = self.0.value(idx); // &[u8] of fixed width + out.write_all(v) + .map_err(|e| ArrowError::IoError(format!("write fixed bytes: {e}"), e)) + } +} + +/// Avro UUID logical type encoder: Arrow FixedSizeBinary(16) to Avro string (UUID). +/// Spec: uuid is a logical type over string (RFC‑4122). We output hyphenated form. +struct UuidEncoder<'a>(&'a FixedSizeBinaryArray); +impl UuidEncoder<'_> { + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let mut buf = [0u8; 1 + uuid::fmt::Hyphenated::LENGTH]; + buf[0] = 0x48; + let v = self.0.value(idx); + let u = Uuid::from_slice(v) + .map_err(|e| ArrowError::InvalidArgumentError(format!("Invalid UUID bytes: {e}")))?; + let _ = u.hyphenated().encode_lower(&mut buf[1..]); + out.write_all(&buf) + .map_err(|e| ArrowError::IoError(format!("write uuid: {e}"), e)) + } +} + +#[derive(Copy, Clone)] +struct DurationParts { + months: u32, + days: u32, + millis: u32, +} +/// Trait mapping an Arrow interval native value to Avro duration `(months, days, millis)`. +trait IntervalToDurationParts: ArrowPrimitiveType { + fn duration_parts(native: Self::Native) -> Result; +} +impl IntervalToDurationParts for IntervalMonthDayNanoType { + fn duration_parts(native: Self::Native) -> Result { + let (months, days, nanos) = IntervalMonthDayNanoType::to_parts(native); + if months < 0 || days < 0 || nanos < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative months/days/nanoseconds".into(), + )); + } + if nanos % 1_000_000 != 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' requires whole milliseconds; nanoseconds must be divisible by 1_000_000" + .into(), + )); + } + let millis = nanos / 1_000_000; + if millis > u32::MAX as i64 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' milliseconds exceed u32::MAX".into(), + )); + } + Ok(DurationParts { + months: months as u32, + days: days as u32, + millis: millis as u32, + }) + } +} +impl IntervalToDurationParts for IntervalYearMonthType { + fn duration_parts(native: Self::Native) -> Result { + if native < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative months".into(), + )); + } + Ok(DurationParts { + months: native as u32, + days: 0, + millis: 0, + }) + } +} +impl IntervalToDurationParts for IntervalDayTimeType { + fn duration_parts(native: Self::Native) -> Result { + let (days, millis) = IntervalDayTimeType::to_parts(native); + if days < 0 || millis < 0 { + return Err(ArrowError::InvalidArgumentError( + "Avro 'duration' cannot encode negative days or milliseconds".into(), + )); + } + Ok(DurationParts { + months: 0, + days: days as u32, + millis: millis as u32, + }) + } +} + +/// Single generic encoder used for all three interval units. +/// Writes Avro `fixed(12)` as three little-endian u32 values in one call. +struct DurationEncoder<'a, P: ArrowPrimitiveType + IntervalToDurationParts>(&'a PrimitiveArray

); +impl<'a, P: ArrowPrimitiveType + IntervalToDurationParts> DurationEncoder<'a, P> { + #[inline(always)] + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let parts = P::duration_parts(self.0.value(idx))?; + let months = parts.months.to_le_bytes(); + let days = parts.days.to_le_bytes(); + let ms = parts.millis.to_le_bytes(); + // SAFETY + // - Endianness & layout: Avro's `duration` logical type is encoded as fixed(12) + // with three *little-endian* unsigned 32-bit integers in order: (months, days, millis). + // We explicitly materialize exactly those 12 bytes. + // - In-bounds indexing: `to_le_bytes()` on `u32` returns `[u8; 4]` by contract, + // therefore, the constant indices 0..=3 used below are *always* in-bounds. + // Rust will panic on out-of-bounds indexing, but there is no such path here; + // the compiler can also elide the bound checks for constant, provably in-range + // indices. [std docs; Rust Performance Book on bounds-check elimination] + // - Memory safety: The `[u8; 12]` array is built on the stack by value, with no + // aliasing and no uninitialized memory. There is no `unsafe`. + // - I/O: `write_all(&buf)` is fallible and its `Result` is propagated and mapped + // into `ArrowError`, so I/O errors are reported, not panicked. + // Consequently, constructing `buf` with the constant indices below is safe and + // panic-free under these validated preconditions. + let buf = [ + months[0], months[1], months[2], months[3], days[0], days[1], days[2], days[3], ms[0], + ms[1], ms[2], ms[3], + ]; + out.write_all(&buf) + .map_err(|e| ArrowError::IoError(format!("write duration: {e}"), e)) + } +} + +/// Minimal trait to obtain a big-endian fixed-size byte array for a decimal's +/// unscaled integer value at `idx`. +trait DecimalBeBytes { + fn value_be_bytes(&self, idx: usize) -> [u8; N]; +} +#[cfg(feature = "small_decimals")] +impl DecimalBeBytes<4> for Decimal32Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 4] { + self.value(idx).to_be_bytes() + } +} +#[cfg(feature = "small_decimals")] +impl DecimalBeBytes<8> for Decimal64Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 8] { + self.value(idx).to_be_bytes() + } +} +impl DecimalBeBytes<16> for Decimal128Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 16] { + self.value(idx).to_be_bytes() + } +} +impl DecimalBeBytes<32> for Decimal256Array { + fn value_be_bytes(&self, idx: usize) -> [u8; 32] { + // Arrow i256 → [u8; 32] big-endian + self.value(idx).to_be_bytes() + } +} + +/// Generic Avro decimal encoder over Arrow decimal arrays. +/// - When `fixed_size` is `None` → Avro `bytes(decimal)`; writes the minimal +/// two's-complement representation with a length prefix. +/// - When `Some(n)` → Avro `fixed(n, decimal)`; sign-extends (or validates) +/// to exactly `n` bytes and writes them directly. +struct DecimalEncoder<'a, const N: usize, A: DecimalBeBytes> { + arr: &'a A, + fixed_size: Option, +} + +impl<'a, const N: usize, A: DecimalBeBytes> DecimalEncoder<'a, N, A> { + fn new(arr: &'a A, fixed_size: Option) -> Self { + Self { arr, fixed_size } + } + + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + let be = self.arr.value_be_bytes(idx); + match self.fixed_size { + Some(n) => write_sign_extended(out, &be, n), + None => write_len_prefixed(out, minimal_twos_complement(&be)), + } + } +} + +#[cfg(feature = "small_decimals")] +type Decimal32Encoder<'a> = DecimalEncoder<'a, 4, Decimal32Array>; +#[cfg(feature = "small_decimals")] +type Decimal64Encoder<'a> = DecimalEncoder<'a, 8, Decimal64Array>; +type Decimal128Encoder<'a> = DecimalEncoder<'a, 16, Decimal128Array>; +type Decimal256Encoder<'a> = DecimalEncoder<'a, 32, Decimal256Array>; + +/// Generic encoder for Arrow `RunArray`-based sites (run-end encoded). +/// Follows the pattern used by other generic encoders (i.e., `ListEncoder`), +/// avoiding runtime branching on run-end width. +struct RunEncodedEncoder<'a, R: RunEndIndexType> { + ends_slice: &'a [::Native], + base: usize, + len: usize, + values: FieldEncoder<'a>, + // Cached run index used for sequential scans of rows [0..n) + cur_run: usize, + // Cached end (logical index, 1-based per spec) for the current run. + cur_end: usize, +} + +type RunEncodedEncoder16<'a> = RunEncodedEncoder<'a, Int16Type>; +type RunEncodedEncoder32<'a> = RunEncodedEncoder<'a, Int32Type>; +type RunEncodedEncoder64<'a> = RunEncodedEncoder<'a, Int64Type>; + +impl<'a, R: RunEndIndexType> RunEncodedEncoder<'a, R> { + fn new(arr: &'a RunArray, values: FieldEncoder<'a>) -> Self { + let ends = arr.run_ends(); + let base = ends.get_start_physical_index(); + let slice = ends.values(); + let len = ends.len(); + let cur_end = if len == 0 { 0 } else { slice[base].as_usize() }; + Self { + ends_slice: slice, + base, + len, + values, + cur_run: 0, + cur_end, + } + } + + /// Advance `cur_run` so that `idx` is within the run ending at `cur_end`. + /// Uses the REE invariant: run ends are strictly increasing, positive, and 1-based. + #[inline(always)] + fn advance_to_row(&mut self, idx: usize) -> Result<(), ArrowError> { + if idx < self.cur_end { + return Ok(()); + } + // Move forward across run boundaries until idx falls within cur_end + while self.cur_run + 1 < self.len && idx >= self.cur_end { + self.cur_run += 1; + self.cur_end = self.ends_slice[self.base + self.cur_run].as_usize(); + } + if idx < self.cur_end { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError(format!( + "row index {idx} out of bounds for run-ends ({} runs)", + self.len + ))) + } + } + + #[inline(always)] + fn encode(&mut self, out: &mut W, idx: usize) -> Result<(), ArrowError> { + self.advance_to_row(idx)?; + // For REE values, the value for any logical row within a run is at + // the physical index of that run. + self.values.encode(out, self.cur_run) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::types::Int32Type; + use arrow_array::{ + Array, ArrayRef, BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, + Int64Array, LargeBinaryArray, LargeListArray, LargeStringArray, ListArray, NullArray, + StringArray, + }; + use arrow_buffer::Buffer; + use arrow_schema::{DataType, Field, Fields, UnionFields}; + + fn zigzag_i64(v: i64) -> u64 { + ((v << 1) ^ (v >> 63)) as u64 + } + + fn varint(mut x: u64) -> Vec { + let mut out = Vec::new(); + while (x & !0x7f) != 0 { + out.push(((x & 0x7f) as u8) | 0x80); + x >>= 7; + } + out.push((x & 0x7f) as u8); + out + } + + fn avro_long_bytes(v: i64) -> Vec { + varint(zigzag_i64(v)) + } + + fn avro_len_prefixed_bytes(payload: &[u8]) -> Vec { + let mut out = avro_long_bytes(payload.len() as i64); + out.extend_from_slice(payload); + out + } + + fn duration_fixed12(months: u32, days: u32, millis: u32) -> [u8; 12] { + let m = months.to_le_bytes(); + let d = days.to_le_bytes(); + let ms = millis.to_le_bytes(); + [ + m[0], m[1], m[2], m[3], d[0], d[1], d[2], d[3], ms[0], ms[1], ms[2], ms[3], + ] + } + + fn encode_all( + array: &dyn Array, + plan: &FieldPlan, + nullability: Option, + ) -> Vec { + let mut enc = FieldEncoder::make_encoder(array, plan, nullability).unwrap(); + let mut out = Vec::new(); + for i in 0..array.len() { + enc.encode(&mut out, i).unwrap(); + } + out + } + + fn assert_bytes_eq(actual: &[u8], expected: &[u8]) { + if actual != expected { + let to_hex = |b: &[u8]| { + b.iter() + .map(|x| format!("{:02X}", x)) + .collect::>() + .join(" ") + }; + panic!( + "mismatch\n expected: [{}]\n actual: [{}]", + to_hex(expected), + to_hex(actual) + ); + } + } + + #[test] + fn binary_encoder() { + let values: Vec<&[u8]> = vec![b"", b"ab", b"\x00\xFF"]; + let arr = BinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"" as &[u8], b"ab", b"\x00\xFF"] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_binary_encoder() { + let values: Vec<&[u8]> = vec![b"xyz", b""]; + let arr = LargeBinaryArray::from_vec(values); + let mut expected = Vec::new(); + for payload in [b"xyz" as &[u8], b""] { + expected.extend(avro_len_prefixed_bytes(payload)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn utf8_encoder() { + let arr = StringArray::from(vec!["", "A", "BC"]); + let mut expected = Vec::new(); + for s in ["", "A", "BC"] { + expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_utf8_encoder() { + let arr = LargeStringArray::from(vec!["hello", ""]); + let mut expected = Vec::new(); + for s in ["hello", ""] { + expected.extend(avro_len_prefixed_bytes(s.as_bytes())); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn list_encoder_int32() { + // Build ListArray [[1,2], [], [3]] + let values = Int32Array::from(vec![1, 2, 3]); + let offsets = vec![0, 2, 2, 3]; + let list = ListArray::new( + Field::new("item", DataType::Int32, true).into(), + arrow_buffer::OffsetBuffer::new(offsets.into()), + Arc::new(values) as ArrayRef, + None, + ); + // Avro array encoding per row + let mut expected = Vec::new(); + // row 0: block len 2, items 1,2 then 0 + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + // row 1: empty + expected.extend(avro_long_bytes(0)); + // row 2: one item 3 + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(0)); + + let plan = FieldPlan::List { + items_nullability: None, + item_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&list, &plan, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn struct_encoder_two_fields() { + // Struct { a: Int32, b: Utf8 } + let a = Int32Array::from(vec![1, 2]); + let b = StringArray::from(vec!["x", "y"]); + let fields = Fields::from(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ]); + let struct_arr = StructArray::new( + fields.clone(), + vec![Arc::new(a) as ArrayRef, Arc::new(b) as ArrayRef], + None, + ); + let plan = FieldPlan::Struct { + bindings: vec![ + FieldBinding { + arrow_index: 0, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 1, + nullability: None, + plan: FieldPlan::Scalar, + }, + ], + }; + let got = encode_all(&struct_arr, &plan, None); + // Expected: rows concatenated: a then b + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(1)); // a=1 + expected.extend(avro_len_prefixed_bytes(b"x")); // b="x" + expected.extend(avro_long_bytes(2)); // a=2 + expected.extend(avro_len_prefixed_bytes(b"y")); // b="y" + assert_bytes_eq(&got, &expected); + } + + #[test] + fn enum_encoder_dictionary() { + // symbols: ["A","B","C"], keys [2,0,1] + let dict_values = StringArray::from(vec!["A", "B", "C"]); + let keys = Int32Array::from(vec![2, 0, 1]); + let dict = + DictionaryArray::::try_new(keys, Arc::new(dict_values) as ArrayRef).unwrap(); + let symbols = Arc::<[String]>::from( + vec!["A".to_string(), "B".to_string(), "C".to_string()].into_boxed_slice(), + ); + let plan = FieldPlan::Enum { symbols }; + let got = encode_all(&dict, &plan, None); + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(1)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn decimal_bytes_and_fixed() { + // Use Decimal128 with small positives and negatives + let dec = Decimal128Array::from(vec![1i128, -1i128, 0i128]) + .with_precision_and_scale(20, 0) + .unwrap(); + // bytes(decimal): minimal two's complement length-prefixed + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + // 1 -> 0x01; -1 -> 0xFF; 0 -> 0x00 + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + + let plan_fixed = FieldPlan::Decimal { size: Some(16) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i128.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i128).to_be_bytes()); + expected_fixed.extend_from_slice(&0i128.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[test] + fn decimal_bytes_256() { + use arrow_buffer::i256; + // Use Decimal256 with small positives and negatives + let dec = Decimal256Array::from(vec![ + i256::from_i128(1), + i256::from_i128(-1), + i256::from_i128(0), + ]) + .with_precision_and_scale(76, 0) + .unwrap(); + // bytes(decimal): minimal two's complement length-prefixed + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + // 1 -> 0x01; -1 -> 0xFF; 0 -> 0x00 + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + + // fixed(32): 32-byte big-endian two's complement + let plan_fixed = FieldPlan::Decimal { size: Some(32) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&i256::from_i128(1).to_be_bytes()); + expected_fixed.extend_from_slice(&i256::from_i128(-1).to_be_bytes()); + expected_fixed.extend_from_slice(&i256::from_i128(0).to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[cfg(feature = "small_decimals")] + #[test] + fn decimal_bytes_and_fixed_32() { + // Use Decimal32 with small positives and negatives + let dec = Decimal32Array::from(vec![1i32, -1i32, 0i32]) + .with_precision_and_scale(9, 0) + .unwrap(); + // bytes(decimal) + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + // fixed(4) + let plan_fixed = FieldPlan::Decimal { size: Some(4) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i32.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i32).to_be_bytes()); + expected_fixed.extend_from_slice(&0i32.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[cfg(feature = "small_decimals")] + #[test] + fn decimal_bytes_and_fixed_64() { + // Use Decimal64 with small positives and negatives + let dec = Decimal64Array::from(vec![1i64, -1i64, 0i64]) + .with_precision_and_scale(18, 0) + .unwrap(); + // bytes(decimal) + let plan_bytes = FieldPlan::Decimal { size: None }; + let got_bytes = encode_all(&dec, &plan_bytes, None); + let mut expected_bytes = Vec::new(); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x01])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0xFF])); + expected_bytes.extend(avro_len_prefixed_bytes(&[0x00])); + assert_bytes_eq(&got_bytes, &expected_bytes); + // fixed(8) + let plan_fixed = FieldPlan::Decimal { size: Some(8) }; + let got_fixed = encode_all(&dec, &plan_fixed, None); + let mut expected_fixed = Vec::new(); + expected_fixed.extend_from_slice(&1i64.to_be_bytes()); + expected_fixed.extend_from_slice(&(-1i64).to_be_bytes()); + expected_fixed.extend_from_slice(&0i64.to_be_bytes()); + assert_bytes_eq(&got_fixed, &expected_fixed); + } + + #[test] + fn float32_and_float64_encoders() { + let f32a = Float32Array::from(vec![0.0f32, -1.5f32, f32::from_bits(0x7fc00000)]); // includes a quiet NaN bit pattern + let f64a = Float64Array::from(vec![0.0f64, -2.25f64]); + // f32 expected + let mut expected32 = Vec::new(); + for v in [0.0f32, -1.5f32, f32::from_bits(0x7fc00000)] { + expected32.extend_from_slice(&v.to_bits().to_le_bytes()); + } + let got32 = encode_all(&f32a, &FieldPlan::Scalar, None); + assert_bytes_eq(&got32, &expected32); + // f64 expected + let mut expected64 = Vec::new(); + for v in [0.0f64, -2.25f64] { + expected64.extend_from_slice(&v.to_bits().to_le_bytes()); + } + let got64 = encode_all(&f64a, &FieldPlan::Scalar, None); + assert_bytes_eq(&got64, &expected64); + } + + #[test] + fn long_encoder_int64() { + let arr = Int64Array::from(vec![0i64, 1i64, -1i64, 2i64, -2i64, i64::MIN + 1]); + let mut expected = Vec::new(); + for v in [0, 1, -1, 2, -2, i64::MIN + 1] { + expected.extend(avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn fixed_encoder_plain() { + // Two values of width 4 + let data = [[0xDE, 0xAD, 0xBE, 0xEF], [0x00, 0x01, 0x02, 0x03]]; + let values: Vec> = data.iter().map(|x| x.to_vec()).collect(); + let arr = FixedSizeBinaryArray::try_from_iter(values.into_iter()).unwrap(); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + let mut expected = Vec::new(); + expected.extend_from_slice(&data[0]); + expected.extend_from_slice(&data[1]); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn uuid_encoder_test() { + // Happy path + let u = Uuid::parse_str("00112233-4455-6677-8899-aabbccddeeff").unwrap(); + let bytes = *u.as_bytes(); + let arr_ok = FixedSizeBinaryArray::try_from_iter(vec![bytes.to_vec()].into_iter()).unwrap(); + // Expected: length 36 (0x48) followed by hyphenated lowercase text + let mut expected = Vec::new(); + expected.push(0x48); + expected.extend_from_slice(u.hyphenated().to_string().as_bytes()); + let got = encode_all(&arr_ok, &FieldPlan::Uuid, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn uuid_encoder_error() { + // Invalid UUID bytes: wrong length + let arr = + FixedSizeBinaryArray::try_new(10, arrow_buffer::Buffer::from(vec![0u8; 10]), None) + .unwrap(); + let plan = FieldPlan::Uuid; + let mut enc = FieldEncoder::make_encoder(&arr, &plan, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("Invalid UUID bytes")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + fn test_scalar_primitive_encoding( + non_nullable_data: &[T::Native], + nullable_data: &[Option], + ) where + T: ArrowPrimitiveType, + T::Native: Into + Copy, + PrimitiveArray: From::Native>>, + { + let plan = FieldPlan::Scalar; + + let array = PrimitiveArray::::from(non_nullable_data.to_vec()); + let got = encode_all(&array, &plan, None); + + let mut expected = Vec::new(); + for &value in non_nullable_data { + expected.extend(avro_long_bytes(value.into())); + } + assert_bytes_eq(&got, &expected); + + let array_nullable: PrimitiveArray = nullable_data.iter().copied().collect(); + let got_nullable = encode_all(&array_nullable, &plan, Some(Nullability::NullFirst)); + + let mut expected_nullable = Vec::new(); + for &opt_value in nullable_data { + match opt_value { + Some(value) => { + // Union index 1 for the value, then the value itself + expected_nullable.extend(avro_long_bytes(1)); + expected_nullable.extend(avro_long_bytes(value.into())); + } + None => { + // Union index 0 for the null + expected_nullable.extend(avro_long_bytes(0)); + } + } + } + assert_bytes_eq(&got_nullable, &expected_nullable); + } + + #[test] + fn date32_encoder() { + test_scalar_primitive_encoding::( + &[ + 19345, // 2022-12-20 + 0, // 1970-01-01 (epoch) + -1, // 1969-12-31 (pre-epoch) + ], + &[Some(19345), None], + ); + } + + #[test] + fn time32_millis_encoder() { + test_scalar_primitive_encoding::( + &[ + 0, // Midnight + 49530123, // 13:45:30.123 + 86399999, // 23:59:59.999 + ], + &[None, Some(49530123)], + ); + } + + #[test] + fn time64_micros_encoder() { + test_scalar_primitive_encoding::( + &[ + 0, // Midnight + 86399999999, // 23:59:59.999999 + ], + &[Some(86399999999), None], + ); + } + + #[test] + fn timestamp_millis_encoder() { + test_scalar_primitive_encoding::( + &[ + 1704067200000, // 2024-01-01T00:00:00Z + 0, // 1970-01-01T00:00:00Z (epoch) + -123456789, // Pre-epoch timestamp + ], + &[None, Some(1704067200000)], + ); + } + + #[test] + fn map_encoder_string_keys_int_values() { + // Build MapArray with two rows + // Row0: {"k1":1, "k2":2} + // Row1: {} + let keys = StringArray::from(vec!["k1", "k2"]); + let values = Int32Array::from(vec![1, 2]); + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + let entries = StructArray::new( + entries_fields, + vec![Arc::new(keys) as ArrayRef, Arc::new(values) as ArrayRef], + None, + ); + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 2, 2].into()); + let map = MapArray::new( + Field::new("entries", entries.data_type().clone(), false).into(), + offsets, + entries, + None, + false, + ); + let plan = FieldPlan::Map { + values_nullability: None, + value_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&map, &plan, None); + let mut expected = Vec::new(); + // Row0: block 2 then pairs + expected.extend(avro_long_bytes(2)); + expected.extend(avro_len_prefixed_bytes(b"k1")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_len_prefixed_bytes(b"k2")); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(0)); + // Row1: empty + expected.extend(avro_long_bytes(0)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn union_encoder_string_int() { + let strings = StringArray::from(vec!["hello", "world"]); + let ints = Int32Array::from(vec![10, 20, 30]); + + let union_fields = UnionFields::try_new( + vec![0, 1], + vec![ + Field::new("v_str", DataType::Utf8, true), + Field::new("v_int", DataType::Int32, true), + ], + ) + .unwrap(); + + let type_ids = Buffer::from_slice_ref([0_i8, 1, 1, 0, 1]); + let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]); + + let union_array = UnionArray::try_new( + union_fields, + type_ids.into(), + Some(offsets.into()), + vec![Arc::new(strings), Arc::new(ints)], + ) + .unwrap(); + + let plan = FieldPlan::Union { + bindings: vec![ + FieldBinding { + arrow_index: 0, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 1, + nullability: None, + plan: FieldPlan::Scalar, + }, + ], + }; + + let got = encode_all(&union_array, &plan, None); + + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_len_prefixed_bytes(b"hello")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(10)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(20)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_len_prefixed_bytes(b"world")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(30)); + + assert_bytes_eq(&got, &expected); + } + + #[test] + fn union_encoder_null_string_int() { + let nulls = NullArray::new(1); + let strings = StringArray::from(vec!["hello"]); + let ints = Int32Array::from(vec![10]); + + let union_fields = UnionFields::try_new( + vec![0, 1, 2], + vec![ + Field::new("v_null", DataType::Null, true), + Field::new("v_str", DataType::Utf8, true), + Field::new("v_int", DataType::Int32, true), + ], + ) + .unwrap(); + + let type_ids = Buffer::from_slice_ref([0_i8, 1, 2]); + // For a null value in a dense union, no value is added to a child array. + // The offset points to the last value of that type. Since there's only one + // null, and one of each other type, all offsets are 0. + let offsets = Buffer::from_slice_ref([0_i32, 0, 0]); + + let union_array = UnionArray::try_new( + union_fields, + type_ids.into(), + Some(offsets.into()), + vec![Arc::new(nulls), Arc::new(strings), Arc::new(ints)], + ) + .unwrap(); + + let plan = FieldPlan::Union { + bindings: vec![ + FieldBinding { + arrow_index: 0, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 1, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 2, + nullability: None, + plan: FieldPlan::Scalar, + }, + ], + }; + + let got = encode_all(&union_array, &plan, None); + + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_len_prefixed_bytes(b"hello")); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(10)); + + assert_bytes_eq(&got, &expected); + } + + #[test] + fn list64_encoder_int32() { + // LargeList [[1,2,3], []] + let values = Int32Array::from(vec![1, 2, 3]); + let offsets: Vec = vec![0, 3, 3]; + let list = LargeListArray::new( + Field::new("item", DataType::Int32, true).into(), + arrow_buffer::OffsetBuffer::new(offsets.into()), + Arc::new(values) as ArrayRef, + None, + ); + let plan = FieldPlan::List { + items_nullability: None, + item_plan: Box::new(FieldPlan::Scalar), + }; + let got = encode_all(&list, &plan, None); + // Expected one block of 3 and then 0, then empty 0 + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(2)); + expected.extend(avro_long_bytes(3)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_long_bytes(0)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn int_encoder_test() { + let ints = Int32Array::from(vec![0, -1, 2]); + let mut expected_i = Vec::new(); + for v in [0i32, -1, 2] { + expected_i.extend(avro_long_bytes(v as i64)); + } + let got_i = encode_all(&ints, &FieldPlan::Scalar, None); + assert_bytes_eq(&got_i, &expected_i); + } + + #[test] + fn boolean_encoder_test() { + let bools = BooleanArray::from(vec![true, false]); + let mut expected_b = Vec::new(); + expected_b.extend_from_slice(&[1]); + expected_b.extend_from_slice(&[0]); + let got_b = encode_all(&bools, &FieldPlan::Scalar, None); + assert_bytes_eq(&got_b, &expected_b); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn duration_encoding_seconds() { + let arr: PrimitiveArray = vec![0i64, -1, 2].into(); + let mut expected = Vec::new(); + for v in [0i64, -1, 2] { + expected.extend_from_slice(&avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn duration_encoding_milliseconds() { + let arr: PrimitiveArray = vec![1i64, 0, -2].into(); + let mut expected = Vec::new(); + for v in [1i64, 0, -2] { + expected.extend_from_slice(&avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn duration_encoding_microseconds() { + let arr: PrimitiveArray = vec![5i64, -6, 7].into(); + let mut expected = Vec::new(); + for v in [5i64, -6, 7] { + expected.extend_from_slice(&avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + #[cfg(feature = "avro_custom_types")] + fn duration_encoding_nanoseconds() { + let arr: PrimitiveArray = vec![8i64, 9, -10].into(); + let mut expected = Vec::new(); + for v in [8i64, 9, -10] { + expected.extend_from_slice(&avro_long_bytes(v)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_year_month_happy_path() { + let arr: PrimitiveArray = vec![0i32, 1i32, 25i32].into(); + let mut expected = Vec::new(); + for m in [0u32, 1u32, 25u32] { + expected.extend_from_slice(&duration_fixed12(m, 0, 0)); + } + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_year_month_rejects_negative() { + let arr: PrimitiveArray = vec![-1i32].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("cannot encode negative months")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn duration_encoder_day_time_happy_path() { + let v0 = IntervalDayTimeType::make_value(2, 500); // days=2, millis=500 + let v1 = IntervalDayTimeType::make_value(0, 0); + let arr: PrimitiveArray = vec![v0, v1].into(); + let mut expected = Vec::new(); + expected.extend_from_slice(&duration_fixed12(0, 2, 500)); + expected.extend_from_slice(&duration_fixed12(0, 0, 0)); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_day_time_rejects_negative() { + let bad = IntervalDayTimeType::make_value(-1, 0); + let arr: PrimitiveArray = vec![bad].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("cannot encode negative days")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn duration_encoder_month_day_nano_happy_path() { + let v0 = IntervalMonthDayNanoType::make_value(1, 2, 3_000_000); // -> millis = 3 + let v1 = IntervalMonthDayNanoType::make_value(0, 0, 0); + let arr: PrimitiveArray = vec![v0, v1].into(); + let mut expected = Vec::new(); + expected.extend_from_slice(&duration_fixed12(1, 2, 3)); + expected.extend_from_slice(&duration_fixed12(0, 0, 0)); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn duration_encoder_month_day_nano_rejects_non_ms_multiple() { + let bad = IntervalMonthDayNanoType::make_value(0, 0, 1); + let arr: PrimitiveArray = vec![bad].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => { + assert!(msg.contains("requires whole milliseconds") || msg.contains("divisible")) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn minimal_twos_complement_test() { + let pos = [0x00, 0x00, 0x01]; + assert_eq!(minimal_twos_complement(&pos), &pos[2..]); + let neg = [0xFF, 0xFF, 0x80]; // negative minimal is 0x80 + assert_eq!(minimal_twos_complement(&neg), &neg[2..]); + let zero = [0x00, 0x00, 0x00]; + assert_eq!(minimal_twos_complement(&zero), &zero[2..]); + } + + #[test] + fn write_sign_extend_test() { + let mut out = Vec::new(); + write_sign_extended(&mut out, &[0x01], 4).unwrap(); + assert_eq!(out, vec![0x00, 0x00, 0x00, 0x01]); + out.clear(); + write_sign_extended(&mut out, &[0xFF], 4).unwrap(); + assert_eq!(out, vec![0xFF, 0xFF, 0xFF, 0xFF]); + out.clear(); + // truncation success (sign bytes only removed) + write_sign_extended(&mut out, &[0xFF, 0xFF, 0x80], 2).unwrap(); + assert_eq!(out, vec![0xFF, 0x80]); + out.clear(); + // truncation overflow + let err = write_sign_extended(&mut out, &[0x01, 0x00], 1).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(_) => {} + _ => panic!("expected InvalidArgumentError"), + } + } + + #[test] + fn duration_month_day_nano_overflow_millis() { + // nanos leading to millis > u32::MAX + let nanos = ((u64::from(u32::MAX) + 1) * 1_000_000) as i64; + let v = IntervalMonthDayNanoType::make_value(0, 0, nanos); + let arr: PrimitiveArray = vec![v].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + ArrowError::InvalidArgumentError(msg) => assert!(msg.contains("exceed u32::MAX")), + _ => panic!("expected InvalidArgumentError"), + } + } + + #[test] + fn fieldplan_decimal_precision_scale_mismatch_errors() { + // Avro expects (10,2), Arrow has (12,2) + use crate::codec::Codec; + use std::collections::HashMap; + let arrow_field = Field::new("d", DataType::Decimal128(12, 2), true); + let avro_dt = AvroDataType::new(Codec::Decimal(10, Some(2), None), HashMap::new(), None); + let err = FieldPlan::build(&avro_dt, &arrow_field).unwrap_err(); + match err { + ArrowError::SchemaError(msg) => { + assert!(msg.contains("Decimal precision/scale mismatch")) + } + _ => panic!("expected SchemaError"), + } + } + + #[test] + fn timestamp_micros_encoder() { + // Mirrors the style used by `timestamp_millis_encoder` + test_scalar_primitive_encoding::( + &[ + 1_704_067_200_000_000, // 2024-01-01T00:00:00Z in micros + 0, // epoch + -123_456_789, // pre-epoch + ], + &[None, Some(1_704_067_200_000_000)], + ); + } + + #[test] + fn list_encoder_nullable_items_null_first() { + // One List row with three elements: [Some(1), None, Some(2)] + let values = Int32Array::from(vec![Some(1), None, Some(2)]); + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 3].into()); + let list = ListArray::new( + Field::new("item", DataType::Int32, true).into(), + offsets, + Arc::new(values) as ArrayRef, + None, + ); + + let plan = FieldPlan::List { + items_nullability: Some(Nullability::NullFirst), + item_plan: Box::new(FieldPlan::Scalar), + }; + + // Avro array encoding per row: one positive block, then 0 terminator. + // For NullFirst: Some(v) => branch 1 (0x02) then the value; None => branch 0 (0x00) + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(3)); // block of 3 + expected.extend(avro_long_bytes(1)); // union branch=1 (value) + expected.extend(avro_long_bytes(1)); // value 1 + expected.extend(avro_long_bytes(0)); // union branch=0 (null) + expected.extend(avro_long_bytes(1)); // union branch=1 (value) + expected.extend(avro_long_bytes(2)); // value 2 + expected.extend(avro_long_bytes(0)); // block terminator + + let got = encode_all(&list, &plan, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn large_list_encoder_nullable_items_null_first() { + // LargeList single row: [Some(10), None] + let values = Int32Array::from(vec![Some(10), None]); + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i64, 2].into()); + let list = LargeListArray::new( + Field::new("item", DataType::Int32, true).into(), + offsets, + Arc::new(values) as ArrayRef, + None, + ); + + let plan = FieldPlan::List { + items_nullability: Some(Nullability::NullFirst), + item_plan: Box::new(FieldPlan::Scalar), + }; + + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(2)); // block of 2 + expected.extend(avro_long_bytes(1)); // union branch=1 (value) + expected.extend(avro_long_bytes(10)); // value 10 + expected.extend(avro_long_bytes(0)); // union branch=0 (null) + expected.extend(avro_long_bytes(0)); // block terminator + + let got = encode_all(&list, &plan, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn map_encoder_string_keys_nullable_int_values_null_first() { + // One map row: {"k1": Some(7), "k2": None} + let keys = StringArray::from(vec!["k1", "k2"]); + let values = Int32Array::from(vec![Some(7), None]); + + let entries_fields = Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ]); + let entries = StructArray::new( + entries_fields, + vec![Arc::new(keys) as ArrayRef, Arc::new(values) as ArrayRef], + None, + ); + + // Single row -> offsets [0, 2] + let offsets = arrow_buffer::OffsetBuffer::new(vec![0i32, 2].into()); + let map = MapArray::new( + Field::new("entries", entries.data_type().clone(), false).into(), + offsets, + entries, + None, + false, + ); + + let plan = FieldPlan::Map { + values_nullability: Some(Nullability::NullFirst), + value_plan: Box::new(FieldPlan::Scalar), + }; + + // Expected: + // - one positive block (len=2) + // - "k1", branch=1 + value=7 + // - "k2", branch=0 (null) + // - end-of-block marker 0 + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(2)); // block length 2 + expected.extend(avro_len_prefixed_bytes(b"k1")); // key "k1" + expected.extend(avro_long_bytes(1)); // union branch 1 (value) + expected.extend(avro_long_bytes(7)); // value 7 + expected.extend(avro_len_prefixed_bytes(b"k2")); // key "k2" + expected.extend(avro_long_bytes(0)); // union branch 0 (null) + expected.extend(avro_long_bytes(0)); // block terminator + + let got = encode_all(&map, &plan, None); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn time32_seconds_to_millis_encoder() { + // Time32(Second) must encode as Avro time-millis (ms since midnight). + let arr: arrow_array::PrimitiveArray = + vec![0i32, 1, -2, 12_345].into(); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + let mut expected = Vec::new(); + for secs in [0i32, 1, -2, 12_345] { + let millis = (secs as i64) * 1000; + expected.extend_from_slice(&avro_long_bytes(millis)); + } + assert_bytes_eq(&got, &expected); + } + + #[test] + fn time32_seconds_to_millis_overflow() { + // Choose a value that will overflow i32 when multiplied by 1000. + let overflow_secs: i32 = i32::MAX / 1000 + 1; + let arr: PrimitiveArray = vec![overflow_secs].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + arrow_schema::ArrowError::InvalidArgumentError(msg) => { + assert!( + msg.contains("overflowed") || msg.contains("overflow"), + "unexpected message: {msg}" + ) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn timestamp_seconds_to_millis_encoder() { + // Timestamp(Second) must encode as Avro timestamp-millis (ms since epoch). + let arr: PrimitiveArray = vec![0i64, 1, -1, 1_234_567_890].into(); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + let mut expected = Vec::new(); + for secs in [0i64, 1, -1, 1_234_567_890] { + let millis = secs * 1000; + expected.extend_from_slice(&avro_long_bytes(millis)); + } + assert_bytes_eq(&got, &expected); + } + + #[test] + fn timestamp_seconds_to_millis_overflow() { + // Overflow i64 when multiplied by 1000. + let overflow_secs: i64 = i64::MAX / 1000 + 1; + let arr: PrimitiveArray = vec![overflow_secs].into(); + let mut enc = FieldEncoder::make_encoder(&arr, &FieldPlan::Scalar, None).unwrap(); + let mut out = Vec::new(); + let err = enc.encode(&mut out, 0).unwrap_err(); + match err { + arrow_schema::ArrowError::InvalidArgumentError(msg) => { + assert!( + msg.contains("overflowed") || msg.contains("overflow"), + "unexpected message: {msg}" + ) + } + other => panic!("expected InvalidArgumentError, got {other:?}"), + } + } + + #[test] + fn timestamp_nanos_encoder() { + let arr: PrimitiveArray = vec![0i64, 1, -1, 123].into(); + let got = encode_all(&arr, &FieldPlan::Scalar, None); + let mut expected = Vec::new(); + for ns in [0i64, 1, -1, 123] { + expected.extend_from_slice(&avro_long_bytes(ns)); + } + assert_bytes_eq(&got, &expected); + } + + #[test] + fn union_encoder_string_int_nonzero_type_ids() { + let strings = StringArray::from(vec!["hello", "world"]); + let ints = Int32Array::from(vec![10, 20, 30]); + let union_fields = UnionFields::try_new( + vec![2, 5], + vec![ + Field::new("v_str", DataType::Utf8, true), + Field::new("v_int", DataType::Int32, true), + ], + ) + .unwrap(); + let type_ids = Buffer::from_slice_ref([2_i8, 5, 5, 2, 5]); + let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]); + let union_array = UnionArray::try_new( + union_fields, + type_ids.into(), + Some(offsets.into()), + vec![Arc::new(strings), Arc::new(ints)], + ) + .unwrap(); + let plan = FieldPlan::Union { + bindings: vec![ + FieldBinding { + arrow_index: 0, + nullability: None, + plan: FieldPlan::Scalar, + }, + FieldBinding { + arrow_index: 1, + nullability: None, + plan: FieldPlan::Scalar, + }, + ], + }; + let got = encode_all(&union_array, &plan, None); + let mut expected = Vec::new(); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_len_prefixed_bytes(b"hello")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(10)); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(20)); + expected.extend(avro_long_bytes(0)); + expected.extend(avro_len_prefixed_bytes(b"world")); + expected.extend(avro_long_bytes(1)); + expected.extend(avro_long_bytes(30)); + assert_bytes_eq(&got, &expected); + } + + #[test] + fn nullable_state_with_null_buffer_and_zero_nulls() { + let values = vec![1i32, 2, 3]; + let arr = Int32Array::from_iter_values_with_nulls(values, Some(NullBuffer::new_valid(3))); + assert_eq!(arr.null_count(), 0); + assert!(arr.nulls().is_some()); + let plan = FieldPlan::Scalar; + let enc = FieldEncoder::make_encoder(&arr, &plan, Some(Nullability::NullFirst)).unwrap(); + match enc.null_state { + NullState::NullableNoNulls { union_value_byte } => { + assert_eq!( + union_value_byte, + union_value_branch_byte(Nullability::NullFirst, false) + ); + } + other => panic!("expected NullableNoNulls, got {other:?}"), + } + } +} diff --git a/arrow-avro/src/writer/format.rs b/arrow-avro/src/writer/format.rs new file mode 100644 index 000000000000..ba2a0b8564b2 --- /dev/null +++ b/arrow-avro/src/writer/format.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro Writer Formats for Arrow. + +use crate::compression::{CODEC_METADATA_KEY, CompressionCodec}; +use crate::schema::{AvroSchema, AvroSchemaOptions, SCHEMA_METADATA_KEY}; +use crate::writer::encoder::write_long; +use arrow_schema::{ArrowError, Schema}; +use rand::RngCore; +use std::fmt::Debug; +use std::io::Write; + +/// Format abstraction implemented by each container‐level writer. +pub trait AvroFormat: Debug + Default { + /// If `true`, the writer for this format will query `single_object_prefix()` + /// and write the prefix before each record. If `false`, the writer can + /// skip this step. This is a performance hint for the writer. + const NEEDS_PREFIX: bool; + + /// Write any bytes required at the very beginning of the output stream + /// (file header, etc.). + /// Implementations **must not** write any record data. + fn start_stream( + &mut self, + writer: &mut W, + schema: &Schema, + compression: Option, + ) -> Result<(), ArrowError>; + + /// Return the 16‑byte sync marker (OCF) or `None` (binary stream). + fn sync_marker(&self) -> Option<&[u8; 16]>; +} + +/// Avro Object Container File (OCF) format writer. +#[derive(Debug, Default)] +pub struct AvroOcfFormat { + sync_marker: [u8; 16], +} + +impl AvroFormat for AvroOcfFormat { + const NEEDS_PREFIX: bool = false; + fn start_stream( + &mut self, + writer: &mut W, + schema: &Schema, + compression: Option, + ) -> Result<(), ArrowError> { + let mut rng = rand::rng(); + rng.fill_bytes(&mut self.sync_marker); + // Choose the Avro schema JSON that the file will advertise. + // If `schema.metadata[SCHEMA_METADATA_KEY]` exists, AvroSchema::try_from + // uses it verbatim; otherwise it is generated from the Arrow schema. + let avro_schema = AvroSchema::from_arrow_with_options( + schema, + Some(AvroSchemaOptions { + null_order: None, + strip_metadata: true, + }), + )?; + // Magic + writer + .write_all(b"Obj\x01") + .map_err(|e| ArrowError::IoError(format!("write OCF magic: {e}"), e))?; + // File metadata map: { "avro.schema": , "avro.codec": } + let codec_str = match compression { + Some(CompressionCodec::Deflate) => "deflate", + Some(CompressionCodec::Snappy) => "snappy", + Some(CompressionCodec::ZStandard) => "zstandard", + Some(CompressionCodec::Bzip2) => "bzip2", + Some(CompressionCodec::Xz) => "xz", + None => "null", + }; + // Map block: count=2, then key/value pairs, then terminating count=0 + write_long(writer, 2)?; + write_string(writer, SCHEMA_METADATA_KEY)?; + write_bytes(writer, avro_schema.json_string.as_bytes())?; + write_string(writer, CODEC_METADATA_KEY)?; + write_bytes(writer, codec_str.as_bytes())?; + write_long(writer, 0)?; + // Sync marker (16 bytes) + writer + .write_all(&self.sync_marker) + .map_err(|e| ArrowError::IoError(format!("write OCF sync marker: {e}"), e))?; + Ok(()) + } + + fn sync_marker(&self) -> Option<&[u8; 16]> { + Some(&self.sync_marker) + } +} + +/// Raw Avro binary streaming format using **Single-Object Encoding** per record. +/// +/// Each record written by the stream writer is framed with a prefix determined +/// by the schema fingerprinting algorithm. +/// +/// See: +/// See: +#[derive(Debug, Default)] +pub struct AvroSoeFormat {} + +impl AvroFormat for AvroSoeFormat { + const NEEDS_PREFIX: bool = true; + fn start_stream( + &mut self, + _writer: &mut W, + _schema: &Schema, + compression: Option, + ) -> Result<(), ArrowError> { + if compression.is_some() { + return Err(ArrowError::InvalidArgumentError( + "Compression not supported for Avro SOE streaming".to_string(), + )); + } + Ok(()) + } + + fn sync_marker(&self) -> Option<&[u8; 16]> { + None + } +} + +#[inline] +fn write_string(writer: &mut W, s: &str) -> Result<(), ArrowError> { + write_bytes(writer, s.as_bytes()) +} + +#[inline] +fn write_bytes(writer: &mut W, bytes: &[u8]) -> Result<(), ArrowError> { + write_long(writer, bytes.len() as i64)?; + writer + .write_all(bytes) + .map_err(|e| ArrowError::IoError(format!("write bytes: {e}"), e)) +} diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs new file mode 100644 index 000000000000..f4a2e60ed57f --- /dev/null +++ b/arrow-avro/src/writer/mod.rs @@ -0,0 +1,2413 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro writer implementation for the `arrow-avro` crate. +//! +//! # Overview +//! +//! Use this module to serialize Arrow `RecordBatch` values into Avro. Two output +//! formats are supported: +//! +//! * **[`AvroWriter`](crate::writer::AvroWriter)** — writes an **Object Container File (OCF)**: a self‑describing +//! file with header (schema JSON + metadata), optional compression, data blocks, and +//! sync markers. See Avro 1.11.1 “Object Container Files.” +//! +//! * **[`AvroStreamWriter`](crate::writer::AvroStreamWriter)** — writes a **Single Object Encoding (SOE) Stream** (“datum” bytes) without +//! any container framing. This is useful when the schema is known out‑of‑band (i.e., +//! via a registry) and you want minimal overhead. +//! +//! ## Which format should you use? +//! +//! * Use **OCF** when you need a portable, self‑contained file. The schema travels with +//! the data, making it easy to read elsewhere. +//! * Use the **SOE stream** when your surrounding protocol supplies schema information +//! (i.e., a schema registry). The writer automatically adds the per‑record prefix: +//! - **SOE**: Each record is prefixed with the 2-byte header (`0xC3 0x01`) followed by +//! an 8‑byte little‑endian CRC‑64‑AVRO fingerprint, then the Avro body. +//! See Avro 1.11.1 "Single object encoding". +//! +//! - **Confluent wire format**: Each record is prefixed with magic byte `0x00` followed by +//! a **big‑endian** 4‑byte schema ID, then the Avro body. Use `FingerprintStrategy::Id(schema_id)`. +//! +//! - **Apicurio wire format**: Each record is prefixed with magic byte `0x00` followed by +//! a **big‑endian** 8‑byte schema ID, then the Avro body. Use `FingerprintStrategy::Id64(schema_id)`. +//! +//! +//! ## Choosing the Avro schema +//! +//! By default, the writer converts your Arrow schema to Avro (including a top‑level record +//! name). If you already have an Avro schema JSON you want to use verbatim, put it into the +//! Arrow schema metadata under the `avro.schema` key before constructing the writer. The +//! builder will use that schema instead of generating a new one (unless `strip_metadata` is +//! set to true in the options). +//! +//! ## Compression +//! +//! For OCF, you may enable a compression codec via `WriterBuilder::with_compression`. The +//! chosen codec is written into the file header and used for subsequent blocks. SOE stream +//! writing doesn’t apply container‑level compression. +//! +//! --- +use crate::codec::AvroFieldBuilder; +use crate::compression::CompressionCodec; +use crate::schema::{ + AvroSchema, Fingerprint, FingerprintAlgorithm, FingerprintStrategy, SCHEMA_METADATA_KEY, +}; +use crate::writer::encoder::{RecordEncoder, RecordEncoderBuilder, write_long}; +use crate::writer::format::{AvroFormat, AvroOcfFormat, AvroSoeFormat}; +use arrow_array::RecordBatch; +use arrow_schema::{ArrowError, Schema}; +use std::io::Write; +use std::sync::Arc; + +/// Encodes `RecordBatch` into the Avro binary format. +mod encoder; +/// Logic for different Avro container file formats. +pub mod format; + +/// Builder to configure and create a `Writer`. +#[derive(Debug, Clone)] +pub struct WriterBuilder { + schema: Schema, + codec: Option, + capacity: usize, + fingerprint_strategy: Option, +} + +impl WriterBuilder { + /// Create a new builder with default settings. + /// + /// The Avro schema used for writing is determined as follows: + /// 1) If the Arrow schema metadata contains `avro::schema` (see `SCHEMA_METADATA_KEY`), + /// that JSON is used verbatim. + /// 2) Otherwise, the Arrow schema is converted to an Avro record schema. + pub fn new(schema: Schema) -> Self { + Self { + schema, + codec: None, + capacity: 1024, + fingerprint_strategy: None, + } + } + + /// Set the fingerprinting strategy for the stream writer. + /// This determines the per-record prefix format. + pub fn with_fingerprint_strategy(mut self, strategy: FingerprintStrategy) -> Self { + self.fingerprint_strategy = Some(strategy); + self + } + + /// Change the compression codec. + pub fn with_compression(mut self, codec: Option) -> Self { + self.codec = codec; + self + } + + /// Sets the capacity for the given object and returns the modified instance. + pub fn with_capacity(mut self, capacity: usize) -> Self { + self.capacity = capacity; + self + } + + /// Create a new `Writer` with specified `AvroFormat` and builder options. + /// Performs one‑time startup (header/stream init, encoder plan). + pub fn build(self, mut writer: W) -> Result, ArrowError> + where + W: Write, + F: AvroFormat, + { + let mut format = F::default(); + let avro_schema = match self.schema.metadata.get(SCHEMA_METADATA_KEY) { + Some(json) => AvroSchema::new(json.clone()), + None => AvroSchema::try_from(&self.schema)?, + }; + let maybe_fingerprint = if F::NEEDS_PREFIX { + match self.fingerprint_strategy { + Some(FingerprintStrategy::Id(id)) => Some(Fingerprint::Id(id)), + Some(FingerprintStrategy::Id64(id)) => Some(Fingerprint::Id64(id)), + Some(strategy) => { + Some(avro_schema.fingerprint(FingerprintAlgorithm::from(strategy))?) + } + None => Some( + avro_schema + .fingerprint(FingerprintAlgorithm::from(FingerprintStrategy::Rabin))?, + ), + } + } else { + None + }; + let mut md = self.schema.metadata().clone(); + md.insert( + SCHEMA_METADATA_KEY.to_string(), + avro_schema.clone().json_string, + ); + let schema = Arc::new(Schema::new_with_metadata(self.schema.fields().clone(), md)); + format.start_stream(&mut writer, &schema, self.codec)?; + let avro_root = AvroFieldBuilder::new(&avro_schema.schema()?).build()?; + let encoder = RecordEncoderBuilder::new(&avro_root, schema.as_ref()) + .with_fingerprint(maybe_fingerprint) + .build()?; + Ok(Writer { + writer, + schema, + format, + compression: self.codec, + capacity: self.capacity, + encoder, + }) + } +} + +/// Generic Avro writer. +/// +/// This type is generic over the output Write sink (`W`) and the Avro format (`F`). +/// You’ll usually use the concrete aliases: +/// +/// * **[`AvroWriter`]** for **OCF** (self‑describing container file) +/// * **[`AvroStreamWriter`]** for **SOE** Avro streams +#[derive(Debug)] +pub struct Writer { + writer: W, + schema: Arc, + format: F, + compression: Option, + capacity: usize, + encoder: RecordEncoder, +} + +/// Alias for an Avro **Object Container File** writer. +/// +/// ### Quickstart (runnable) +/// +/// ``` +/// use std::io::Cursor; +/// use std::sync::Arc; +/// use arrow_array::{ArrayRef, Int64Array, StringArray, RecordBatch}; +/// use arrow_schema::{DataType, Field, Schema}; +/// use arrow_avro::writer::AvroWriter; +/// use arrow_avro::reader::ReaderBuilder; +/// +/// # fn main() -> Result<(), Box> { +/// // Writer schema: { id: long, name: string } +/// let writer_schema = Schema::new(vec![ +/// Field::new("id", DataType::Int64, false), +/// Field::new("name", DataType::Utf8, false), +/// ]); +/// +/// // Build a RecordBatch with two rows +/// let batch = RecordBatch::try_new( +/// Arc::new(writer_schema.clone()), +/// vec![ +/// Arc::new(Int64Array::from(vec![1, 2])) as ArrayRef, +/// Arc::new(StringArray::from(vec!["a", "b"])) as ArrayRef, +/// ], +/// )?; +/// +/// // Write an Avro **Object Container File** (OCF) to memory +/// let mut w = AvroWriter::new(Vec::::new(), writer_schema.clone())?; +/// w.write(&batch)?; +/// w.finish()?; +/// let bytes = w.into_inner(); +/// +/// // Build a Reader and decode the batch back +/// let mut r = ReaderBuilder::new().build(Cursor::new(bytes))?; +/// let out = r.next().unwrap()?; +/// assert_eq!(out.num_rows(), 2); +/// # Ok(()) } +/// ``` +pub type AvroWriter = Writer; + +/// Alias for an Avro **Single Object Encoding** stream writer. +/// +/// ### Example +/// +/// This writer automatically adds the appropriate per-record prefix (based on the +/// fingerprint strategy) before the Avro body of each record. The default is Single +/// Object Encoding (SOE) with a Rabin fingerprint. +/// +/// ``` +/// use std::sync::Arc; +/// use arrow_array::{ArrayRef, Int64Array, RecordBatch}; +/// use arrow_schema::{DataType, Field, Schema}; +/// use arrow_avro::writer::AvroStreamWriter; +/// +/// # fn main() -> Result<(), Box> { +/// // One‑column Arrow batch +/// let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]); +/// let batch = RecordBatch::try_new( +/// Arc::new(schema.clone()), +/// vec![Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef], +/// )?; +/// +/// // Write an Avro Single Object Encoding stream to a Vec +/// let sink: Vec = Vec::new(); +/// let mut w = AvroStreamWriter::new(sink, schema)?; +/// w.write(&batch)?; +/// w.finish()?; +/// let bytes = w.into_inner(); +/// assert!(!bytes.is_empty()); +/// # Ok(()) } +/// ``` +pub type AvroStreamWriter = Writer; + +impl Writer { + /// Convenience constructor – same as [`WriterBuilder::build`] with `AvroOcfFormat`. + /// + /// ### Example + /// + /// ``` + /// use std::sync::Arc; + /// use arrow_array::{ArrayRef, Int32Array, RecordBatch}; + /// use arrow_schema::{DataType, Field, Schema}; + /// use arrow_avro::writer::AvroWriter; + /// + /// # fn main() -> Result<(), Box> { + /// let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]); + /// let batch = RecordBatch::try_new( + /// Arc::new(schema.clone()), + /// vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + /// )?; + /// + /// let buf: Vec = Vec::new(); + /// let mut w = AvroWriter::new(buf, schema)?; + /// w.write(&batch)?; + /// w.finish()?; + /// let bytes = w.into_inner(); + /// assert!(!bytes.is_empty()); + /// # Ok(()) } + /// ``` + pub fn new(writer: W, schema: Schema) -> Result { + WriterBuilder::new(schema).build::(writer) + } + + /// Return a reference to the 16‑byte sync marker generated for this file. + pub fn sync_marker(&self) -> Option<&[u8; 16]> { + self.format.sync_marker() + } +} + +impl Writer { + /// Convenience constructor to create a new [`AvroStreamWriter`]. + /// + /// The resulting stream contains **Single Object Encodings** (no OCF header/sync). + /// + /// ### Example + /// + /// ``` + /// use std::sync::Arc; + /// use arrow_array::{ArrayRef, Int64Array, RecordBatch}; + /// use arrow_schema::{DataType, Field, Schema}; + /// use arrow_avro::writer::AvroStreamWriter; + /// + /// # fn main() -> Result<(), Box> { + /// let schema = Schema::new(vec![Field::new("x", DataType::Int64, false)]); + /// let batch = RecordBatch::try_new( + /// Arc::new(schema.clone()), + /// vec![Arc::new(Int64Array::from(vec![10, 20])) as ArrayRef], + /// )?; + /// + /// let sink: Vec = Vec::new(); + /// let mut w = AvroStreamWriter::new(sink, schema)?; + /// w.write(&batch)?; + /// w.finish()?; + /// let bytes = w.into_inner(); + /// assert!(!bytes.is_empty()); + /// # Ok(()) } + /// ``` + pub fn new(writer: W, schema: Schema) -> Result { + WriterBuilder::new(schema).build::(writer) + } +} + +impl Writer { + /// Serialize one [`RecordBatch`] to the output. + pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + if batch.schema().fields() != self.schema.fields() { + return Err(ArrowError::SchemaError( + "Schema of RecordBatch differs from Writer schema".to_string(), + )); + } + match self.format.sync_marker() { + Some(&sync) => self.write_ocf_block(batch, &sync), + None => self.write_stream(batch), + } + } + + /// A convenience method to write a slice of [`RecordBatch`]. + /// + /// This is equivalent to calling `write` for each batch in the slice. + pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> { + for b in batches { + self.write(b)?; + } + Ok(()) + } + + /// Flush remaining buffered data and (for OCF) ensure the header is present. + pub fn finish(&mut self) -> Result<(), ArrowError> { + self.writer + .flush() + .map_err(|e| ArrowError::IoError(format!("Error flushing writer: {e}"), e)) + } + + /// Consume the writer, returning the underlying output object. + pub fn into_inner(self) -> W { + self.writer + } + + fn write_ocf_block(&mut self, batch: &RecordBatch, sync: &[u8; 16]) -> Result<(), ArrowError> { + let mut buf = Vec::::with_capacity(self.capacity); + self.encoder.encode(&mut buf, batch)?; + let encoded = match self.compression { + Some(codec) => codec.compress(&buf)?, + None => buf, + }; + write_long(&mut self.writer, batch.num_rows() as i64)?; + write_long(&mut self.writer, encoded.len() as i64)?; + self.writer + .write_all(&encoded) + .map_err(|e| ArrowError::IoError(format!("Error writing Avro block: {e}"), e))?; + self.writer + .write_all(sync) + .map_err(|e| ArrowError::IoError(format!("Error writing Avro sync: {e}"), e))?; + Ok(()) + } + + fn write_stream(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + self.encoder.encode(&mut self.writer, batch)?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::compression::CompressionCodec; + use crate::reader::ReaderBuilder; + use crate::schema::{AvroSchema, SchemaStore}; + use crate::test_util::arrow_test_data; + use arrow::datatypes::TimeUnit; + #[cfg(feature = "avro_custom_types")] + use arrow_array::types::{Int16Type, Int32Type, Int64Type}; + use arrow_array::types::{ + Time32MillisecondType, Time64MicrosecondType, TimestampMicrosecondType, + TimestampMillisecondType, TimestampNanosecondType, + }; + use arrow_array::{ + Array, ArrayRef, BinaryArray, Date32Array, Int32Array, PrimitiveArray, RecordBatch, + StringArray, StructArray, UnionArray, + }; + #[cfg(feature = "avro_custom_types")] + use arrow_array::{Int16Array, Int64Array, RunArray}; + use arrow_schema::UnionMode; + #[cfg(not(feature = "avro_custom_types"))] + use arrow_schema::{DataType, Field, Schema}; + #[cfg(feature = "avro_custom_types")] + use arrow_schema::{DataType, Field, Schema}; + use std::collections::HashMap; + use std::collections::HashSet; + use std::fs::File; + use std::io::{BufReader, Cursor}; + use std::path::PathBuf; + use std::sync::Arc; + use tempfile::NamedTempFile; + + fn files() -> impl Iterator { + [ + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + "avro/alltypes_plain.avro", + #[cfg(feature = "snappy")] + "avro/alltypes_plain.snappy.avro", + #[cfg(feature = "zstd")] + "avro/alltypes_plain.zstandard.avro", + #[cfg(feature = "bzip2")] + "avro/alltypes_plain.bzip2.avro", + #[cfg(feature = "xz")] + "avro/alltypes_plain.xz.avro", + ] + .into_iter() + } + + fn make_schema() -> Schema { + Schema::new(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Binary, false), + ]) + } + + fn make_batch() -> RecordBatch { + let ids = Int32Array::from(vec![1, 2, 3]); + let names = BinaryArray::from_vec(vec![b"a".as_ref(), b"b".as_ref(), b"c".as_ref()]); + RecordBatch::try_new( + Arc::new(make_schema()), + vec![Arc::new(ids) as ArrayRef, Arc::new(names) as ArrayRef], + ) + .expect("failed to build test RecordBatch") + } + + #[test] + fn test_stream_writer_writes_prefix_per_row_rt() -> Result<(), ArrowError> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef], + )?; + let buf: Vec = Vec::new(); + let mut writer = AvroStreamWriter::new(buf, schema.clone())?; + writer.write(&batch)?; + let encoded = writer.into_inner(); + let mut store = SchemaStore::new(); // Rabin by default + let avro_schema = AvroSchema::try_from(&schema)?; + let _fp = store.register(avro_schema)?; + let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .build_decoder()?; + let _consumed = decoder.decode(&encoded)?; + let decoded = decoder + .flush()? + .expect("expected at least one batch from decoder"); + assert_eq!(decoded.num_columns(), 1); + assert_eq!(decoded.num_rows(), 2); + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("int column"); + assert_eq!(col, &Int32Array::from(vec![10, 20])); + Ok(()) + } + + #[test] + fn test_nullable_struct_with_nonnullable_field_sliced_encoding() { + use arrow_array::{ArrayRef, Int32Array, StringArray, StructArray}; + use arrow_buffer::NullBuffer; + use arrow_schema::{DataType, Field, Fields, Schema}; + use std::sync::Arc; + let inner_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), // non-nullable + Field::new("name", DataType::Utf8, true), // nullable + ]); + let inner_struct_type = DataType::Struct(inner_fields.clone()); + let schema = Schema::new(vec![ + Field::new("before", inner_struct_type.clone(), true), // nullable struct + Field::new("after", inner_struct_type.clone(), true), // nullable struct + Field::new("op", DataType::Utf8, false), // non-nullable + ]); + let before_ids = Int32Array::from(vec![None, None]); + let before_names = StringArray::from(vec![None::<&str>, None]); + let before_struct = StructArray::new( + inner_fields.clone(), + vec![ + Arc::new(before_ids) as ArrayRef, + Arc::new(before_names) as ArrayRef, + ], + Some(NullBuffer::from(vec![false, false])), + ); + let after_ids = Int32Array::from(vec![1, 2]); // non-nullable, no nulls + let after_names = StringArray::from(vec![Some("Alice"), Some("Bob")]); + let after_struct = StructArray::new( + inner_fields.clone(), + vec![ + Arc::new(after_ids) as ArrayRef, + Arc::new(after_names) as ArrayRef, + ], + Some(NullBuffer::from(vec![true, true])), + ); + let op_col = StringArray::from(vec!["r", "r"]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(before_struct) as ArrayRef, + Arc::new(after_struct) as ArrayRef, + Arc::new(op_col) as ArrayRef, + ], + ) + .expect("failed to create test batch"); + let mut sink = Vec::new(); + let mut writer = WriterBuilder::new(schema) + .with_fingerprint_strategy(FingerprintStrategy::Id(1)) + .build::<_, AvroSoeFormat>(&mut sink) + .expect("failed to create writer"); + for row_idx in 0..batch.num_rows() { + let single_row = batch.slice(row_idx, 1); + let after_col = single_row.column(1); + assert_eq!( + after_col.null_count(), + 0, + "after column should have no nulls in sliced row" + ); + writer + .write(&single_row) + .unwrap_or_else(|e| panic!("Failed to encode row {row_idx}: {e}")); + } + writer.finish().expect("failed to finish writer"); + assert!(!sink.is_empty(), "encoded output should not be empty"); + } + + #[test] + fn test_nullable_struct_with_decimal_and_timestamp_sliced() { + use arrow_array::{ + ArrayRef, Decimal128Array, Int32Array, StringArray, StructArray, + TimestampMicrosecondArray, + }; + use arrow_buffer::NullBuffer; + use arrow_schema::{DataType, Field, Fields, Schema}; + use std::sync::Arc; + let row_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + Field::new("category", DataType::Utf8, true), + Field::new("price", DataType::Decimal128(10, 2), true), + Field::new("stock_quantity", DataType::Int32, true), + Field::new( + "created_at", + DataType::Timestamp(TimeUnit::Microsecond, None), + true, + ), + ]); + let row_struct_type = DataType::Struct(row_fields.clone()); + let schema = Schema::new(vec![ + Field::new("before", row_struct_type.clone(), true), + Field::new("after", row_struct_type.clone(), true), + Field::new("op", DataType::Utf8, false), + ]); + let before_struct = StructArray::new_null(row_fields.clone(), 2); + let ids = Int32Array::from(vec![1, 2]); + let names = StringArray::from(vec![Some("Widget"), Some("Gadget")]); + let categories = StringArray::from(vec![Some("Electronics"), Some("Electronics")]); + let prices = Decimal128Array::from(vec![Some(1999), Some(2999)]) + .with_precision_and_scale(10, 2) + .unwrap(); + let quantities = Int32Array::from(vec![Some(100), Some(50)]); + let timestamps = TimestampMicrosecondArray::from(vec![ + Some(1700000000000000i64), + Some(1700000001000000i64), + ]); + let after_struct = StructArray::new( + row_fields.clone(), + vec![ + Arc::new(ids) as ArrayRef, + Arc::new(names) as ArrayRef, + Arc::new(categories) as ArrayRef, + Arc::new(prices) as ArrayRef, + Arc::new(quantities) as ArrayRef, + Arc::new(timestamps) as ArrayRef, + ], + Some(NullBuffer::from(vec![true, true])), + ); + let op_col = StringArray::from(vec!["r", "r"]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(before_struct) as ArrayRef, + Arc::new(after_struct) as ArrayRef, + Arc::new(op_col) as ArrayRef, + ], + ) + .expect("failed to create products batch"); + let mut sink = Vec::new(); + let mut writer = WriterBuilder::new(schema) + .with_fingerprint_strategy(FingerprintStrategy::Id(1)) + .build::<_, AvroSoeFormat>(&mut sink) + .expect("failed to create writer"); + // Encode row by row + for row_idx in 0..batch.num_rows() { + let single_row = batch.slice(row_idx, 1); + writer + .write(&single_row) + .unwrap_or_else(|e| panic!("Failed to encode product row {row_idx}: {e}")); + } + writer.finish().expect("failed to finish writer"); + assert!(!sink.is_empty()); + } + + #[test] + fn non_nullable_child_in_nullable_struct_should_encode_per_row() { + use arrow_array::{ + ArrayRef, Int32Array, Int64Array, RecordBatch, StringArray, StructArray, + }; + use arrow_schema::{DataType, Field, Fields, Schema}; + use std::sync::Arc; + let row_fields = Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ]); + let row_struct_dt = DataType::Struct(row_fields.clone()); + let before: ArrayRef = Arc::new(StructArray::new_null(row_fields.clone(), 1)); + let id_col: ArrayRef = Arc::new(Int32Array::from(vec![1])); + let name_col: ArrayRef = Arc::new(StringArray::from(vec![None::<&str>])); + let after: ArrayRef = Arc::new(StructArray::new( + row_fields.clone(), + vec![id_col, name_col], + None, + )); + let schema = Arc::new(Schema::new(vec![ + Field::new("before", row_struct_dt.clone(), true), + Field::new("after", row_struct_dt, true), + Field::new("op", DataType::Utf8, false), + Field::new("ts_ms", DataType::Int64, false), + ])); + let op = Arc::new(StringArray::from(vec!["r"])) as ArrayRef; + let ts_ms = Arc::new(Int64Array::from(vec![1732900000000_i64])) as ArrayRef; + let batch = RecordBatch::try_new(schema.clone(), vec![before, after, op, ts_ms]).unwrap(); + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new(schema.as_ref().clone()) + .build::<_, AvroSoeFormat>(&mut buf) + .unwrap(); + let single = batch.slice(0, 1); + let res = writer.write(&single); + assert!( + res.is_ok(), + "expected to encode successfully, got: {:?}", + res.err() + ); + } + + #[test] + fn test_union_nonzero_type_ids() -> Result<(), ArrowError> { + use arrow_array::UnionArray; + use arrow_buffer::Buffer; + use arrow_schema::UnionFields; + let union_fields = UnionFields::try_new( + vec![2, 5], + vec![ + Field::new("v_str", DataType::Utf8, true), + Field::new("v_int", DataType::Int32, true), + ], + ) + .unwrap(); + let strings = StringArray::from(vec!["hello", "world"]); + let ints = Int32Array::from(vec![10, 20, 30]); + let type_ids = Buffer::from_slice_ref([2_i8, 5, 5, 2, 5]); + let offsets = Buffer::from_slice_ref([0_i32, 0, 1, 1, 2]); + let union_array = UnionArray::try_new( + union_fields.clone(), + type_ids.into(), + Some(offsets.into()), + vec![Arc::new(strings) as ArrayRef, Arc::new(ints) as ArrayRef], + )?; + let schema = Schema::new(vec![Field::new( + "union_col", + DataType::Union(union_fields, UnionMode::Dense), + false, + )]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(union_array) as ArrayRef], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + assert!( + writer.write(&batch).is_ok(), + "Expected no error from writing" + ); + writer.finish()?; + assert!( + writer.finish().is_ok(), + "Expected no error from finishing writer" + ); + Ok(()) + } + + #[test] + fn test_stream_writer_with_id_fingerprint_rt() -> Result<(), ArrowError> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + )?; + let schema_id: u32 = 42; + let mut writer = WriterBuilder::new(schema.clone()) + .with_fingerprint_strategy(FingerprintStrategy::Id(schema_id)) + .build::<_, AvroSoeFormat>(Vec::new())?; + writer.write(&batch)?; + let encoded = writer.into_inner(); + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id); + let avro_schema = AvroSchema::try_from(&schema)?; + let _ = store.set(Fingerprint::Id(schema_id), avro_schema)?; + let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .build_decoder()?; + let _ = decoder.decode(&encoded)?; + let decoded = decoder + .flush()? + .expect("expected at least one batch from decoder"); + assert_eq!(decoded.num_columns(), 1); + assert_eq!(decoded.num_rows(), 3); + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("int column"); + assert_eq!(col, &Int32Array::from(vec![1, 2, 3])); + Ok(()) + } + + #[test] + fn test_stream_writer_with_id64_fingerprint_rt() -> Result<(), ArrowError> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + )?; + let schema_id: u64 = 42; + let mut writer = WriterBuilder::new(schema.clone()) + .with_fingerprint_strategy(FingerprintStrategy::Id64(schema_id)) + .build::<_, AvroSoeFormat>(Vec::new())?; + writer.write(&batch)?; + let encoded = writer.into_inner(); + let mut store = SchemaStore::new_with_type(FingerprintAlgorithm::Id64); + let avro_schema = AvroSchema::try_from(&schema)?; + let _ = store.set(Fingerprint::Id64(schema_id), avro_schema)?; + let mut decoder = ReaderBuilder::new() + .with_writer_schema_store(store) + .build_decoder()?; + let _ = decoder.decode(&encoded)?; + let decoded = decoder + .flush()? + .expect("expected at least one batch from decoder"); + assert_eq!(decoded.num_columns(), 1); + assert_eq!(decoded.num_rows(), 3); + let col = decoded + .column(0) + .as_any() + .downcast_ref::() + .expect("int column"); + assert_eq!(col, &Int32Array::from(vec![1, 2, 3])); + Ok(()) + } + + #[test] + fn test_ocf_writer_generates_header_and_sync() -> Result<(), ArrowError> { + let batch = make_batch(); + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.write(&batch)?; + writer.finish()?; + let out = writer.into_inner(); + assert_eq!(&out[..4], b"Obj\x01", "OCF magic bytes missing/incorrect"); + let trailer = &out[out.len() - 16..]; + assert_eq!(trailer.len(), 16, "expected 16‑byte sync marker"); + Ok(()) + } + + #[test] + fn test_schema_mismatch_yields_error() { + let batch = make_batch(); + let alt_schema = Schema::new(vec![Field::new("x", DataType::Int32, false)]); + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, alt_schema).unwrap(); + let err = writer.write(&batch).unwrap_err(); + assert!(matches!(err, ArrowError::SchemaError(_))); + } + + #[test] + fn test_write_batches_accumulates_multiple() -> Result<(), ArrowError> { + let batch1 = make_batch(); + let batch2 = make_batch(); + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.write_batches(&[&batch1, &batch2])?; + writer.finish()?; + let out = writer.into_inner(); + assert!(out.len() > 4, "combined batches produced tiny file"); + Ok(()) + } + + #[test] + fn test_finish_without_write_adds_header() -> Result<(), ArrowError> { + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, make_schema())?; + writer.finish()?; + let out = writer.into_inner(); + assert_eq!(&out[..4], b"Obj\x01", "finish() should emit OCF header"); + Ok(()) + } + + #[test] + fn test_write_long_encodes_zigzag_varint() -> Result<(), ArrowError> { + let mut buf = Vec::new(); + write_long(&mut buf, 0)?; + write_long(&mut buf, -1)?; + write_long(&mut buf, 1)?; + write_long(&mut buf, -2)?; + write_long(&mut buf, 2147483647)?; + assert!( + buf.starts_with(&[0x00, 0x01, 0x02, 0x03]), + "zig‑zag varint encodings incorrect: {buf:?}" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_alltypes_roundtrip_writer() -> Result<(), ArrowError> { + for rel in files() { + let path = arrow_test_data(rel); + let rdr_file = File::open(&path).expect("open input avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + let out_file = File::create(&out_path).expect("create temp avro"); + let codec = if rel.contains(".snappy.") { + Some(CompressionCodec::Snappy) + } else if rel.contains(".zstandard.") { + Some(CompressionCodec::ZStandard) + } else if rel.contains(".bzip2.") { + Some(CompressionCodec::Bzip2) + } else if rel.contains(".xz.") { + Some(CompressionCodec::Xz) + } else { + None + }; + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(codec) + .build::<_, AvroOcfFormat>(out_file)?; + writer.write(&original)?; + writer.finish()?; + drop(writer); + let rt_file = File::open(&out_path).expect("open roundtrip avro"); + let rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build roundtrip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!( + roundtrip, original, + "Round-trip batch mismatch for file: {}", + rel + ); + } + Ok(()) + } + + #[test] + fn test_roundtrip_nested_records_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_records.avro"); + let rdr_file = File::open(&path).expect("open nested_records.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_records.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = File::open(&out_path).expect("open round_trip avro"); + let rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_records.avro" + ); + Ok(()) + } + + #[test] + #[cfg(feature = "snappy")] + fn test_roundtrip_nested_lists_writer() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nested_lists.snappy.avro"); + let rdr_file = File::open(&path).expect("open nested_lists.snappy.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nested_lists.snappy.avro"); + let schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let original = arrow::compute::concat_batches(&schema, &batches).expect("concat original"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + { + let out_file = File::create(&out_path).expect("create output avro"); + let mut writer = WriterBuilder::new(original.schema().as_ref().clone()) + .with_compression(Some(CompressionCodec::Snappy)) + .build::<_, AvroOcfFormat>(out_file)?; + writer.write(&original)?; + writer.finish()?; + } + let rt_file = File::open(&out_path).expect("open round_trip avro"); + let rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for nested_lists.snappy.avro" + ); + Ok(()) + } + + #[test] + fn test_round_trip_simple_fixed_ocf() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/simple_fixed.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_fixed.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build avro reader"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let tmp = NamedTempFile::new().expect("create temp file"); + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + drop(writer); + let rt_file = File::open(tmp.path()).expect("open round_trip avro"); + let rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, original); + Ok(()) + } + + // Strict equality (schema + values) only when canonical extension types are enabled + #[test] + #[cfg(feature = "canonical_extension_types")] + fn test_round_trip_duration_and_uuid_ocf() -> Result<(), ArrowError> { + use arrow_schema::{DataType, IntervalUnit}; + let in_file = + File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for duration_uuid.avro"); + let in_schema = reader.schema(); + let has_mdn = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Interval(IntervalUnit::MonthDayNano) + ) + }); + assert!( + has_mdn, + "expected at least one Interval(MonthDayNano) field in duration_uuid.avro" + ); + let has_uuid_fixed = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16))); + assert!( + has_uuid_fixed, + "expected at least one FixedSizeBinary(16) (uuid) field in duration_uuid.avro" + ); + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Write to an in‑memory OCF and read back + let mut writer = AvroWriter::new(Vec::::new(), in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!(round_trip, input); + Ok(()) + } + + // Feature OFF: only values are asserted equal; schema may legitimately differ (uuid as fixed(16)) + #[test] + #[cfg(not(feature = "canonical_extension_types"))] + fn test_duration_and_uuid_ocf_without_extensions_round_trips_values() -> Result<(), ArrowError> + { + use arrow::datatypes::{DataType, IntervalUnit}; + use std::io::BufReader; + + // Read input Avro (duration + uuid) + let in_file = + File::open("test/data/duration_uuid.avro").expect("open test/data/duration_uuid.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for duration_uuid.avro"); + let in_schema = reader.schema(); + + // Sanity checks: has MonthDayNano and a FixedSizeBinary(16) + assert!( + in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Interval(IntervalUnit::MonthDayNano) + ) + }), + "expected at least one Interval(MonthDayNano) field" + ); + assert!( + in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::FixedSizeBinary(16))), + "expected a FixedSizeBinary(16) field (uuid)" + ); + + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + + // Write to a temp OCF and read back + let mut writer = AvroWriter::new(Vec::::new(), in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + + // 1) Values must round-trip for both columns + assert_eq!( + round_trip.column(0), + input.column(0), + "duration column values differ" + ); + assert_eq!(round_trip.column(1), input.column(1), "uuid bytes differ"); + + // 2) Schema expectation without extensions: + // uuid is written as named fixed(16), so reader attaches avro.name + let uuid_rt = rt_schema.field_with_name("uuid_field")?; + assert_eq!(uuid_rt.data_type(), &DataType::FixedSizeBinary(16)); + assert_eq!( + uuid_rt.metadata().get("logicalType").map(|s| s.as_str()), + Some("uuid"), + "expected `logicalType = \"uuid\"` on round-tripped field metadata" + ); + + // 3) Duration remains Interval(MonthDayNano) + let dur_rt = rt_schema.field_with_name("duration_field")?; + assert!(matches!( + dur_rt.data_type(), + DataType::Interval(IntervalUnit::MonthDayNano) + )); + + Ok(()) + } + + // This test reads the same 'nonnullable.impala.avro' used by the reader tests, + // writes it back out with the writer (hitting Map encoding paths), then reads it + // again and asserts exact Arrow equivalence. + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_nonnullable_impala_roundtrip_writer() -> Result<(), ArrowError> { + // Load source Avro with Map fields + let path = arrow_test_data("avro/nonnullable.impala.avro"); + let rdr_file = File::open(&path).expect("open avro/nonnullable.impala.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nonnullable.impala.avro"); + // Collect all input batches and concatenate to a single RecordBatch + let in_schema = reader.schema(); + // Sanity: ensure the file actually contains at least one Map field + let has_map = in_schema + .fields() + .iter() + .any(|f| matches!(f.data_type(), DataType::Map(_, _))); + assert!( + has_map, + "expected at least one Map field in avro/nonnullable.impala.avro" + ); + + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Write out using the OCF writer into an in-memory Vec + let buffer = Vec::::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let out_bytes = writer.into_inner(); + // Read the produced bytes back with the Reader + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(out_bytes)) + .expect("build reader for round-tripped in-memory OCF"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + // Exact value fidelity (schema + data) + assert_eq!( + roundtrip, original, + "Round-trip Avro map data mismatch for nonnullable.impala.avro" + ); + Ok(()) + } + + #[test] + // TODO: avoid requiring snappy for these files + #[cfg(feature = "snappy")] + fn test_roundtrip_decimals_via_writer() -> Result<(), ArrowError> { + // (file, resolve via ARROW_TEST_DATA?) + let files: [(&str, bool); 8] = [ + ("avro/fixed_length_decimal.avro", true), // fixed-backed -> Decimal128(25,2) + ("avro/fixed_length_decimal_legacy.avro", true), // legacy fixed[8] -> Decimal64(13,2) + ("avro/int32_decimal.avro", true), // bytes-backed -> Decimal32(4,2) + ("avro/int64_decimal.avro", true), // bytes-backed -> Decimal64(10,2) + ("test/data/int256_decimal.avro", false), // bytes-backed -> Decimal256(76,2) + ("test/data/fixed256_decimal.avro", false), // fixed[32]-backed -> Decimal256(76,10) + ("test/data/fixed_length_decimal_legacy_32.avro", false), // legacy fixed[4] -> Decimal32(9,2) + ("test/data/int128_decimal.avro", false), // bytes-backed -> Decimal128(38,2) + ]; + for (rel, in_test_data_dir) in files { + // Resolve path the same way as reader::test_decimal + let path: String = if in_test_data_dir { + arrow_test_data(rel) + } else { + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join(rel) + .to_string_lossy() + .into_owned() + }; + // Read original file into a single RecordBatch for comparison + let f_in = File::open(&path).expect("open input avro"); + let rdr = ReaderBuilder::new().build(BufReader::new(f_in))?; + let in_schema = rdr.schema(); + let in_batches = rdr.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + // Write it out with the OCF writer (no special compression) + let tmp = NamedTempFile::new().expect("create temp file"); + let out_path = tmp.into_temp_path(); + let out_file = File::create(&out_path).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + // Read back the file we just wrote and compare equality (schema + data) + let f_rt = File::open(&out_path).expect("open roundtrip avro"); + let rt_rdr = ReaderBuilder::new().build(BufReader::new(f_rt))?; + let rt_schema = rt_rdr.schema(); + let rt_batches = rt_rdr.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat rt"); + assert_eq!(roundtrip, original, "decimal round-trip mismatch for {rel}"); + } + Ok(()) + } + + #[test] + fn test_named_types_complex_roundtrip() -> Result<(), ArrowError> { + // 1. Read the new, more complex named references file. + let path = + PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/data/named_types_complex.avro"); + let rdr_file = File::open(&path).expect("open avro/named_types_complex.avro"); + + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for named_types_complex.avro"); + + // 2. Concatenate all batches to one RecordBatch. + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + + // 3. Sanity Checks: Validate that all named types were reused correctly. + { + let arrow_schema = original.schema(); + + // --- A. Validate 'User' record reuse --- + let author_field = arrow_schema.field_with_name("author")?; + let author_type = author_field.data_type(); + let editors_field = arrow_schema.field_with_name("editors")?; + let editors_item_type = match editors_field.data_type() { + DataType::List(item_field) => item_field.data_type(), + other => panic!("Editors field should be a List, but was {:?}", other), + }; + assert_eq!( + author_type, editors_item_type, + "The DataType for the 'author' struct and the 'editors' list items must be identical" + ); + + // --- B. Validate 'PostStatus' enum reuse --- + let status_field = arrow_schema.field_with_name("status")?; + let status_type = status_field.data_type(); + assert!( + matches!(status_type, DataType::Dictionary(_, _)), + "Status field should be a Dictionary (Enum)" + ); + + let prev_status_field = arrow_schema.field_with_name("previous_status")?; + let prev_status_type = prev_status_field.data_type(); + assert_eq!( + status_type, prev_status_type, + "The DataType for 'status' and 'previous_status' enums must be identical" + ); + + // --- C. Validate 'MD5' fixed reuse --- + let content_hash_field = arrow_schema.field_with_name("content_hash")?; + let content_hash_type = content_hash_field.data_type(); + assert!( + matches!(content_hash_type, DataType::FixedSizeBinary(16)), + "Content hash should be FixedSizeBinary(16)" + ); + + let thumb_hash_field = arrow_schema.field_with_name("thumbnail_hash")?; + let thumb_hash_type = thumb_hash_field.data_type(); + assert_eq!( + content_hash_type, thumb_hash_type, + "The DataType for 'content_hash' and 'thumbnail_hash' fixed types must be identical" + ); + } + + // 4. Write the data to an in-memory buffer. + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + + // 5. Read the data back and compare for exact equality. + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build reader for round-trip"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + + assert_eq!( + roundtrip, original, + "Avro complex named types round-trip mismatch" + ); + + Ok(()) + } + + // Union Roundtrip Test Helpers + + // Asserts that the `actual` schema is a semantically equivalent superset of the `expected` one. + // This allows the `actual` schema to contain additional metadata keys + // (`arrowUnionMode`, `arrowUnionTypeIds`, `avro.name`) that are added during an Arrow-to-Avro-to-Arrow + // roundtrip, while ensuring no other information was lost or changed. + fn assert_schema_is_semantically_equivalent(expected: &Schema, actual: &Schema) { + // Compare top-level schema metadata using the same superset logic. + assert_metadata_is_superset(expected.metadata(), actual.metadata(), "Schema"); + + // Compare fields. + assert_eq!( + expected.fields().len(), + actual.fields().len(), + "Schema must have the same number of fields" + ); + + for (expected_field, actual_field) in expected.fields().iter().zip(actual.fields().iter()) { + assert_field_is_semantically_equivalent(expected_field, actual_field); + } + } + + fn assert_field_is_semantically_equivalent(expected: &Field, actual: &Field) { + let context = format!("Field '{}'", expected.name()); + + assert_eq!( + expected.name(), + actual.name(), + "{context}: names must match" + ); + assert_eq!( + expected.is_nullable(), + actual.is_nullable(), + "{context}: nullability must match" + ); + + // Recursively check the data types. + assert_datatype_is_semantically_equivalent( + expected.data_type(), + actual.data_type(), + &context, + ); + + // Check that metadata is a valid superset. + assert_metadata_is_superset(expected.metadata(), actual.metadata(), &context); + } + + fn assert_datatype_is_semantically_equivalent( + expected: &DataType, + actual: &DataType, + context: &str, + ) { + match (expected, actual) { + (DataType::List(expected_field), DataType::List(actual_field)) + | (DataType::LargeList(expected_field), DataType::LargeList(actual_field)) + | (DataType::Map(expected_field, _), DataType::Map(actual_field, _)) => { + assert_field_is_semantically_equivalent(expected_field, actual_field); + } + (DataType::Struct(expected_fields), DataType::Struct(actual_fields)) => { + assert_eq!( + expected_fields.len(), + actual_fields.len(), + "{context}: struct must have same number of fields" + ); + for (ef, af) in expected_fields.iter().zip(actual_fields.iter()) { + assert_field_is_semantically_equivalent(ef, af); + } + } + ( + DataType::Union(expected_fields, expected_mode), + DataType::Union(actual_fields, actual_mode), + ) => { + assert_eq!( + expected_mode, actual_mode, + "{context}: union mode must match" + ); + assert_eq!( + expected_fields.len(), + actual_fields.len(), + "{context}: union must have same number of variants" + ); + for ((exp_id, exp_field), (act_id, act_field)) in + expected_fields.iter().zip(actual_fields.iter()) + { + assert_eq!(exp_id, act_id, "{context}: union type ids must match"); + assert_field_is_semantically_equivalent(exp_field, act_field); + } + } + _ => { + assert_eq!(expected, actual, "{context}: data types must be identical"); + } + } + } + + fn assert_batch_data_is_identical(expected: &RecordBatch, actual: &RecordBatch) { + assert_eq!( + expected.num_columns(), + actual.num_columns(), + "RecordBatches must have the same number of columns" + ); + assert_eq!( + expected.num_rows(), + actual.num_rows(), + "RecordBatches must have the same number of rows" + ); + + for i in 0..expected.num_columns() { + let context = format!("Column {i}"); + let expected_col = expected.column(i); + let actual_col = actual.column(i); + assert_array_data_is_identical(expected_col, actual_col, &context); + } + } + + /// Recursively asserts that the data content of two Arrays is identical. + fn assert_array_data_is_identical(expected: &dyn Array, actual: &dyn Array, context: &str) { + assert_eq!( + expected.nulls(), + actual.nulls(), + "{context}: null buffers must match" + ); + assert_eq!( + expected.len(), + actual.len(), + "{context}: array lengths must match" + ); + + match (expected.data_type(), actual.data_type()) { + (DataType::Union(expected_fields, _), DataType::Union(..)) => { + let expected_union = expected.as_any().downcast_ref::().unwrap(); + let actual_union = actual.as_any().downcast_ref::().unwrap(); + + // Compare the type_ids buffer (always the first buffer). + assert_eq!( + &expected.to_data().buffers()[0], + &actual.to_data().buffers()[0], + "{context}: union type_ids buffer mismatch" + ); + + // For dense unions, compare the value_offsets buffer (the second buffer). + if expected.to_data().buffers().len() > 1 { + assert_eq!( + &expected.to_data().buffers()[1], + &actual.to_data().buffers()[1], + "{context}: union value_offsets buffer mismatch" + ); + } + + // Recursively compare children based on the fields in the DataType. + for (type_id, _) in expected_fields.iter() { + let child_context = format!("{context} -> child variant {type_id}"); + assert_array_data_is_identical( + expected_union.child(type_id), + actual_union.child(type_id), + &child_context, + ); + } + } + (DataType::Struct(_), DataType::Struct(_)) => { + let expected_struct = expected.as_any().downcast_ref::().unwrap(); + let actual_struct = actual.as_any().downcast_ref::().unwrap(); + for i in 0..expected_struct.num_columns() { + let child_context = format!("{context} -> struct child {i}"); + assert_array_data_is_identical( + expected_struct.column(i), + actual_struct.column(i), + &child_context, + ); + } + } + // Fallback for primitive types and other types where buffer comparison is sufficient. + _ => { + assert_eq!( + expected.to_data().buffers(), + actual.to_data().buffers(), + "{context}: data buffers must match" + ); + } + } + } + + /// Checks that `actual_meta` contains all of `expected_meta`, and any additional + /// keys in `actual_meta` are from a permitted set. + fn assert_metadata_is_superset( + expected_meta: &HashMap, + actual_meta: &HashMap, + context: &str, + ) { + let allowed_additions: HashSet<&str> = + vec!["arrowUnionMode", "arrowUnionTypeIds", "avro.name"] + .into_iter() + .collect(); + for (key, expected_value) in expected_meta { + match actual_meta.get(key) { + Some(actual_value) => assert_eq!( + expected_value, actual_value, + "{context}: preserved metadata for key '{key}' must have the same value" + ), + None => panic!("{context}: metadata key '{key}' was lost during roundtrip"), + } + } + for key in actual_meta.keys() { + if !expected_meta.contains_key(key) && !allowed_additions.contains(key.as_str()) { + panic!("{context}: unexpected metadata key '{key}' was added during roundtrip"); + } + } + } + + #[test] + fn test_union_roundtrip() -> Result<(), ArrowError> { + let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test/data/union_fields.avro") + .to_string_lossy() + .into_owned(); + let rdr_file = File::open(&file_path).expect("open avro/union_fields.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for union_fields.avro"); + let schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&schema, &input_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + + // The nature of the crate is such that metadata gets appended during the roundtrip, + // so we can't compare the schemas directly. Instead, we semantically compare the schemas and data. + assert_schema_is_semantically_equivalent(&original.schema(), &round_trip.schema()); + + assert_batch_data_is_identical(&original, &round_trip); + Ok(()) + } + + #[test] + fn test_enum_roundtrip_uses_reader_fixture() -> Result<(), ArrowError> { + // Read the known-good enum file (same as reader::test_simple) + let path = arrow_test_data("avro/simple_enum.avro"); + let rdr_file = File::open(&path).expect("open avro/simple_enum.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for simple_enum.avro"); + // Concatenate all batches to one RecordBatch for a clean equality check + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + // Sanity: expect at least one Dictionary(Int32, Utf8) column (enum) + let has_enum_dict = in_schema.fields().iter().any(|f| { + matches!( + f.data_type(), + DataType::Dictionary(k, v) if **k == DataType::Int32 && **v == DataType::Utf8 + ) + }); + assert!( + has_enum_dict, + "Expected at least one enum-mapped Dictionary field" + ); + // Write with OCF writer into memory using the reader-provided Arrow schema. + // The writer will embed the Avro JSON from `avro.schema` metadata if present. + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + // Read back and compare for exact equality (schema + data) + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("reader for round-trip"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!(roundtrip, original, "Avro enum round-trip mismatch"); + Ok(()) + } + + #[test] + fn test_builder_propagates_capacity_to_writer() -> Result<(), ArrowError> { + let cap = 64 * 1024; + let buffer = Vec::::new(); + let mut writer = WriterBuilder::new(make_schema()) + .with_capacity(cap) + .build::<_, AvroOcfFormat>(buffer)?; + assert_eq!(writer.capacity, cap, "builder capacity not propagated"); + let batch = make_batch(); + writer.write(&batch)?; + writer.finish()?; + let out = writer.into_inner(); + assert_eq!(&out[..4], b"Obj\x01", "OCF magic missing/incorrect"); + Ok(()) + } + + #[test] + fn test_stream_writer_stores_capacity_direct_writes() -> Result<(), ArrowError> { + use arrow_array::{ArrayRef, Int32Array}; + use arrow_schema::{DataType, Field, Schema}; + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef], + )?; + let cap = 8192; + let mut writer = WriterBuilder::new(schema) + .with_capacity(cap) + .build::<_, AvroSoeFormat>(Vec::new())?; + assert_eq!(writer.capacity, cap); + writer.write(&batch)?; + let _bytes = writer.into_inner(); + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_roundtrip_duration_logical_types_ocf() -> Result<(), ArrowError> { + let file_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test/data/duration_logical_types.avro") + .to_string_lossy() + .into_owned(); + + let in_file = File::open(&file_path) + .unwrap_or_else(|_| panic!("Failed to open test file: {}", file_path)); + + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for duration_logical_types.avro"); + let in_schema = reader.schema(); + + let expected_units: HashSet = [ + TimeUnit::Nanosecond, + TimeUnit::Microsecond, + TimeUnit::Millisecond, + TimeUnit::Second, + ] + .into_iter() + .collect(); + + let found_units: HashSet = in_schema + .fields() + .iter() + .filter_map(|f| match f.data_type() { + DataType::Duration(unit) => Some(*unit), + _ => None, + }) + .collect(); + + assert_eq!( + found_units, expected_units, + "Expected to find all four Duration TimeUnits in the schema from the initial read" + ); + + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + + let tmp = NamedTempFile::new().expect("create temp file"); + { + let out_file = File::create(tmp.path()).expect("create temp avro"); + let mut writer = AvroWriter::new(out_file, in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + } + + let rt_file = File::open(tmp.path()).expect("open round_trip avro"); + let rt_reader = ReaderBuilder::new() + .build(BufReader::new(rt_file)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + + assert_eq!(round_trip, input); + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_roundtrip_writer() -> Result<(), ArrowError> { + let run_ends = Int32Array::from(vec![3, 5, 7, 8]); + let run_values = Int32Array::from(vec![Some(1), Some(2), None, Some(3)]); + let ree = RunArray::::try_new(&run_ends, &run_values)?; + let field = Field::new("x", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(ree.clone()) as ArrayRef], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 8); + match out.schema().field(0).data_type() { + DataType::RunEndEncoded(run_ends_field, values_field) => { + assert_eq!(run_ends_field.name(), "run_ends"); + assert_eq!(run_ends_field.data_type(), &DataType::Int32); + assert_eq!(values_field.name(), "values"); + assert_eq!(values_field.data_type(), &DataType::Int32); + assert!(values_field.is_nullable()); + let got_ree = out + .column(0) + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(got_ree, &ree); + } + other => panic!( + "Unexpected DataType for round-tripped RunEndEncoded column: {:?}", + other + ), + } + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_string_values_int16_run_ends_roundtrip_writer() -> Result<(), ArrowError> + { + let run_ends = Int16Array::from(vec![2, 5, 7]); // end indices + let run_values = StringArray::from(vec![Some("a"), None, Some("c")]); + let ree = RunArray::::try_new(&run_ends, &run_values)?; + let field = Field::new("s", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(ree.clone()) as ArrayRef], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 7); + match out.schema().field(0).data_type() { + DataType::RunEndEncoded(run_ends_field, values_field) => { + assert_eq!(run_ends_field.data_type(), &DataType::Int16); + assert_eq!(values_field.data_type(), &DataType::Utf8); + assert!( + values_field.is_nullable(), + "REE 'values' child should be nullable" + ); + let got = out + .column(0) + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(got, &ree); + } + other => panic!("Unexpected DataType: {:?}", other), + } + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_int64_run_ends_numeric_values_roundtrip_writer() + -> Result<(), ArrowError> { + let run_ends = Int64Array::from(vec![4_i64, 8_i64]); + let run_values = Int32Array::from(vec![Some(999), Some(-5)]); + let ree = RunArray::::try_new(&run_ends, &run_values)?; + let field = Field::new("y", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(ree.clone()) as ArrayRef], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 8); + match out.schema().field(0).data_type() { + DataType::RunEndEncoded(run_ends_field, values_field) => { + assert_eq!(run_ends_field.data_type(), &DataType::Int64); + assert_eq!(values_field.data_type(), &DataType::Int32); + assert!(values_field.is_nullable()); + let got = out + .column(0) + .as_any() + .downcast_ref::>() + .expect("RunArray"); + assert_eq!(got, &ree); + } + other => panic!("Unexpected DataType for REE column: {:?}", other), + } + Ok(()) + } + + #[cfg(feature = "avro_custom_types")] + #[test] + fn test_run_end_encoded_sliced_roundtrip_writer() -> Result<(), ArrowError> { + let run_ends = Int32Array::from(vec![3, 5, 7, 8]); + let run_values = Int32Array::from(vec![Some(1), Some(2), None, Some(3)]); + let base = RunArray::::try_new(&run_ends, &run_values)?; + let offset = 1usize; + let length = 6usize; + let base_values = base + .values() + .as_any() + .downcast_ref::() + .expect("REE values as Int32Array"); + let mut logical_window: Vec> = Vec::with_capacity(length); + for i in offset..offset + length { + let phys = base.get_physical_index(i); + let v = if base_values.is_null(phys) { + None + } else { + Some(base_values.value(phys)) + }; + logical_window.push(v); + } + + fn compress_run_ends_i32(vals: &[Option]) -> (Int32Array, Int32Array) { + if vals.is_empty() { + return (Int32Array::new_null(0), Int32Array::new_null(0)); + } + let mut run_ends_out: Vec = Vec::new(); + let mut run_vals_out: Vec> = Vec::new(); + let mut cur = vals[0]; + let mut len = 1i32; + for v in &vals[1..] { + if *v == cur { + len += 1; + } else { + let last_end = run_ends_out.last().copied().unwrap_or(0); + run_ends_out.push(last_end + len); + run_vals_out.push(cur); + cur = *v; + len = 1; + } + } + let last_end = run_ends_out.last().copied().unwrap_or(0); + run_ends_out.push(last_end + len); + run_vals_out.push(cur); + ( + Int32Array::from(run_ends_out), + Int32Array::from(run_vals_out), + ) + } + let (owned_run_ends, owned_run_values) = compress_run_ends_i32(&logical_window); + let owned_slice = RunArray::::try_new(&owned_run_ends, &owned_run_values)?; + let field = Field::new("x", owned_slice.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(owned_slice.clone()) as ArrayRef], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), length); + match out.schema().field(0).data_type() { + DataType::RunEndEncoded(run_ends_field, values_field) => { + assert_eq!(run_ends_field.data_type(), &DataType::Int32); + assert_eq!(values_field.data_type(), &DataType::Int32); + assert!(values_field.is_nullable()); + let got = out + .column(0) + .as_any() + .downcast_ref::>() + .expect("RunArray"); + fn expand_ree_to_int32(a: &RunArray) -> Int32Array { + let vals = a + .values() + .as_any() + .downcast_ref::() + .expect("REE values as Int32Array"); + let mut out: Vec> = Vec::with_capacity(a.len()); + for i in 0..a.len() { + let phys = a.get_physical_index(i); + out.push(if vals.is_null(phys) { + None + } else { + Some(vals.value(phys)) + }); + } + Int32Array::from(out) + } + let got_logical = expand_ree_to_int32(got); + let expected_logical = Int32Array::from(logical_window); + assert_eq!( + got_logical, expected_logical, + "Logical values differ after REE slice round-trip" + ); + } + other => panic!("Unexpected DataType for REE column: {:?}", other), + } + Ok(()) + } + + #[cfg(not(feature = "avro_custom_types"))] + #[test] + fn test_run_end_encoded_roundtrip_writer_feature_off() -> Result<(), ArrowError> { + use arrow_schema::{DataType, Field, Schema}; + let run_ends = arrow_array::Int32Array::from(vec![3, 5, 7, 8]); + let run_values = arrow_array::Int32Array::from(vec![Some(1), Some(2), None, Some(3)]); + let ree = arrow_array::RunArray::::try_new( + &run_ends, + &run_values, + )?; + let field = Field::new("x", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 8); + assert_eq!(out.schema().field(0).data_type(), &DataType::Int32); + let got = out + .column(0) + .as_any() + .downcast_ref::() + .expect("Int32Array"); + let expected = Int32Array::from(vec![ + Some(1), + Some(1), + Some(1), + Some(2), + Some(2), + None, + None, + Some(3), + ]); + assert_eq!(got, &expected); + Ok(()) + } + + #[cfg(not(feature = "avro_custom_types"))] + #[test] + fn test_run_end_encoded_string_values_int16_run_ends_roundtrip_writer_feature_off() + -> Result<(), ArrowError> { + use arrow_schema::{DataType, Field, Schema}; + let run_ends = arrow_array::Int16Array::from(vec![2, 5, 7]); + let run_values = arrow_array::StringArray::from(vec![Some("a"), None, Some("c")]); + let ree = arrow_array::RunArray::::try_new( + &run_ends, + &run_values, + )?; + let field = Field::new("s", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 7); + assert_eq!(out.schema().field(0).data_type(), &DataType::Utf8); + let got = out + .column(0) + .as_any() + .downcast_ref::() + .expect("StringArray"); + let expected = arrow_array::StringArray::from(vec![ + Some("a"), + Some("a"), + None, + None, + None, + Some("c"), + Some("c"), + ]); + assert_eq!(got, &expected); + Ok(()) + } + + #[cfg(not(feature = "avro_custom_types"))] + #[test] + fn test_run_end_encoded_int64_run_ends_numeric_values_roundtrip_writer_feature_off() + -> Result<(), ArrowError> { + use arrow_schema::{DataType, Field, Schema}; + let run_ends = arrow_array::Int64Array::from(vec![4_i64, 8_i64]); + let run_values = Int32Array::from(vec![Some(999), Some(-5)]); + let ree = arrow_array::RunArray::::try_new( + &run_ends, + &run_values, + )?; + let field = Field::new("y", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 8); + assert_eq!(out.schema().field(0).data_type(), &DataType::Int32); + let got = out + .column(0) + .as_any() + .downcast_ref::() + .expect("Int32Array"); + let expected = Int32Array::from(vec![ + Some(999), + Some(999), + Some(999), + Some(999), + Some(-5), + Some(-5), + Some(-5), + Some(-5), + ]); + assert_eq!(got, &expected); + Ok(()) + } + + #[cfg(not(feature = "avro_custom_types"))] + #[test] + fn test_run_end_encoded_sliced_roundtrip_writer_feature_off() -> Result<(), ArrowError> { + use arrow_schema::{DataType, Field, Schema}; + let run_ends = Int32Array::from(vec![2, 4, 6]); + let run_values = Int32Array::from(vec![Some(1), Some(2), None]); + let ree = arrow_array::RunArray::::try_new( + &run_ends, + &run_values, + )?; + let field = Field::new("x", ree.data_type().clone(), true); + let schema = Schema::new(vec![field]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ree) as ArrayRef])?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let reader = ReaderBuilder::new().build(Cursor::new(bytes))?; + let out_schema = reader.schema(); + let batches = reader.collect::, _>>()?; + let out = arrow::compute::concat_batches(&out_schema, &batches).expect("concat output"); + assert_eq!(out.num_columns(), 1); + assert_eq!(out.num_rows(), 6); + assert_eq!(out.schema().field(0).data_type(), &DataType::Int32); + let got = out + .column(0) + .as_any() + .downcast_ref::() + .expect("Int32Array"); + let expected = Int32Array::from(vec![Some(1), Some(1), Some(2), Some(2), None, None]); + assert_eq!(got, &expected); + Ok(()) + } + + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_nullable_impala_roundtrip() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/nullable.impala.avro"); + let rdr_file = File::open(&path).expect("open avro/nullable.impala.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for nullable.impala.avro"); + let in_schema = reader.schema(); + assert!( + in_schema.fields().iter().any(|f| f.is_nullable()), + "expected at least one nullable field in avro/nullable.impala.avro" + ); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let out_bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(out_bytes)) + .expect("build reader for round-tripped in-memory OCF"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!( + roundtrip, original, + "Round-trip Avro data mismatch for nullable.impala.avro" + ); + Ok(()) + } + + #[test] + #[cfg(feature = "snappy")] + fn test_datapage_v2_roundtrip() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/datapage_v2.snappy.avro"); + let rdr_file = File::open(&path).expect("open avro/datapage_v2.snappy.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for datapage_v2.snappy.avro"); + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), in_schema.as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round-trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for datapage_v2.snappy.avro" + ); + Ok(()) + } + + #[test] + #[cfg(feature = "snappy")] + fn test_single_nan_roundtrip() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/single_nan.avro"); + let in_file = File::open(&path).expect("open avro/single_nan.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for single_nan.avro"); + let in_schema = reader.schema(); + let in_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for avro/single_nan.avro" + ); + Ok(()) + } + #[test] + // TODO: avoid requiring snappy for this file + #[cfg(feature = "snappy")] + fn test_dict_pages_offset_zero_roundtrip() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/dict-page-offset-zero.avro"); + let rdr_file = File::open(&path).expect("open avro/dict-page-offset-zero.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for dict-page-offset-zero.avro"); + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let buffer: Vec = Vec::new(); + let mut writer = AvroWriter::new(buffer, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build reader for round-trip"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!( + roundtrip, original, + "Round-trip batch mismatch for avro/dict-page-offset-zero.avro" + ); + Ok(()) + } + + #[test] + #[cfg(feature = "snappy")] + fn test_repeated_no_annotation_roundtrip() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/repeated_no_annotation.avro"); + let in_file = File::open(&path).expect("open avro/repeated_no_annotation.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for repeated_no_annotation.avro"); + let in_schema = reader.schema(); + let in_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build reader for round-trip buffer"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round-trip"); + assert_eq!( + round_trip, original, + "Round-trip batch mismatch for avro/repeated_no_annotation.avro" + ); + Ok(()) + } + + #[test] + fn test_nested_record_type_reuse_roundtrip() -> Result<(), ArrowError> { + let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test/data/nested_record_reuse.avro") + .to_string_lossy() + .into_owned(); + let in_file = File::open(&path).expect("open avro/nested_record_reuse.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(in_file)) + .expect("build reader for nested_record_reuse.avro"); + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let input = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), in_schema.as_ref().clone())?; + writer.write(&input)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, input, + "Round-trip batch mismatch for nested_record_reuse.avro" + ); + Ok(()) + } + + #[test] + fn test_enum_type_reuse_roundtrip() -> Result<(), ArrowError> { + let path = + std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("test/data/enum_reuse.avro"); + let rdr_file = std::fs::File::open(&path).expect("open test/data/enum_reuse.avro"); + let reader = ReaderBuilder::new() + .build(std::io::BufReader::new(rdr_file)) + .expect("build reader for enum_reuse.avro"); + let in_schema = reader.schema(); + let input_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &input_batches).expect("concat input"); + let mut writer = AvroWriter::new(Vec::::new(), original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(std::io::Cursor::new(bytes)) + .expect("build round_trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let round_trip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat round_trip"); + assert_eq!( + round_trip, original, + "Avro enum type reuse round-trip mismatch" + ); + Ok(()) + } + + #[test] + fn comprehensive_e2e_test_roundtrip() -> Result<(), ArrowError> { + let path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("test/data/comprehensive_e2e.avro"); + let rdr_file = File::open(&path).expect("open test/data/comprehensive_e2e.avro"); + let reader = ReaderBuilder::new() + .build(BufReader::new(rdr_file)) + .expect("build reader for comprehensive_e2e.avro"); + let in_schema = reader.schema(); + let in_batches = reader.collect::, _>>()?; + let original = + arrow::compute::concat_batches(&in_schema, &in_batches).expect("concat input"); + let sink: Vec = Vec::new(); + let mut writer = AvroWriter::new(sink, original.schema().as_ref().clone())?; + writer.write(&original)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(Cursor::new(bytes)) + .expect("build round-trip reader"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!( + roundtrip, original, + "Round-trip batch mismatch for comprehensive_e2e.avro" + ); + Ok(()) + } + + #[test] + fn test_roundtrip_new_time_encoders_writer() -> Result<(), ArrowError> { + let schema = Schema::new(vec![ + Field::new("d32", DataType::Date32, false), + Field::new("t32_ms", DataType::Time32(TimeUnit::Millisecond), false), + Field::new("t64_us", DataType::Time64(TimeUnit::Microsecond), false), + Field::new( + "ts_ms", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ), + Field::new( + "ts_us", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ), + Field::new( + "ts_ns", + DataType::Timestamp(TimeUnit::Nanosecond, None), + false, + ), + ]); + let d32 = Date32Array::from(vec![0, 1, -1]); + let t32_ms: PrimitiveArray = + vec![0_i32, 12_345_i32, 86_399_999_i32].into(); + let t64_us: PrimitiveArray = + vec![0_i64, 1_234_567_i64, 86_399_999_999_i64].into(); + let ts_ms: PrimitiveArray = + vec![0_i64, -1_i64, 1_700_000_000_000_i64].into(); + let ts_us: PrimitiveArray = vec![0_i64, 1_i64, -1_i64].into(); + let ts_ns: PrimitiveArray = vec![0_i64, 1_i64, -1_i64].into(); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![ + Arc::new(d32) as ArrayRef, + Arc::new(t32_ms) as ArrayRef, + Arc::new(t64_us) as ArrayRef, + Arc::new(ts_ms) as ArrayRef, + Arc::new(ts_us) as ArrayRef, + Arc::new(ts_ns) as ArrayRef, + ], + )?; + let mut writer = AvroWriter::new(Vec::::new(), schema.clone())?; + writer.write(&batch)?; + writer.finish()?; + let bytes = writer.into_inner(); + let rt_reader = ReaderBuilder::new() + .build(std::io::Cursor::new(bytes)) + .expect("build reader for round-trip of new time encoders"); + let rt_schema = rt_reader.schema(); + let rt_batches = rt_reader.collect::, _>>()?; + let roundtrip = + arrow::compute::concat_batches(&rt_schema, &rt_batches).expect("concat roundtrip"); + assert_eq!(roundtrip, batch); + Ok(()) + } +} diff --git a/arrow-avro/test/data/README.md b/arrow-avro/test/data/README.md new file mode 100644 index 000000000000..226e0700fb94 --- /dev/null +++ b/arrow-avro/test/data/README.md @@ -0,0 +1,359 @@ + + +# Avro test files for `arrow-avro` + +This directory contains small Avro Object Container Files (OCF) used by +`arrow-avro` tests to validate the `Reader` implementation. These files are generated from +a set of python scripts and will gradually be removed as they are merged into `arrow-testing`. + +## Decimal Files + +This directory contains OCF files used to exercise decoding of Avro’s `decimal` logical type +across both `bytes` and `fixed` encodings, and to cover Arrow decimal widths ranging +from `Decimal32` up through `Decimal256`. The files were generated from a +script (see **How these files were created** below). + +> **Avro decimal recap.** Avro’s `decimal` logical type annotates either a +> `bytes` or `fixed` primitive and stores the **two’s‑complement big‑endian +> representation of the unscaled integer** (value × 10^scale). Implementations +> should reject invalid combinations such as `scale > precision`. + +> **Arrow decimal recap.** Arrow defines `Decimal32`, `Decimal64`, `Decimal128`, +> and `Decimal256` types with maximum precisions of 9, 18, 38, and 76 digits, +> respectively. Tests here validate that the Avro reader selects compatible +> Arrow decimal widths given the Avro decimal’s precision and storage. + +--- + +All files are one‑column Avro OCFs with a field named `value`. Each contains 24 +rows with the sequence `1 … 24` rendered at the file’s declared `scale` +(i.e., at scale 10: `1.0000000000`, `2.0000000000`). + +| File | Avro storage | Decimal (precision, scale) | Intended Arrow width | +|---|---|---|---| +| `int256_decimal.avro` | `bytes` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed256_decimal.avro` | `fixed[32]` + `logicalType: decimal` | (76, 10) | `Decimal256` | +| `fixed_length_decimal_legacy_32.avro` | `fixed[4]` + `logicalType: decimal` | (9, 2) | `Decimal32` (legacy fixed‑width path) | +| `int128_decimal.avro` | `bytes` + `logicalType: decimal` | (38, 2) | `Decimal128` | + +### Schemas (for reference) + +#### int256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 76, "scale": 10 } + }] +} +``` + +#### fixed256_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal256Fixed", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal256Fixed", "size": 32, + "logicalType": "decimal", "precision": 76, "scale": 10 + } + }] +} +``` + +#### fixed_length_decimal_legacy_32.avro + +```json +{ + "type": "record", + "name": "OneColDecimal32FixedLegacy", + "fields": [{ + "name": "value", + "type": { + "type": "fixed", "name": "Decimal32FixedLegacy", "size": 4, + "logicalType": "decimal", "precision": 9, "scale": 2 + } + }] +} +``` + +#### int128_decimal.avro + +```json +{ + "type": "record", + "name": "OneColDecimal128Bytes", + "fields": [{ + "name": "value", + "type": { "type": "bytes", "logicalType": "decimal", "precision": 38, "scale": 2 } + }] +} +``` + +### How these files were created + +All four files were generated by the Python script +`create_avro_decimal_files.py` authored for this purpose. The script uses +`fastavro` to write OCFs and encodes decimal values as required by the Avro +spec (two’s‑complement big‑endian of the unscaled integer). + +#### Re‑generation + +From the repository root (defaults write into arrow-avro/test/data): + +```bash +# 1) Ensure Python 3 is available, then install fastavro +python -m pip install --upgrade fastavro + +# 2) Fetch the script +curl -L -o create_avro_decimal_files.py \ +https://gist.githubusercontent.com/jecsand838/3890349bdb33082a3e8fdcae3257eef7/raw/create_avro_decimal_files.py + +# 3) Generate the files (prints a verification dump by default) +python create_avro_decimal_files.py -o arrow-avro/test/data +``` + +Options: +* --num-rows (default 24) — number of rows to emit per file +* --scale (default 10) — the decimal scale used for the 256 files +* --no-verify — skip reading the files back for printed verification + +## Duration Logical Types File + +This directory contains an OCF file used to test the decoding of Avro long types annotated with custom logicalType values. This is used to map directly to Arrow Duration types with different time units. + +#### duration_logical_types.avro + +```json +{ + "type": "record", + "name": "DurationLogicalTypes", + "fields": [ + { + "name": "duration_time_nanos", + "type": { + "type": "long", + "logicalType": "arrow.duration-nanos" + } + }, + { + "name": "duration_time_micros", + "type": { + "type": "long", + "logicalType": "arrow.duration-micros" + } + }, + { + "name": "duration_time_millis", + "type": { + "type": "long", + "logicalType": "arrow.duration-millis" + } + }, + { + "name": "duration_time_seconds", + "type": { + "type": "long", + "logicalType": "arrow.duration-seconds" + } + } + ] +} +``` + +This file contains 24 rows of random long values across four fields, each annotated with a different custom logical type corresponding to an Arrow Duration unit. + + +#### How this file was created + +The file was generated by the Python script generate_duration_avro.py. The script uses fastavro to write an OCF with the schema and random data described above. + +#### Re‑generation +From the repository root (defaults write into arrow-avro/test/data): + +```Bash + +# 1) Ensure Python 3 is available, then install fastavro +python3 -m pip install --upgrade fastavro + +# 2) Fetch the script +curl -L -o generate_duration_avro.py \ +https://gist.githubusercontent.com/nathaniel-d-ef/c253cb180b041023e3ccfe9df20ccef7/raw/06c8ca1321efcd8e1c8746fd65aa013e1a566944/generate_duration_avro.py + +# 3) Run the generation script +python3 generate_duration_avro.py -o arrow-avro/test/data +``` + +Options: + +* --num-rows (default 24) — number of rows to emit + +* --no-verify — skip reading the file back for printed verification + +## Union File + +**Purpose:** Exercise a wide variety of Avro **union** shapes (including nullable unions, unions of ambiguous scalar types, unions of named types, and unions inside arrays, maps, and nested records) to validate `arrow-avro` union decoding and schema‑resolution paths. + +**Format:** Avro Object Container File (OCF) written by `fastavro.writer` with embedded writer schema. + +**Record count:** four rows. Each row selects different branches across the unions to ensure coverage (i.e., toggling between bytes vs. string, fixed vs. duration vs. decimal, enum vs. record alternatives, etc.). + +**How this file was created:** + +1. Script: [`create_avro_union_file.py`](https://gist.github.com/jecsand838/f4bf85ad597ab34575219df515156444) + Runs with Python 3 and uses **fastavro** to emit `union_fields.avro` in the working directory. +2. Quick reproduce: + ```bash + pip install fastavro + python3 create_avro_union_file.py + # Outputs: ./union_fields.avro + ``` + +> Note: Avro OCF files include a *sync marker*; `fastavro.writer` generates a random one if not provided, so byte‑for‑byte output may vary between runs even with the same data. This does not affect the embedded schema or logical content. + +**Writer schema (overview):** The record is named `UnionTypesRecord` and defines the following fields: + +| Field | Union branches / details | +|-----------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| `nullable_int_nullfirst` | `["null","int"]` (tests null‑first ordering) | +| `nullable_string_nullsecond` | `["string","null"]` (tests null‑second ordering; in Avro, a union field’s default must match the *first* branch) | +| `union_prim` | `["boolean","int","long","float","double","bytes","string"]` | +| `union_bytes_vs_string` | `["bytes","string"]` (ambiguous scalar union; script uses fastavro’s tuple notation to disambiguate) | +| `union_fixed_dur_decfix` | `["Fx8","Dur12","DecFix16"]` where:
• `Fx8` = `fixed`(size=8)
• `Dur12` = `fixed`(size=12, `logicalType`=`duration`)
• `DecFix16` = `fixed`(size=16, `logicalType`=`decimal`, precision=10, scale=2)
**Notes:** Avro `duration` is a `fixed[12]` storing **months, days, millis** as three **little‑endian** 32‑bit integers; Avro `decimal` on `bytes`/`fixed` uses **two’s‑complement big‑endian** encoding of the unscaled integer. | +| `union_enum_records_array_map` | `[ColorU, RecA, RecB, array, map]` where:
• `ColorU` = `enum` {`RED`,`GREEN`,`BLUE`}
• `RecA` = `record` {`a:int`, `b:string`}
• `RecB` = `record` {`x:long`, `y:bytes`} | +| `union_date_or_fixed4` | `[int (logicalType=`date`), Fx4]` where `Fx4` = `fixed`(size=4) | +| `union_time_millis_or_enum` | `[int (logicalType=`time-millis`), OnOff]` where `OnOff` = `enum` {`ON`,`OFF`} | +| `union_time_micros_or_string` | `[long (logicalType=`time-micros`), string]` | +| `union_ts_millis_utc_or_array` | `[long (logicalType=`timestamp-millis`), array]` | +| `union_ts_micros_local_or_bytes` | `[long (logicalType=`local-timestamp-micros`), bytes]` | +| `union_uuid_or_fixed10` | `[string (logicalType=`uuid`), Fx10]` where `Fx10` = `fixed`(size=10) | +| `union_dec_bytes_or_dec_fixed` | `[bytes (decimal p=10 s=2), DecFix20]` where `DecFix20` = `fixed`(size=20, decimal p=20 s=4) — decimal encoding is big‑endian two’s‑complement. | +| `union_null_bytes_string` | `["null","bytes","string"]` | +| `array_of_union` | `array<["long","string"]>` | +| `map_of_union` | `map<["null","double"]>` | +| `record_with_union_field` | `HasUnion` = `record` {`id:int`, `u:["int","string"]`} | +| `union_ts_micros_utc_or_map` | `[long (logicalType=`timestamp-micros`), map]` | +| `union_ts_millis_local_or_string` | `[long (logicalType=`local-timestamp-millis`), string]` | +| `union_bool_or_string` | `["boolean","string"]` | + +**Implementation notes (generation):** + +* The script uses **fastavro’s tuple notation** `(branch_name, value)` to select branches in ambiguous unions (e.g., bytes vs. string, multiple named records). See *“Using the tuple notation to specify which branch of a union to take”* in the fastavro docs. +* Decimal values are pre‑encoded to the required **big‑endian two’s‑complement** byte sequence before writing (for both `bytes` and `fixed` decimal logical types). +* The `duration` logical type payloads are 12‑byte triples: **months / days / milliseconds**, little‑endian each. + +**Source / Repro script:** +`create_avro_union_file.py` (Gist): contains the full writer schema, record builders covering four rows, and the `fastavro.writer` call which emits `union_fields.avro`. + +## Comprehensive E2E Coverage File + +**Purpose:** A single OCF that exercises **all decoder paths** used by `arrow-avro` with both **nested and non‑nested** shapes, including **dense unions** (null‑first, null‑second, multi‑branch), **aliases** (type and field), **default values**, **docs** and **namespaces**, and combinations thereof. It’s intended to validate the final `Reader` implementation and to stress schema‑resolution behavior in the tests under `arrow-avro/src/reader/mod.rs`. + +**File:** `comprehensive_e2e.avro` +**Top‑level record (writer schema):** `org.apache.arrow.avrotests.v1.E2eComprehensive` +**Record count:** four rows (each row selects different union branches and nested shapes) + +**Coverage summary (by Arrow / Avro mapping):** + +* Primitives: **boolean, int, long, float, double** +* Binary / Text: **bytes**, **string (UTF‑8)** +* Logical types: **date**, **time‑millis**, **time‑micros**, **timestamp‑millis (UTC)**, **timestamp‑micros (UTC)**, **local‑timestamp‑millis**, **local‑timestamp‑micros**, **uuid (string)**, **decimal** on **bytes** and **fixed**, **duration** on **fixed(12)** +* Named types: **fixed**, **enum**, **record** +* Collections: **array**, **map** +* Unions: **nullable unions**, **ambiguous scalar unions**, **unions of named types**, and **unions nested inside arrays/maps/records** +* Schema‑evolution hooks: **type aliases**, **field aliases**, **defaults** (including union defaults on the first branch), **docs**, and **namespaces** + +**Writer schema (overview of fields):** + +| Field | Type / details | +|-------------------------------|---------------------------------------------------------------------------------------------------------| +| `id` | `long` | +| `flag` | `boolean` (default `true`) | +| `ratio_f32` | `float` (default `0.0`) | +| `ratio_f64` | `double` (default `0.0`) | +| `count_i32` | `int` (default `0`) | +| `count_i64` | `long` (default `0`) | +| `opt_i32_nullfirst` | `["null","int"]` (default `null`) | +| `opt_str_nullsecond` | `["string","null"]` (default `""`, alias: `old_opt_str`) | +| `tri_union_prim` | `["int","string","boolean"]` (default `0`) | +| `str_utf8` | `string` (default `"default"`) | +| `raw_bytes` | `bytes` (default `""`) | +| `fx16_plain` | `fixed` `types.Fx16` (size 16, alias `Fixed16Old`) | +| `dec_bytes_s10_2` | `bytes` + `logicalType: decimal` (precision 10, scale 2) | +| `dec_fix_s20_4` | `fixed` `types.DecFix20` (size 20) + `logicalType: decimal` (precision 20, scale 4) | +| `uuid_str` | `string` + `logicalType: uuid` | +| `d_date` | `int` + `logicalType: date` | +| `t_millis` | `int` + `logicalType: time-millis` | +| `t_micros` | `long` + `logicalType: time-micros` | +| `ts_millis_utc` | `long` + `logicalType: timestamp-millis` | +| `ts_micros_utc` | `long` + `logicalType: timestamp-micros` | +| `ts_millis_local` | `long` + `logicalType: local-timestamp-millis` | +| `ts_micros_local` | `long` + `logicalType: local-timestamp-micros` | +| `interval_mdn` | `fixed` `types.Dur12` (size 12) + `logicalType: duration` | +| `status` | `enum` `types.Status` = {`UNKNOWN`,`NEW`,`PROCESSING`,`DONE`} (alias: `State`) | +| `arr_union` | `array<["long","string","null"]>` | +| `map_union` | `map<["null","double","string"]>` | +| `address` | `record` `types.Address` {`street` (alias: `street_name`), `zip:int`, `country:string`} | +| `maybe_auth` | `record` `types.MaybeAuth` {`user:string`, `token:["null","bytes"]` (default `null`)} | +| `union_enum_record_array_map` | `[types.Color enum, types.RecA record, types.RecB record, array, map]` | +| `union_date_or_fixed4` | `[int (logicalType=date), fixed Fx4 size 4]` | +| `union_interval_or_string` | `[fixed Dur12U size 12 (logicalType=duration), string]` | +| `union_uuid_or_fixed10` | `[string (logicalType=uuid), fixed Fx10 size 10]` | +| `array_records_with_union` | `array` | +| `union_map_or_array_int` | `[map, array]` | +| `renamed_with_default` | `int` (default `42`, alias: `old_count`) | +| `person` | `record` `com.example.v2.PersonV2` (alias: `com.example.Person`) `{ name:string, age:int (default 0) }` | + +**How this file was created** + +* Script: [`create_comprehensive_avro_file.py`](https://gist.github.com/jecsand838/26f9666da8de22651027d485bd83f4a3) + Uses **fastavro** to write `comprehensive_e2e.avro` with the schema above and four records that intentionally vary union branches and nested shapes. + +**Re‑generation** + +From the repository root: + +```bash +# 1) Ensure Python 3 is available, then install fastavro +python -m pip install --upgrade fastavro + +# 2) Run the generator (writes ./comprehensive_e2e.avro by default) +python create_comprehensive_avro_file.py + +# 3) Move or copy the file into this directory if needed +mv comprehensive_e2e.avro arrow-avro/test/data/ +``` + +**Notes / tips for tests** + +* For **unions of named types** (record/enum/fixed), the generator uses fastavro’s **tuple notation** to select the union branch and, where needed, supplies the **fully‑qualified name (FQN)** to avoid ambiguity when namespaces apply. +* The file contains many **defaults** and **aliases** (type and field) to exercise **schema resolution** code paths. +* As with all OCFs, a random **sync marker** is embedded in the file header; byte‑for‑byte output may vary across runs without affecting the schema or logical content. + +## Other Files + +This directory contains other small OCF files used by `arrow-avro` tests. Details on these will be added in +follow-up PRs. \ No newline at end of file diff --git a/arrow-avro/test/data/comprehensive_e2e.avro b/arrow-avro/test/data/comprehensive_e2e.avro new file mode 100644 index 000000000000..a3e55716c325 Binary files /dev/null and b/arrow-avro/test/data/comprehensive_e2e.avro differ diff --git a/arrow-avro/test/data/duration_logical_types.avro b/arrow-avro/test/data/duration_logical_types.avro new file mode 100644 index 000000000000..4d514fa9ba59 Binary files /dev/null and b/arrow-avro/test/data/duration_logical_types.avro differ diff --git a/arrow-avro/test/data/duration_uuid.avro b/arrow-avro/test/data/duration_uuid.avro new file mode 100644 index 000000000000..09dd67b7807a Binary files /dev/null and b/arrow-avro/test/data/duration_uuid.avro differ diff --git a/arrow-avro/test/data/enum_reuse.avro b/arrow-avro/test/data/enum_reuse.avro new file mode 100644 index 000000000000..7891870df3c9 Binary files /dev/null and b/arrow-avro/test/data/enum_reuse.avro differ diff --git a/arrow-avro/test/data/fixed256_decimal.avro b/arrow-avro/test/data/fixed256_decimal.avro new file mode 100644 index 000000000000..d1fc97dd8c83 Binary files /dev/null and b/arrow-avro/test/data/fixed256_decimal.avro differ diff --git a/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro new file mode 100644 index 000000000000..b746df9619b5 Binary files /dev/null and b/arrow-avro/test/data/fixed_length_decimal_legacy_32.avro differ diff --git a/arrow-avro/test/data/int128_decimal.avro b/arrow-avro/test/data/int128_decimal.avro new file mode 100644 index 000000000000..bd54d20ba487 Binary files /dev/null and b/arrow-avro/test/data/int128_decimal.avro differ diff --git a/arrow-avro/test/data/int256_decimal.avro b/arrow-avro/test/data/int256_decimal.avro new file mode 100644 index 000000000000..62ad7ea4df08 Binary files /dev/null and b/arrow-avro/test/data/int256_decimal.avro differ diff --git a/arrow-avro/test/data/named_types_complex.avro b/arrow-avro/test/data/named_types_complex.avro new file mode 100644 index 000000000000..eae439317e5b Binary files /dev/null and b/arrow-avro/test/data/named_types_complex.avro differ diff --git a/arrow-avro/test/data/nested_record_reuse.avro b/arrow-avro/test/data/nested_record_reuse.avro new file mode 100644 index 000000000000..5e2a9e0328bc Binary files /dev/null and b/arrow-avro/test/data/nested_record_reuse.avro differ diff --git a/arrow-avro/test/data/skippable_types.avro b/arrow-avro/test/data/skippable_types.avro new file mode 100644 index 000000000000..b0518e0056b5 Binary files /dev/null and b/arrow-avro/test/data/skippable_types.avro differ diff --git a/arrow-avro/test/data/union_fields.avro b/arrow-avro/test/data/union_fields.avro new file mode 100644 index 000000000000..e0ffb82bd412 Binary files /dev/null and b/arrow-avro/test/data/union_fields.avro differ diff --git a/arrow-avro/test/data/zero_byte.avro b/arrow-avro/test/data/zero_byte.avro new file mode 100644 index 000000000000..f7ffd29b6890 Binary files /dev/null and b/arrow-avro/test/data/zero_byte.avro differ diff --git a/arrow-buffer/Cargo.toml b/arrow-buffer/Cargo.toml index d4fa0614e01a..02ea49c37c46 100644 --- a/arrow-buffer/Cargo.toml +++ b/arrow-buffer/Cargo.toml @@ -35,13 +35,17 @@ bench = false [package.metadata.docs.rs] all-features = true +[features] +pool = [] + [dependencies] bytes = { version = "1.4" } -num = { version = "0.4", default-features = false, features = ["std"] } +num-bigint = { version = "0.4.6", default-features = false, features = ["std"] } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false } [dev-dependencies] -criterion = { version = "0.5", default-features = false } +criterion = { workspace = true, default-features = false } rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } [[bench]] @@ -55,3 +59,8 @@ harness = false [[bench]] name = "offset" harness = false + +[[bench]] +name = "mutable_buffer_repeat_slice" +harness = false + diff --git a/arrow-buffer/benches/bit_mask.rs b/arrow-buffer/benches/bit_mask.rs index 545528724e5d..0384089e32c5 100644 --- a/arrow-buffer/benches/bit_mask.rs +++ b/arrow-buffer/benches/bit_mask.rs @@ -16,7 +16,7 @@ // under the License. use arrow_buffer::bit_mask::set_bits; -use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use criterion::{BenchmarkId, Criterion, criterion_group, criterion_main}; use std::hint; fn criterion_benchmark(c: &mut Criterion) { diff --git a/arrow-buffer/benches/i256.rs b/arrow-buffer/benches/i256.rs index 7dec226bbc08..2bbb5c0284c2 100644 --- a/arrow-buffer/benches/i256.rs +++ b/arrow-buffer/benches/i256.rs @@ -17,6 +17,7 @@ use arrow_buffer::i256; use criterion::*; +use num_traits::cast::ToPrimitive; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use std::{hint, str::FromStr}; @@ -36,13 +37,19 @@ fn criterion_benchmark(c: &mut Criterion) { i256::MAX, ]; - for number in numbers { + for number in numbers.iter() { let t = hint::black_box(number.to_string()); c.bench_function(&format!("i256_parse({t})"), |b| { b.iter(|| i256::from_str(&t).unwrap()); }); } + for number in numbers.iter() { + c.bench_function(&format!("i256_to_f64({number})"), |b| { + b.iter(|| (*number).to_f64().unwrap()) + }); + } + let mut rng = StdRng::seed_from_u64(42); let numerators: Vec<_> = (0..SIZE) diff --git a/arrow-buffer/benches/mutable_buffer_repeat_slice.rs b/arrow-buffer/benches/mutable_buffer_repeat_slice.rs new file mode 100644 index 000000000000..a59c24baef56 --- /dev/null +++ b/arrow-buffer/benches/mutable_buffer_repeat_slice.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_buffer::Buffer; +use criterion::*; +use rand::distr::Alphanumeric; +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::hint; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("MutableBuffer repeat slice"); + let mut rng = StdRng::seed_from_u64(42); + + for slice_length in [3, 20, 100] { + let slice_to_repeat: Vec = hint::black_box( + (&mut rng) + .sample_iter(&Alphanumeric) + .take(slice_length) + .collect(), + ); + let slice_to_repeat: &[u8] = slice_to_repeat.as_ref(); + + for repeat_count in [3, 64, 1024, 8192] { + let parameter_string = format!("slice_len={slice_length} n={repeat_count}"); + + group.bench_with_input( + BenchmarkId::new("repeat_slice_n_times", ¶meter_string), + &(repeat_count), + |b, &repeat_count| { + b.iter(|| { + let mut mutable_buffer = arrow_buffer::MutableBuffer::with_capacity(0); + + mutable_buffer.repeat_slice_n_times(slice_to_repeat, repeat_count); + + Buffer::from(mutable_buffer) + }) + }, + ); + group.bench_with_input( + BenchmarkId::new("extend_from_slice loop", ¶meter_string), + &(repeat_count), + |b, &repeat_count| { + b.iter(|| { + let mut mutable_buffer = arrow_buffer::MutableBuffer::with_capacity( + size_of_val(slice_to_repeat) * repeat_count, + ); + + for _ in 0..repeat_count { + mutable_buffer.extend_from_slice(slice_to_repeat); + } + + Buffer::from(mutable_buffer) + }) + }, + ); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/arrow-buffer/src/bigint/mod.rs b/arrow-buffer/src/bigint/mod.rs index 9868ab55cc11..15faed43a130 100644 --- a/arrow-buffer/src/bigint/mod.rs +++ b/arrow-buffer/src/bigint/mod.rs @@ -17,8 +17,12 @@ use crate::arith::derive_arith; use crate::bigint::div::div_rem; -use num::cast::AsPrimitive; -use num::{BigInt, FromPrimitive, ToPrimitive}; +use num_bigint::BigInt; +use num_traits::{ + Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedNeg, CheckedRem, CheckedSub, FromPrimitive, + Num, One, Signed, ToPrimitive, WrappingAdd, WrappingMul, WrappingNeg, WrappingSub, Zero, + cast::AsPrimitive, +}; use std::cmp::Ordering; use std::num::ParseIntError; use std::ops::{BitAnd, BitOr, BitXor, Neg, Shl, Shr}; @@ -232,11 +236,7 @@ impl i256 { pub fn from_f64(v: f64) -> Option { BigInt::from_f64(v).and_then(|i| { let (integer, overflow) = i256::from_bigint_with_overflow(i); - if overflow { - None - } else { - Some(integer) - } + if overflow { None } else { Some(integer) } }) } @@ -304,7 +304,7 @@ impl i256 { let v_bytes = v.to_signed_bytes_le(); match v_bytes.len().cmp(&32) { Ordering::Less => { - let mut bytes = if num::Signed::is_negative(&v) { + let mut bytes = if num_traits::Signed::is_negative(&v) { [255_u8; 32] } else { [0; 32] @@ -586,6 +586,34 @@ impl i256 { pub const fn is_positive(self) -> bool { self.high.is_positive() || self.high == 0 && self.low != 0 } + + /// Returns the number of leading zeros in the binary representation of this [`i256`]. + pub const fn leading_zeros(&self) -> u32 { + match self.high { + 0 => u128::BITS + self.low.leading_zeros(), + _ => self.high.leading_zeros(), + } + } + + /// Returns the number of trailing zeros in the binary representation of this [`i256`]. + pub const fn trailing_zeros(&self) -> u32 { + match self.low { + 0 => u128::BITS + self.high.trailing_zeros(), + _ => self.low.trailing_zeros(), + } + } + + fn redundant_leading_sign_bits_i256(n: i256) -> u8 { + let mask = n >> 255; // all ones or all zeros + ((n ^ mask).leading_zeros() - 1) as u8 // we only need one sign bit + } + + fn i256_to_f64(input: i256) -> f64 { + let k = i256::redundant_leading_sign_bits_i256(input); + let n = input << k; // left-justify (no redundant sign bits) + let n = (n.high >> 64) as i64; // throw away the lower 192 bits + (n as f64) * f64::powi(2.0, 192 - (k as i32)) // convert to f64 and scale it, as we left-shift k bit previous, so we need to scale it by 2^(192-k) + } } /// Temporary workaround due to lack of stable const array slicing @@ -821,6 +849,15 @@ impl ToPrimitive for i256 { } } + fn to_f64(&self) -> Option { + match *self { + Self::MIN => Some(-2_f64.powi(255)), + Self::ZERO => Some(0f64), + Self::ONE => Some(1f64), + n => Some(Self::i256_to_f64(n)), + } + } + fn to_u64(&self) -> Option { let as_i128 = self.low as i128; @@ -836,11 +873,142 @@ impl ToPrimitive for i256 { } } +// num_traits checked implementations + +impl CheckedNeg for i256 { + fn checked_neg(&self) -> Option { + (*self).checked_neg() + } +} + +impl CheckedAdd for i256 { + fn checked_add(&self, v: &i256) -> Option { + (*self).checked_add(*v) + } +} + +impl CheckedSub for i256 { + fn checked_sub(&self, v: &i256) -> Option { + (*self).checked_sub(*v) + } +} + +impl CheckedDiv for i256 { + fn checked_div(&self, v: &i256) -> Option { + (*self).checked_div(*v) + } +} + +impl CheckedMul for i256 { + fn checked_mul(&self, v: &i256) -> Option { + (*self).checked_mul(*v) + } +} + +impl CheckedRem for i256 { + fn checked_rem(&self, v: &i256) -> Option { + (*self).checked_rem(*v) + } +} + +impl WrappingAdd for i256 { + fn wrapping_add(&self, v: &Self) -> Self { + (*self).wrapping_add(*v) + } +} + +impl WrappingSub for i256 { + fn wrapping_sub(&self, v: &Self) -> Self { + (*self).wrapping_sub(*v) + } +} + +impl WrappingMul for i256 { + fn wrapping_mul(&self, v: &Self) -> Self { + (*self).wrapping_mul(*v) + } +} + +impl WrappingNeg for i256 { + fn wrapping_neg(&self) -> Self { + (*self).wrapping_neg() + } +} + +impl Zero for i256 { + fn zero() -> Self { + i256::ZERO + } + + fn is_zero(&self) -> bool { + *self == i256::ZERO + } +} + +impl One for i256 { + fn one() -> Self { + i256::ONE + } + + fn is_one(&self) -> bool { + *self == i256::ONE + } +} + +impl Num for i256 { + type FromStrRadixErr = ParseI256Error; + + fn from_str_radix(str: &str, radix: u32) -> Result { + if radix == 10 { + str.parse() + } else { + // Parsing from non-10 baseseeÎ is not supported + Err(ParseI256Error {}) + } + } +} + +impl Signed for i256 { + fn abs(&self) -> Self { + self.wrapping_abs() + } + + fn abs_sub(&self, other: &Self) -> Self { + if self > other { + self.wrapping_sub(other) + } else { + i256::ZERO + } + } + + fn signum(&self) -> Self { + (*self).signum() + } + + fn is_positive(&self) -> bool { + (*self).is_positive() + } + + fn is_negative(&self) -> bool { + (*self).is_negative() + } +} + +impl Bounded for i256 { + fn min_value() -> Self { + i256::MIN + } + + fn max_value() -> Self { + i256::MAX + } +} + #[cfg(all(test, not(miri)))] // llvm.x86.subborrow.64 not supported by MIRI mod tests { use super::*; - use num::Signed; - use rand::{rng, Rng}; + use num_traits::Signed; + use rand::{Rng, rng}; #[test] fn test_signed_cmp() { @@ -1264,4 +1432,152 @@ mod tests { } } } + + #[test] + fn test_decimal256_to_f64_typical_values() { + let v = i256::from_i128(42_i128); + assert_eq!(v.to_f64().unwrap(), 42.0); + + let v = i256::from_i128(-123456789012345678i128); + assert_eq!(v.to_f64().unwrap(), -123456789012345678.0); + + let v = i256::from_string("0").unwrap(); + assert_eq!(v.to_f64().unwrap(), 0.0); + + let v = i256::from_string("1").unwrap(); + assert_eq!(v.to_f64().unwrap(), 1.0); + + let mut rng = rng(); + for _ in 0..10 { + let f64_value = + (rng.random_range(i128::MIN..i128::MAX) as f64) * rng.random_range(0.0..1.0); + let big = i256::from_f64(f64_value).unwrap(); + assert_eq!(big.to_f64().unwrap(), f64_value); + } + } + + #[test] + fn test_decimal256_to_f64_large_positive_value() { + let max_f = f64::MAX; + let big = i256::from_f64(max_f * 2.0).unwrap_or(i256::MAX); + let out = big.to_f64().unwrap(); + assert!(out.is_finite() && out.is_sign_positive()); + } + + #[test] + fn test_decimal256_to_f64_large_negative_value() { + let max_f = f64::MAX; + let big_neg = i256::from_f64(-(max_f * 2.0)).unwrap_or(i256::MIN); + let out = big_neg.to_f64().unwrap(); + assert!(out.is_finite() && out.is_sign_negative()); + } + + #[test] + fn test_num_traits() { + let value = i256::from_i128(-5); + assert_eq!( + ::checked_neg(&value), + Some(i256::from(5)) + ); + + assert_eq!( + ::checked_add(&value, &value), + Some(i256::from(-10)) + ); + + assert_eq!( + ::checked_sub(&value, &value), + Some(i256::from(0)) + ); + + assert_eq!( + ::checked_mul(&value, &value), + Some(i256::from(25)) + ); + + assert_eq!( + ::checked_div(&value, &value), + Some(i256::from(1)) + ); + + assert_eq!( + ::checked_rem(&value, &value), + Some(i256::from(0)) + ); + + assert_eq!( + ::wrapping_add(&value, &value), + i256::from(-10) + ); + + assert_eq!( + ::wrapping_sub(&value, &value), + i256::from(0) + ); + + assert_eq!( + ::wrapping_mul(&value, &value), + i256::from(25) + ); + + assert_eq!(::wrapping_neg(&value), i256::from(5)); + + // A single check for wrapping behavior, rely on trait implementation for others + let result = ::wrapping_add(&i256::MAX, &i256::ONE); + assert_eq!(result, i256::MIN); + + assert_eq!(::abs(&value), i256::from(5)); + + assert_eq!(::one(), i256::from(1)); + assert_eq!(::zero(), i256::from(0)); + + assert_eq!(::min_value(), i256::MIN); + assert_eq!(::max_value(), i256::MAX); + } + + #[test] + fn test_numtraits_from_str_radix() { + assert_eq!( + i256::from_str_radix("123456789", 10).expect("parsed"), + i256::from(123456789) + ); + assert_eq!( + i256::from_str_radix("0", 10).expect("parsed"), + i256::from(0) + ); + assert!(i256::from_str_radix("abc", 10).is_err()); + assert!(i256::from_str_radix("0", 16).is_err()); + } + + #[test] + fn test_leading_zeros() { + // Without high part + assert_eq!(i256::from(0).leading_zeros(), 256); + assert_eq!(i256::from(1).leading_zeros(), 256 - 1); + assert_eq!(i256::from(16).leading_zeros(), 256 - 5); + assert_eq!(i256::from(17).leading_zeros(), 256 - 5); + + // With high part + assert_eq!(i256::from_parts(2, 16).leading_zeros(), 128 - 5); + assert_eq!(i256::from_parts(2, i128::MAX).leading_zeros(), 1); + + assert_eq!(i256::MAX.leading_zeros(), 1); + assert_eq!(i256::from(-1).leading_zeros(), 0); + } + + #[test] + fn test_trailing_zeros() { + // Without high part + assert_eq!(i256::from(0).trailing_zeros(), 256); + assert_eq!(i256::from(2).trailing_zeros(), 1); + assert_eq!(i256::from(16).trailing_zeros(), 4); + assert_eq!(i256::from(17).trailing_zeros(), 0); + // With high part + assert_eq!(i256::from_parts(0, i128::MAX).trailing_zeros(), 128); + assert_eq!(i256::from_parts(0, 16).trailing_zeros(), 128 + 4); + assert_eq!(i256::from_parts(2, i128::MAX).trailing_zeros(), 1); + + assert_eq!(i256::MAX.trailing_zeros(), 0); + assert_eq!(i256::from(-1).trailing_zeros(), 0); + } } diff --git a/arrow-buffer/src/buffer/boolean.rs b/arrow-buffer/src/buffer/boolean.rs index c8e5144c14cb..ff836bf28729 100644 --- a/arrow-buffer/src/buffer/boolean.rs +++ b/arrow-buffer/src/buffer/boolean.rs @@ -16,33 +16,74 @@ // under the License. use crate::bit_chunk_iterator::BitChunks; -use crate::bit_iterator::{BitIndexIterator, BitIterator, BitSliceIterator}; +use crate::bit_iterator::{BitIndexIterator, BitIndexU32Iterator, BitIterator, BitSliceIterator}; use crate::{ - bit_util, buffer_bin_and, buffer_bin_or, buffer_bin_xor, buffer_unary_not, - BooleanBufferBuilder, Buffer, MutableBuffer, + BooleanBufferBuilder, Buffer, MutableBuffer, bit_util, buffer_bin_and, buffer_bin_or, + buffer_bin_xor, buffer_unary_not, }; use std::ops::{BitAnd, BitOr, BitXor, Not}; /// A slice-able [`Buffer`] containing bit-packed booleans /// -/// `BooleanBuffer`s can be creating using [`BooleanBufferBuilder`] +/// This structure represents a sequence of boolean values packed into a +/// byte-aligned [`Buffer`]. Both the offset and length are represented in bits. /// -/// # See Also +/// # Layout +/// +/// The values are represented as little endian bit-packed values, where the +/// least significant bit of each byte represents the first boolean value and +/// then proceeding to the most significant bit. +/// +/// For example, the 10 bit bitmask `0b0111001101` has length 10, and is +/// represented using 2 bytes with offset 0 like this: +/// +/// ```text +/// ┌─────────────────────────────────┐ ┌─────────────────────────────────┐ +/// │┌───┬───┬───┬───┬───┬───┬───┬───┐│ │┌───┬───┬───┬───┬───┬───┬───┬───┐│ +/// ││ 1 │ 0 │ 1 │ 1 │ 0 │ 0 │ 1 │ 1 ││ ││ 1 │ 0 │ ? │ ? │ ? │ ? │ ? │ ? ││ +/// │└───┴───┴───┴───┴───┴───┴───┴───┘│ │└───┴───┴───┴───┴───┴───┴───┴───┘│ +/// bit └─────────────────────────────────┘ └─────────────────────────────────┘ +/// offset 0 Byte 0 7 0 Byte 1 7 +/// +/// length = 10 bits, offset = 0 +/// ``` +/// +/// The same bitmask with length 10 and offset 3 would be represented using 2 +/// bytes like this: +/// +/// ```text +/// ┌─────────────────────────────────┐ ┌─────────────────────────────────┐ +/// │┌───┬───┬───┬───┬───┬───┬───┬───┐│ │┌───┬───┬───┬───┬───┬───┬───┬───┐│ +/// ││ ? │ ? │ ? │ 1 │ 0 │ 1 │ 1 │ 0 ││ ││ 0 │ 1 │ 1 │ 1 │ 0 │ ? │ ? │ ? ││ +/// │└───┴───┴───┴───┴───┴───┴───┴───┘│ │└───┴───┴───┴───┴───┴───┴───┴───┘│ +/// bit └─────────────────────────────────┘ └─────────────────────────────────┘ +/// offset 0 Byte 0 7 0 Byte 1 7 +/// +/// length = 10 bits, offset = 3 +/// ``` /// +/// Note that the bits marked `?` are not logically part of the mask and may +/// contain either `0` or `1` +/// +/// # See Also +/// * [`BooleanBufferBuilder`] for building [`BooleanBuffer`] instances /// * [`NullBuffer`] for representing null values in Arrow arrays /// /// [`NullBuffer`]: crate::NullBuffer #[derive(Debug, Clone, Eq)] pub struct BooleanBuffer { + /// Underlying buffer (byte aligned) buffer: Buffer, - offset: usize, - len: usize, + /// Offset in bits (not bytes) + bit_offset: usize, + /// Length in bits (not bytes) + bit_len: usize, } impl PartialEq for BooleanBuffer { fn eq(&self, other: &Self) -> bool { - if self.len != other.len { + if self.bit_len != other.bit_len { return false; } @@ -53,40 +94,40 @@ impl PartialEq for BooleanBuffer { } impl BooleanBuffer { - /// Create a new [`BooleanBuffer`] from a [`Buffer`], an `offset` and `length` in bits + /// Create a new [`BooleanBuffer`] from a [`Buffer`], `bit_offset` offset and `bit_len` length /// /// # Panics /// /// This method will panic if `buffer` is not large enough - pub fn new(buffer: Buffer, offset: usize, len: usize) -> Self { - let total_len = offset.saturating_add(len); + pub fn new(buffer: Buffer, bit_offset: usize, bit_len: usize) -> Self { + let total_len = bit_offset.saturating_add(bit_len); let buffer_len = buffer.len(); - let bit_len = buffer_len.saturating_mul(8); + let buffer_bit_len = buffer_len.saturating_mul(8); assert!( - total_len <= bit_len, - "buffer not large enough (offset: {offset}, len: {len}, buffer_len: {buffer_len})" + total_len <= buffer_bit_len, + "buffer not large enough (bit_offset: {bit_offset}, bit_len: {bit_len}, buffer_len: {buffer_len})" ); Self { buffer, - offset, - len, + bit_offset, + bit_len, } } - /// Create a new [`BooleanBuffer`] of `length` where all values are `true` + /// Create a new [`BooleanBuffer`] of `length` bits (not bytes) where all values are `true` pub fn new_set(length: usize) -> Self { let mut builder = BooleanBufferBuilder::new(length); builder.append_n(length, true); builder.finish() } - /// Create a new [`BooleanBuffer`] of `length` where all values are `false` + /// Create a new [`BooleanBuffer`] of `length` bits (not bytes) where all values are `false` pub fn new_unset(length: usize) -> Self { let buffer = MutableBuffer::new_null(length).into_buffer(); Self { buffer, - offset: 0, - len: length, + bit_offset: 0, + bit_len: length, } } @@ -96,34 +137,258 @@ impl BooleanBuffer { Self::new(buffer.into(), 0, len) } + /// Create a new [`BooleanBuffer`] by copying the relevant bits from an + /// input buffer. + /// + /// # Notes: + /// * The new `BooleanBuffer` has zero offset, even if `offset_in_bits` is non-zero + /// + /// # Example: Create a new [`BooleanBuffer`] copying a bit slice from in input slice + /// ``` + /// # use arrow_buffer::BooleanBuffer; + /// let input = [0b11001100u8, 0b10111010u8]; + /// // // Copy bits 4..16 from input + /// let result = BooleanBuffer::from_bits(&input, 4, 12); + /// assert_eq!(result.values(), &[0b10101100u8, 0b00001011u8]); + pub fn from_bits(src: impl AsRef<[u8]>, offset_in_bits: usize, len_in_bits: usize) -> Self { + Self::from_bitwise_unary_op(src, offset_in_bits, len_in_bits, |a| a) + } + + /// Create a new [`BooleanBuffer`] by applying the bitwise operation to `op` + /// to an input buffer. + /// + /// This function is faster than applying the operation bit by bit as + /// it processes input buffers in chunks of 64 bits (8 bytes) at a time + /// + /// # Notes: + /// * `op` takes a single `u64` inputs and produces one `u64` output. + /// * `op` must only apply bitwise operations + /// on the relevant bits; the input `u64` may contain irrelevant bits + /// and may be processed differently on different endian architectures. + /// * `op` may be called with input bits outside the requested range + /// * The output always has zero offset + /// + /// # See Also + /// - [`BooleanBuffer::from_bitwise_binary_op`] to create a new buffer from a binary operation + /// - [`apply_bitwise_unary_op`](bit_util::apply_bitwise_unary_op) for in-place unary bitwise operations + /// + /// # Example: Create new [`BooleanBuffer`] from bitwise `NOT` of a byte slice + /// ``` + /// # use arrow_buffer::BooleanBuffer; + /// let input = [0b11001100u8, 0b10111010u8]; // 2 bytes = 16 bits + /// // NOT of the first 12 bits + /// let result = BooleanBuffer::from_bitwise_unary_op( + /// &input, 0, 12, |a| !a + /// ); + /// assert_eq!(result.values(), &[0b00110011u8, 0b11110101u8]); + /// ``` + pub fn from_bitwise_unary_op( + src: impl AsRef<[u8]>, + offset_in_bits: usize, + len_in_bits: usize, + mut op: F, + ) -> Self + where + F: FnMut(u64) -> u64, + { + // try fast path for aligned input + if offset_in_bits & 0x7 == 0 { + // align to byte boundary + let aligned = &src.as_ref()[offset_in_bits / 8..]; + if let Some(result) = + Self::try_from_aligned_bitwise_unary_op(aligned, len_in_bits, &mut op) + { + return result; + } + } + + let chunks = BitChunks::new(src.as_ref(), offset_in_bits, len_in_bits); + let mut result = MutableBuffer::with_capacity(chunks.num_u64s() * 8); + for chunk in chunks.iter() { + // SAFETY: reserved enough capacity above, (exactly num_u64s() + // items) and we assume `BitChunks` correctly reports upper bound + unsafe { + result.push_unchecked(op(chunk)); + } + } + if chunks.remainder_len() > 0 { + debug_assert!(result.capacity() >= result.len() + 8); // should not reallocate + // SAFETY: reserved enough capacity above, (exactly num_u64s() + // items) and we assume `BitChunks` correctly reports upper bound + unsafe { + result.push_unchecked(op(chunks.remainder_bits())); + } + // Just pushed one u64, which may have trailing zeros + result.truncate(chunks.num_bytes()); + } + + BooleanBuffer { + buffer: Buffer::from(result), + bit_offset: 0, + bit_len: len_in_bits, + } + } + + /// Fast path for [`Self::from_bitwise_unary_op`] when input is aligned to + /// 8-byte (64-bit) boundaries + /// + /// Returns None if the fast path cannot be taken + fn try_from_aligned_bitwise_unary_op( + src: &[u8], + len_in_bits: usize, + op: &mut F, + ) -> Option + where + F: FnMut(u64) -> u64, + { + // Safety: all valid bytes are valid u64s + let (prefix, aligned_u6us, suffix) = unsafe { src.align_to::() }; + if !(prefix.is_empty() && suffix.is_empty()) { + // Couldn't make this case any faster than the default path, see + // https://github.com/apache/arrow-rs/pull/8996/changes#r2620022082 + return None; + } + // the buffer is word (64 bit) aligned, so use optimized Vec code. + let result_u64s: Vec = aligned_u6us.iter().map(|l| op(*l)).collect(); + let buffer = Buffer::from(result_u64s); + Some(BooleanBuffer::new(buffer, 0, len_in_bits)) + } + + /// Create a new [`BooleanBuffer`] by applying the bitwise operation `op` to + /// the relevant bits from two input buffers. + /// + /// This function is faster than applying the operation bit by bit as + /// it processes input buffers in chunks of 64 bits (8 bytes) at a time + /// + /// # Notes: + /// See notes on [Self::from_bitwise_unary_op] + /// + /// # See Also + /// - [`BooleanBuffer::from_bitwise_unary_op`] for unary operations on a single input buffer. + /// - [`apply_bitwise_binary_op`](bit_util::apply_bitwise_binary_op) for in-place binary bitwise operations + /// + /// # Example: Create new [`BooleanBuffer`] from bitwise `AND` of two [`Buffer`]s + /// ``` + /// # use arrow_buffer::{Buffer, BooleanBuffer}; + /// let left = Buffer::from(vec![0b11001100u8, 0b10111010u8]); // 2 bytes = 16 bits + /// let right = Buffer::from(vec![0b10101010u8, 0b11011100u8, 0b11110000u8]); // 3 bytes = 24 bits + /// // AND of the first 12 bits + /// let result = BooleanBuffer::from_bitwise_binary_op( + /// &left, 0, &right, 0, 12, |a, b| a & b + /// ); + /// assert_eq!(result.inner().as_slice(), &[0b10001000u8, 0b00001000u8]); + /// ``` + /// + /// # Example: Create new [`BooleanBuffer`] from bitwise `OR` of two byte slices + /// ``` + /// # use arrow_buffer::BooleanBuffer; + /// let left = [0b11001100u8, 0b10111010u8]; + /// let right = [0b10101010u8, 0b11011100u8]; + /// // OR of bits 4..16 from left and bits 0..12 from right + /// let result = BooleanBuffer::from_bitwise_binary_op( + /// &left, 4, &right, 0, 12, |a, b| a | b + /// ); + /// assert_eq!(result.inner().as_slice(), &[0b10101110u8, 0b00001111u8]); + /// ``` + pub fn from_bitwise_binary_op( + left: impl AsRef<[u8]>, + left_offset_in_bits: usize, + right: impl AsRef<[u8]>, + right_offset_in_bits: usize, + len_in_bits: usize, + mut op: F, + ) -> Self + where + F: FnMut(u64, u64) -> u64, + { + let left = left.as_ref(); + let right = right.as_ref(); + // try fast path for aligned input + // If the underlying buffers are aligned to u64 we can apply the operation directly on the u64 slices + // to improve performance. + if left_offset_in_bits & 0x7 == 0 && right_offset_in_bits & 0x7 == 0 { + // align to byte boundary + let left = &left[left_offset_in_bits / 8..]; + let right = &right[right_offset_in_bits / 8..]; + + unsafe { + let (left_prefix, left_u64s, left_suffix) = left.align_to::(); + let (right_prefix, right_u64s, right_suffix) = right.align_to::(); + // if there is no prefix or suffix, both buffers are aligned and + // we can do the operation directly on u64s. + // TODO: consider `slice::as_chunks` and `u64::from_le_bytes` when MSRV reaches 1.88. + // https://github.com/apache/arrow-rs/pull/9022#discussion_r2639949361 + if left_prefix.is_empty() + && right_prefix.is_empty() + && left_suffix.is_empty() + && right_suffix.is_empty() + { + let result_u64s = left_u64s + .iter() + .zip(right_u64s.iter()) + .map(|(l, r)| op(*l, *r)) + .collect::>(); + return BooleanBuffer { + buffer: Buffer::from(result_u64s), + bit_offset: 0, + bit_len: len_in_bits, + }; + } + } + } + let left_chunks = BitChunks::new(left, left_offset_in_bits, len_in_bits); + let right_chunks = BitChunks::new(right, right_offset_in_bits, len_in_bits); + + let chunks = left_chunks + .iter() + .zip(right_chunks.iter()) + .map(|(left, right)| op(left, right)); + // Soundness: `BitChunks` is a `BitChunks` trusted length iterator which + // correctly reports its upper bound + let mut buffer = unsafe { MutableBuffer::from_trusted_len_iter(chunks) }; + + let remainder_bytes = bit_util::ceil(left_chunks.remainder_len(), 8); + let rem = op(left_chunks.remainder_bits(), right_chunks.remainder_bits()); + // we are counting its starting from the least significant bit, to to_le_bytes should be correct + let rem = &rem.to_le_bytes()[0..remainder_bytes]; + buffer.extend_from_slice(rem); + + BooleanBuffer { + buffer: Buffer::from(buffer), + bit_offset: 0, + bit_len: len_in_bits, + } + } + /// Returns the number of set bits in this buffer pub fn count_set_bits(&self) -> usize { - self.buffer.count_set_bits_offset(self.offset, self.len) + self.buffer + .count_set_bits_offset(self.bit_offset, self.bit_len) } - /// Returns a `BitChunks` instance which can be used to iterate over + /// Returns a [`BitChunks`] instance which can be used to iterate over /// this buffer's bits in `u64` chunks #[inline] - pub fn bit_chunks(&self) -> BitChunks { - BitChunks::new(self.values(), self.offset, self.len) + pub fn bit_chunks(&self) -> BitChunks<'_> { + BitChunks::new(self.values(), self.bit_offset, self.bit_len) } - /// Returns the offset of this [`BooleanBuffer`] in bits + /// Returns the offset of this [`BooleanBuffer`] in bits (not bytes) #[inline] pub fn offset(&self) -> usize { - self.offset + self.bit_offset } - /// Returns the length of this [`BooleanBuffer`] in bits + /// Returns the length of this [`BooleanBuffer`] in bits (not bytes) #[inline] pub fn len(&self) -> usize { - self.len + self.bit_len } /// Returns true if this [`BooleanBuffer`] is empty #[inline] pub fn is_empty(&self) -> bool { - self.len == 0 + self.bit_len == 0 } /// Free up unused memory. @@ -139,7 +404,7 @@ impl BooleanBuffer { /// Panics if `i >= self.len()` #[inline] pub fn value(&self, idx: usize) -> bool { - assert!(idx < self.len); + assert!(idx < self.bit_len); unsafe { self.value_unchecked(idx) } } @@ -149,7 +414,7 @@ impl BooleanBuffer { /// This doesn't check bounds, the caller must ensure that index < self.len() #[inline] pub unsafe fn value_unchecked(&self, i: usize) -> bool { - unsafe { bit_util::get_bit_raw(self.buffer.as_ptr(), i + self.offset) } + unsafe { bit_util::get_bit_raw(self.buffer.as_ptr(), i + self.bit_offset) } } /// Returns the packed values of this [`BooleanBuffer`] not including any offset @@ -161,13 +426,13 @@ impl BooleanBuffer { /// Slices this [`BooleanBuffer`] by the provided `offset` and `length` pub fn slice(&self, offset: usize, len: usize) -> Self { assert!( - offset.saturating_add(len) <= self.len, + offset.saturating_add(len) <= self.bit_len, "the length + offset of the sliced BooleanBuffer cannot exceed the existing length" ); Self { buffer: self.buffer.clone(), - offset: self.offset + offset, - len, + bit_offset: self.bit_offset + offset, + bit_len: len, } } @@ -175,7 +440,7 @@ impl BooleanBuffer { /// /// Equivalent to `self.buffer.bit_slice(self.offset, self.len)` pub fn sliced(&self) -> Buffer { - self.buffer.bit_slice(self.offset, self.len) + self.buffer.bit_slice(self.bit_offset, self.bit_len) } /// Returns true if this [`BooleanBuffer`] is equal to `other`, using pointer comparisons @@ -183,17 +448,21 @@ impl BooleanBuffer { /// return false when the arrays are logically equal pub fn ptr_eq(&self, other: &Self) -> bool { self.buffer.as_ptr() == other.buffer.as_ptr() - && self.offset == other.offset - && self.len == other.len + && self.bit_offset == other.bit_offset + && self.bit_len == other.bit_len } /// Returns the inner [`Buffer`] + /// + /// Note: this does not account for offset and length of this [`BooleanBuffer`] #[inline] pub fn inner(&self) -> &Buffer { &self.buffer } /// Returns the inner [`Buffer`], consuming self + /// + /// Note: this does not account for offset and length of this [`BooleanBuffer`] pub fn into_inner(self) -> Buffer { self.buffer } @@ -205,12 +474,17 @@ impl BooleanBuffer { /// Returns an iterator over the set bit positions in this [`BooleanBuffer`] pub fn set_indices(&self) -> BitIndexIterator<'_> { - BitIndexIterator::new(self.values(), self.offset, self.len) + BitIndexIterator::new(self.values(), self.bit_offset, self.bit_len) + } + + /// Returns a `u32` iterator over set bit positions without any usize->u32 conversion + pub fn set_indices_u32(&self) -> BitIndexU32Iterator<'_> { + BitIndexU32Iterator::new(self.values(), self.bit_offset, self.bit_len) } /// Returns a [`BitSliceIterator`] yielding contiguous ranges of set bits pub fn set_slices(&self) -> BitSliceIterator<'_> { - BitSliceIterator::new(self.values(), self.offset, self.len) + BitSliceIterator::new(self.values(), self.bit_offset, self.bit_len) } } @@ -219,9 +493,9 @@ impl Not for &BooleanBuffer { fn not(self) -> Self::Output { BooleanBuffer { - buffer: buffer_unary_not(&self.buffer, self.offset, self.len), - offset: 0, - len: self.len, + buffer: buffer_unary_not(&self.buffer, self.bit_offset, self.bit_len), + bit_offset: 0, + bit_len: self.bit_len, } } } @@ -230,11 +504,17 @@ impl BitAnd<&BooleanBuffer> for &BooleanBuffer { type Output = BooleanBuffer; fn bitand(self, rhs: &BooleanBuffer) -> Self::Output { - assert_eq!(self.len, rhs.len); + assert_eq!(self.bit_len, rhs.bit_len); BooleanBuffer { - buffer: buffer_bin_and(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), - offset: 0, - len: self.len, + buffer: buffer_bin_and( + &self.buffer, + self.bit_offset, + &rhs.buffer, + rhs.bit_offset, + self.bit_len, + ), + bit_offset: 0, + bit_len: self.bit_len, } } } @@ -243,11 +523,17 @@ impl BitOr<&BooleanBuffer> for &BooleanBuffer { type Output = BooleanBuffer; fn bitor(self, rhs: &BooleanBuffer) -> Self::Output { - assert_eq!(self.len, rhs.len); + assert_eq!(self.bit_len, rhs.bit_len); BooleanBuffer { - buffer: buffer_bin_or(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), - offset: 0, - len: self.len, + buffer: buffer_bin_or( + &self.buffer, + self.bit_offset, + &rhs.buffer, + rhs.bit_offset, + self.bit_len, + ), + bit_offset: 0, + bit_len: self.bit_len, } } } @@ -256,11 +542,17 @@ impl BitXor<&BooleanBuffer> for &BooleanBuffer { type Output = BooleanBuffer; fn bitxor(self, rhs: &BooleanBuffer) -> Self::Output { - assert_eq!(self.len, rhs.len); + assert_eq!(self.bit_len, rhs.bit_len); BooleanBuffer { - buffer: buffer_bin_xor(&self.buffer, self.offset, &rhs.buffer, rhs.offset, self.len), - offset: 0, - len: self.len, + buffer: buffer_bin_xor( + &self.buffer, + self.bit_offset, + &rhs.buffer, + rhs.bit_offset, + self.bit_len, + ), + bit_offset: 0, + bit_len: self.bit_len, } } } @@ -270,7 +562,7 @@ impl<'a> IntoIterator for &'a BooleanBuffer { type IntoIter = BitIterator<'a>; fn into_iter(self) -> Self::IntoIter { - BitIterator::new(self.values(), self.offset, self.len) + BitIterator::new(self.values(), self.bit_offset, self.bit_len) } } @@ -358,12 +650,12 @@ mod tests { assert_eq!(boolean_slice1.values(), boolean_slice2.values()); assert_eq!(bytes, boolean_slice1.values()); - assert_eq!(16, boolean_slice1.offset); - assert_eq!(16, boolean_slice1.len); + assert_eq!(16, boolean_slice1.bit_offset); + assert_eq!(16, boolean_slice1.bit_len); assert_eq!(bytes, boolean_slice2.values()); - assert_eq!(0, boolean_slice2.offset); - assert_eq!(16, boolean_slice2.len); + assert_eq!(0, boolean_slice2.bit_offset); + assert_eq!(16, boolean_slice2.bit_len); } #[test] @@ -432,4 +724,103 @@ mod tests { assert_eq!(buf.values().len(), 1); assert!(buf.value(0)); } + + #[test] + fn test_from_bitwise_unary_op() { + // Use 1024 boolean values so that at least some of the tests cover multiple u64 chunks and + // perfect alignment + let input_bools = (0..1024) + .map(|_| rand::random::()) + .collect::>(); + let input_buffer = BooleanBuffer::from(&input_bools[..]); + + // Note ensure we test offsets over 100 to cover multiple u64 chunks + for offset in 0..1024 { + let result = BooleanBuffer::from_bitwise_unary_op( + input_buffer.values(), + offset, + input_buffer.len() - offset, + |a| !a, + ); + let expected = input_bools[offset..] + .iter() + .map(|b| !*b) + .collect::(); + assert_eq!(result, expected); + } + + // Also test when the input doesn't cover the entire buffer + for offset in 0..512 { + let len = 512 - offset; // fixed length less than total + let result = + BooleanBuffer::from_bitwise_unary_op(input_buffer.values(), offset, len, |a| !a); + let expected = input_bools[offset..] + .iter() + .take(len) + .map(|b| !*b) + .collect::(); + assert_eq!(result, expected); + } + } + + #[test] + fn test_from_bitwise_binary_op() { + // pick random boolean inputs + let input_bools_left = (0..1024) + .map(|_| rand::random::()) + .collect::>(); + let input_bools_right = (0..1024) + .map(|_| rand::random::()) + .collect::>(); + let input_buffer_left = BooleanBuffer::from(&input_bools_left[..]); + let input_buffer_right = BooleanBuffer::from(&input_bools_right[..]); + + for left_offset in 0..200 { + for right_offset in [0, 4, 5, 17, 33, 24, 45, 64, 65, 100, 200] { + for len_offset in [0, 1, 44, 100, 256, 300, 512] { + let len = 1024 - len_offset - left_offset.max(right_offset); // ensure we don't go out of bounds + // compute with AND + let result = BooleanBuffer::from_bitwise_binary_op( + input_buffer_left.values(), + left_offset, + input_buffer_right.values(), + right_offset, + len, + |a, b| a & b, + ); + // compute directly from bools + let expected = input_bools_left[left_offset..] + .iter() + .zip(&input_bools_right[right_offset..]) + .take(len) + .map(|(a, b)| *a & *b) + .collect::(); + assert_eq!(result, expected); + } + } + } + } + + #[test] + fn test_extend_trusted_len_sets_byte_len() { + // Ensures extend_trusted_len keeps the underlying byte length in sync with bit length. + let mut builder = BooleanBufferBuilder::new(0); + let bools: Vec<_> = (0..10).map(|i| i % 2 == 0).collect(); + unsafe { builder.extend_trusted_len(bools.into_iter()) }; + assert_eq!(builder.as_slice().len(), bit_util::ceil(builder.len(), 8)); + } + + #[test] + fn test_extend_trusted_len_then_append() { + // Exercises append after extend_trusted_len to validate byte length and values. + let mut builder = BooleanBufferBuilder::new(0); + let bools: Vec<_> = (0..9).map(|i| i % 3 == 0).collect(); + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + builder.append(true); + assert_eq!(builder.as_slice().len(), bit_util::ceil(builder.len(), 8)); + let finished = builder.finish(); + for (i, v) in bools.into_iter().chain(std::iter::once(true)).enumerate() { + assert_eq!(finished.value(i), v, "at index {}", i); + } + } } diff --git a/arrow-buffer/src/buffer/immutable.rs b/arrow-buffer/src/buffer/immutable.rs index 946299d0061b..7bf67503562d 100644 --- a/arrow-buffer/src/buffer/immutable.rs +++ b/arrow-buffer/src/buffer/immutable.rs @@ -22,10 +22,12 @@ use std::sync::Arc; use crate::alloc::{Allocation, Deallocation}; use crate::util::bit_chunk_iterator::{BitChunks, UnalignedBitChunk}; -use crate::BufferBuilder; +use crate::{BooleanBuffer, BufferBuilder}; use crate::{bit_util, bytes::Bytes, native::ArrowNativeType}; -use super::ops::bitwise_unary_op_helper; +#[cfg(feature = "pool")] +use crate::pool::MemoryPool; + use super::{MutableBuffer, ScalarBuffer}; /// A contiguous memory region that can be shared with other buffers and across @@ -169,7 +171,7 @@ impl Buffer { len: usize, owner: Arc, ) -> Self { - Buffer::build_with_arguments(ptr, len, Deallocation::Custom(owner, len)) + unsafe { Buffer::build_with_arguments(ptr, len, Deallocation::Custom(owner, len)) } } /// Auxiliary method to create a new Buffer @@ -178,7 +180,7 @@ impl Buffer { len: usize, deallocation: Deallocation, ) -> Self { - let bytes = Bytes::new(ptr, len, deallocation); + let bytes = unsafe { Bytes::new(ptr, len, deallocation) }; let ptr = bytes.as_ptr(); Buffer { ptr, @@ -341,13 +343,13 @@ impl Buffer { return self.slice_with_length(offset / 8, bit_util::ceil(len, 8)); } - bitwise_unary_op_helper(self, offset, len, |a| a) + BooleanBuffer::from_bits(self.as_slice(), offset, len).into_inner() } /// Returns a `BitChunks` instance which can be used to iterate over this buffers bits /// in larger chunks and starting at arbitrary bit offsets. /// Note that both `offset` and `length` are measured in bits. - pub fn bit_chunks(&self, offset: usize, len: usize) -> BitChunks { + pub fn bit_chunks(&self, offset: usize, len: usize) -> BitChunks<'_> { BitChunks::new(self.as_slice(), offset, len) } @@ -361,6 +363,23 @@ impl Buffer { /// Returns `Err` if this is shared or its allocation is from an external source or /// it is not allocated with alignment [`ALIGNMENT`] /// + /// # Example: Creating a [`MutableBuffer`] from a [`Buffer`] + /// ``` + /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; + /// let buffer: Buffer = Buffer::from(&[1u8, 2, 3, 4][..]); + /// // Only possible to convert a Buffer into a MutableBuffer if uniquely owned + /// // (i.e., there are no other references to it). + /// let mut mutable_buffer = match buffer.into_mutable() { + /// Ok(mutable) => mutable, + /// Err(orig_buffer) => { + /// panic!("buffer was not uniquely owned"); + /// } + /// }; + /// mutable_buffer.push(5u8); + /// let buffer = Buffer::from(mutable_buffer); + /// assert_eq!(buffer.as_slice(), &[1u8, 2, 3, 4, 5]) + /// ``` + /// /// [`ALIGNMENT`]: crate::alloc::ALIGNMENT pub fn into_mutable(self) -> Result { let ptr = self.ptr; @@ -385,8 +404,8 @@ impl Buffer { /// # Errors /// /// Returns `Err(self)` if - /// 1. this buffer does not have the same [`Layout`] as the destination Vec - /// 2. contains a non-zero offset + /// 1. The buffer does not have the same [`Layout`] as the destination Vec + /// 2. The buffer contains a non-zero offset /// 3. The buffer is shared pub fn into_vec(self) -> Result, Self> { let layout = match self.data.deallocation() { @@ -430,6 +449,17 @@ impl Buffer { pub fn ptr_eq(&self, other: &Self) -> bool { self.ptr == other.ptr && self.length == other.length } + + /// Register this [`Buffer`] with the provided [`MemoryPool`] + /// + /// This claims the memory used by this buffer in the pool, allowing for + /// accurate accounting of memory usage. Any prior reservation will be + /// released so this works well when the buffer is being shared among + /// multiple arrays. + #[cfg(feature = "pool")] + pub fn claim(&self, pool: &dyn MemoryPool) { + self.data.claim(pool) + } } /// Note that here we deliberately do not implement @@ -510,6 +540,12 @@ impl std::ops::Deref for Buffer { } } +impl AsRef<[u8]> for &Buffer { + fn as_ref(&self) -> &[u8] { + self.as_slice() + } +} + impl From for Buffer { #[inline] fn from(buffer: MutableBuffer) -> Self { @@ -547,7 +583,7 @@ impl Buffer { pub unsafe fn from_trusted_len_iter>( iterator: I, ) -> Self { - MutableBuffer::from_trusted_len_iter(iterator).into() + unsafe { MutableBuffer::from_trusted_len_iter(iterator).into() } } /// Creates a [`Buffer`] from an [`Iterator`] with a trusted (upper) length or errors @@ -564,7 +600,7 @@ impl Buffer { >( iterator: I, ) -> Result { - Ok(MutableBuffer::try_from_trusted_len_iter(iterator)?.into()) + unsafe { Ok(MutableBuffer::try_from_trusted_len_iter(iterator)?.into()) } } } @@ -983,13 +1019,13 @@ mod tests { #[should_panic(expected = "capacity overflow")] fn test_from_iter_overflow() { let iter_len = usize::MAX / std::mem::size_of::() + 1; - let _ = Buffer::from_iter(std::iter::repeat(0_u64).take(iter_len)); + let _ = Buffer::from_iter(std::iter::repeat_n(0_u64, iter_len)); } #[test] fn bit_slice_length_preserved() { // Create a boring buffer - let buf = Buffer::from_iter(std::iter::repeat(true).take(64)); + let buf = Buffer::from_iter(std::iter::repeat_n(true, 64)); let assert_preserved = |offset: usize, len: usize| { let new_buf = buf.bit_slice(offset, len); @@ -1021,7 +1057,7 @@ mod tests { #[test] fn test_strong_count() { - let buffer = Buffer::from_iter(std::iter::repeat(0_u8).take(100)); + let buffer = Buffer::from_iter(std::iter::repeat_n(0_u8, 100)); assert_eq!(buffer.strong_count(), 1); let buffer2 = buffer.clone(); diff --git a/arrow-buffer/src/buffer/mutable.rs b/arrow-buffer/src/buffer/mutable.rs index 19ca0fef1519..9fc860506194 100644 --- a/arrow-buffer/src/buffer/mutable.rs +++ b/arrow-buffer/src/buffer/mutable.rs @@ -15,41 +15,86 @@ // specific language governing permissions and limitations // under the License. -use std::alloc::{handle_alloc_error, Layout}; +use std::alloc::{Layout, handle_alloc_error}; use std::mem; use std::ptr::NonNull; -use crate::alloc::{Deallocation, ALIGNMENT}; +use crate::alloc::{ALIGNMENT, Deallocation}; use crate::{ bytes::Bytes, native::{ArrowNativeType, ToByteSlice}, util::bit_util, }; +#[cfg(feature = "pool")] +use crate::pool::{MemoryPool, MemoryReservation}; +#[cfg(feature = "pool")] +use std::sync::Mutex; + use super::Buffer; -/// A [`MutableBuffer`] is Arrow's interface to build a [`Buffer`] out of items or slices of items. +/// A [`MutableBuffer`] is a wrapper over memory regions, used to build +/// [`Buffer`]s out of items or slices of items. /// -/// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to have its pointer aligned -/// along cache lines and in multiple of 64 bytes. +/// [`Buffer`]s created from [`MutableBuffer`] (via `into`) are guaranteed to be +/// aligned along cache lines and in multiples of 64 bytes. /// /// Use [MutableBuffer::push] to insert an item, [MutableBuffer::extend_from_slice] -/// to insert many items, and `into` to convert it to [`Buffer`]. -/// -/// For a safe, strongly typed API consider using [`Vec`] and [`ScalarBuffer`](crate::ScalarBuffer) +/// to insert many items, and `into` to convert it to [`Buffer`]. For typed data, +/// it is often more efficient to use [`Vec`] and convert it to [`Buffer`] rather +/// than using [`MutableBuffer`] (see examples below). /// -/// Note: this may be deprecated in a future release ([#1176](https://github.com/apache/arrow-rs/issues/1176)) +/// # See Also +/// * For a safe, strongly typed API consider using [`Vec`] and [`ScalarBuffer`](crate::ScalarBuffer) +/// * To apply bitwise operations, see [`apply_bitwise_binary_op`] and [`apply_bitwise_unary_op`] /// -/// # Example +/// [`apply_bitwise_binary_op`]: crate::bit_util::apply_bitwise_binary_op +/// [`apply_bitwise_unary_op`]: crate::bit_util::apply_bitwise_unary_op /// +/// # Example: Creating a [`Buffer`] from a [`MutableBuffer`] /// ``` /// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; /// let mut buffer = MutableBuffer::new(0); /// buffer.push(256u32); /// buffer.extend_from_slice(&[1u32]); -/// let buffer: Buffer = buffer.into(); +/// let buffer = Buffer::from(buffer); /// assert_eq!(buffer.as_slice(), &[0u8, 1, 0, 0, 1, 0, 0, 0]) /// ``` +/// +/// The same can be achieved more efficiently by using a `Vec` +/// ``` +/// # use arrow_buffer::buffer::Buffer; +/// let mut vec = Vec::new(); +/// vec.push(256u32); +/// vec.extend_from_slice(&[1u32]); +/// let buffer = Buffer::from(vec); +/// assert_eq!(buffer.as_slice(), &[0u8, 1, 0, 0, 1, 0, 0, 0]); +/// ``` +/// +/// # Example: Creating a [`MutableBuffer`] from a `Vec` +/// ``` +/// # use arrow_buffer::buffer::MutableBuffer; +/// let vec = vec![1u32, 2, 3]; +/// let mutable_buffer = MutableBuffer::from(vec); // reuses the allocation from vec +/// assert_eq!(mutable_buffer.len(), 12); // 3 * 4 bytes +/// ``` +/// +/// # Example: Creating a [`MutableBuffer`] from a [`Buffer`] +/// ``` +/// # use arrow_buffer::buffer::{Buffer, MutableBuffer}; +/// let buffer: Buffer = Buffer::from(&[1u8, 2, 3, 4][..]); +/// // Only possible to convert a Buffer into a MutableBuffer if uniquely owned +/// // (i.e., there are no other references to it). +/// let mut mutable_buffer = match buffer.into_mutable() { +/// Ok(mutable) => mutable, +/// Err(orig_buffer) => { +/// panic!("buffer was not uniquely owned"); +/// } +/// }; +/// mutable_buffer.push(5u8); +/// let buffer = Buffer::from(mutable_buffer); +/// assert_eq!(buffer.as_slice(), &[1u8, 2, 3, 4, 5]) +/// ``` #[derive(Debug)] pub struct MutableBuffer { // dangling iff capacity = 0 @@ -57,6 +102,10 @@ pub struct MutableBuffer { // invariant: len <= capacity len: usize, layout: Layout, + + /// Memory reservation for tracking memory usage + #[cfg(feature = "pool")] + reservation: Mutex>>, } impl MutableBuffer { @@ -91,6 +140,8 @@ impl MutableBuffer { data, len: 0, layout, + #[cfg(feature = "pool")] + reservation: std::sync::Mutex::new(None), } } @@ -115,7 +166,13 @@ impl MutableBuffer { NonNull::new(raw_ptr).unwrap_or_else(|| handle_alloc_error(layout)) } }; - Self { data, len, layout } + Self { + data, + len, + layout, + #[cfg(feature = "pool")] + reservation: std::sync::Mutex::new(None), + } } /// Allocates a new [MutableBuffer] from given `Bytes`. @@ -127,9 +184,17 @@ impl MutableBuffer { let len = bytes.len(); let data = bytes.ptr(); + #[cfg(feature = "pool")] + let reservation = bytes.reservation.lock().unwrap().take(); mem::forget(bytes); - Ok(Self { data, len, layout }) + Ok(Self { + data, + len, + layout, + #[cfg(feature = "pool")] + reservation: Mutex::new(reservation), + }) } /// creates a new [MutableBuffer] with capacity and length capable of holding `len` bits. @@ -197,6 +262,75 @@ impl MutableBuffer { } } + /// Adding to this mutable buffer `slice_to_repeat` repeated `repeat_count` times. + /// + /// # Example + /// + /// ## Repeat the same string bytes multiple times + /// ``` + /// # use arrow_buffer::buffer::MutableBuffer; + /// let mut buffer = MutableBuffer::new(0); + /// let bytes_to_repeat = b"ab"; + /// buffer.repeat_slice_n_times(bytes_to_repeat, 3); + /// assert_eq!(buffer.as_slice(), b"ababab"); + /// ``` + pub fn repeat_slice_n_times( + &mut self, + slice_to_repeat: &[T], + repeat_count: usize, + ) { + if repeat_count == 0 || slice_to_repeat.is_empty() { + return; + } + + let bytes_to_repeat = size_of_val(slice_to_repeat); + + // Ensure capacity + self.reserve(repeat_count * bytes_to_repeat); + + // Save the length before we do all the copies to know where to start from + let length_before = self.len; + + // Copy the initial slice once so we can use doubling strategy on it + self.extend_from_slice(slice_to_repeat); + + // This tracks how much bytes we have added by repeating so far + let added_repeats_length = bytes_to_repeat; + assert_eq!( + self.len - length_before, + added_repeats_length, + "should copy exactly the same number of bytes" + ); + + // Number of times the slice was repeated + let mut already_repeated_times = 1; + + // We will use doubling strategy to fill the buffer in log(repeat_count) steps + while already_repeated_times < repeat_count { + // How many slices can we copy in this iteration + // (either double what we have, or just the remaining ones) + let number_of_slices_to_copy = + already_repeated_times.min(repeat_count - already_repeated_times); + let number_of_bytes_to_copy = number_of_slices_to_copy * bytes_to_repeat; + + unsafe { + // Get to the start of the data before we started copying anything + let src = self.data.as_ptr().add(length_before) as *const u8; + + // Go to the current location to copy to (end of current data) + let dst = self.data.as_ptr().add(self.len); + + // SAFETY: the pointers are not overlapping as there is `number_of_bytes_to_copy` or less between them + std::ptr::copy_nonoverlapping(src, dst, number_of_bytes_to_copy) + } + + // Advance the length by the amount of data we just copied (doubled) + self.len += number_of_bytes_to_copy; + + already_repeated_times += number_of_slices_to_copy; + } + } + #[cold] fn reallocate(&mut self, capacity: usize) { let new_layout = Layout::from_size_align(capacity, self.layout.align()).unwrap(); @@ -217,6 +351,12 @@ impl MutableBuffer { }; self.data = NonNull::new(data).unwrap_or_else(|| handle_alloc_error(new_layout)); self.layout = new_layout; + #[cfg(feature = "pool")] + { + if let Some(reservation) = self.reservation.lock().unwrap().as_mut() { + reservation.resize(self.layout.size()); + } + } } /// Truncates this buffer to `len` bytes @@ -228,6 +368,12 @@ impl MutableBuffer { return; } self.len = len; + #[cfg(feature = "pool")] + { + if let Some(reservation) = self.reservation.lock().unwrap().as_mut() { + reservation.resize(self.len); + } + } } /// Resizes the buffer, either truncating its contents (with no change in capacity), or @@ -251,6 +397,12 @@ impl MutableBuffer { } // this truncates the buffer when new_len < self.len self.len = new_len; + #[cfg(feature = "pool")] + { + if let Some(reservation) = self.reservation.lock().unwrap().as_mut() { + reservation.resize(self.len); + } + } } /// Shrinks the capacity of the buffer as much as possible. @@ -328,6 +480,11 @@ impl MutableBuffer { #[inline] pub(super) fn into_buffer(self) -> Buffer { let bytes = unsafe { Bytes::new(self.data, self.len, Deallocation::Standard(self.layout)) }; + #[cfg(feature = "pool")] + { + let reservation = self.reservation.lock().unwrap().take(); + *bytes.reservation.lock().unwrap() = reservation; + } std::mem::forget(self); Buffer::from(bytes) } @@ -412,8 +569,8 @@ impl MutableBuffer { pub unsafe fn push_unchecked(&mut self, item: T) { let additional = std::mem::size_of::(); let src = item.to_byte_slice().as_ptr(); - let dst = self.data.as_ptr().add(self.len); - std::ptr::copy_nonoverlapping(src, dst, additional); + let dst = unsafe { self.data.as_ptr().add(self.len) }; + unsafe { std::ptr::copy_nonoverlapping(src, dst, additional) }; self.len += additional; } @@ -437,20 +594,19 @@ impl MutableBuffer { /// as it eliminates the conditional `Iterator::next` #[inline] pub fn collect_bool bool>(len: usize, mut f: F) -> Self { - let mut buffer = Self::new(bit_util::ceil(len, 64) * 8); + let mut buffer: Vec = Vec::with_capacity(bit_util::ceil(len, 64)); let chunks = len / 64; let remainder = len % 64; - for chunk in 0..chunks { + buffer.extend((0..chunks).map(|chunk| { let mut packed = 0; for bit_idx in 0..64 { let i = bit_idx + chunk * 64; packed |= (f(i) as u64) << bit_idx; } - // SAFETY: Already allocated sufficient capacity - unsafe { buffer.push_unchecked(packed) } - } + packed + })); if remainder != 0 { let mut packed = 0; @@ -459,13 +615,152 @@ impl MutableBuffer { packed |= (f(i) as u64) << bit_idx; } - // SAFETY: Already allocated sufficient capacity - unsafe { buffer.push_unchecked(packed) } + buffer.push(packed) } + let mut buffer: MutableBuffer = buffer.into(); buffer.truncate(bit_util::ceil(len, 8)); buffer } + + /// Extends this buffer with boolean values. + /// + /// This requires `iter` to report an exact size via `size_hint`. + /// `offset` indicates the starting offset in bits in this buffer to begin writing to + /// and must be less than or equal to the current length of this buffer. + /// All bits not written to (but readable due to byte alignment) will be zeroed out. + /// # Safety + /// Callers must ensure that `iter` reports an exact size via `size_hint`. + #[inline] + pub unsafe fn extend_bool_trusted_len>( + &mut self, + mut iter: I, + offset: usize, + ) { + let (lower, upper) = iter.size_hint(); + let len = upper.expect("Iterator must have exact size_hint"); + assert_eq!(lower, len, "Iterator must have exact size_hint"); + debug_assert!( + offset <= self.len * 8, + "offset must be <= buffer length in bits" + ); + + if len == 0 { + return; + } + + let start_len = offset; + let end_bit = start_len + len; + + // SAFETY: we will initialize all newly exposed bytes before they are read + let new_len_bytes = bit_util::ceil(end_bit, 8); + if new_len_bytes > self.len { + self.reserve(new_len_bytes - self.len); + // SAFETY: caller will initialize all newly exposed bytes before they are read + unsafe { self.set_len(new_len_bytes) }; + } + + let slice = self.as_slice_mut(); + + let mut bit_idx = start_len; + + // ---- Unaligned prefix: advance to the next 64-bit boundary ---- + let misalignment = bit_idx & 63; + let prefix_bits = if misalignment == 0 { + 0 + } else { + (64 - misalignment).min(end_bit - bit_idx) + }; + + if prefix_bits != 0 { + let byte_start = bit_idx / 8; + let byte_end = bit_util::ceil(bit_idx + prefix_bits, 8); + let bit_offset = bit_idx % 8; + + // Clear any newly-visible bits in the existing partial byte + if bit_offset != 0 { + let keep_mask = (1u8 << bit_offset).wrapping_sub(1); + slice[byte_start] &= keep_mask; + } + + // Zero any new bytes we will partially fill in this prefix + let zero_from = if bit_offset == 0 { + byte_start + } else { + byte_start + 1 + }; + if byte_end > zero_from { + slice[zero_from..byte_end].fill(0); + } + + for _ in 0..prefix_bits { + let v = iter.next().unwrap(); + if v { + let byte_idx = bit_idx / 8; + let bit = bit_idx % 8; + slice[byte_idx] |= 1 << bit; + } + bit_idx += 1; + } + } + + if bit_idx < end_bit { + // ---- Aligned middle: write u64 chunks ---- + debug_assert_eq!(bit_idx & 63, 0); + let remaining_bits = end_bit - bit_idx; + let chunks = remaining_bits / 64; + + let words_start = bit_idx / 8; + let words_end = words_start + chunks * 8; + for dst in slice[words_start..words_end].chunks_exact_mut(8) { + let mut packed: u64 = 0; + for i in 0..64 { + packed |= (iter.next().unwrap() as u64) << i; + } + dst.copy_from_slice(&packed.to_le_bytes()); + bit_idx += 64; + } + + // ---- Unaligned suffix: remaining < 64 bits ---- + let suffix_bits = end_bit - bit_idx; + if suffix_bits != 0 { + debug_assert_eq!(bit_idx % 8, 0); + let byte_start = bit_idx / 8; + let byte_end = bit_util::ceil(end_bit, 8); + slice[byte_start..byte_end].fill(0); + + for _ in 0..suffix_bits { + let v = iter.next().unwrap(); + if v { + let byte_idx = bit_idx / 8; + let bit = bit_idx % 8; + slice[byte_idx] |= 1 << bit; + } + bit_idx += 1; + } + } + } + + // Clear any unused bits in the last byte + let remainder = end_bit % 8; + if remainder != 0 { + let mask = (1u8 << remainder).wrapping_sub(1); + slice[bit_util::ceil(end_bit, 8) - 1] &= mask; + } + + debug_assert_eq!(bit_idx, end_bit); + } + + /// Register this [`MutableBuffer`] with the provided [`MemoryPool`] + /// + /// This claims the memory used by this buffer in the pool, allowing for + /// accurate accounting of memory usage. Any prior reservation will be + /// released so this works well when the buffer is being shared among + /// multiple arrays. + #[cfg(feature = "pool")] + pub fn claim(&self, pool: &dyn MemoryPool) { + *self.reservation.lock().unwrap() = Some(pool.reserve(self.capacity())); + } } /// Creates a non-null pointer with alignment of [`ALIGNMENT`] @@ -506,7 +801,13 @@ impl From> for MutableBuffer { // This is based on `RawVec::current_memory` let layout = unsafe { Layout::array::(value.capacity()).unwrap_unchecked() }; mem::forget(value); - Self { data, len, layout } + Self { + data, + len, + layout, + #[cfg(feature = "pool")] + reservation: std::sync::Mutex::new(None), + } } } @@ -575,11 +876,11 @@ impl MutableBuffer { for item in iterator { // note how there is no reserve here (compared with `extend_from_iter`) let src = item.to_byte_slice().as_ptr(); - std::ptr::copy_nonoverlapping(src, dst, item_size); - dst = dst.add(item_size); + unsafe { std::ptr::copy_nonoverlapping(src, dst, item_size) }; + dst = unsafe { dst.add(item_size) }; } assert_eq!( - dst.offset_from(buffer.data.as_ptr()) as usize, + unsafe { dst.offset_from(buffer.data.as_ptr()) } as usize, len, "Trusted iterator length was not accurately reported" ); @@ -638,20 +939,22 @@ impl MutableBuffer { let item = item?; // note how there is no reserve here (compared with `extend_from_iter`) let src = item.to_byte_slice().as_ptr(); - std::ptr::copy_nonoverlapping(src, dst, item_size); - dst = dst.add(item_size); + unsafe { std::ptr::copy_nonoverlapping(src, dst, item_size) }; + dst = unsafe { dst.add(item_size) }; } // try_from_trusted_len_iter is instantiated a lot, so we extract part of it into a less // generic method to reduce compile time unsafe fn finalize_buffer(dst: *mut u8, buffer: &mut MutableBuffer, len: usize) { - assert_eq!( - dst.offset_from(buffer.data.as_ptr()) as usize, - len, - "Trusted iterator length was not accurately reported" - ); - buffer.len = len; - } - finalize_buffer(dst, &mut buffer, len); + unsafe { + assert_eq!( + dst.offset_from(buffer.data.as_ptr()) as usize, + len, + "Trusted iterator length was not accurately reported" + ); + buffer.len = len; + } + } + unsafe { finalize_buffer(dst, &mut buffer, len) }; Ok(buffer) } } @@ -676,6 +979,12 @@ impl std::ops::DerefMut for MutableBuffer { } } +impl AsRef<[u8]> for &MutableBuffer { + fn as_ref(&self) -> &[u8] { + self.as_slice() + } +} + impl Drop for MutableBuffer { fn drop(&mut self) { if self.layout.size() != 0 { @@ -1013,4 +1322,229 @@ mod tests { let max_capacity = isize::MAX as usize - (isize::MAX as usize % ALIGNMENT); let _ = MutableBuffer::with_capacity(max_capacity + 1); } + + #[cfg(feature = "pool")] + mod pool_tests { + use super::*; + use crate::pool::{MemoryPool, TrackingMemoryPool}; + + #[test] + fn test_reallocate_with_pool() { + let pool = TrackingMemoryPool::default(); + let mut buffer = MutableBuffer::with_capacity(100); + buffer.claim(&pool); + + // Initial capacity should be 128 (multiple of 64) + assert_eq!(buffer.capacity(), 128); + assert_eq!(pool.used(), 128); + + // Reallocate to a larger size + buffer.reallocate(200); + + // The capacity is exactly the requested size, not rounded up + assert_eq!(buffer.capacity(), 200); + assert_eq!(pool.used(), 200); + + // Reallocate to a smaller size + buffer.reallocate(50); + + // The capacity is exactly the requested size, not rounded up + assert_eq!(buffer.capacity(), 50); + assert_eq!(pool.used(), 50); + } + + #[test] + fn test_truncate_with_pool() { + let pool = TrackingMemoryPool::default(); + let mut buffer = MutableBuffer::with_capacity(100); + + // Fill buffer with some data + buffer.resize(80, 1); + assert_eq!(buffer.len(), 80); + + buffer.claim(&pool); + assert_eq!(pool.used(), 128); + + // Truncate buffer + buffer.truncate(40); + assert_eq!(buffer.len(), 40); + assert_eq!(pool.used(), 40); + + // Truncate to zero + buffer.truncate(0); + assert_eq!(buffer.len(), 0); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_resize_with_pool() { + let pool = TrackingMemoryPool::default(); + let mut buffer = MutableBuffer::with_capacity(100); + buffer.claim(&pool); + + // Initial state + assert_eq!(buffer.len(), 0); + assert_eq!(pool.used(), 128); + + // Resize to increase length + buffer.resize(50, 1); + assert_eq!(buffer.len(), 50); + assert_eq!(pool.used(), 50); + + // Resize to increase length beyond capacity + buffer.resize(150, 1); + assert_eq!(buffer.len(), 150); + assert_eq!(buffer.capacity(), 256); + assert_eq!(pool.used(), 150); + + // Resize to decrease length + buffer.resize(30, 1); + assert_eq!(buffer.len(), 30); + assert_eq!(pool.used(), 30); + } + + #[test] + fn test_buffer_lifecycle_with_pool() { + let pool = TrackingMemoryPool::default(); + + // Create a buffer with memory reservation + let mut mutable = MutableBuffer::with_capacity(100); + mutable.resize(80, 1); + mutable.claim(&pool); + + // Memory reservation is based on capacity when using claim() + assert_eq!(pool.used(), 128); + + // Convert to immutable Buffer + let buffer = mutable.into_buffer(); + + // Memory reservation should be preserved + assert_eq!(pool.used(), 128); + + // Drop the buffer and the reservation should be released + drop(buffer); + assert_eq!(pool.used(), 0); + } + } + + fn create_expected_repeated_slice( + slice_to_repeat: &[T], + repeat_count: usize, + ) -> Buffer { + let mut expected = MutableBuffer::new(size_of_val(slice_to_repeat) * repeat_count); + for _ in 0..repeat_count { + // Not using push_slice_repeated as this is the function under test + expected.extend_from_slice(slice_to_repeat); + } + expected.into() + } + + // Helper to test a specific repeat count with various slice sizes + fn test_repeat_count( + repeat_count: usize, + test_data: &[T], + ) { + let mut buffer = MutableBuffer::new(0); + buffer.repeat_slice_n_times(test_data, repeat_count); + + let expected = create_expected_repeated_slice(test_data, repeat_count); + let result: Buffer = buffer.into(); + + assert_eq!( + result, + expected, + "Failed for repeat_count={}, slice_len={}", + repeat_count, + test_data.len() + ); + } + + #[test] + fn test_repeat_slice_count_edge_cases() { + // Empty slice + test_repeat_count(100, &[] as &[i32]); + + // Zero repeats + test_repeat_count(0, &[1i32, 2, 3]); + } + + #[test] + fn test_small_repeats_counts() { + // test any special implementation for small repeat counts + let data = &[1u8, 2, 3, 4, 5]; + + for _ in 1..=10 { + test_repeat_count(2, data); + } + } + + #[test] + fn test_different_size_of_i32_repeat_slice() { + let data: &[i32] = &[1, 2, 3]; + let data_with_single_item: &[i32] = &[42]; + + for data in &[data, data_with_single_item] { + for item in 1..=9 { + let base_repeat_count = 2_usize.pow(item); + test_repeat_count(base_repeat_count - 1, data); + test_repeat_count(base_repeat_count, data); + test_repeat_count(base_repeat_count + 1, data); + } + } + } + + #[test] + fn test_different_size_of_u8_repeat_slice() { + let data: &[u8] = &[1, 2, 3]; + let data_with_single_item: &[u8] = &[10]; + + for data in &[data, data_with_single_item] { + for item in 1..=9 { + let base_repeat_count = 2_usize.pow(item); + test_repeat_count(base_repeat_count - 1, data); + test_repeat_count(base_repeat_count, data); + test_repeat_count(base_repeat_count + 1, data); + } + } + } + + #[test] + fn test_different_size_of_u16_repeat_slice() { + let data: &[u16] = &[1, 2, 3]; + let data_with_single_item: &[u16] = &[10]; + + for data in &[data, data_with_single_item] { + for item in 1..=9 { + let base_repeat_count = 2_usize.pow(item); + test_repeat_count(base_repeat_count - 1, data); + test_repeat_count(base_repeat_count, data); + test_repeat_count(base_repeat_count + 1, data); + } + } + } + + #[test] + fn test_various_slice_lengths() { + // Test different slice lengths with same repeat pattern + let repeat_count = 37; // Arbitrary non-power-of-2 + + // Single element + test_repeat_count(repeat_count, &[42i32]); + + // Small slices + test_repeat_count(repeat_count, &[1i32, 2]); + test_repeat_count(repeat_count, &[1i32, 2, 3]); + test_repeat_count(repeat_count, &[1i32, 2, 3, 4]); + test_repeat_count(repeat_count, &[1i32, 2, 3, 4, 5]); + + // Larger slices + let data_10: Vec = (0..10).collect(); + test_repeat_count(repeat_count, &data_10); + + let data_100: Vec = (0..100).collect(); + test_repeat_count(repeat_count, &data_100); + + let data_1000: Vec = (0..1000).collect(); + test_repeat_count(repeat_count, &data_1000); + } } diff --git a/arrow-buffer/src/buffer/offset.rs b/arrow-buffer/src/buffer/offset.rs index fe3a57a38248..66fa7dd22ec5 100644 --- a/arrow-buffer/src/buffer/offset.rs +++ b/arrow-buffer/src/buffer/offset.rs @@ -112,6 +112,9 @@ impl OffsetBuffer { /// assert_eq!(offsets.as_ref(), &[0, 1, 4, 9]); /// ``` /// + /// If you want to create an [`OffsetBuffer`] where all lengths are the same, + /// consider using the faster [`OffsetBuffer::from_repeated_length`] instead. + /// /// # Panics /// /// Panics on overflow @@ -133,6 +136,43 @@ impl OffsetBuffer { Self(out.into()) } + /// Create a new [`OffsetBuffer`] where each slice has the same length + /// `length`, repeated `n` times. + /// + /// + /// Example + /// ``` + /// # use arrow_buffer::OffsetBuffer; + /// let offsets = OffsetBuffer::::from_repeated_length(4, 3); + /// assert_eq!(offsets.as_ref(), &[0, 4, 8, 12]); + /// ``` + /// + /// # Panics + /// + /// Panics on overflow + pub fn from_repeated_length(length: usize, n: usize) -> Self { + if n == 0 { + return Self::new_empty(); + } + + if length == 0 { + return Self::new_zeroed(n); + } + + // Check for overflow + // Making sure we don't overflow usize or O when calculating the total length + length.checked_mul(n).expect("usize overflow"); + + // Check for overflow + O::from_usize(length * n).expect("offset overflow"); + + let offsets = (0..=n) + .map(|index| O::usize_as(index * length)) + .collect::>(); + + Self(ScalarBuffer::from(offsets)) + } + /// Get an Iterator over the lengths of this [`OffsetBuffer`] /// /// ``` @@ -283,6 +323,36 @@ mod tests { OffsetBuffer::::from_lengths([usize::MAX, 1]); } + #[test] + #[should_panic(expected = "offset overflow")] + fn from_repeated_lengths_offset_length_overflow() { + OffsetBuffer::::from_repeated_length(i32::MAX as usize / 4, 5); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn from_repeated_lengths_offset_repeat_overflow() { + OffsetBuffer::::from_repeated_length(1, i32::MAX as usize + 1); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn from_repeated_lengths_usize_length_overflow() { + OffsetBuffer::::from_repeated_length(usize::MAX, 1); + } + + #[test] + #[should_panic(expected = "usize overflow")] + fn from_repeated_lengths_usize_length_usize_overflow() { + OffsetBuffer::::from_repeated_length(usize::MAX, 2); + } + + #[test] + #[should_panic(expected = "offset overflow")] + fn from_repeated_lengths_usize_repeat_overflow() { + OffsetBuffer::::from_repeated_length(1, usize::MAX); + } + #[test] fn get_lengths() { let offsets = OffsetBuffer::::new(ScalarBuffer::::from(vec![0, 1, 4, 9])); @@ -323,4 +393,76 @@ mod tests { let default = OffsetBuffer::::default(); assert_eq!(default.as_ref(), &[0]); } + + #[test] + fn from_repeated_length_basic() { + // Basic case with length 4, repeated 3 times + let buffer = OffsetBuffer::::from_repeated_length(4, 3); + assert_eq!(buffer.as_ref(), &[0, 4, 8, 12]); + + // Verify the lengths are correct + let lengths: Vec = buffer.lengths().collect(); + assert_eq!(lengths, vec![4, 4, 4]); + } + + #[test] + fn from_repeated_length_single_repeat() { + // Length 5, repeated once + let buffer = OffsetBuffer::::from_repeated_length(5, 1); + assert_eq!(buffer.as_ref(), &[0, 5]); + + let lengths: Vec = buffer.lengths().collect(); + assert_eq!(lengths, vec![5]); + } + + #[test] + fn from_repeated_length_zero_repeats() { + let buffer = OffsetBuffer::::from_repeated_length(10, 0); + assert_eq!(buffer, OffsetBuffer::::new_empty()); + } + + #[test] + fn from_repeated_length_zero_length() { + // Zero length, repeated 5 times (all zeros) + let buffer = OffsetBuffer::::from_repeated_length(0, 5); + assert_eq!(buffer.as_ref(), &[0, 0, 0, 0, 0, 0]); + + // All lengths should be 0 + let lengths: Vec = buffer.lengths().collect(); + assert_eq!(lengths, vec![0, 0, 0, 0, 0]); + } + + #[test] + fn from_repeated_length_large_values() { + // Test with larger values that don't overflow + let buffer = OffsetBuffer::::from_repeated_length(1000, 100); + assert_eq!(buffer[0], 0); + + // Verify all lengths are 1000 + let lengths: Vec = buffer.lengths().collect(); + assert_eq!(lengths.len(), 100); + assert!(lengths.iter().all(|&len| len == 1000)); + } + + #[test] + fn from_repeated_length_unit_length() { + // Length 1, repeated multiple times + let buffer = OffsetBuffer::::from_repeated_length(1, 10); + assert_eq!(buffer.as_ref(), &[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); + + let lengths: Vec = buffer.lengths().collect(); + assert_eq!(lengths, vec![1; 10]); + } + + #[test] + fn from_repeated_length_max_safe_values() { + // Test with maximum safe values for i32 + // i32::MAX / 3 ensures we don't overflow when repeated twice + let third_max = (i32::MAX / 3) as usize; + let buffer = OffsetBuffer::::from_repeated_length(third_max, 2); + assert_eq!( + buffer.as_ref(), + &[0, third_max as i32, (third_max * 2) as i32] + ); + } } diff --git a/arrow-buffer/src/buffer/ops.rs b/arrow-buffer/src/buffer/ops.rs index c69e5c6deb10..36efe876432d 100644 --- a/arrow-buffer/src/buffer/ops.rs +++ b/arrow-buffer/src/buffer/ops.rs @@ -16,10 +16,16 @@ // under the License. use super::{Buffer, MutableBuffer}; +use crate::BooleanBuffer; use crate::util::bit_util::ceil; /// Apply a bitwise operation `op` to four inputs and return the result as a Buffer. -/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. +/// +/// The inputs are treated as bitmaps, meaning that offsets and length are +/// specified in number of bits. +/// +/// NOTE: The operation `op` is applied to chunks of 64 bits (u64) and any bits +/// outside the offsets and len are set to zero out before calling `op`. pub fn bitwise_quaternary_op_helper( buffers: [&Buffer; 4], offsets: [usize; 4], @@ -59,7 +65,12 @@ where } /// Apply a bitwise operation `op` to two inputs and return the result as a Buffer. -/// The inputs are treated as bitmaps, meaning that offsets and length are specified in number of bits. +/// +/// The inputs are treated as bitmaps, meaning that offsets and length are +/// specified in number of bits. +/// +/// NOTE: The operation `op` is applied to chunks of 64 bits (u64) and any bits +/// outside the offsets and len are set to zero out before calling `op`. pub fn bitwise_bin_op_helper( left: &Buffer, left_offset_in_bits: usize, @@ -92,7 +103,12 @@ where } /// Apply a bitwise operation `op` to one input and return the result as a Buffer. -/// The input is treated as a bitmap, meaning that offset and length are specified in number of bits. +/// +/// The input is treated as a bitmap, meaning that offset and length are +/// specified in number of bits. +/// +/// NOTE: The operation `op` is applied to chunks of 64 bits (u64) and any bits +/// outside the offsets and len are set to zero out before calling `op`. pub fn bitwise_unary_op_helper( left: &Buffer, offset_in_bits: usize, @@ -134,7 +150,7 @@ pub fn buffer_bin_and( right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { - bitwise_bin_op_helper( + BooleanBuffer::from_bitwise_binary_op( left, left_offset_in_bits, right, @@ -142,6 +158,7 @@ pub fn buffer_bin_and( len_in_bits, |a, b| a & b, ) + .into_inner() } /// Apply a bitwise or to two inputs and return the result as a Buffer. @@ -153,7 +170,7 @@ pub fn buffer_bin_or( right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { - bitwise_bin_op_helper( + BooleanBuffer::from_bitwise_binary_op( left, left_offset_in_bits, right, @@ -161,6 +178,7 @@ pub fn buffer_bin_or( len_in_bits, |a, b| a | b, ) + .into_inner() } /// Apply a bitwise xor to two inputs and return the result as a Buffer. @@ -172,7 +190,7 @@ pub fn buffer_bin_xor( right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { - bitwise_bin_op_helper( + BooleanBuffer::from_bitwise_binary_op( left, left_offset_in_bits, right, @@ -180,6 +198,7 @@ pub fn buffer_bin_xor( len_in_bits, |a, b| a ^ b, ) + .into_inner() } /// Apply a bitwise and_not to two inputs and return the result as a Buffer. @@ -191,7 +210,7 @@ pub fn buffer_bin_and_not( right_offset_in_bits: usize, len_in_bits: usize, ) -> Buffer { - bitwise_bin_op_helper( + BooleanBuffer::from_bitwise_binary_op( left, left_offset_in_bits, right, @@ -199,10 +218,11 @@ pub fn buffer_bin_and_not( len_in_bits, |a, b| a & !b, ) + .into_inner() } /// Apply a bitwise not to one input and return the result as a Buffer. /// The input is treated as a bitmap, meaning that offset and length are specified in number of bits. pub fn buffer_unary_not(left: &Buffer, offset_in_bits: usize, len_in_bits: usize) -> Buffer { - bitwise_unary_op_helper(left, offset_in_bits, len_in_bits, |a| !a) + BooleanBuffer::from_bitwise_unary_op(left, offset_in_bits, len_in_bits, |a| !a).into_inner() } diff --git a/arrow-buffer/src/buffer/run.rs b/arrow-buffer/src/buffer/run.rs index cc6d19044feb..6603dec1bac1 100644 --- a/arrow-buffer/src/buffer/run.rs +++ b/arrow-buffer/src/buffer/run.rs @@ -15,78 +15,111 @@ // specific language governing permissions and limitations // under the License. -use crate::buffer::ScalarBuffer; use crate::ArrowNativeType; +use crate::buffer::ScalarBuffer; -/// A slice-able buffer of monotonically increasing, positive integers used to store run-ends -/// -/// # Logical vs Physical +/// A buffer of monotonically increasing, positive integers used to store run-ends. /// -/// A [`RunEndBuffer`] is used to encode runs of the same value, the index of each run is -/// called the physical index. The logical index is then the corresponding index in the logical -/// run-encoded array, i.e. a single run of length `3`, would have the logical indices `0..3`. +/// Used to compactly represent runs of the same value. Values being represented +/// are stored in a separate buffer from this struct. See [`RunArray`] for an example +/// of how this is used with a companion array to represent the values. /// -/// Each value in [`RunEndBuffer::values`] is the cumulative length of all runs in the -/// logical array, up to that physical index. +/// # Logical vs Physical /// -/// Consider a [`RunEndBuffer`] containing `[3, 4, 6]`. The maximum physical index is `2`, -/// as there are `3` values, and the maximum logical index is `5`, as the maximum run end -/// is `6`. The physical indices are therefore `[0, 0, 0, 1, 2, 2]` +/// Physically, each value in the `run_ends` buffer is the cumulative length of +/// all runs in the logical representation, up to that physical index. Consider +/// the following example: /// /// ```text -/// ┌─────────┐ ┌─────────┐ ┌─────────┐ -/// │ 3 │ │ 0 │ ─┬──────▶ │ 0 │ -/// ├─────────┤ ├─────────┤ │ ├─────────┤ -/// │ 4 │ │ 1 │ ─┤ ┌────▶ │ 1 │ -/// ├─────────┤ ├─────────┤ │ │ ├─────────┤ -/// │ 6 │ │ 2 │ ─┘ │ ┌──▶ │ 2 │ -/// └─────────┘ ├─────────┤ │ │ └─────────┘ -/// run ends │ 3 │ ───┘ │ physical indices -/// ├─────────┤ │ -/// │ 4 │ ─────┤ -/// ├─────────┤ │ -/// │ 5 │ ─────┘ -/// └─────────┘ -/// logical indices +/// physical logical +/// ┌─────────┬─────────┐ ┌─────────┬─────────┐ +/// │ 3 │ 0 │ ◄──────┬─ │ A │ 0 │ +/// ├─────────┼─────────┤ │ ├─────────┼─────────┤ +/// │ 4 │ 1 │ ◄────┐ ├─ │ A │ 1 │ +/// ├─────────┼─────────┤ │ │ ├─────────┼─────────┤ +/// │ 6 │ 2 │ ◄──┐ │ └─ │ A │ 2 │ +/// └─────────┴─────────┘ │ │ ├─────────┼─────────┤ +/// run-ends index │ └─── │ B │ 3 │ +/// │ ├─────────┼─────────┤ +/// logical_offset = 0 ├───── │ C │ 4 │ +/// logical_length = 6 │ ├─────────┼─────────┤ +/// └───── │ C │ 5 │ +/// └─────────┴─────────┘ +/// values index /// ``` /// +/// A [`RunEndBuffer`] is physically the buffer and offset with length on the left. +/// In this case, the offset and length represent the whole buffer, so it is essentially +/// unsliced. See the section below on slicing for more details on how this buffer +/// handles slicing. +/// +/// This means that multiple logical values are represented in the same physical index, +/// and multiple logical indices map to the same physical index. The [`RunEndBuffer`] +/// containing `[3, 4, 6]` is essentially the physical indices `[0, 0, 0, 1, 2, 2]`, +/// and having a separately stored buffer of values such as `[A, B, C]` can turn +/// this into a representation of `[A, A, A, B, C, C]`. +/// /// # Slicing /// -/// In order to provide zero-copy slicing, this container stores a separate offset and length +/// In order to provide zero-copy slicing, this struct stores a separate **logical** +/// offset and length. Consider the following example: /// -/// For example, a [`RunEndBuffer`] containing values `[3, 6, 8]` with offset and length `4` would -/// describe the physical indices `1, 1, 2, 2` +/// ```text +/// physical logical +/// ┌─────────┬─────────┐ ┌ ─ ─ ─ ─ ┬ ─ ─ ─ ─ ┐ +/// │ 3 │ 0 │ ◄──────┐ A 0 +/// ├─────────┼─────────┤ │ ├── ─ ─ ─ ┼ ─ ─ ─ ─ ┤ +/// │ 4 │ 1 │ ◄────┐ │ A 1 +/// ├─────────┼─────────┤ │ │ ├─────────┼─────────┤ +/// │ 6 │ 2 │ ◄──┐ │ └─ │ A │ 2 │◄─── logical_offset +/// └─────────┴─────────┘ │ │ ├─────────┼─────────┤ +/// run-ends index │ └─── │ B │ 3 │ +/// │ ├─────────┼─────────┤ +/// logical_offset = 2 └───── │ C │ 4 │ +/// logical_length = 3 ├─────────┼─────────┤ +/// C 5 ◄─── logical_offset + logical_length +/// └ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ┘ +/// values index +/// ``` +/// +/// The physical `run_ends` [`ScalarBuffer`] remains unchanged, in order to facilitate +/// zero-copy. However, we now offset into the **logical** representation with an +/// accompanying length. This allows us to represent values `[A, B, C]` using physical +/// indices `0, 1, 2` with the same underlying physical buffer, at the cost of two +/// extra `usize`s to represent the logical slice that was taken. /// -/// For example, a [`RunEndBuffer`] containing values `[6, 8, 9]` with offset `2` and length `5` -/// would describe the physical indices `0, 0, 0, 0, 1` +/// (A [`RunEndBuffer`] is considered unsliced when `logical_offset` is `0` and +/// `logical_length` is equal to the last value in `run_ends`) /// +/// [`RunArray`]: https://docs.rs/arrow/latest/arrow/array/struct.RunArray.html /// [Run-End encoded layout]: https://arrow.apache.org/docs/format/Columnar.html#run-end-encoded-layout #[derive(Debug, Clone)] pub struct RunEndBuffer { run_ends: ScalarBuffer, - len: usize, - offset: usize, + logical_length: usize, + logical_offset: usize, } impl RunEndBuffer where E: ArrowNativeType, { - /// Create a new [`RunEndBuffer`] from a [`ScalarBuffer`], an `offset` and `len` + /// Create a new [`RunEndBuffer`] from a [`ScalarBuffer`], `logical_offset` + /// and `logical_length`. /// /// # Panics /// - /// - `buffer` does not contain strictly increasing values greater than zero - /// - the last value of `buffer` is less than `offset + len` - pub fn new(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + /// - `run_ends` does not contain strictly increasing values greater than zero + /// - The last value of `run_ends` is less than `logical_offset + logical_length` + pub fn new(run_ends: ScalarBuffer, logical_offset: usize, logical_length: usize) -> Self { assert!( run_ends.windows(2).all(|w| w[0] < w[1]), "run-ends not strictly increasing" ); - if len != 0 { + if logical_length != 0 { assert!(!run_ends.is_empty(), "non-empty slice but empty run-ends"); - let end = E::from_usize(offset.saturating_add(len)).unwrap(); + let end = E::from_usize(logical_offset.saturating_add(logical_length)).unwrap(); assert!( *run_ends.first().unwrap() > E::usize_as(0), "run-ends not greater than 0" @@ -99,41 +132,46 @@ where Self { run_ends, - offset, - len, + logical_offset, + logical_length, } } - /// Create a new [`RunEndBuffer`] from an [`ScalarBuffer`], an `offset` and `len` + /// Create a new [`RunEndBuffer`] from a [`ScalarBuffer`], `logical_offset` + /// and `logical_length`. /// /// # Safety /// - /// - `buffer` must contain strictly increasing values greater than zero - /// - The last value of `buffer` must be greater than or equal to `offset + len` - pub unsafe fn new_unchecked(run_ends: ScalarBuffer, offset: usize, len: usize) -> Self { + /// - `run_ends` must contain strictly increasing values greater than zero + /// - The last value of `run_ends` must be greater than or equal to `logical_offset + logical_len` + pub unsafe fn new_unchecked( + run_ends: ScalarBuffer, + logical_offset: usize, + logical_length: usize, + ) -> Self { Self { run_ends, - offset, - len, + logical_offset, + logical_length, } } - /// Returns the logical offset into the run-ends stored by this buffer + /// Returns the logical offset into the run-ends stored by this buffer. #[inline] pub fn offset(&self) -> usize { - self.offset + self.logical_offset } - /// Returns the logical length of the run-ends stored by this buffer + /// Returns the logical length of the run-ends stored by this buffer. #[inline] pub fn len(&self) -> usize { - self.len + self.logical_length } - /// Returns true if this buffer is empty + /// Returns true if this buffer is logically empty. #[inline] pub fn is_empty(&self) -> bool { - self.len == 0 + self.logical_length == 0 } /// Free up unused memory. @@ -142,23 +180,50 @@ where self.run_ends.shrink_to_fit(); } - /// Returns the values of this [`RunEndBuffer`] not including any offset + /// Returns the physical (**unsliced**) run ends of this buffer. + /// + /// Take care when operating on these values as it doesn't take into account + /// any logical slicing that may have occurred. #[inline] pub fn values(&self) -> &[E] { &self.run_ends } - /// Returns the maximum run-end encoded in the underlying buffer + /// Returns an iterator yielding run ends adjusted for the logical slice. + /// + /// Each yielded value is subtracted by the [`logical_offset`] and capped + /// at the [`logical_length`]. + /// + /// [`logical_offset`]: Self::offset + /// [`logical_length`]: Self::len + pub fn sliced_values(&self) -> impl Iterator + '_ { + let offset = self.logical_offset; + let len = self.logical_length; + let start = self.get_start_physical_index(); + let end = self.get_end_physical_index(); + self.run_ends[start..=end].iter().map(move |&val| { + let val = val.as_usize().saturating_sub(offset).min(len); + E::from_usize(val).unwrap() + }) + } + + /// Returns the maximum run-end encoded in the underlying buffer; that is, the + /// last physical run of the buffer. This does not take into account any logical + /// slicing that may have occurred. #[inline] pub fn max_value(&self) -> usize { self.values().last().copied().unwrap_or_default().as_usize() } - /// Performs a binary search to find the physical index for the given logical index + /// Performs a binary search to find the physical index for the given logical + /// index. + /// + /// Useful for extracting the corresponding physical `run_ends` when this buffer + /// is logically sliced. /// - /// The result is arbitrary if `logical_index >= self.len()` + /// The result is arbitrary if `logical_index >= self.len()`. pub fn get_physical_index(&self, logical_index: usize) -> usize { - let logical_index = E::usize_as(self.offset + logical_index); + let logical_index = E::usize_as(self.logical_offset + logical_index); let cmp = |p: &E| p.partial_cmp(&logical_index).unwrap(); match self.run_ends.binary_search_by(cmp) { @@ -167,49 +232,137 @@ where } } - /// Returns the physical index at which the logical array starts + /// Returns the physical index at which the logical array starts. + /// + /// The same as calling `get_physical_index(0)` but with a fast path if the + /// buffer is not logically sliced, in which case it always returns `0`. pub fn get_start_physical_index(&self) -> usize { - if self.offset == 0 || self.len == 0 { + if self.logical_offset == 0 || self.logical_length == 0 { return 0; } // Fallback to binary search self.get_physical_index(0) } - /// Returns the physical index at which the logical array ends + /// Returns the physical index at which the logical array ends. + /// + /// The same as calling `get_physical_index(length - 1)` but with a fast path + /// if the buffer is not logically sliced, in which case it returns `length - 1`. pub fn get_end_physical_index(&self) -> usize { - if self.len == 0 { + if self.logical_length == 0 { return 0; } - if self.max_value() == self.offset + self.len { + if self.max_value() == self.logical_offset + self.logical_length { return self.values().len() - 1; } // Fallback to binary search - self.get_physical_index(self.len - 1) + self.get_physical_index(self.logical_length - 1) } - /// Slices this [`RunEndBuffer`] by the provided `offset` and `length` - pub fn slice(&self, offset: usize, len: usize) -> Self { + /// Slices this [`RunEndBuffer`] by the provided `logical_offset` and `logical_length`. + /// + /// # Panics + /// + /// - Specified slice (`logical_offset` + `logical_length`) exceeds existing + /// logical length + pub fn slice(&self, logical_offset: usize, logical_length: usize) -> Self { assert!( - offset.saturating_add(len) <= self.len, + logical_offset.saturating_add(logical_length) <= self.logical_length, "the length + offset of the sliced RunEndBuffer cannot exceed the existing length" ); Self { run_ends: self.run_ends.clone(), - offset: self.offset + offset, - len, + logical_offset: self.logical_offset + logical_offset, + logical_length, } } - /// Returns the inner [`ScalarBuffer`] + /// Returns the inner [`ScalarBuffer`]. pub fn inner(&self) -> &ScalarBuffer { &self.run_ends } - /// Returns the inner [`ScalarBuffer`], consuming self + /// Returns the inner [`ScalarBuffer`], consuming self. pub fn into_inner(self) -> ScalarBuffer { self.run_ends } + + /// Returns the physical indices corresponding to the provided logical indices. + /// + /// Given a slice of logical indices, this method returns a `Vec` containing the + /// corresponding physical indices into the run-ends buffer. + /// + /// This method operates by iterating the logical indices in sorted order, instead of + /// finding the physical index for each logical index using binary search via + /// the function [`RunEndBuffer::get_physical_index`]. + /// + /// Running benchmarks on both approaches showed that the approach used here + /// scaled well for larger inputs. + /// + /// See for more details. + /// + /// # Errors + /// + /// If any logical index is out of bounds (>= self.len()), returns an error containing the invalid index. + #[inline] + pub fn get_physical_indices(&self, logical_indices: &[I]) -> Result, I> + where + I: ArrowNativeType, + { + let len = self.len(); + let offset = self.offset(); + + let indices_len = logical_indices.len(); + + if indices_len == 0 { + return Ok(vec![]); + } + + // `ordered_indices` store index into `logical_indices` and can be used + // to iterate `logical_indices` in sorted order. + let mut ordered_indices: Vec = (0..indices_len).collect(); + + // Instead of sorting `logical_indices` directly, sort the `ordered_indices` + // whose values are index of `logical_indices` + ordered_indices.sort_unstable_by(|lhs, rhs| { + logical_indices[*lhs] + .partial_cmp(&logical_indices[*rhs]) + .unwrap() + }); + + // Return early if all the logical indices cannot be converted to physical indices. + let largest_logical_index = logical_indices[*ordered_indices.last().unwrap()].as_usize(); + if largest_logical_index >= len { + return Err(logical_indices[*ordered_indices.last().unwrap()]); + } + + // Skip some physical indices based on offset. + let skip_value = self.get_start_physical_index(); + + let mut physical_indices = vec![0; indices_len]; + + let mut ordered_index = 0_usize; + for (physical_index, run_end) in self.values().iter().enumerate().skip(skip_value) { + // Get the run end index (relative to offset) of current physical index + let run_end_value = run_end.as_usize() - offset; + + // All the `logical_indices` that are less than current run end index + // belongs to current physical index. + while ordered_index < indices_len + && logical_indices[ordered_indices[ordered_index]].as_usize() < run_end_value + { + physical_indices[ordered_indices[ordered_index]] = physical_index; + ordered_index += 1; + } + } + + // If there are input values >= run_ends.last_value then we'll not be able to convert + // all logical indices to physical indices. + if ordered_index < logical_indices.len() { + return Err(logical_indices[ordered_indices[ordered_index]]); + } + Ok(physical_indices) + } } #[cfg(test)] @@ -233,4 +386,26 @@ mod tests { assert_eq!(buffer.get_start_physical_index(), 0); assert_eq!(buffer.get_end_physical_index(), 0); } + + #[test] + fn test_sliced_values() { + // [0, 0, 1, 2, 2, 2] + let buffer = RunEndBuffer::new(vec![2i32, 3, 6].into(), 0, 6); + + // Slice: [0, 1, 2, 2] start: 1, len: 4 + // Logical indices: 1, 2, 3, 4 + // Original run ends: [2, 3, 6] + // Adjusted: [2-1, 3-1, 6-1] capped at 4 -> [1, 2, 4] + let sliced = buffer.slice(1, 4); + let sliced_values: Vec = sliced.sliced_values().collect(); + assert_eq!(sliced_values, &[1, 2, 4]); + + // Slice: [2, 2] start: 4, len: 2 + // Original run ends: [2, 3, 6] + // Slicing at 4 means we only have the last run (physical index 2, which ends at 6) + // Adjusted: [6-4] capped at 2 -> [2] + let sliced = buffer.slice(4, 2); + let sliced_values: Vec = sliced.sliced_values().collect(); + assert_eq!(sliced_values, &[2]); + } } diff --git a/arrow-buffer/src/buffer/scalar.rs b/arrow-buffer/src/buffer/scalar.rs index 6c66060fb95f..3c5334ca5118 100644 --- a/arrow-buffer/src/buffer/scalar.rs +++ b/arrow-buffer/src/buffer/scalar.rs @@ -29,17 +29,38 @@ use std::ops::Deref; /// with the following differences: /// /// - slicing and cloning is O(1). -/// - it supports external allocated memory +/// - support for external allocated memory (e.g. via FFI). /// +/// See [`Buffer`] for more low-level memory management details. +/// +/// # Example: Convert to/from Vec (without copies) +/// +/// (See [`Buffer::from_vec`] and [`Buffer::into_vec`] for a lower level API) /// ``` /// # use arrow_buffer::ScalarBuffer; /// // Zero-copy conversion from Vec /// let buffer = ScalarBuffer::from(vec![1, 2, 3]); /// assert_eq!(&buffer, &[1, 2, 3]); +/// // convert the buffer back to Vec without copy assuming: +/// // 1. the inner buffer is not sliced +/// // 2. the inner buffer uses standard allocation +/// // 3. there are no other references to the inner buffer +/// let vec: Vec = buffer.into(); +/// assert_eq!(&vec, &[1, 2, 3]); +/// ``` /// +/// # Example: Zero copy slicing +/// ``` +/// # use arrow_buffer::ScalarBuffer; +/// let buffer = ScalarBuffer::from(vec![1, 2, 3]); +/// assert_eq!(&buffer, &[1, 2, 3]); /// // Zero-copy slicing /// let sliced = buffer.slice(1, 2); /// assert_eq!(&sliced, &[2, 3]); +/// // Original buffer is unchanged +/// assert_eq!(&buffer, &[1, 2, 3]); +/// // converting the sliced buffer back to Vec incurs a copy +/// let vec: Vec = sliced.into(); /// ``` #[derive(Clone, Default)] pub struct ScalarBuffer { @@ -72,6 +93,19 @@ impl ScalarBuffer { buffer.slice_with_length(byte_offset, byte_len).into() } + /// Unsafe function to create a new [`ScalarBuffer`] from a [`Buffer`]. + /// Only use for testing purpose. + /// + /// # Safety + /// + /// This function is unsafe because it does not check if the `buffer` is aligned + pub unsafe fn new_unchecked(buffer: Buffer) -> Self { + Self { + buffer, + phantom: Default::default(), + } + } + /// Free up unused memory. pub fn shrink_to_fit(&mut self) { self.buffer.shrink_to_fit(); @@ -99,6 +133,16 @@ impl ScalarBuffer { pub fn ptr_eq(&self, other: &Self) -> bool { self.buffer.ptr_eq(&other.buffer) } + + /// Returns the number of elements in the buffer + pub fn len(&self) -> usize { + self.buffer.len() / std::mem::size_of::() + } + + /// Returns if the buffer is empty + pub fn is_empty(&self) -> bool { + self.len() == 0 + } } impl Deref for ScalarBuffer { @@ -139,8 +183,10 @@ impl From for ScalarBuffer { is_aligned, "Memory pointer is not aligned with the specified scalar type" ), - Deallocation::Custom(_, _) => - assert!(is_aligned, "Memory pointer from external source (e.g, FFI) is not aligned with the specified scalar type. Before importing buffer through FFI, please make sure the allocation is aligned."), + Deallocation::Custom(_, _) => assert!( + is_aligned, + "Memory pointer from external source (e.g, FFI) is not aligned with the specified scalar type. Before importing buffer through FFI, please make sure the allocation is aligned." + ), } Self { diff --git a/arrow-buffer/src/builder/boolean.rs b/arrow-buffer/src/builder/boolean.rs index bdcc3a55dbf2..7990be1e7cc9 100644 --- a/arrow-buffer/src/builder/boolean.rs +++ b/arrow-buffer/src/builder/boolean.rs @@ -15,7 +15,8 @@ // specific language governing permissions and limitations // under the License. -use crate::{bit_mask, bit_util, BooleanBuffer, Buffer, MutableBuffer}; +use crate::bit_util::apply_bitwise_binary_op; +use crate::{BooleanBuffer, Buffer, MutableBuffer, NullBuffer, bit_util}; use std::ops::Range; /// Builder for [`BooleanBuffer`] @@ -139,7 +140,6 @@ impl BooleanBufferBuilder { /// Reserve space to at least `additional` new bits. /// Capacity will be `>= self.len() + additional`. - /// New bytes are uninitialized and reading them is undefined behavior. #[inline] pub fn reserve(&mut self, additional: usize) { let capacity = self.len + additional; @@ -218,13 +218,16 @@ impl BooleanBufferBuilder { pub fn append_packed_range(&mut self, range: Range, to_set: &[u8]) { let offset_write = self.len; let len = range.end - range.start; + // allocate new bits as 0 self.advance(len); - bit_mask::set_bits( + // copy bits from to_set into self.buffer a word at a time + apply_bitwise_binary_op( self.buffer.as_slice_mut(), - to_set, offset_write, + to_set, range.start, len, + |_a, b| b, // copy bits from to_set ); } @@ -256,6 +259,20 @@ impl BooleanBufferBuilder { pub fn finish_cloned(&self) -> BooleanBuffer { BooleanBuffer::new(Buffer::from_slice_ref(self.as_slice()), 0, self.len) } + + /// Extends the builder from a trusted length iterator of booleans. + /// # Safety + /// Callers must ensure that `iter` reports an exact size via `size_hint`. + /// + #[inline] + pub unsafe fn extend_trusted_len(&mut self, iterator: I) + where + I: Iterator, + { + let len = iterator.size_hint().0; + unsafe { self.buffer.extend_bool_trusted_len(iterator, self.len) }; + self.len += len; + } } impl From for Buffer { @@ -272,6 +289,14 @@ impl From for BooleanBuffer { } } +impl From for NullBuffer { + #[inline] + fn from(builder: BooleanBufferBuilder) -> Self { + let boolean_buffer = BooleanBuffer::from(builder); + NullBuffer::new(boolean_buffer) + } +} + #[cfg(test)] mod tests { use super::*; @@ -523,4 +548,65 @@ mod tests { assert_eq!(buf.len(), buf2.inner().len()); assert_eq!(buf.as_slice(), buf2.values()); } + + #[test] + fn test_extend() { + let mut builder = BooleanBufferBuilder::new(0); + let bools = vec![true, false, true, true, false, true, true, true, false]; + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + assert_eq!(builder.len(), 9); + let finished = builder.finish(); + for (i, v) in bools.into_iter().enumerate() { + assert_eq!(finished.value(i), v); + } + + // Test > 64 bits + let mut builder = BooleanBufferBuilder::new(0); + let bools: Vec<_> = (0..100).map(|i| i % 3 == 0 || i % 7 == 0).collect(); + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + assert_eq!(builder.len(), 100); + let finished = builder.finish(); + for (i, v) in bools.into_iter().enumerate() { + assert_eq!(finished.value(i), v, "at index {}", i); + } + } + + #[test] + fn test_extend_misaligned() { + // Test misaligned start + for offset in 1..65 { + let mut builder = BooleanBufferBuilder::new(0); + builder.append_n(offset, false); + + let bools: Vec<_> = (0..100).map(|i| i % 3 == 0 || i % 7 == 0).collect(); + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + assert_eq!(builder.len(), offset + 100); + + let finished = builder.finish(); + for i in 0..offset { + assert!(!finished.value(i)); + } + for (i, v) in bools.into_iter().enumerate() { + assert_eq!(finished.value(offset + i), v, "at index {}", offset + i); + } + } + } + + #[test] + fn test_extend_misaligned_end() { + for len in 1..130 { + let mut builder = BooleanBufferBuilder::new(0); + let mut bools: Vec<_> = (0..len).map(|i| i % 2 == 0).collect(); + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + unsafe { builder.extend_trusted_len(bools.clone().into_iter()) }; + let copy = bools.clone(); + bools.extend(copy); + assert_eq!(builder.len(), 2 * len); + + let finished = builder.finish(); + for (i, &v) in bools.iter().enumerate() { + assert_eq!(finished.value(i), v, "at index {} for len {}", i, len); + } + } + } } diff --git a/arrow-buffer/src/builder/mod.rs b/arrow-buffer/src/builder/mod.rs index f7e0e29dace4..abe510bdabc6 100644 --- a/arrow-buffer/src/builder/mod.rs +++ b/arrow-buffer/src/builder/mod.rs @@ -26,7 +26,7 @@ pub use null::*; pub use offset::*; use crate::{ArrowNativeType, Buffer, MutableBuffer}; -use std::{iter, marker::PhantomData}; +use std::marker::PhantomData; /// Builder for creating a [Buffer] object. /// @@ -214,7 +214,7 @@ impl BufferBuilder { #[inline] pub fn append_n(&mut self, n: usize, v: T) { self.reserve(n); - self.extend(iter::repeat(v).take(n)) + self.extend(std::iter::repeat_n(v, n)) } /// Appends `n`, zero-initialized values diff --git a/arrow-buffer/src/bytes.rs b/arrow-buffer/src/bytes.rs index b811bd2c6b40..8f912b807da5 100644 --- a/arrow-buffer/src/bytes.rs +++ b/arrow-buffer/src/bytes.rs @@ -26,6 +26,11 @@ use std::{fmt::Debug, fmt::Formatter}; use crate::alloc::Deallocation; use crate::buffer::dangling_ptr; +#[cfg(feature = "pool")] +use crate::pool::{MemoryPool, MemoryReservation}; +#[cfg(feature = "pool")] +use std::sync::Mutex; + /// A continuous, fixed-size, immutable memory region that knows how to de-allocate itself. /// /// Note that this structure is an internal implementation detail of the @@ -49,6 +54,10 @@ pub struct Bytes { /// how to deallocate this region deallocation: Deallocation, + + /// Memory reservation for tracking memory usage + #[cfg(feature = "pool")] + pub(super) reservation: Mutex>>, } impl Bytes { @@ -70,6 +79,8 @@ impl Bytes { ptr, len, deallocation, + #[cfg(feature = "pool")] + reservation: Mutex::new(None), } } @@ -101,6 +112,27 @@ impl Bytes { } } + /// Register this [`Bytes`] with the provided [`MemoryPool`], replacing any prior reservation. + #[cfg(feature = "pool")] + pub fn claim(&self, pool: &dyn MemoryPool) { + *self.reservation.lock().unwrap() = Some(pool.reserve(self.capacity())); + } + + /// Resize the memory reservation of this buffer + /// + /// This is a no-op if this buffer doesn't have a reservation. + #[cfg(feature = "pool")] + fn resize_reservation(&self, new_size: usize) { + let mut guard = self.reservation.lock().unwrap(); + if let Some(mut reservation) = guard.take() { + // Resize the reservation + reservation.resize(new_size); + + // Put it back + *guard = Some(reservation); + } + } + /// Try to reallocate the underlying memory region to a new size (smaller or larger). /// /// Only works for bytes allocated with the standard allocator. @@ -135,6 +167,13 @@ impl Bytes { self.ptr = ptr; self.len = new_len; self.deallocation = Deallocation::Standard(new_layout); + + #[cfg(feature = "pool")] + { + // Resize reservation + self.resize_reservation(new_len); + } + return Ok(()); } } @@ -199,6 +238,8 @@ impl From for Bytes { len, ptr: NonNull::new(value.as_ptr() as _).unwrap(), deallocation: Deallocation::Custom(std::sync::Arc::new(value), len), + #[cfg(feature = "pool")] + reservation: Mutex::new(None), } } } @@ -209,14 +250,83 @@ mod tests { #[test] fn test_from_bytes() { - let bytes = bytes::Bytes::from(vec![1, 2, 3, 4]); - let arrow_bytes: Bytes = bytes.clone().into(); + let message = b"hello arrow"; - assert_eq!(bytes.as_ptr(), arrow_bytes.as_ptr()); + // we can create a Bytes from bytes::Bytes (created from slices) + let c_bytes: bytes::Bytes = message.as_ref().into(); + let a_bytes: Bytes = c_bytes.into(); + assert_eq!(a_bytes.as_slice(), message); - drop(bytes); - drop(arrow_bytes); + // we can create a Bytes from bytes::Bytes (created from Vec) + let c_bytes: bytes::Bytes = bytes::Bytes::from(message.to_vec()); + let a_bytes: Bytes = c_bytes.into(); + assert_eq!(a_bytes.as_slice(), message); + } + + #[cfg(feature = "pool")] + mod pool_tests { + use super::*; + + use crate::pool::TrackingMemoryPool; + + #[test] + fn test_bytes_with_pool() { + // Create a standard allocation + let buffer = unsafe { + let layout = + std::alloc::Layout::from_size_align(1024, crate::alloc::ALIGNMENT).unwrap(); + let ptr = std::alloc::alloc(layout); + assert!(!ptr.is_null()); + + Bytes::new( + NonNull::new(ptr).unwrap(), + 1024, + Deallocation::Standard(layout), + ) + }; + + // Create a memory pool + let pool = TrackingMemoryPool::default(); + assert_eq!(pool.used(), 0); + + // Reserve memory and assign to buffer. Claim twice. + buffer.claim(&pool); + assert_eq!(pool.used(), 1024); + buffer.claim(&pool); + assert_eq!(pool.used(), 1024); + + // Memory should be released when buffer is dropped + drop(buffer); + assert_eq!(pool.used(), 0); + } + + #[test] + fn test_bytes_drop_releases_pool() { + let pool = TrackingMemoryPool::default(); + + { + // Create a buffer with pool + let _buffer = unsafe { + let layout = + std::alloc::Layout::from_size_align(1024, crate::alloc::ALIGNMENT).unwrap(); + let ptr = std::alloc::alloc(layout); + assert!(!ptr.is_null()); + + let bytes = Bytes::new( + NonNull::new(ptr).unwrap(), + 1024, + Deallocation::Standard(layout), + ); + + bytes.claim(&pool); + bytes + }; - let _ = Bytes::from(bytes::Bytes::new()); + assert_eq!(pool.used(), 1024); + } + + // Buffer has been dropped, memory should be released + assert_eq!(pool.used(), 0); + } } } diff --git a/arrow-buffer/src/lib.rs b/arrow-buffer/src/lib.rs index 174cdc4d9c18..230747b8b84a 100644 --- a/arrow-buffer/src/lib.rs +++ b/arrow-buffer/src/lib.rs @@ -16,14 +16,27 @@ // under the License. //! Low-level buffer abstractions for [Apache Arrow Rust](https://docs.rs/arrow) +//! +//! # Byte Storage abstractions +//! - [`MutableBuffer`]: Raw memory buffer that can be mutated and grown +//! - [`Buffer`]: Immutable buffer that is shared across threads +//! +//! # Typed Abstractions +//! +//! There are also several wrappers over [`Buffer`] with methods for +//! easier manipulation: +//! +//! - [`BooleanBuffer`][]: Bitmasks (buffer of packed bits) +//! - [`NullBuffer`][]: Arrow null (validity) bitmaps ([`BooleanBuffer`] with extra utilities) +//! - [`ScalarBuffer`][]: Typed buffer for primitive types (e.g., `i32`, `f64`) +//! - [`OffsetBuffer`][]: Offsets used in variable-length types (e.g., strings, lists) +//! - [`RunEndBuffer`][]: Run-ends used in run-encoded encoded data #![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] -// used by [`buffer::mutable::dangling_ptr`] -#![cfg_attr(miri, feature(strict_provenance))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] pub mod alloc; @@ -48,3 +61,8 @@ mod interval; pub use interval::*; mod arith; + +#[cfg(feature = "pool")] +mod pool; +#[cfg(feature = "pool")] +pub use pool::*; diff --git a/arrow-buffer/src/native.rs b/arrow-buffer/src/native.rs index eb8e067db0be..68058a4eeccd 100644 --- a/arrow-buffer/src/native.rs +++ b/arrow-buffer/src/native.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::{i256, IntervalDayTime, IntervalMonthDayNano}; +use crate::{IntervalDayTime, IntervalMonthDayNano, i256}; use half::f16; mod private { diff --git a/arrow-buffer/src/pool.rs b/arrow-buffer/src/pool.rs new file mode 100644 index 000000000000..95bd308a35be --- /dev/null +++ b/arrow-buffer/src/pool.rs @@ -0,0 +1,189 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains traits for memory pool traits and an implementation +//! for tracking memory usage. +//! +//! The basic traits are [`MemoryPool`] and [`MemoryReservation`]. And default +//! implementation of [`MemoryPool`] is [`TrackingMemoryPool`]. Their relationship +//! is as follows: +//! +//! ```text +//! (pool tracker) (resizable) +//! ┌──────────────────┐ fn reserve() ┌─────────────────────────┐ +//! │ trait MemoryPool │─────────────►│ trait MemoryReservation │ +//! └──────────────────┘ └─────────────────────────┘ +//! ``` + +use std::fmt::Debug; +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +/// A memory reservation within a [`MemoryPool`] that is freed on drop +pub trait MemoryReservation: Debug + Send + Sync { + /// Returns the size of this reservation in bytes. + fn size(&self) -> usize; + + /// Resize this reservation to a new size in bytes. + fn resize(&mut self, new_size: usize); +} + +/// A pool of memory that can be reserved and released. +/// +/// This is used to accurately track memory usage when buffers are shared +/// between multiple arrays or other data structures. +/// +/// For example, assume we have two arrays that share underlying buffer. +/// It's hard to tell how much memory is used by them because we can't +/// tell if the buffer is shared or not. +/// +/// ```text +/// Array A Array B +/// ┌────────────┐ ┌────────────┐ +/// │ slices... │ │ slices... │ +/// │────────────│ │────────────│ +/// │ Arc │ │ Arc │ (shared buffer) +/// └─────▲──────┘ └───────▲────┘ +/// │ │ +/// │ Bytes │ +/// │ ┌─────────────┐ │ +/// │ │ data... │ │ +/// │ │─────────────│ │ +/// └──│ Memory │──┘ (tracked with a memory pool) +/// │ Reservation │ +/// └─────────────┘ +/// ``` +/// +/// With a memory pool, we can count the memory usage by the shared buffer +/// directly. +pub trait MemoryPool: Debug + Send + Sync { + /// Reserves memory from the pool. Infallible. + /// + /// Returns a reservation of the requested size. + fn reserve(&self, size: usize) -> Box; + + /// Returns the current available memory in the pool. + /// + /// The pool may be overfilled, so this method might return a negative value. + fn available(&self) -> isize; + + /// Returns the current used memory from the pool. + fn used(&self) -> usize; + + /// Returns the maximum memory that can be reserved from the pool. + fn capacity(&self) -> usize; +} + +/// A simple [`MemoryPool`] that reports the total memory usage +#[derive(Debug, Default)] +pub struct TrackingMemoryPool(Arc); + +impl TrackingMemoryPool { + /// Returns the total allocated size + pub fn allocated(&self) -> usize { + self.0.load(Ordering::Relaxed) + } +} + +impl MemoryPool for TrackingMemoryPool { + fn reserve(&self, size: usize) -> Box { + self.0.fetch_add(size, Ordering::Relaxed); + Box::new(Tracker { + size, + shared: Arc::clone(&self.0), + }) + } + + fn available(&self) -> isize { + isize::MAX - self.used() as isize + } + + fn used(&self) -> usize { + self.0.load(Ordering::Relaxed) + } + + fn capacity(&self) -> usize { + usize::MAX + } +} + +#[derive(Debug)] +struct Tracker { + size: usize, + shared: Arc, +} + +impl Drop for Tracker { + fn drop(&mut self) { + self.shared.fetch_sub(self.size, Ordering::Relaxed); + } +} + +impl MemoryReservation for Tracker { + fn size(&self) -> usize { + self.size + } + + fn resize(&mut self, new: usize) { + match self.size < new { + true => self.shared.fetch_add(new - self.size, Ordering::Relaxed), + false => self.shared.fetch_sub(self.size - new, Ordering::Relaxed), + }; + self.size = new; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tracking_memory_pool() { + let pool = TrackingMemoryPool::default(); + + // Reserve 512 bytes + let reservation = pool.reserve(512); + assert_eq!(reservation.size(), 512); + assert_eq!(pool.used(), 512); + assert_eq!(pool.available(), isize::MAX - 512); + + // Reserve another 256 bytes + let reservation2 = pool.reserve(256); + assert_eq!(reservation2.size(), 256); + assert_eq!(pool.used(), 768); + assert_eq!(pool.available(), isize::MAX - 768); + + // Test resize to increase + let mut reservation_mut = reservation; + reservation_mut.resize(600); + assert_eq!(reservation_mut.size(), 600); + assert_eq!(pool.used(), 856); // 600 + 256 + + // Test resize to decrease + reservation_mut.resize(400); + assert_eq!(reservation_mut.size(), 400); + assert_eq!(pool.used(), 656); // 400 + 256 + + // Drop the first reservation + drop(reservation_mut); + assert_eq!(pool.used(), 256); + + // Drop the second reservation + drop(reservation2); + assert_eq!(pool.used(), 0); + } +} diff --git a/arrow-buffer/src/util/bit_chunk_iterator.rs b/arrow-buffer/src/util/bit_chunk_iterator.rs index ea8e8f472ace..8c7ec5e9a8f6 100644 --- a/arrow-buffer/src/util/bit_chunk_iterator.rs +++ b/arrow-buffer/src/util/bit_chunk_iterator.rs @@ -202,11 +202,10 @@ fn compute_suffix_mask(len: usize, lead_padding: usize) -> (u64, usize) { (suffix_mask, trailing_padding) } -/// Iterates over an arbitrarily aligned byte buffer +/// Iterates over an arbitrarily aligned byte buffer 64 bits at a time /// -/// Yields an iterator of u64, and a remainder. The first byte in the buffer +/// [`Self::iter`] yields iterator of `u64`, and a remainder. The first byte in the buffer /// will be the least significant byte in output u64 -/// #[derive(Debug)] pub struct BitChunks<'a> { buffer: &'a [u8], @@ -221,7 +220,10 @@ pub struct BitChunks<'a> { impl<'a> BitChunks<'a> { /// Create a new [`BitChunks`] from a byte array, and an offset and length in bits pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { - assert!(ceil(offset + len, 8) <= buffer.len() * 8); + assert!( + ceil(offset + len, 8) <= buffer.len(), + "offset + len out of bounds" + ); let byte_offset = offset / 8; let bit_offset = offset % 8; @@ -256,7 +258,7 @@ impl<'a> BitChunks<'a> { self.remainder_len } - /// Returns the number of chunks + /// Returns the number of `u64` chunks #[inline] pub const fn chunk_len(&self) -> usize { self.chunk_len @@ -290,7 +292,28 @@ impl<'a> BitChunks<'a> { } } - /// Returns an iterator over chunks of 64 bits represented as an u64 + /// Return the number of `u64` that are needed to represent all bits + /// (including remainder). + /// + /// This is equal to `chunk_len + 1` if there is a remainder, + /// otherwise it is equal to `chunk_len`. + #[inline] + pub fn num_u64s(&self) -> usize { + if self.remainder_len == 0 { + self.chunk_len + } else { + self.chunk_len + 1 + } + } + + /// Return the number of *bytes* that are needed to represent all bits + /// (including remainder). + #[inline] + pub fn num_bytes(&self) -> usize { + ceil(self.chunk_len * 64 + self.remainder_len, 8) + } + + /// Returns an iterator over chunks of 64 bits represented as an `u64` #[inline] pub const fn iter(&self) -> BitChunkIterator<'a> { BitChunkIterator::<'a> { @@ -476,6 +499,57 @@ mod tests { assert_eq!(0x7F, bitchunks.remainder_bits()); } + #[test] + #[should_panic(expected = "offset + len out of bounds")] + fn test_out_of_bound_should_panic_length_is_more_than_buffer_length() { + const ALLOC_SIZE: usize = 4 * 1024; + let input = vec![0xFF_u8; ALLOC_SIZE]; + + let buffer: Buffer = Buffer::from_vec(input); + + // We are reading more than exists in the buffer + buffer.bit_chunks(0, (ALLOC_SIZE + 1) * 8); + } + + #[test] + #[should_panic(expected = "offset + len out of bounds")] + fn test_out_of_bound_should_panic_length_is_more_than_buffer_length_but_not_when_not_using_ceil() + { + const ALLOC_SIZE: usize = 4 * 1024; + let input = vec![0xFF_u8; ALLOC_SIZE]; + + let buffer: Buffer = Buffer::from_vec(input); + + // We are reading more than exists in the buffer + buffer.bit_chunks(0, (ALLOC_SIZE * 8) + 1); + } + + #[test] + #[should_panic(expected = "offset + len out of bounds")] + fn test_out_of_bound_should_panic_when_offset_is_not_zero_and_length_is_the_entire_buffer_length() + { + const ALLOC_SIZE: usize = 4 * 1024; + let input = vec![0xFF_u8; ALLOC_SIZE]; + + let buffer: Buffer = Buffer::from_vec(input); + + // We are reading more than exists in the buffer + buffer.bit_chunks(8, ALLOC_SIZE * 8); + } + + #[test] + #[should_panic(expected = "offset + len out of bounds")] + fn test_out_of_bound_should_panic_when_offset_is_not_zero_and_length_is_the_entire_buffer_length_with_ceil() + { + const ALLOC_SIZE: usize = 4 * 1024; + let input = vec![0xFF_u8; ALLOC_SIZE]; + + let buffer: Buffer = Buffer::from_vec(input); + + // We are reading more than exists in the buffer + buffer.bit_chunks(1, ALLOC_SIZE * 8); + } + #[test] #[allow(clippy::assertions_on_constants)] fn test_unaligned_bit_chunk_iterator() { diff --git a/arrow-buffer/src/util/bit_iterator.rs b/arrow-buffer/src/util/bit_iterator.rs index c3e72044bf87..0aa94a5d4dc1 100644 --- a/arrow-buffer/src/util/bit_iterator.rs +++ b/arrow-buffer/src/util/bit_iterator.rs @@ -23,6 +23,7 @@ use crate::bit_util::{ceil, get_bit_raw}; /// Iterator over the bits within a packed bitmask /// /// To efficiently iterate over just the set bits see [`BitIndexIterator`] and [`BitSliceIterator`] +#[derive(Clone)] pub struct BitIterator<'a> { buffer: &'a [u8], current_offset: usize, @@ -71,6 +72,71 @@ impl Iterator for BitIterator<'_> { let remaining_bits = self.end_offset - self.current_offset; (remaining_bits, Some(remaining_bits)) } + + fn count(self) -> usize + where + Self: Sized, + { + self.len() + } + + fn nth(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. + // When n is 0 it means we want the next() value + // and when n is 1 we want the next().next() value + // so adding n to the current offset and not n - 1 + match self.current_offset.checked_add(n) { + // Yes, and still within bounds + Some(new_offset) if new_offset < self.end_offset => { + self.current_offset = new_offset; + } + + // Either overflow or would exceed end_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next() + } + + fn last(mut self) -> Option { + // If already at the end, return None + if self.current_offset == self.end_offset { + return None; + } + + // Go to the one before the last bit + self.current_offset = self.end_offset - 1; + + // Return the last bit + self.next() + } + + fn max(self) -> Option + where + Self: Sized, + Self::Item: Ord, + { + if self.current_offset == self.end_offset { + return None; + } + + // true is greater than false so we only need to check if there's any true bit + let mut bit_index_iter = BitIndexIterator::new( + self.buffer, + self.current_offset, + self.end_offset - self.current_offset, + ); + + if bit_index_iter.next().is_some() { + return Some(true); + } + + // We know the iterator is not empty and there are no set bits so false is the max + Some(false) + } } impl ExactSizeIterator for BitIterator<'_> {} @@ -86,6 +152,27 @@ impl DoubleEndedIterator for BitIterator<'_> { let v = unsafe { get_bit_raw(self.buffer.as_ptr(), self.end_offset) }; Some(v) } + + fn nth_back(&mut self, n: usize) -> Option { + // Check if we can advance to the desired offset. + // When n is 0 it means we want the next_back() value + // and when n is 1 we want the next_back().next_back() value + // so subtracting n to the current offset and not n - 1 + match self.end_offset.checked_sub(n) { + // Yes, and still within bounds + Some(new_offset) if self.current_offset < new_offset => { + self.end_offset = new_offset; + } + + // Either underflow or would exceed current_offset + _ => { + self.current_offset = self.end_offset; + return None; + } + } + + self.next_back() + } } /// Iterator of contiguous ranges of set bits within a provided packed bitmask @@ -216,6 +303,7 @@ impl<'a> BitIndexIterator<'a> { impl Iterator for BitIndexIterator<'_> { type Item = usize; + #[inline] fn next(&mut self) -> Option { loop { if self.current_chunk != 0 { @@ -230,6 +318,63 @@ impl Iterator for BitIndexIterator<'_> { } } +/// An iterator of u32 whose index in a provided bitmask is true +/// Respects arbitrary offsets and slice lead/trail padding exactly like BitIndexIterator +#[derive(Debug)] +pub struct BitIndexU32Iterator<'a> { + curr: u64, + chunk_offset: i64, + iter: UnalignedBitChunkIterator<'a>, +} + +impl<'a> BitIndexU32Iterator<'a> { + /// Create a new [BitIndexU32Iterator] from the provided buffer, + /// offset and len in bits. + pub fn new(buffer: &'a [u8], offset: usize, len: usize) -> Self { + // Build the aligned chunks (including prefix/suffix masked) + let chunks = UnalignedBitChunk::new(buffer, offset, len); + let mut iter = chunks.iter(); + + // First 64-bit word (masked for lead padding), or 0 if empty + let curr = iter.next().unwrap_or(0); + // Negative lead padding ensures the first bit in curr maps to index 0 + let chunk_offset = -(chunks.lead_padding() as i64); + + Self { + curr, + chunk_offset, + iter, + } + } +} + +impl<'a> Iterator for BitIndexU32Iterator<'a> { + type Item = u32; + + #[inline(always)] + fn next(&mut self) -> Option { + loop { + if self.curr != 0 { + // Position of least-significant set bit + let tz = self.curr.trailing_zeros(); + // Clear that bit + self.curr &= self.curr - 1; + // Return global index = chunk_offset + tz + return Some((self.chunk_offset + tz as i64) as u32); + } + // Advance to next 64-bit chunk + match self.iter.next() { + Some(next_chunk) => { + // Move offset forward by 64 bits + self.chunk_offset += 64; + self.curr = next_chunk; + } + None => return None, + } + } + } +} + /// Calls the provided closure for each index in the provided null mask that is set, /// using an adaptive strategy based on the null count /// @@ -269,6 +414,12 @@ pub fn try_for_each_valid_idx Result<(), E>>( #[cfg(test)] mod tests { use super::*; + use crate::BooleanBuffer; + use rand::rngs::StdRng; + use rand::{Rng, SeedableRng}; + use std::fmt::Debug; + use std::iter::Copied; + use std::slice::Iter; #[test] fn test_bit_iterator_size_hint() { @@ -322,4 +473,533 @@ mod tests { let mask = &[223, 23]; BitIterator::new(mask, 17, 0); } + + #[test] + fn test_bit_index_u32_iterator_basic() { + let mask = &[0b00010010, 0b00100011]; + + let result: Vec = BitIndexU32Iterator::new(mask, 0, 16).collect(); + let expected: Vec = BitIndexIterator::new(mask, 0, 16) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + + let result: Vec = BitIndexU32Iterator::new(mask, 4, 8).collect(); + let expected: Vec = BitIndexIterator::new(mask, 4, 8) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + + let result: Vec = BitIndexU32Iterator::new(mask, 10, 4).collect(); + let expected: Vec = BitIndexIterator::new(mask, 10, 4) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + + let result: Vec = BitIndexU32Iterator::new(mask, 0, 0).collect(); + let expected: Vec = BitIndexIterator::new(mask, 0, 0) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_iterator_all_set() { + let mask = &[0xFF, 0xFF]; + let result: Vec = BitIndexU32Iterator::new(mask, 0, 16).collect(); + let expected: Vec = BitIndexIterator::new(mask, 0, 16) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_iterator_none_set() { + let mask = &[0x00, 0x00]; + let result: Vec = BitIndexU32Iterator::new(mask, 0, 16).collect(); + let expected: Vec = BitIndexIterator::new(mask, 0, 16) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_cross_chunk() { + let mut buf = vec![0u8; 16]; + for bit in 60..68 { + let byte = (bit / 8) as usize; + let bit_in_byte = bit % 8; + buf[byte] |= 1 << bit_in_byte; + } + let offset = 58; + let len = 10; + + let result: Vec = BitIndexU32Iterator::new(&buf, offset, len).collect(); + let expected: Vec = BitIndexIterator::new(&buf, offset, len) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_unaligned_offset() { + let mask = &[0b0110_1100, 0b1010_0000]; + let offset = 2; + let len = 12; + + let result: Vec = BitIndexU32Iterator::new(mask, offset, len).collect(); + let expected: Vec = BitIndexIterator::new(mask, offset, len) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_long_all_set() { + let len = 200; + let num_bytes = len / 8 + if len % 8 != 0 { 1 } else { 0 }; + let bytes = vec![0xFFu8; num_bytes]; + + let result: Vec = BitIndexU32Iterator::new(&bytes, 0, len).collect(); + let expected: Vec = BitIndexIterator::new(&bytes, 0, len) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_bit_index_u32_none_set() { + let len = 50; + let num_bytes = len / 8 + if len % 8 != 0 { 1 } else { 0 }; + let bytes = vec![0u8; num_bytes]; + + let result: Vec = BitIndexU32Iterator::new(&bytes, 0, len).collect(); + let expected: Vec = BitIndexIterator::new(&bytes, 0, len) + .map(|i| i as u32) + .collect(); + assert_eq!(result, expected); + } + + trait SharedBetweenBitIteratorAndSliceIter: + ExactSizeIterator + DoubleEndedIterator + { + } + impl + DoubleEndedIterator> + SharedBetweenBitIteratorAndSliceIter for T + { + } + + fn get_bit_iterator_cases() -> impl Iterator)> { + let mut rng = StdRng::seed_from_u64(42); + + [0, 1, 6, 8, 100, 164] + .map(|len| { + let source = (0..len).map(|_| rng.random_bool(0.5)).collect::>(); + + (BooleanBuffer::from(source.as_slice()), source) + }) + .into_iter() + } + + fn setup_and_assert( + setup_iters: impl Fn(&mut dyn SharedBetweenBitIteratorAndSliceIter), + assert_fn: impl Fn(BitIterator, Copied>), + ) { + for (boolean_buffer, source) in get_bit_iterator_cases() { + // Not using `boolean_buffer.iter()` in case the implementation change to not call BitIterator internally + // in which case the test would not test what it intends to test + let mut actual = BitIterator::new(boolean_buffer.values(), 0, boolean_buffer.len()); + let mut expected = source.iter().copied(); + + setup_iters(&mut actual); + setup_iters(&mut expected); + + assert_fn(actual, expected); + } + } + + /// Trait representing an operation on a BitIterator + /// that can be compared against a slice iterator + trait BitIteratorOp { + /// What the operation returns (e.g. Option for last/max, usize for count, etc) + type Output: PartialEq + Debug; + + /// The name of the operation, used for error messages + const NAME: &'static str; + + /// Get the value of the operation for the provided iterator + /// This will be either a BitIterator or a slice iterator to make sure they produce the same result + fn get_value(iter: T) -> Self::Output; + } + + /// Helper function that will assert that the provided operation + /// produces the same result for both BitIterator and slice iterator + /// under various consumption patterns (e.g. some calls to next/next_back/consume_all/etc) + fn assert_bit_iterator_cases() { + setup_and_assert( + |_iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| {}, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from the end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + iter.next(); + iter.next_back(); + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming 1 element from start and end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.len() > 1 { + iter.next_back(); + } + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end but 1 (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the start (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + + setup_and_assert( + |iter: &mut dyn SharedBetweenBitIteratorAndSliceIter| { + while iter.next_back().is_some() {} + }, + |actual, expected| { + let current_iterator_values: Vec = expected.clone().collect(); + + assert_eq!( + O::get_value(actual), + O::get_value(expected), + "Failed on op {} for new iter after consuming all from the end (left actual, right expected) ({current_iterator_values:?})", + O::NAME + ); + }, + ); + } + + #[test] + fn assert_bit_iterator_count() { + struct CountOp; + + impl BitIteratorOp for CountOp { + type Output = usize; + const NAME: &'static str = "count"; + + fn get_value(iter: T) -> Self::Output { + iter.count() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_last() { + struct LastOp; + + impl BitIteratorOp for LastOp { + type Output = Option; + const NAME: &'static str = "last"; + + fn get_value(iter: T) -> Self::Output { + iter.last() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_max() { + struct MaxOp; + + impl BitIteratorOp for MaxOp { + type Output = Option; + const NAME: &'static str = "max"; + + fn get_value(iter: T) -> Self::Output { + iter.max() + } + } + + assert_bit_iterator_cases::() + } + + #[test] + fn assert_bit_iterator_nth_0() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(0)" } else { "nth(0)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(0) } else { iter.nth(0) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_1() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { "nth_back(1)" } else { "nth(1)" }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { iter.nth_back(1) } else { iter.nth(1) } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_after_end() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len() + 1)" + } else { + "nth(iter.len() + 1)" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len() + 1) + } else { + iter.nth(iter.len() + 1) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_len() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len())" + } else { + "nth(iter.len())" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len()) + } else { + iter.nth(iter.len()) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_last() { + struct NthOp; + + impl BitIteratorOp for NthOp { + type Output = Option; + const NAME: &'static str = if BACK { + "nth_back(iter.len().saturating_sub(1))" + } else { + "nth(iter.len().saturating_sub(1))" + }; + + fn get_value(mut iter: T) -> Self::Output { + if BACK { + iter.nth_back(iter.len().saturating_sub(1)) + } else { + iter.nth(iter.len().saturating_sub(1)) + } + } + } + + assert_bit_iterator_cases::>(); + assert_bit_iterator_cases::>(); + } + + #[test] + fn assert_bit_iterator_nth_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth(0); + #[allow(clippy::iter_nth_zero)] + let expected_val = expected.nth(0); + assert_eq!(actual_val, expected_val, "Failed on nth(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(1); + let expected_val = expected.nth(1); + assert_eq!(actual_val, expected_val, "Failed on nth(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth(2); + let expected_val = expected.nth(2); + assert_eq!(actual_val, expected_val, "Failed on nth(2)"); + } + } + }, + ); + } + + #[test] + fn assert_bit_iterator_nth_back_and_reuse() { + setup_and_assert( + |_| {}, + |actual, expected| { + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + #[allow(clippy::iter_nth_zero)] + let actual_val = actual.nth_back(0); + let expected_val = expected.nth_back(0); + assert_eq!(actual_val, expected_val, "Failed on nth_back(0)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(1); + let expected_val = expected.nth_back(1); + assert_eq!(actual_val, expected_val, "Failed on nth_back(1)"); + } + } + + { + let mut actual = actual.clone(); + let mut expected = expected.clone(); + for _ in 0..expected.len() { + let actual_val = actual.nth_back(2); + let expected_val = expected.nth_back(2); + assert_eq!(actual_val, expected_val, "Failed on nth_back(2)"); + } + } + }, + ); + } } diff --git a/arrow-buffer/src/util/bit_mask.rs b/arrow-buffer/src/util/bit_mask.rs index 0d694d13ec75..a8ae1a765414 100644 --- a/arrow-buffer/src/util/bit_mask.rs +++ b/arrow-buffer/src/util/bit_mask.rs @@ -132,10 +132,8 @@ unsafe fn set_upto_64bits( unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 { debug_assert!(count <= 8); let mut tmp: u64 = 0; - let src = data.as_ptr().add(offset); - unsafe { - std::ptr::copy_nonoverlapping(src, &mut tmp as *mut _ as *mut u8, count); - } + let src = unsafe { data.as_ptr().add(offset) }; + unsafe { std::ptr::copy_nonoverlapping(src, &mut tmp as *mut _ as *mut u8, count) }; tmp } @@ -143,8 +141,8 @@ unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 { /// The caller must ensure `data` has `offset..(offset + 8)` range #[inline] unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { - let ptr = data.as_mut_ptr().add(offset) as *mut u64; - ptr.write_unaligned(chunk); + let ptr = unsafe { data.as_mut_ptr().add(offset) } as *mut u64; + unsafe { ptr.write_unaligned(chunk) }; } /// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk` @@ -154,9 +152,9 @@ unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { /// The caller must ensure `data` has `offset..(offset + 8)` range #[inline] unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { - let ptr = data.as_mut_ptr().add(offset); - let chunk = chunk | (*ptr) as u64; - (ptr as *mut u64).write_unaligned(chunk); + let ptr = unsafe { data.as_mut_ptr().add(offset) }; + let chunk = chunk | (unsafe { *ptr }) as u64; + unsafe { (ptr as *mut u64).write_unaligned(chunk) }; } #[cfg(test)] @@ -278,7 +276,7 @@ mod tests { impl Display for BinaryFormatter<'_> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { for byte in self.0 { - write!(f, "{:08b} ", byte)?; + write!(f, "{byte:08b} ")?; } write!(f, " ")?; Ok(()) @@ -389,8 +387,8 @@ mod tests { self.len, ); - assert_eq!(actual, self.expected_data, "self: {}", self); - assert_eq!(null_count, self.expected_null_count, "self: {}", self); + assert_eq!(actual, self.expected_data, "self: {self}"); + assert_eq!(null_count, self.expected_null_count, "self: {self}"); } } diff --git a/arrow-buffer/src/util/bit_util.rs b/arrow-buffer/src/util/bit_util.rs index c297321bdcf9..67c72fc08906 100644 --- a/arrow-buffer/src/util/bit_util.rs +++ b/arrow-buffer/src/util/bit_util.rs @@ -17,6 +17,8 @@ //! Utils for working with bits +use crate::bit_chunk_iterator::BitChunks; + /// Returns the nearest number that is `>=` than `num` and is a multiple of 64 #[inline] pub fn round_upto_multiple_of_64(num: usize) -> usize { @@ -47,7 +49,7 @@ pub fn get_bit(data: &[u8], i: usize) -> bool { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn get_bit_raw(data: *const u8, i: usize) -> bool { - (*data.add(i / 8) & (1 << (i % 8))) != 0 + unsafe { (*data.add(i / 8) & (1 << (i % 8))) != 0 } } /// Sets bit at position `i` for `data` to 1 @@ -64,7 +66,9 @@ pub fn set_bit(data: &mut [u8], i: usize) { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn set_bit_raw(data: *mut u8, i: usize) { - *data.add(i / 8) |= 1 << (i % 8); + unsafe { + *data.add(i / 8) |= 1 << (i % 8); + } } /// Sets bit at position `i` for `data` to 0 @@ -81,7 +85,9 @@ pub fn unset_bit(data: &mut [u8], i: usize) { /// responsible to guarantee that `i` is within bounds. #[inline] pub unsafe fn unset_bit_raw(data: *mut u8, i: usize) { - *data.add(i / 8) &= !(1 << (i % 8)); + unsafe { + *data.add(i / 8) &= !(1 << (i % 8)); + } } /// Returns the ceil of `value`/`divisor` @@ -90,11 +96,726 @@ pub fn ceil(value: usize, divisor: usize) -> usize { value.div_ceil(divisor) } +/// Read up to 8 bits from a byte slice starting at a given bit offset. +/// +/// # Arguments +/// +/// * `slice` - The byte slice to read from +/// * `number_of_bits_to_read` - Number of bits to read (must be < 8) +/// * `bit_offset` - Starting bit offset within the first byte (must be < 8) +/// +/// # Returns +/// +/// A `u8` containing the requested bits in the least significant positions +/// +/// # Panics +/// - Panics if `number_of_bits_to_read` is 0 or >= 8 +/// - Panics if `bit_offset` is >= 8 +/// - Panics if `slice` is empty or too small to read the requested bits +/// +#[inline] +pub(crate) fn read_up_to_byte_from_offset( + slice: &[u8], + number_of_bits_to_read: usize, + bit_offset: usize, +) -> u8 { + assert!(number_of_bits_to_read < 8, "can read up to 8 bits only"); + assert!(bit_offset < 8, "bit offset must be less than 8"); + assert_ne!( + number_of_bits_to_read, 0, + "number of bits to read must be greater than 0" + ); + assert_ne!(slice.len(), 0, "slice must not be empty"); + + let number_of_bytes_to_read = ceil(number_of_bits_to_read + bit_offset, 8); + + // number of bytes to read + assert!(slice.len() >= number_of_bytes_to_read, "slice is too small"); + + let mut bits = slice[0] >> bit_offset; + for (i, &byte) in slice + .iter() + .take(number_of_bytes_to_read) + .enumerate() + .skip(1) + { + bits |= byte << (i * 8 - bit_offset); + } + + bits & ((1 << number_of_bits_to_read) - 1) +} + +/// Applies a bitwise operation relative to another bit-packed byte slice +/// (right) in place +/// +/// Note: applies the operation 64-bits (u64) at a time. +/// +/// # Arguments +/// +/// * `left` - The mutable buffer to be modified in-place +/// * `offset_in_bits` - Starting bit offset in Self buffer +/// * `right` - slice of bit-packed bytes in LSB order +/// * `right_offset_in_bits` - Starting bit offset in the right buffer +/// * `len_in_bits` - Number of bits to process +/// * `op` - Binary operation to apply (e.g., `|a, b| a & b`). Applied a word at a time +/// +/// # Example: Modify entire buffer +/// ``` +/// # use arrow_buffer::MutableBuffer; +/// # use arrow_buffer::bit_util::apply_bitwise_binary_op; +/// let mut left = MutableBuffer::new(2); +/// left.extend_from_slice(&[0b11110000u8, 0b00110011u8]); +/// let right = &[0b10101010u8, 0b10101010u8]; +/// // apply bitwise AND between left and right buffers, updating left in place +/// apply_bitwise_binary_op(left.as_slice_mut(), 0, right, 0, 16, |a, b| a & b); +/// assert_eq!(left.as_slice(), &[0b10100000u8, 0b00100010u8]); +/// ``` +/// +/// # Example: Modify buffer with offsets +/// ``` +/// # use arrow_buffer::MutableBuffer; +/// # use arrow_buffer::bit_util::apply_bitwise_binary_op; +/// let mut left = MutableBuffer::new(2); +/// left.extend_from_slice(&[0b00000000u8, 0b00000000u8]); +/// let right = &[0b10110011u8, 0b11111110u8]; +/// // apply bitwise OR between left and right buffers, +/// // Apply only 8 bits starting from bit offset 3 in left and bit offset 2 in right +/// apply_bitwise_binary_op(left.as_slice_mut(), 3, right, 2, 8, |a, b| a | b); +/// assert_eq!(left.as_slice(), &[0b01100000, 0b00000101u8]); +/// ``` +/// +/// # Panics +/// +/// If the offset or lengths exceed the buffer or slice size. +pub fn apply_bitwise_binary_op( + left: &mut [u8], + left_offset_in_bits: usize, + right: impl AsRef<[u8]>, + right_offset_in_bits: usize, + len_in_bits: usize, + mut op: F, +) where + F: FnMut(u64, u64) -> u64, +{ + if len_in_bits == 0 { + return; + } + + // offset inside a byte + let bit_offset = left_offset_in_bits % 8; + + let is_mutable_buffer_byte_aligned = bit_offset == 0; + + if is_mutable_buffer_byte_aligned { + byte_aligned_bitwise_bin_op_helper( + left, + left_offset_in_bits, + right, + right_offset_in_bits, + len_in_bits, + op, + ); + } else { + // If we are not byte aligned, run `op` on the first few bits to reach byte alignment + let bits_to_next_byte = (8 - bit_offset) + // Minimum with the amount of bits we need to process + // to avoid reading out of bounds + .min(len_in_bits); + + { + let right_byte_offset = right_offset_in_bits / 8; + + // Read the same amount of bits from the right buffer + let right_first_byte: u8 = crate::util::bit_util::read_up_to_byte_from_offset( + &right.as_ref()[right_byte_offset..], + bits_to_next_byte, + // Right bit offset + right_offset_in_bits % 8, + ); + + align_to_byte( + left, + // Hope it gets inlined + &mut |left| op(left, right_first_byte as u64), + left_offset_in_bits, + ); + } + + let offset_in_bits = left_offset_in_bits + bits_to_next_byte; + let right_offset_in_bits = right_offset_in_bits + bits_to_next_byte; + let len_in_bits = len_in_bits.saturating_sub(bits_to_next_byte); + + if len_in_bits == 0 { + return; + } + + // We are now byte aligned + byte_aligned_bitwise_bin_op_helper( + left, + offset_in_bits, + right, + right_offset_in_bits, + len_in_bits, + op, + ); + } +} + +/// Apply a bitwise operation to a mutable buffer, updating it in place. +/// +/// Note: applies the operation 64-bits (u64) at a time. +/// +/// # Arguments +/// +/// * `offset_in_bits` - Starting bit offset for the current buffer +/// * `len_in_bits` - Number of bits to process +/// * `op` - Unary operation to apply (e.g., `|a| !a`). Applied a word at a time +/// +/// # Example: Modify entire buffer +/// ``` +/// # use arrow_buffer::MutableBuffer; +/// # use arrow_buffer::bit_util::apply_bitwise_unary_op; +/// let mut buffer = MutableBuffer::new(2); +/// buffer.extend_from_slice(&[0b11110000u8, 0b00110011u8]); +/// // apply bitwise NOT to the buffer in place +/// apply_bitwise_unary_op(buffer.as_slice_mut(), 0, 16, |a| !a); +/// assert_eq!(buffer.as_slice(), &[0b00001111u8, 0b11001100u8]); +/// ``` +/// +/// # Example: Modify buffer with offsets +/// ``` +/// # use arrow_buffer::MutableBuffer; +/// # use arrow_buffer::bit_util::apply_bitwise_unary_op; +/// let mut buffer = MutableBuffer::new(2); +/// buffer.extend_from_slice(&[0b00000000u8, 0b00000000u8]); +/// // apply bitwise NOT to 8 bits starting from bit offset 3 +/// apply_bitwise_unary_op(buffer.as_slice_mut(), 3, 8, |a| !a); +/// assert_eq!(buffer.as_slice(), &[0b11111000u8, 0b00000111u8]); +/// ``` +/// +/// # Panics +/// +/// If the offset and length exceed the buffer size. +pub fn apply_bitwise_unary_op( + buffer: &mut [u8], + offset_in_bits: usize, + len_in_bits: usize, + mut op: F, +) where + F: FnMut(u64) -> u64, +{ + if len_in_bits == 0 { + return; + } + + // offset inside a byte + let left_bit_offset = offset_in_bits % 8; + + let is_mutable_buffer_byte_aligned = left_bit_offset == 0; + + if is_mutable_buffer_byte_aligned { + byte_aligned_bitwise_unary_op_helper(buffer, offset_in_bits, len_in_bits, op); + } else { + align_to_byte(buffer, &mut op, offset_in_bits); + + // If we are not byte aligned we will read the first few bits + let bits_to_next_byte = 8 - left_bit_offset; + + let offset_in_bits = offset_in_bits + bits_to_next_byte; + let len_in_bits = len_in_bits.saturating_sub(bits_to_next_byte); + + if len_in_bits == 0 { + return; + } + + // We are now byte aligned + byte_aligned_bitwise_unary_op_helper(buffer, offset_in_bits, len_in_bits, op); + } +} + +/// Perform bitwise binary operation on byte-aligned buffers (i.e. not offsetting into a middle of a byte). +/// +/// This is the optimized path for byte-aligned operations. It processes data in +/// u64 chunks for maximum efficiency, then handles any remainder bits. +/// +/// # Arguments +/// +/// * `left` - The left mutable buffer (must be byte-aligned) +/// * `left_offset_in_bits` - Starting bit offset in the left buffer (must be multiple of 8) +/// * `right` - The right buffer as byte slice +/// * `right_offset_in_bits` - Starting bit offset in the right buffer +/// * `len_in_bits` - Number of bits to process +/// * `op` - Binary operation to apply +#[inline] +fn byte_aligned_bitwise_bin_op_helper( + left: &mut [u8], + left_offset_in_bits: usize, + right: impl AsRef<[u8]>, + right_offset_in_bits: usize, + len_in_bits: usize, + mut op: F, +) where + F: FnMut(u64, u64) -> u64, +{ + // Must not reach here if we not byte aligned + assert_eq!( + left_offset_in_bits % 8, + 0, + "offset_in_bits must be byte aligned" + ); + + // 1. Prepare the buffers + let (complete_u64_chunks, remainder_bytes) = + U64UnalignedSlice::split(left, left_offset_in_bits, len_in_bits); + + let right_chunks = BitChunks::new(right.as_ref(), right_offset_in_bits, len_in_bits); + assert_eq!( + self::ceil(right_chunks.remainder_len(), 8), + remainder_bytes.len() + ); + + let right_chunks_iter = right_chunks.iter(); + assert_eq!(right_chunks_iter.len(), complete_u64_chunks.len()); + + // 2. Process complete u64 chunks + complete_u64_chunks.zip_modify(right_chunks_iter, &mut op); + + // Handle remainder bits if any + if right_chunks.remainder_len() > 0 { + handle_mutable_buffer_remainder( + &mut op, + remainder_bytes, + right_chunks.remainder_bits(), + right_chunks.remainder_len(), + ) + } +} + +/// Perform bitwise unary operation on byte-aligned buffer. +/// +/// This is the optimized path for byte-aligned unary operations. It processes data in +/// u64 chunks for maximum efficiency, then handles any remainder bits. +/// +/// # Arguments +/// +/// * `buffer` - The mutable buffer (must be byte-aligned) +/// * `offset_in_bits` - Starting bit offset (must be multiple of 8) +/// * `len_in_bits` - Number of bits to process +/// * `op` - Unary operation to apply (e.g., `|a| !a`) +#[inline] +fn byte_aligned_bitwise_unary_op_helper( + buffer: &mut [u8], + offset_in_bits: usize, + len_in_bits: usize, + mut op: F, +) where + F: FnMut(u64) -> u64, +{ + // Must not reach here if we not byte aligned + assert_eq!(offset_in_bits % 8, 0, "offset_in_bits must be byte aligned"); + + let remainder_len = len_in_bits % 64; + + let (complete_u64_chunks, remainder_bytes) = + U64UnalignedSlice::split(buffer, offset_in_bits, len_in_bits); + + assert_eq!(self::ceil(remainder_len, 8), remainder_bytes.len()); + + // 2. Process complete u64 chunks + complete_u64_chunks.apply_unary_op(&mut op); + + // Handle remainder bits if any + if remainder_len > 0 { + handle_mutable_buffer_remainder_unary(&mut op, remainder_bytes, remainder_len) + } +} + +/// Align to byte boundary by applying operation to bits before the next byte boundary. +/// +/// This function handles non-byte-aligned operations by processing bits from the current +/// position up to the next byte boundary, while preserving all other bits in the byte. +/// +/// # Arguments +/// +/// * `op` - Unary operation to apply +/// * `buffer` - The mutable buffer to modify +/// * `offset_in_bits` - Starting bit offset (not byte-aligned) +fn align_to_byte(buffer: &mut [u8], op: &mut F, offset_in_bits: usize) +where + F: FnMut(u64) -> u64, +{ + let byte_offset = offset_in_bits / 8; + let bit_offset = offset_in_bits % 8; + + // 1. read the first byte from the buffer + let first_byte: u8 = buffer[byte_offset]; + + // 2. Shift byte by the bit offset, keeping only the relevant bits + let relevant_first_byte = first_byte >> bit_offset; + + // 3. run the op on the first byte only + let result_first_byte = op(relevant_first_byte as u64) as u8; + + // 4. Shift back the result to the original position + let result_first_byte = result_first_byte << bit_offset; + + // 5. Mask the bits that are outside the relevant bits in the byte + // so the bits until bit_offset are 1 and the rest are 0 + let mask_for_first_bit_offset = (1 << bit_offset) - 1; + + let result_first_byte = + (first_byte & mask_for_first_bit_offset) | (result_first_byte & !mask_for_first_bit_offset); + + // 6. write back the result to the buffer + buffer[byte_offset] = result_first_byte; +} + +/// Centralized structure to handle a mutable u8 slice as a mutable u64 pointer. +/// +/// Handle the following: +/// 1. the lifetime is correct +/// 2. we read/write within the bounds +/// 3. We read and write using unaligned +/// +/// This does not deallocate the underlying pointer when dropped +/// +/// This is the only place that uses unsafe code to read and write unaligned +/// +struct U64UnalignedSlice<'a> { + /// Pointer to the start of the u64 data + /// + /// We are using raw pointer as the data came from a u8 slice so we need to read and write unaligned + ptr: *mut u64, + + /// Number of u64 elements + len: usize, + + /// Marker to tie the lifetime of the pointer to the lifetime of the u8 slice + _marker: std::marker::PhantomData<&'a u8>, +} + +impl<'a> U64UnalignedSlice<'a> { + /// Create a new [`U64UnalignedSlice`] from a `&mut [u8]` buffer + /// + /// return the [`U64UnalignedSlice`] and slice of bytes that are not part of the u64 chunks (guaranteed to be less than 8 bytes) + /// + fn split( + buffer: &'a mut [u8], + offset_in_bits: usize, + len_in_bits: usize, + ) -> (Self, &'a mut [u8]) { + // 1. Prepare the buffers + let left_buffer_mut: &mut [u8] = { + let last_offset = self::ceil(offset_in_bits + len_in_bits, 8); + assert!(last_offset <= buffer.len()); + + let byte_offset = offset_in_bits / 8; + + &mut buffer[byte_offset..last_offset] + }; + + let number_of_u64_we_can_fit = len_in_bits / (u64::BITS as usize); + + // 2. Split + let u64_len_in_bytes = number_of_u64_we_can_fit * size_of::(); + + assert!(u64_len_in_bytes <= left_buffer_mut.len()); + let (bytes_for_u64, remainder) = left_buffer_mut.split_at_mut(u64_len_in_bytes); + + let ptr = bytes_for_u64.as_mut_ptr() as *mut u64; + + let this = Self { + ptr, + len: number_of_u64_we_can_fit, + _marker: std::marker::PhantomData, + }; + + (this, remainder) + } + + fn len(&self) -> usize { + self.len + } + + /// Modify the underlying u64 data in place using a binary operation + /// with another iterator. + fn zip_modify( + mut self, + mut zip_iter: impl ExactSizeIterator, + mut map: impl FnMut(u64, u64) -> u64, + ) { + assert_eq!(self.len, zip_iter.len()); + + // In order to avoid advancing the pointer at the end of the loop which will + // make the last pointer invalid, we handle the first element outside the loop + // and then advance the pointer at the start of the loop + // making sure that the iterator is not empty + if let Some(right) = zip_iter.next() { + // SAFETY: We asserted that the iterator length and the current length are the same + // and the iterator is not empty, so the pointer is valid + unsafe { + self.apply_bin_op(right, &mut map); + } + + // Because this consumes self we don't update the length + } + + for right in zip_iter { + // Advance the pointer + // + // SAFETY: We asserted that the iterator length and the current length are the same + self.ptr = unsafe { self.ptr.add(1) }; + + // SAFETY: the pointer is valid as we are within the length + unsafe { + self.apply_bin_op(right, &mut map); + } + + // Because this consumes self we don't update the length + } + } + + /// Centralized function to correctly read the current u64 value and write back the result + /// + /// # SAFETY + /// the caller must ensure that the pointer is valid for reads and writes + /// + #[inline] + unsafe fn apply_bin_op(&mut self, right: u64, mut map: impl FnMut(u64, u64) -> u64) { + // SAFETY: The constructor ensures the pointer is valid, + // and as to all modifications in U64UnalignedSlice + let current_input = unsafe { + self.ptr + // Reading unaligned as we came from u8 slice + .read_unaligned() + // bit-packed buffers are stored starting with the least-significant byte first + // so when reading as u64 on a big-endian machine, the bytes need to be swapped + .to_le() + }; + + let combined = map(current_input, right); + + // Write the result back + // + // The pointer came from mutable u8 slice so the pointer is valid for writes, + // and we need to write unaligned + unsafe { self.ptr.write_unaligned(combined) } + } + + /// Modify the underlying u64 data in place using a unary operation. + fn apply_unary_op(mut self, mut map: impl FnMut(u64) -> u64) { + if self.len == 0 { + return; + } + + // In order to avoid advancing the pointer at the end of the loop which will + // make the last pointer invalid, we handle the first element outside the loop + // and then advance the pointer at the start of the loop + // making sure that the iterator is not empty + unsafe { + // I hope the function get inlined and the compiler remove the dead right parameter + self.apply_bin_op(0, &mut |left, _| map(left)); + + // Because this consumes self we don't update the length + } + + for _ in 1..self.len { + // Advance the pointer + // + // SAFETY: we only advance the pointer within the length and not beyond + self.ptr = unsafe { self.ptr.add(1) }; + + // SAFETY: the pointer is valid as we are within the length + unsafe { + // I hope the function get inlined and the compiler remove the dead right parameter + self.apply_bin_op(0, &mut |left, _| map(left)); + } + + // Because this consumes self we don't update the length + } + } +} + +/// Handle remainder bits (< 64 bits) for binary operations. +/// +/// This function processes the bits that don't form a complete u64 chunk, +/// ensuring that bits outside the operation range are preserved. +/// +/// # Arguments +/// +/// * `op` - Binary operation to apply +/// * `start_remainder_mut_slice` - slice to the start of remainder bytes +/// the length must be equal to `ceil(remainder_len, 8)` +/// * `right_remainder_bits` - Right operand bits +/// * `remainder_len` - Number of remainder bits +#[inline] +fn handle_mutable_buffer_remainder( + op: &mut F, + start_remainder_mut_slice: &mut [u8], + right_remainder_bits: u64, + remainder_len: usize, +) where + F: FnMut(u64, u64) -> u64, +{ + // Only read from slice the number of remainder bits + let left_remainder_bits = get_remainder_bits(start_remainder_mut_slice, remainder_len); + + // Apply the operation + let rem = op(left_remainder_bits, right_remainder_bits); + + // Write only the relevant bits back the result to the mutable slice + set_remainder_bits(start_remainder_mut_slice, rem, remainder_len); +} + +/// Write remainder bits back to buffer while preserving bits outside the range. +/// +/// This function carefully updates only the specified bits, leaving all other +/// bits in the affected bytes unchanged. +/// +/// # Arguments +/// +/// * `start_remainder_mut_slice` - the slice of bytes to write the remainder bits to, +/// the length must be equal to `ceil(remainder_len, 8)` +/// * `rem` - The result bits to write +/// * `remainder_len` - Number of bits to write +#[inline] +fn set_remainder_bits(start_remainder_mut_slice: &mut [u8], rem: u64, remainder_len: usize) { + assert_ne!( + start_remainder_mut_slice.len(), + 0, + "start_remainder_mut_slice must not be empty" + ); + assert!(remainder_len < 64, "remainder_len must be less than 64"); + + // This assertion is to make sure that the last byte in the slice is the boundary byte + // (i.e., the byte that contains both remainder bits and bits outside the remainder) + assert_eq!( + start_remainder_mut_slice.len(), + self::ceil(remainder_len, 8), + "start_remainder_mut_slice length must be equal to ceil(remainder_len, 8)" + ); + + // Need to update the remainder bytes in the mutable buffer + // but not override the bits outside the remainder + + // Update `rem` end with the current bytes in the mutable buffer + // to preserve the bits outside the remainder + let rem = { + // 1. Read the byte that we will override + // we only read the last byte as we verified that start_remainder_mut_slice length is + // equal to ceil(remainder_len, 8), which means the last byte is the boundary byte + // containing both remainder bits and bits outside the remainder + let current = start_remainder_mut_slice + .last() + // Unwrap as we already validated the slice is not empty + .unwrap(); + + let current = *current as u64; + + // Mask where the bits that are inside the remainder are 1 + // and the bits outside the remainder are 0 + let inside_remainder_mask = (1 << remainder_len) - 1; + // Mask where the bits that are outside the remainder are 1 + // and the bits inside the remainder are 0 + let outside_remainder_mask = !inside_remainder_mask; + + // 2. Only keep the bits that are outside the remainder for the value from the mutable buffer + let current = current & outside_remainder_mask; + + // 3. Only keep the bits that are inside the remainder for the value from the operation + let rem = rem & inside_remainder_mask; + + // 4. Combine the two values + current | rem + }; + + // Write back the result to the mutable slice + { + let remainder_bytes = self::ceil(remainder_len, 8); + + // we are counting starting from the least significant bit, so to_le_bytes should be correct + let rem = &rem.to_le_bytes()[0..remainder_bytes]; + + // this assumes that `[ToByteSlice]` can be copied directly + // without calling `to_byte_slice` for each element, + // which is correct for all ArrowNativeType implementations including u64. + let src = rem.as_ptr(); + unsafe { + std::ptr::copy_nonoverlapping( + src, + start_remainder_mut_slice.as_mut_ptr(), + remainder_bytes, + ) + }; + } +} + +/// Read remainder bits from a slice. +/// +/// Reads the specified number of bits from slice and returns them as a u64. +/// +/// # Arguments +/// +/// * `remainder` - slice to the start of the bits +/// * `remainder_len` - Number of bits to read (must be < 64) +/// +/// # Returns +/// +/// A u64 containing the bits in the least significant positions +#[inline] +fn get_remainder_bits(remainder: &[u8], remainder_len: usize) -> u64 { + assert!(remainder.len() < 64, "remainder_len must be less than 64"); + assert_eq!( + remainder.len(), + self::ceil(remainder_len, 8), + "remainder and remainder len ceil must be the same" + ); + + let bits = remainder + .iter() + .enumerate() + .fold(0_u64, |acc, (index, &byte)| { + acc | (byte as u64) << (index * 8) + }); + + bits & ((1 << remainder_len) - 1) +} + +/// Handle remainder bits (< 64 bits) for unary operations. +/// +/// This function processes the bits that don't form a complete u64 chunk, +/// ensuring that bits outside the operation range are preserved. +/// +/// # Arguments +/// +/// * `op` - Unary operation to apply +/// * `start_remainder_mut` - Slice of bytes to write the remainder bits to +/// * `remainder_len` - Number of remainder bits +#[inline] +fn handle_mutable_buffer_remainder_unary( + op: &mut F, + start_remainder_mut: &mut [u8], + remainder_len: usize, +) where + F: FnMut(u64) -> u64, +{ + // Only read from the slice the number of remainder bits + let left_remainder_bits = get_remainder_bits(start_remainder_mut, remainder_len); + + // Apply the operation + let rem = op(left_remainder_bits); + + // Write only the relevant bits back the result to the slice + set_remainder_bits(start_remainder_mut, rem, remainder_len); +} + #[cfg(test)] mod tests { use std::collections::HashSet; use super::*; + use crate::bit_iterator::BitIterator; + use crate::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer}; use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; @@ -275,4 +996,500 @@ mod tests { assert_eq!(ceil(10, 10000000000), 1); assert_eq!(ceil(10000000000, 1000000000), 10); } + + #[test] + fn test_read_up_to() { + let all_ones = &[0b10111001, 0b10001100]; + + for (bit_offset, expected) in [ + (0, 0b00000001), + (1, 0b00000000), + (2, 0b00000000), + (3, 0b00000001), + (4, 0b00000001), + (5, 0b00000001), + (6, 0b00000000), + (7, 0b00000001), + ] { + let result = read_up_to_byte_from_offset(all_ones, 1, bit_offset); + assert_eq!( + result, expected, + "failed at bit_offset {bit_offset}. result, expected:\n{result:08b}\n{expected:08b}" + ); + } + + for (bit_offset, expected) in [ + (0, 0b00000001), + (1, 0b00000000), + (2, 0b00000010), + (3, 0b00000011), + (4, 0b00000011), + (5, 0b00000001), + (6, 0b00000010), + (7, 0b00000001), + ] { + let result = read_up_to_byte_from_offset(all_ones, 2, bit_offset); + assert_eq!( + result, expected, + "failed at bit_offset {bit_offset}. result, expected:\n{result:08b}\n{expected:08b}" + ); + } + + for (bit_offset, expected) in [ + (0, 0b00111001), + (1, 0b00011100), + (2, 0b00101110), + (3, 0b00010111), + (4, 0b00001011), + (5, 0b00100101), + (6, 0b00110010), + (7, 0b00011001), + ] { + let result = read_up_to_byte_from_offset(all_ones, 6, bit_offset); + assert_eq!( + result, expected, + "failed at bit_offset {bit_offset}. result, expected:\n{result:08b}\n{expected:08b}" + ); + } + + for (bit_offset, expected) in [ + (0, 0b00111001), + (1, 0b01011100), + (2, 0b00101110), + (3, 0b00010111), + (4, 0b01001011), + (5, 0b01100101), + (6, 0b00110010), + (7, 0b00011001), + ] { + let result = read_up_to_byte_from_offset(all_ones, 7, bit_offset); + assert_eq!( + result, expected, + "failed at bit_offset {bit_offset}. result, expected:\n{result:08b}\n{expected:08b}" + ); + } + } + + /// Verifies that a unary operation applied to a buffer using u64 chunks + /// is the same as applying the operation bit by bit. + fn test_mutable_buffer_bin_op_helper( + left_data: &[bool], + right_data: &[bool], + left_offset_in_bits: usize, + right_offset_in_bits: usize, + len_in_bits: usize, + op: F, + mut expected_op: G, + ) where + F: FnMut(u64, u64) -> u64, + G: FnMut(bool, bool) -> bool, + { + let mut left_buffer = BooleanBufferBuilder::new(len_in_bits); + left_buffer.append_slice(left_data); + let right_buffer = BooleanBuffer::from(right_data); + + let expected: Vec = left_data + .iter() + .skip(left_offset_in_bits) + .zip(right_data.iter().skip(right_offset_in_bits)) + .take(len_in_bits) + .map(|(l, r)| expected_op(*l, *r)) + .collect(); + + apply_bitwise_binary_op( + left_buffer.as_slice_mut(), + left_offset_in_bits, + right_buffer.inner(), + right_offset_in_bits, + len_in_bits, + op, + ); + + let result: Vec = + BitIterator::new(left_buffer.as_slice(), left_offset_in_bits, len_in_bits).collect(); + + assert_eq!( + result, expected, + "Failed with left_offset={}, right_offset={}, len={}", + left_offset_in_bits, right_offset_in_bits, len_in_bits + ); + } + + /// Verifies that a unary operation applied to a buffer using u64 chunks + /// is the same as applying the operation bit by bit. + fn test_mutable_buffer_unary_op_helper( + data: &[bool], + offset_in_bits: usize, + len_in_bits: usize, + op: F, + mut expected_op: G, + ) where + F: FnMut(u64) -> u64, + G: FnMut(bool) -> bool, + { + let mut buffer = BooleanBufferBuilder::new(len_in_bits); + buffer.append_slice(data); + + let expected: Vec = data + .iter() + .skip(offset_in_bits) + .take(len_in_bits) + .map(|b| expected_op(*b)) + .collect(); + + apply_bitwise_unary_op(buffer.as_slice_mut(), offset_in_bits, len_in_bits, op); + + let result: Vec = + BitIterator::new(buffer.as_slice(), offset_in_bits, len_in_bits).collect(); + + assert_eq!( + result, expected, + "Failed with offset={}, len={}", + offset_in_bits, len_in_bits + ); + } + + // Helper to create test data of specific length + fn create_test_data(len: usize) -> (Vec, Vec) { + let mut rng = rand::rng(); + let left: Vec = (0..len).map(|_| rng.random_bool(0.5)).collect(); + let right: Vec = (0..len).map(|_| rng.random_bool(0.5)).collect(); + (left, right) + } + + /// Test all binary operations (AND, OR, XOR) with the given parameters + fn test_all_binary_ops( + left_data: &[bool], + right_data: &[bool], + left_offset_in_bits: usize, + right_offset_in_bits: usize, + len_in_bits: usize, + ) { + // Test AND + test_mutable_buffer_bin_op_helper( + left_data, + right_data, + left_offset_in_bits, + right_offset_in_bits, + len_in_bits, + |a, b| a & b, + |a, b| a & b, + ); + + // Test OR + test_mutable_buffer_bin_op_helper( + left_data, + right_data, + left_offset_in_bits, + right_offset_in_bits, + len_in_bits, + |a, b| a | b, + |a, b| a | b, + ); + + // Test XOR + test_mutable_buffer_bin_op_helper( + left_data, + right_data, + left_offset_in_bits, + right_offset_in_bits, + len_in_bits, + |a, b| a ^ b, + |a, b| a ^ b, + ); + } + + // ===== Combined Binary Operation Tests ===== + + #[test] + fn test_binary_ops_less_than_byte() { + let (left, right) = create_test_data(4); + test_all_binary_ops(&left, &right, 0, 0, 4); + } + + #[test] + fn test_binary_ops_less_than_byte_across_boundary() { + let (left, right) = create_test_data(16); + test_all_binary_ops(&left, &right, 6, 6, 4); + } + + #[test] + fn test_binary_ops_exactly_byte() { + let (left, right) = create_test_data(16); + test_all_binary_ops(&left, &right, 0, 0, 8); + } + + #[test] + fn test_binary_ops_more_than_byte_less_than_u64() { + let (left, right) = create_test_data(64); + test_all_binary_ops(&left, &right, 0, 0, 32); + } + + #[test] + fn test_binary_ops_exactly_u64() { + let (left, right) = create_test_data(180); + test_all_binary_ops(&left, &right, 0, 0, 64); + test_all_binary_ops(&left, &right, 64, 9, 64); + test_all_binary_ops(&left, &right, 8, 100, 64); + test_all_binary_ops(&left, &right, 1, 15, 64); + test_all_binary_ops(&left, &right, 12, 10, 64); + test_all_binary_ops(&left, &right, 180 - 64, 2, 64); + } + + #[test] + fn test_binary_ops_more_than_u64_not_multiple() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 0, 0, 100); + } + + #[test] + fn test_binary_ops_exactly_multiple_u64() { + let (left, right) = create_test_data(256); + test_all_binary_ops(&left, &right, 0, 0, 128); + } + + #[test] + fn test_binary_ops_more_than_multiple_u64() { + let (left, right) = create_test_data(300); + test_all_binary_ops(&left, &right, 0, 0, 200); + } + + #[test] + fn test_binary_ops_byte_aligned_no_remainder() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 0, 0, 128); + } + + #[test] + fn test_binary_ops_byte_aligned_with_remainder() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 0, 0, 100); + } + + #[test] + fn test_binary_ops_not_byte_aligned_no_remainder() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 3, 3, 128); + } + + #[test] + fn test_binary_ops_not_byte_aligned_with_remainder() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 5, 5, 100); + } + + #[test] + fn test_binary_ops_different_offsets() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 3, 7, 50); + } + + #[test] + fn test_binary_ops_offsets_greater_than_8_less_than_64() { + let (left, right) = create_test_data(200); + test_all_binary_ops(&left, &right, 13, 27, 100); + } + + // ===== NOT (Unary) Operation Tests ===== + + #[test] + fn test_not_less_than_byte() { + let data = vec![true, false, true, false]; + test_mutable_buffer_unary_op_helper(&data, 0, 4, |a| !a, |a| !a); + } + + #[test] + fn test_not_less_than_byte_across_boundary() { + let data: Vec = (0..16).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 6, 4, |a| !a, |a| !a); + } + + #[test] + fn test_not_exactly_byte() { + let data: Vec = (0..16).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 8, |a| !a, |a| !a); + } + + #[test] + fn test_not_more_than_byte_less_than_u64() { + let data: Vec = (0..64).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 32, |a| !a, |a| !a); + } + + #[test] + fn test_not_exactly_u64() { + let data: Vec = (0..128).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 64, |a| !a, |a| !a); + } + + #[test] + fn test_not_more_than_u64_not_multiple() { + let data: Vec = (0..200).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 100, |a| !a, |a| !a); + } + + #[test] + fn test_not_exactly_multiple_u64() { + let data: Vec = (0..256).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 128, |a| !a, |a| !a); + } + + #[test] + fn test_not_more_than_multiple_u64() { + let data: Vec = (0..300).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 200, |a| !a, |a| !a); + } + + #[test] + fn test_not_byte_aligned_no_remainder() { + let data: Vec = (0..200).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 128, |a| !a, |a| !a); + } + + #[test] + fn test_not_byte_aligned_with_remainder() { + let data: Vec = (0..200).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 0, 100, |a| !a, |a| !a); + } + + #[test] + fn test_not_not_byte_aligned_no_remainder() { + let data: Vec = (0..200).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 3, 128, |a| !a, |a| !a); + } + + #[test] + fn test_not_not_byte_aligned_with_remainder() { + let data: Vec = (0..200).map(|i| i % 2 == 0).collect(); + test_mutable_buffer_unary_op_helper(&data, 5, 100, |a| !a, |a| !a); + } + + // ===== Edge Cases ===== + + #[test] + fn test_empty_length() { + let (left, right) = create_test_data(16); + test_all_binary_ops(&left, &right, 0, 0, 0); + } + + #[test] + fn test_single_bit() { + let (left, right) = create_test_data(16); + test_all_binary_ops(&left, &right, 0, 0, 1); + } + + #[test] + fn test_single_bit_at_offset() { + let (left, right) = create_test_data(16); + test_all_binary_ops(&left, &right, 7, 7, 1); + } + + #[test] + fn test_not_single_bit() { + let data = vec![true, false, true, false]; + test_mutable_buffer_unary_op_helper(&data, 0, 1, |a| !a, |a| !a); + } + + #[test] + fn test_not_empty_length() { + let data = vec![true, false, true, false]; + test_mutable_buffer_unary_op_helper(&data, 0, 0, |a| !a, |a| !a); + } + + #[test] + fn test_less_than_byte_unaligned_and_not_enough_bits() { + let left_offset_in_bits = 2; + let right_offset_in_bits = 4; + let len_in_bits = 1; + + // Single byte + let right = (0..8).map(|i| (i / 2) % 2 == 0).collect::>(); + // less than a byte + let left = (0..3).map(|i| i % 2 == 0).collect::>(); + test_all_binary_ops( + &left, + &right, + left_offset_in_bits, + right_offset_in_bits, + len_in_bits, + ); + } + + #[test] + fn test_bitwise_binary_op_offset_out_of_bounds() { + let input = vec![0b10101010u8, 0b01010101u8]; + let mut buffer = MutableBuffer::new(2); // space for 16 bits + buffer.extend_from_slice(&input); // only 2 bytes + apply_bitwise_binary_op( + buffer.as_slice_mut(), + 100, // exceeds buffer length, becomes a noop + [0b11110000u8, 0b00001111u8], + 0, + 0, + |a, b| a & b, + ); + assert_eq!(buffer.as_slice(), &input); + } + + #[test] + #[should_panic(expected = "assertion failed: last_offset <= buffer.len()")] + fn test_bitwise_binary_op_length_out_of_bounds() { + let mut buffer = MutableBuffer::new(2); // space for 16 bits + buffer.extend_from_slice(&[0b10101010u8, 0b01010101u8]); // only 2 bytes + apply_bitwise_binary_op( + buffer.as_slice_mut(), + 0, // exceeds buffer length + [0b11110000u8, 0b00001111u8], + 0, + 100, + |a, b| a & b, + ); + assert_eq!(buffer.as_slice(), &[0b10101010u8, 0b01010101u8]); + } + + #[test] + #[should_panic(expected = "offset + len out of bounds")] + fn test_bitwise_binary_op_right_len_out_of_bounds() { + let mut buffer = MutableBuffer::new(2); // space for 16 bits + buffer.extend_from_slice(&[0b10101010u8, 0b01010101u8]); // only 2 bytes + apply_bitwise_binary_op( + buffer.as_slice_mut(), + 0, // exceeds buffer length + [0b11110000u8, 0b00001111u8], + 1000, + 16, + |a, b| a & b, + ); + assert_eq!(buffer.as_slice(), &[0b10101010u8, 0b01010101u8]); + } + + #[test] + #[should_panic(expected = "the len is 2 but the index is 12")] + fn test_bitwise_unary_op_offset_out_of_bounds() { + let input = vec![0b10101010u8, 0b01010101u8]; + let mut buffer = MutableBuffer::new(2); // space for 16 bits + buffer.extend_from_slice(&input); // only 2 bytes + apply_bitwise_unary_op( + buffer.as_slice_mut(), + 100, // exceeds buffer length, becomes a noop + 8, + |a| !a, + ); + assert_eq!(buffer.as_slice(), &input); + } + + #[test] + #[should_panic(expected = "assertion failed: last_offset <= buffer.len()")] + fn test_bitwise_unary_op_length_out_of_bounds2() { + let input = vec![0b10101010u8, 0b01010101u8]; + let mut buffer = MutableBuffer::new(2); // space for 16 bits + buffer.extend_from_slice(&input); // only 2 bytes + apply_bitwise_unary_op( + buffer.as_slice_mut(), + 3, // start at bit 3, to exercise different path + 100, // exceeds buffer length + |a| !a, + ); + assert_eq!(buffer.as_slice(), &input); + } } diff --git a/arrow-cast/Cargo.toml b/arrow-cast/Cargo.toml index 49145cf987f9..536bc101a816 100644 --- a/arrow-cast/Cargo.toml +++ b/arrow-cast/Cargo.toml @@ -43,19 +43,20 @@ force_validate = [] arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } +arrow-ord = { workspace = true } arrow-schema = { workspace = true } arrow-select = { workspace = true } chrono = { workspace = true } half = { version = "2.1", default-features = false } -num = { version = "0.4", default-features = false, features = ["std"] } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } lexical-core = { version = "1.0", default-features = false, features = ["write-integers", "write-floats", "parse-integers", "parse-floats"] } atoi = "2.0.0" -comfy-table = { version = "7.0", optional = true, default-features = false } +comfy-table = { version = "7", optional = true, default-features = false } base64 = "0.22" ryu = "1.0.16" [dev-dependencies] -criterion = { version = "0.5", default-features = false } +criterion = { workspace = true, default-features = false } half = { version = "2.1", default-features = false } rand = "0.9" @@ -74,3 +75,4 @@ harness = false [[bench]] name = "parse_decimal" harness = false + diff --git a/arrow-cast/src/base64.rs b/arrow-cast/src/base64.rs index e7bb84ebe24c..5637bdc689d9 100644 --- a/arrow-cast/src/base64.rs +++ b/arrow-cast/src/base64.rs @@ -79,18 +79,14 @@ pub fn b64_decode( // Safety: offsets monotonically increasing by construction let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; - Ok(GenericBinaryArray::new( - offsets, - Buffer::from_vec(buffer), - array.nulls().cloned(), - )) + GenericBinaryArray::try_new(offsets, Buffer::from_vec(buffer), array.nulls().cloned()) } #[cfg(test)] mod tests { use super::*; use arrow_array::BinaryArray; - use rand::{rng, Rng}; + use rand::{Rng, rng}; fn test_engine(e: &E, a: &BinaryArray) { let encoded = b64_encode(e, a); diff --git a/arrow-cast/src/cast/decimal.rs b/arrow-cast/src/cast/decimal.rs index b86d93bc81a7..71338a6921e9 100644 --- a/arrow-cast/src/cast/decimal.rs +++ b/arrow-cast/src/cast/decimal.rs @@ -19,17 +19,89 @@ use crate::cast::*; /// A utility trait that provides checked conversions between /// decimal types inspired by [`NumCast`] -pub(crate) trait DecimalCast: Sized { +pub trait DecimalCast: Sized { + /// Convert the decimal to an i32 + fn to_i32(self) -> Option; + + /// Convert the decimal to an i64 + fn to_i64(self) -> Option; + + /// Convert the decimal to an i128 fn to_i128(self) -> Option; + /// Convert the decimal to an i256 fn to_i256(self) -> Option; + /// Convert a decimal from a decimal fn from_decimal(n: T) -> Option; + /// Convert a decimal from a f64 fn from_f64(n: f64) -> Option; } +impl DecimalCast for i32 { + fn to_i32(self) -> Option { + Some(self) + } + + fn to_i64(self) -> Option { + Some(self as i64) + } + + fn to_i128(self) -> Option { + Some(self as i128) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self as i128)) + } + + fn from_decimal(n: T) -> Option { + n.to_i32() + } + + fn from_f64(n: f64) -> Option { + n.to_i32() + } +} + +impl DecimalCast for i64 { + fn to_i32(self) -> Option { + i32::try_from(self).ok() + } + + fn to_i64(self) -> Option { + Some(self) + } + + fn to_i128(self) -> Option { + Some(self as i128) + } + + fn to_i256(self) -> Option { + Some(i256::from_i128(self as i128)) + } + + fn from_decimal(n: T) -> Option { + n.to_i64() + } + + fn from_f64(n: f64) -> Option { + // Call implementation explicitly otherwise this resolves to `to_i64` + // in arrow-buffer that behaves differently. + num_traits::ToPrimitive::to_i64(&n) + } +} + impl DecimalCast for i128 { + fn to_i32(self) -> Option { + i32::try_from(self).ok() + } + + fn to_i64(self) -> Option { + i64::try_from(self).ok() + } + fn to_i128(self) -> Option { Some(self) } @@ -48,6 +120,14 @@ impl DecimalCast for i128 { } impl DecimalCast for i256 { + fn to_i32(self) -> Option { + self.to_i128().map(|x| i32::try_from(x).ok())? + } + + fn to_i64(self) -> Option { + self.to_i128().map(|x| i64::try_from(x).ok())? + } + fn to_i128(self) -> Option { self.to_i128() } @@ -65,63 +145,96 @@ impl DecimalCast for i256 { } } -pub(crate) fn cast_decimal_to_decimal_error( +/// Construct closures to upscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale increase `delta_scale = output_scale - input_scale` +/// exceeds the supported precomputed precision table `O::MAX_FOR_EACH_PRECISION`. +/// In that case, the caller should treat this as an overflow for the output scale +/// and handle it accordingly (e.g., return a cast error). +#[allow(clippy::type_complexity)] +fn make_upscaler( + input_precision: u8, + input_scale: i8, output_precision: u8, output_scale: i8, -) -> impl Fn(::Native) -> ArrowError +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - move |x: I::Native| { - ArrowError::CastError(format!( - "Cannot cast to {}({}, {}). Overflowing on {:?}", - O::PREFIX, - output_precision, - output_scale, - x - )) - } + let delta_scale = output_scale - input_scale; + + // O::MAX_FOR_EACH_PRECISION[k] stores 10^k - 1 (e.g., 9, 99, 999, ...). + // Adding 1 yields exactly 10^k without computing a power at runtime. + // Using the precomputed table avoids pow(10, k) and its checked/overflow + // handling, which is faster and simpler for scaling by 10^delta_scale. + let max = O::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; + let mul = max.add_wrapping(O::Native::ONE); + let f_fallible = move |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); + + // if the gain in precision (digits) is greater than the multiplication due to scaling + // every number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then an increase of scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type + // needs to provide at least 8 digits precision + let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); + let f_infallible = is_infallible_cast + .then_some(move |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul)); + Some((f_fallible, f_infallible)) } -pub(crate) fn convert_to_smaller_scale_decimal( - array: &PrimitiveArray, +/// Construct closures to downscale decimals from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)`. +/// +/// Returns `(f_fallible, f_infallible)` where: +/// * `f_fallible` yields `None` when the requested cast would overflow +/// * `f_infallible` is present only when every input is guaranteed to succeed; otherwise it is `None` +/// and callers must fall back to `f_fallible` +/// +/// Returns `None` if the required scale reduction `delta_scale = input_scale - output_scale` +/// exceeds the supported precomputed precision table `I::MAX_FOR_EACH_PRECISION`. +/// In this scenario, any value would round to zero (e.g., dividing by 10^k where k exceeds the +/// available precision). Callers should therefore produce zero values (preserving nulls) rather +/// than returning an error. +#[allow(clippy::type_complexity)] +fn make_downscaler( input_precision: u8, input_scale: i8, output_precision: u8, output_scale: i8, - cast_options: &CastOptions, -) -> Result, ArrowError> +) -> Option<( + impl Fn(I::Native) -> Option, + Option O::Native>, +)> where - I: DecimalType, - O: DecimalType, I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); let delta_scale = input_scale - output_scale; - // if the reduction of the input number through scaling (dividing) is greater - // than a possible precision loss (plus potential increase via rounding) - // every input number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then and decrease the scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). - // The rounding may add an additional digit, so the cast to be infallible, - // the output type needs to have at least 3 digits of precision. - // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: - // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible - let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); - let div = I::Native::from_decimal(10_i128) - .unwrap() - .pow_checked(delta_scale as u32)?; + // delta_scale is guaranteed to be > 0, but may also be larger than I::MAX_PRECISION. If so, the + // scale change divides out more digits than the input has precision and the result of the cast + // is always zero. For example, if we try to apply delta_scale=10 a decimal32 value, the largest + // possible result is 999999999/10000000000 = 0.0999999999, which rounds to zero. Smaller values + // (e.g. 1/10000000000) or larger delta_scale (e.g. 999999999/10000000000000) produce even + // smaller results, which also round to zero. In that case, just return an array of zeros. + let max = I::MAX_FOR_EACH_PRECISION.get(delta_scale as usize)?; - let half = div.div_wrapping(I::Native::from_usize(2).unwrap()); + let div = max.add_wrapping(I::Native::ONE); + let half = div.div_wrapping(I::Native::ONE.add_wrapping(I::Native::ONE)); let half_neg = half.neg_wrapping(); - let f = |x: I::Native| { + let f_fallible = move |x: I::Native| { // div is >= 10 and so this cannot overflow let d = x.div_wrapping(div); let r = x.mod_wrapping(div); @@ -135,23 +248,136 @@ where O::Native::from_decimal(adjusted) }; - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - let g = |x: I::Native| f(x).unwrap(); // unwrapping is safe since the result is guaranteed - // to fit into the target type - array.unary(g) + // if the reduction of the input number through scaling (dividing) is greater + // than a possible precision loss (plus potential increase via rounding) + // every input number will fit into the output type + // Example: If we are starting with any number of precision 5 [xxxxx], + // then and decrease the scale by 3 will have the following effect on the representation: + // [xxxxx] -> [xx] (+ 1 possibly, due to rounding). + // The rounding may add a digit, so the cast to be infallible, + // the output type needs to have at least 3 digits of precision. + // e.g. Decimal(5, 3) 99.999 to Decimal(3, 0) will result in 100: + // [99999] -> [99] + 1 = [100], a cast to Decimal(2, 0) would not be possible + let is_infallible_cast = (input_precision as i8) - delta_scale < (output_precision as i8); + let f_infallible = is_infallible_cast.then_some(move |x| f_fallible(x).unwrap()); + Some((f_fallible, f_infallible)) +} + +/// Apply the rescaler function to the value. +/// If the rescaler is infallible, use the infallible function. +/// Otherwise, use the fallible function and validate the precision. +fn apply_rescaler( + value: I::Native, + output_precision: u8, + f: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, +) -> Option +where + I::Native: DecimalCast, + O::Native: DecimalCast, +{ + if let Some(f_infallible) = f_infallible { + Some(f_infallible(value)) + } else { + f(value).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + } +} + +/// Rescales a decimal value from `(input_precision, input_scale)` to +/// `(output_precision, output_scale)` and returns the converted number when it fits +/// within the output precision. +/// +/// The function first validates that the requested precision and scale are supported for +/// both the source and destination decimal types. It then either upscales (multiplying +/// by an appropriate power of ten) or downscales (dividing with rounding) the input value. +/// When the scaling factor exceeds the precision table of the destination type, the value +/// is treated as an overflow for upscaling, or rounded to zero for downscaling (as any +/// possible result would be zero at the requested scale). +/// +/// This mirrors the column-oriented helpers of decimal casting but operates on a single value +/// (row-level) instead of an entire array. +/// +/// Returns `None` if the value cannot be represented with the requested precision. +pub fn rescale_decimal( + value: I::Native, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, +) -> Option +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + validate_decimal_precision_and_scale::(input_precision, input_scale).ok()?; + validate_decimal_precision_and_scale::(output_precision, output_scale).ok()?; + + if input_scale <= output_scale { + let (f, f_infallible) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale)?; + apply_rescaler::(value, output_precision, f, f_infallible) + } else { + let Some((f, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + return Some(O::Native::ZERO); + }; + apply_rescaler::(value, output_precision, f, f_infallible) + } +} + +fn cast_decimal_to_decimal_error( + output_precision: u8, + output_scale: i8, +) -> impl Fn(::Native) -> ArrowError +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + move |x: I::Native| { + ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Overflowing on {:?}", + O::PREFIX, + output_precision, + output_scale, + x + )) + } +} + +fn apply_decimal_cast( + array: &PrimitiveArray, + output_precision: u8, + output_scale: i8, + f_fallible: impl Fn(I::Native) -> Option, + f_infallible: Option O::Native>, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + let array = if let Some(f_infallible) = f_infallible { + array.unary(f_infallible) } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) + array.unary_opt(|x| { + f_fallible(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision)) + }) } else { + let error = cast_decimal_to_decimal_error::(output_precision, output_scale); array.try_unary(|x| { - f(x).ok_or_else(|| error(x)) - .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) + f_fallible(x).ok_or_else(|| error(x)).and_then(|v| { + O::validate_decimal_precision(v, output_precision, output_scale).map(|_| v) + }) })? - }) + }; + Ok(array) } -pub(crate) fn convert_to_bigger_or_equal_scale_decimal( +fn convert_to_smaller_scale_decimal( array: &PrimitiveArray, input_precision: u8, input_scale: i8, @@ -165,35 +391,58 @@ where I::Native: DecimalCast + ArrowNativeTypeOp, O::Native: DecimalCast + ArrowNativeTypeOp, { - let error = cast_decimal_to_decimal_error::(output_precision, output_scale); - let delta_scale = output_scale - input_scale; - let mul = O::Native::from_decimal(10_i128) - .unwrap() - .pow_checked(delta_scale as u32)?; + if let Some((f_fallible, f_infallible)) = + make_downscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f_fallible, + f_infallible, + cast_options, + ) + } else { + // Scale reduction exceeds supported precision; result mathematically rounds to zero + let zeros = vec![O::Native::ZERO; array.len()]; + Ok(PrimitiveArray::new(zeros.into(), array.nulls().cloned())) + } +} - // if the gain in precision (digits) is greater than the multiplication due to scaling - // every number will fit into the output type - // Example: If we are starting with any number of precision 5 [xxxxx], - // then an increase of scale by 3 will have the following effect on the representation: - // [xxxxx] -> [xxxxx000], so for the cast to be infallible, the output type - // needs to provide at least 8 digits precision - let is_infallible_cast = (input_precision as i8) + delta_scale <= (output_precision as i8); - let f = |x| O::Native::from_decimal(x).and_then(|x| x.mul_checked(mul).ok()); - - Ok(if is_infallible_cast { - // make sure we don't perform calculations that don't make sense w/o validation - validate_decimal_precision_and_scale::(output_precision, output_scale)?; - // unwrapping is safe since the result is guaranteed to fit into the target type - let f = |x| O::Native::from_decimal(x).unwrap().mul_wrapping(mul); - array.unary(f) - } else if cast_options.safe { - array.unary_opt(|x| f(x).filter(|v| O::is_valid_decimal_precision(*v, output_precision))) +fn convert_to_bigger_or_equal_scale_decimal( + array: &PrimitiveArray, + input_precision: u8, + input_scale: i8, + output_precision: u8, + output_scale: i8, + cast_options: &CastOptions, +) -> Result, ArrowError> +where + I: DecimalType, + O: DecimalType, + I::Native: DecimalCast + ArrowNativeTypeOp, + O::Native: DecimalCast + ArrowNativeTypeOp, +{ + if let Some((f, f_infallible)) = + make_upscaler::(input_precision, input_scale, output_precision, output_scale) + { + apply_decimal_cast( + array, + output_precision, + output_scale, + f, + f_infallible, + cast_options, + ) } else { - array.try_unary(|x| { - f(x).ok_or_else(|| error(x)) - .and_then(|v| O::validate_decimal_precision(v, output_precision).map(|_| v)) - })? - }) + // Scale increase exceeds supported precision; return overflow error + Err(ArrowError::CastError(format!( + "Cannot cast to {}({}, {}). Value overflows for output scale", + O::PREFIX, + output_precision, + output_scale + ))) + } } // Only support one type of decimal cast operations @@ -412,12 +661,11 @@ where parse_string_to_decimal_native::(v, scale as usize) .map_err(|_| { ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, + "Cannot cast string '{v}' to value of {} type", T::DATA_TYPE, )) }) - .and_then(|v| T::validate_decimal_precision(v, precision).map(|_| v)) + .and_then(|v| T::validate_decimal_precision(v, precision, scale).map(|_| v)) }) .transpose() }) @@ -505,9 +753,8 @@ where )?, other => { return Err(ArrowError::ComputeError(format!( - "Cannot cast {:?} to decimal", - other - ))) + "Cannot cast {other:?} to decimal", + ))); } }; @@ -548,7 +795,7 @@ where v )) }) - .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v)) })? .with_precision_and_scale(precision, scale) .map(|a| Arc::new(a) as ArrayRef) @@ -615,7 +862,11 @@ where Ok(Arc::new(value_builder.finish())) } -// Cast the decimal array to floating-point array +/// Cast a decimal array to a floating point array. +/// +/// Conversion is lossy and follows standard floating point semantics. Values +/// that exceed the representable range become `INFINITY` or `-INFINITY` without +/// returning an error. pub(crate) fn cast_decimal_to_float( array: &dyn Array, op: F, @@ -671,4 +922,58 @@ mod tests { ); Ok(()) } + + #[test] + fn test_rescale_decimal_upscale_within_precision() { + let result = rescale_decimal::( + 12_345_i128, // 123.45 with scale 2 + 5, + 2, + 8, + 5, + ); + assert_eq!(result, Some(12_345_000_i128)); + } + + #[test] + fn test_rescale_decimal_downscale_rounds_half_away_from_zero() { + let positive = rescale_decimal::( + 1_050_i128, // 1.050 with scale 3 + 5, 3, 5, 1, + ); + assert_eq!(positive, Some(11_i128)); // 1.1 with scale 1 + + let negative = rescale_decimal::( + -1_050_i128, // -1.050 with scale 3 + 5, + 3, + 5, + 1, + ); + assert_eq!(negative, Some(-11_i128)); // -1.1 with scale 1 + } + + #[test] + fn test_rescale_decimal_downscale_large_delta_returns_zero() { + let result = rescale_decimal::(12_345_i32, 9, 9, 9, 4); + assert_eq!(result, Some(0_i32)); + } + + #[test] + fn test_rescale_decimal_upscale_overflow_returns_none() { + let result = rescale_decimal::(9_999_i32, 4, 0, 5, 2); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_input_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 39, 39, 38, 38); + assert_eq!(result, None); + } + + #[test] + fn test_rescale_decimal_invalid_output_precision_scale_returns_none() { + let result = rescale_decimal::(123_i128, 38, 38, 39, 39); + assert_eq!(result, None); + } } diff --git a/arrow-cast/src/cast/dictionary.rs b/arrow-cast/src/cast/dictionary.rs index eae2f2167b39..64d77236ccd5 100644 --- a/arrow-cast/src/cast/dictionary.rs +++ b/arrow-cast/src/cast/dictionary.rs @@ -28,111 +28,92 @@ pub(crate) fn dictionary_cast( ) -> Result { use DataType::*; - match to_type { - Dictionary(to_index_type, to_value_type) => { - let dict_array = array - .as_any() - .downcast_ref::>() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast dictionary to DictionaryArray of expected type".to_string(), - ) - })?; - - let keys_array: ArrayRef = - Arc::new(PrimitiveArray::::from(dict_array.keys().to_data())); - let values_array = dict_array.values(); - let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; - let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; + let array = array.as_dictionary::(); + let from_child_type = array.values().data_type(); + match (from_child_type, to_type) { + (_, Dictionary(to_index_type, to_value_type)) => { + dictionary_to_dictionary_cast(array, to_index_type, to_value_type, cast_options) + } + // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data + // copy of the value buffer. Fast path which avoids copying underlying values buffer. + // TODO: handle LargeUtf8/LargeBinary -> View (need to check offsets can fit) + // TODO: handle cross types (String -> BinaryView, Binary -> StringView) + // (need to validate utf8?) + (Utf8, Utf8View) => view_from_dict_values::( + array.keys(), + array.values().as_string::(), + ), + (Binary, BinaryView) => view_from_dict_values::( + array.keys(), + array.values().as_binary::(), + ), + _ => unpack_dictionary(array, to_type, cast_options), + } +} - // Failure to cast keys (because they don't fit in the - // target type) results in NULL values; - if cast_keys.null_count() > keys_array.null_count() { - return Err(ArrowError::ComputeError(format!( - "Could not convert {} dictionary indexes from {:?} to {:?}", - cast_keys.null_count() - keys_array.null_count(), - keys_array.data_type(), - to_index_type - ))); - } +fn dictionary_to_dictionary_cast( + array: &DictionaryArray, + to_index_type: &DataType, + to_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + use DataType::*; - let data = cast_keys.into_data(); - let builder = data - .into_builder() - .data_type(to_type.clone()) - .child_data(vec![cast_values.into_data()]); + let keys_array: ArrayRef = Arc::new(PrimitiveArray::::from(array.keys().to_data())); + let values_array = array.values(); + let cast_keys = cast_with_options(&keys_array, to_index_type, cast_options)?; + let cast_values = cast_with_options(values_array, to_value_type, cast_options)?; - // Safety - // Cast keys are still valid - let data = unsafe { builder.build_unchecked() }; + // Failure to cast keys (because they don't fit in the + // target type) results in NULL values; + if cast_keys.null_count() > keys_array.null_count() { + return Err(ArrowError::ComputeError(format!( + "Could not convert {} dictionary indexes from {:?} to {:?}", + cast_keys.null_count() - keys_array.null_count(), + keys_array.data_type(), + to_index_type + ))); + } - // create the appropriate array type - let new_array: ArrayRef = match **to_index_type { - Int8 => Arc::new(DictionaryArray::::from(data)), - Int16 => Arc::new(DictionaryArray::::from(data)), - Int32 => Arc::new(DictionaryArray::::from(data)), - Int64 => Arc::new(DictionaryArray::::from(data)), - UInt8 => Arc::new(DictionaryArray::::from(data)), - UInt16 => Arc::new(DictionaryArray::::from(data)), - UInt32 => Arc::new(DictionaryArray::::from(data)), - UInt64 => Arc::new(DictionaryArray::::from(data)), - _ => { - return Err(ArrowError::CastError(format!( - "Unsupported type {to_index_type:?} for dictionary index" - ))); - } - }; + let data = cast_keys.into_data(); + let builder = data + .into_builder() + .data_type(Dictionary( + Box::new(to_index_type.clone()), + Box::new(to_value_type.clone()), + )) + .child_data(vec![cast_values.into_data()]); - Ok(new_array) - } - Utf8View => { - // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. - // we handle it here to avoid the copy. - let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast Utf8View to StringArray of expected type" - .to_string(), - ) - })?; + // Safety + // Cast keys are still valid + let data = unsafe { builder.build_unchecked() }; - let string_view = view_from_dict_values::>( - dict_array.values(), - dict_array.keys(), - )?; - Ok(Arc::new(string_view)) + // create the appropriate array type + let new_array: ArrayRef = match to_index_type { + Int8 => Arc::new(DictionaryArray::::from(data)), + Int16 => Arc::new(DictionaryArray::::from(data)), + Int32 => Arc::new(DictionaryArray::::from(data)), + Int64 => Arc::new(DictionaryArray::::from(data)), + UInt8 => Arc::new(DictionaryArray::::from(data)), + UInt16 => Arc::new(DictionaryArray::::from(data)), + UInt32 => Arc::new(DictionaryArray::::from(data)), + UInt64 => Arc::new(DictionaryArray::::from(data)), + _ => { + return Err(ArrowError::CastError(format!( + "Unsupported type {to_index_type} for dictionary index" + ))); } - BinaryView => { - // `unpack_dictionary` can handle Utf8View/BinaryView types, but incurs unnecessary data copy of the value buffer. - // we handle it here to avoid the copy. - let dict_array = array - .as_dictionary::() - .downcast_dict::() - .ok_or_else(|| { - ArrowError::ComputeError( - "Internal Error: Cannot cast BinaryView to BinaryArray of expected type" - .to_string(), - ) - })?; + }; - let binary_view = view_from_dict_values::( - dict_array.values(), - dict_array.keys(), - )?; - Ok(Arc::new(binary_view)) - } - _ => unpack_dictionary::(array, to_type, cast_options), - } + Ok(new_array) } -fn view_from_dict_values( - array: &GenericByteArray, +fn view_from_dict_values( keys: &PrimitiveArray, -) -> Result, ArrowError> { - let value_buffer = array.values(); - let value_offsets = array.value_offsets(); + values: &GenericByteArray, +) -> Result { + let value_buffer = values.values(); + let value_offsets = values.value_offsets(); let mut builder = GenericByteViewBuilder::::with_capacity(keys.len()); builder.append_block(value_buffer.clone()); for i in keys.iter() { @@ -157,21 +138,17 @@ fn view_from_dict_values into a flattened array of type to_type -pub(crate) fn unpack_dictionary( - array: &dyn Array, +// Unpack a dictionary into a flattened array of type to_type +pub(crate) fn unpack_dictionary( + array: &DictionaryArray, to_type: &DataType, cast_options: &CastOptions, -) -> Result -where - K: ArrowDictionaryKeyType, -{ - let dict_array = array.as_dictionary::(); - let cast_dict_values = cast_with_options(dict_array.values(), to_type, cast_options)?; - take(cast_dict_values.as_ref(), dict_array.keys(), None) +) -> Result { + let cast_dict_values = cast_with_options(array.values(), to_type, cast_options)?; + take(cast_dict_values.as_ref(), array.keys(), None) } /// Pack a data type into a dictionary array passing the values through a primitive array @@ -214,6 +191,20 @@ pub(crate) fn cast_to_dictionary( UInt16 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt32 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), UInt64 => pack_numeric_to_dictionary::(array, dict_value_type, cast_options), + Decimal32(p, s) => pack_decimal_to_dictionary::( + array, + dict_value_type, + p, + s, + cast_options, + ), + Decimal64(p, s) => pack_decimal_to_dictionary::( + array, + dict_value_type, + p, + s, + cast_options, + ), Decimal128(p, s) => pack_decimal_to_dictionary::( array, dict_value_type, @@ -299,7 +290,7 @@ pub(crate) fn cast_to_dictionary( pack_byte_to_fixed_size_dictionary::(array, cast_options, byte_size) } _ => Err(ArrowError::CastError(format!( - "Unsupported output type for dictionary packing: {dict_value_type:?}" + "Unsupported output type for dictionary packing: {dict_value_type}" ))), } } diff --git a/arrow-cast/src/cast/list.rs b/arrow-cast/src/cast/list.rs index ddcbca361bf0..f6c8d2465c86 100644 --- a/arrow-cast/src/cast/list.rs +++ b/arrow-cast/src/cast/list.rs @@ -24,8 +24,8 @@ pub(crate) fn cast_values_to_list( cast_options: &CastOptions, ) -> Result { let values = cast_with_options(array, to.data_type(), cast_options)?; - let offsets = OffsetBuffer::from_lengths(std::iter::repeat(1).take(values.len())); - let list = GenericListArray::::new(to.clone(), offsets, values, None); + let offsets = OffsetBuffer::from_repeated_length(1, values.len()); + let list = GenericListArray::::try_new(to.clone(), offsets, values, None)?; Ok(Arc::new(list)) } @@ -37,7 +37,7 @@ pub(crate) fn cast_values_to_fixed_size_list( cast_options: &CastOptions, ) -> Result { let values = cast_with_options(array, to.data_type(), cast_options)?; - let list = FixedSizeListArray::new(to.clone(), size, values, None); + let list = FixedSizeListArray::try_new(to.clone(), size, values, None)?; Ok(Arc::new(list)) } @@ -140,7 +140,7 @@ where // Construct the FixedSizeListArray let nulls = nulls.map(|mut x| x.finish().into()); - let array = FixedSizeListArray::new(field.clone(), size, values, nulls); + let array = FixedSizeListArray::try_new(field.clone(), size, values, nulls)?; Ok(Arc::new(array)) } @@ -152,12 +152,12 @@ pub(crate) fn cast_list_values( ) -> Result { let list = array.as_list::(); let values = cast_with_options(list.values(), to.data_type(), cast_options)?; - Ok(Arc::new(GenericListArray::::new( + Ok(Arc::new(GenericListArray::::try_new( to.clone(), list.offsets().clone(), values, list.nulls().cloned(), - ))) + )?)) } /// Cast the container type of List/Largelist array along with the inner datatype @@ -184,10 +184,10 @@ pub(crate) fn cast_list( // Safety: valid offsets and checked for overflow let offsets = unsafe { OffsetBuffer::new_unchecked(offsets.into()) }; - Ok(Arc::new(GenericListArray::::new( + Ok(Arc::new(GenericListArray::::try_new( field.clone(), offsets, values, nulls, - ))) + )?)) } diff --git a/arrow-cast/src/cast/list_view.rs b/arrow-cast/src/cast/list_view.rs new file mode 100644 index 000000000000..0fdab8c6247d --- /dev/null +++ b/arrow-cast/src/cast/list_view.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; + +/// Helper function to cast a list view to a list +pub(crate) fn cast_list_view_to_list( + array: &dyn Array, + to: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list_view = array.as_list_view::(); + let list_view_offsets = list_view.offsets(); + let sizes = list_view.sizes(); + let source_values = list_view.values(); + + // Construct the indices and offsets for the new list array by iterating over the list view subarrays + let mut indices = Vec::with_capacity(list_view.values().len()); + let mut offsets = Vec::with_capacity(list_view.len() + 1); + // Add the offset for the first subarray + offsets.push(O::usize_as(0)); + for i in 0..list_view.len() { + // For each subarray, add the indices of the values to take + let offset = list_view_offsets[i].as_usize(); + let size = sizes[i].as_usize(); + let end = offset + size; + for j in offset..end { + indices.push(j as i32); + } + // Add the offset for the next subarray + offsets.push(O::usize_as(indices.len())); + } + + // Take the values from the source values using the indices, creating a new array + let values = arrow_select::take::take(source_values, &Int32Array::from(indices), None)?; + + // Cast the values to the target data type + let values = cast_with_options(&values, to.data_type(), cast_options)?; + + Ok(Arc::new(GenericListArray::::try_new( + to.clone(), + OffsetBuffer::new(offsets.into()), + values, + list_view.nulls().cloned(), + )?)) +} + +pub(crate) fn cast_list_view( + array: &dyn Array, + to_field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let list_view = array.as_list_view::(); + let (_field, offsets, sizes, values, nulls) = list_view.clone().into_parts(); + + // Recursively cast values + let values = cast_with_options(&values, to_field.data_type(), cast_options)?; + + let new_offsets: Vec<_> = offsets.iter().map(|x| O::usize_as(x.as_usize())).collect(); + let new_sizes: Vec<_> = sizes.iter().map(|x| O::usize_as(x.as_usize())).collect(); + Ok(Arc::new(GenericListViewArray::::try_new( + to_field.clone(), + new_offsets.into(), + new_sizes.into(), + values, + nulls, + )?)) +} + +pub(crate) fn cast_list_to_list_view(array: &dyn Array) -> Result +where + OffsetSize: OffsetSizeTrait, +{ + let list = array.as_list::(); + let list_view: GenericListViewArray = list.clone().into(); + Ok(Arc::new(list_view)) +} diff --git a/arrow-cast/src/cast/map.rs b/arrow-cast/src/cast/map.rs index d62a9519b7b3..e7a9b7495edb 100644 --- a/arrow-cast/src/cast/map.rs +++ b/arrow-cast/src/cast/map.rs @@ -42,17 +42,17 @@ pub(crate) fn cast_map_values( let key_array = cast_with_options(from.keys(), key_field.data_type(), cast_options)?; let value_array = cast_with_options(from.values(), value_field.data_type(), cast_options)?; - Ok(Arc::new(MapArray::new( + Ok(Arc::new(MapArray::try_new( entries_field.clone(), from.offsets().clone(), - StructArray::new( + StructArray::try_new( Fields::from(vec![key_field, value_field]), vec![key_array, value_array], from.entries().nulls().cloned(), - ), + )?, from.nulls().cloned(), to_ordered, - ))) + )?)) } /// Gets the key field from the entries of a map. For all other types returns None. diff --git a/arrow-cast/src/cast/mod.rs b/arrow-cast/src/cast/mod.rs index b317dabd5dda..fb77993a3028 100644 --- a/arrow-cast/src/cast/mod.rs +++ b/arrow-cast/src/cast/mod.rs @@ -40,12 +40,16 @@ mod decimal; mod dictionary; mod list; +mod list_view; mod map; +mod run_array; mod string; + use crate::cast::decimal::*; use crate::cast::dictionary::*; use crate::cast::list::*; use crate::cast::map::*; +use crate::cast::run_array::*; use crate::cast::string::*; use arrow_buffer::IntervalMonthDayNano; @@ -56,17 +60,19 @@ use std::sync::Arc; use crate::display::{ArrayFormatter, FormatOptions}; use crate::parse::{ - parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, - string_to_datetime, Parser, + Parser, parse_interval_day_time, parse_interval_month_day_nano, parse_interval_year_month, + string_to_datetime, }; use arrow_array::{builder::*, cast::*, temporal_conversions::*, timezone::Tz, types::*, *}; -use arrow_buffer::{i256, ArrowNativeType, OffsetBuffer}; -use arrow_data::transform::MutableArrayData; +use arrow_buffer::{ArrowNativeType, OffsetBuffer, i256}; use arrow_data::ArrayData; +use arrow_data::transform::MutableArrayData; use arrow_schema::*; use arrow_select::take::take; -use num::cast::AsPrimitive; -use num::{NumCast, ToPrimitive}; +use num_traits::{NumCast, ToPrimitive, cast::AsPrimitive}; + +use crate::cast::list_view::{cast_list_to_list_view, cast_list_view, cast_list_view_to_list}; +pub use decimal::{DecimalCast, rescale_decimal}; /// CastOptions provides a way to override the default cast behaviors #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -98,45 +104,14 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } match (from_type, to_type) { - ( - Null, - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | Timestamp(_, _) - | Time64(_) - | Duration(_) - | Interval(_) - | FixedSizeBinary(_) - | Binary - | Utf8 - | LargeBinary - | LargeUtf8 - | BinaryView - | Utf8View - | List(_) - | LargeList(_) - | FixedSizeList(_, _) - | Struct(_) - | Map(_, _) - | Dictionary(_, _), - ) => true, + (Null, _) => true, // Dictionary/List conditions should be put in front of others (Dictionary(_, from_value_type), Dictionary(_, to_value_type)) => { can_cast_types(from_value_type, to_value_type) } (Dictionary(_, value_type), _) => can_cast_types(value_type, to_type), + (RunEndEncoded(_, value_type), _) => can_cast_types(value_type.data_type(), to_type), + (_, RunEndEncoded(_, value_type)) => can_cast_types(from_type, value_type.data_type()), (_, Dictionary(_, value_type)) => can_cast_types(from_type, value_type), (List(list_from) | LargeList(list_from), List(list_to) | LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) @@ -147,9 +122,21 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (List(list_from) | LargeList(list_from), FixedSizeList(list_to, _)) => { can_cast_types(list_from.data_type(), list_to.data_type()) } + (List(list_from) | LargeList(list_from), ListView(list_to) | LargeListView(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } (List(_), _) => false, - (FixedSizeList(list_from,_), List(list_to)) | - (FixedSizeList(list_from,_), LargeList(list_to)) => { + (ListView(list_from) | LargeListView(list_from), List(list_to) | LargeList(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (ListView(list_from), LargeListView(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (LargeListView(list_from), ListView(list_to)) => { + can_cast_types(list_from.data_type(), list_to.data_type()) + } + (FixedSizeList(list_from, _), List(list_to)) + | (FixedSizeList(list_from, _), LargeList(list_to)) => { can_cast_types(list_from.data_type(), list_to.data_type()) } (FixedSizeList(inner, size), FixedSizeList(inner_to, size_to)) if size == size_to => { @@ -157,42 +144,100 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { } (_, List(list_to)) => can_cast_types(from_type, list_to.data_type()), (_, LargeList(list_to)) => can_cast_types(from_type, list_to.data_type()), - (_, FixedSizeList(list_to,size)) if *size == 1 => { - can_cast_types(from_type, list_to.data_type())}, - (FixedSizeList(list_from,size), _) if *size == 1 => { - can_cast_types(list_from.data_type(), to_type)}, - (Map(from_entries,ordered_from), Map(to_entries, ordered_to)) if ordered_from == ordered_to => - match (key_field(from_entries), key_field(to_entries), value_field(from_entries), value_field(to_entries)) { - (Some(from_key), Some(to_key), Some(from_value), Some(to_value)) => - can_cast_types(from_key.data_type(), to_key.data_type()) && can_cast_types(from_value.data_type(), to_value.data_type()), - _ => false - }, + (_, FixedSizeList(list_to, size)) if *size == 1 => { + can_cast_types(from_type, list_to.data_type()) + } + (FixedSizeList(list_from, size), _) if *size == 1 => { + can_cast_types(list_from.data_type(), to_type) + } + (Map(from_entries, ordered_from), Map(to_entries, ordered_to)) + if ordered_from == ordered_to => + { + match ( + key_field(from_entries), + key_field(to_entries), + value_field(from_entries), + value_field(to_entries), + ) { + (Some(from_key), Some(to_key), Some(from_value), Some(to_value)) => { + can_cast_types(from_key.data_type(), to_key.data_type()) + && can_cast_types(from_value.data_type(), to_value.data_type()) + } + _ => false, + } + } // cast one decimal type to another decimal type - (Decimal128(_, _), Decimal128(_, _)) => true, - (Decimal256(_, _), Decimal256(_, _)) => true, - (Decimal128(_, _), Decimal256(_, _)) => true, - (Decimal256(_, _), Decimal128(_, _)) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // unsigned integer to decimal - (UInt8 | UInt16 | UInt32 | UInt64, Decimal128(_, _)) | - (UInt8 | UInt16 | UInt32 | UInt64, Decimal256(_, _)) | + ( + UInt8 | UInt16 | UInt32 | UInt64, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // signed numeric to decimal - (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal128(_, _)) | - (Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, Decimal256(_, _)) | + ( + Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, // decimal to unsigned numeric - (Decimal128(_, _) | Decimal256(_, _), UInt8 | UInt16 | UInt32 | UInt64) | + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + UInt8 | UInt16 | UInt32 | UInt64, + ) => true, // decimal to signed numeric - (Decimal128(_, _) | Decimal256(_, _), Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Null | Int8 | Int16 | Int32 | Int64 | Float32 | Float64, + ) => true, // decimal to string - (Decimal128(_, _) | Decimal256(_, _), Utf8View | Utf8 | LargeUtf8) => true, + ( + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + Utf8View | Utf8 | LargeUtf8, + ) => true, // string to decimal - (Utf8View | Utf8 | LargeUtf8, Decimal128(_, _) | Decimal256(_, _)) => true, + ( + Utf8View | Utf8 | LargeUtf8, + Decimal32(_, _) | Decimal64(_, _) | Decimal128(_, _) | Decimal256(_, _), + ) => true, (Struct(from_fields), Struct(to_fields)) => { - from_fields.len() == to_fields.len() && - from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { + if from_fields.len() != to_fields.len() { + return false; + } + + // fast path, all field names are in the same order and same number of fields + if from_fields + .iter() + .zip(to_fields.iter()) + .all(|(f1, f2)| f1.name() == f2.name()) + { + return from_fields.iter().zip(to_fields.iter()).all(|(f1, f2)| { // Assume that nullability between two structs are compatible, if not, // cast kernel will return error. can_cast_types(f1.data_type(), f2.data_type()) - }) + }); + } + + // slow path, we match the fields by name + if to_fields.iter().all(|to_field| { + from_fields + .iter() + .find(|from_field| from_field.name() == to_field.name()) + .is_some_and(|from_field| { + // Assume that nullability between two structs are compatible, if not, + // cast kernel will return error. + can_cast_types(from_field.data_type(), to_field.data_type()) + }) + }) { + return true; + } + + // if we couldn't match by name, we try to see if they can be matched by position + from_fields + .iter() + .zip(to_fields.iter()) + .all(|(f1, f2)| can_cast_types(f1.data_type(), f2.data_type())) } (Struct(_), _) => false, (_, Struct(_)) => false, @@ -211,8 +256,12 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { || to_type == &LargeUtf8 } - (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, - (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View ) => true, + (Binary, LargeBinary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View) => { + true + } + (LargeBinary, Binary | Utf8 | LargeUtf8 | FixedSizeBinary(_) | BinaryView | Utf8View) => { + true + } (FixedSizeBinary(_), Binary | LargeBinary | BinaryView) => true, ( Utf8 | LargeUtf8 | Utf8View, @@ -236,22 +285,23 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool { (Utf8 | LargeUtf8, Utf8View) => true, (BinaryView, Binary | LargeBinary | Utf8 | LargeUtf8 | Utf8View) => true, (Utf8View | Utf8 | LargeUtf8, _) => to_type.is_numeric() && to_type != &Float16, - (_, Utf8 | LargeUtf8) => from_type.is_primitive(), - (_, Utf8View) => from_type.is_numeric(), + (_, Utf8 | Utf8View | LargeUtf8) => from_type.is_primitive(), (_, Binary | LargeBinary) => from_type.is_integer(), // start numeric casts ( - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, - UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 + | Float64, + UInt8 | UInt16 | UInt32 | UInt64 | Int8 | Int16 | Int32 | Int64 | Float16 | Float32 + | Float64, ) => true, // end numeric casts // temporal casts (Int32, Date32 | Date64 | Time32(_)) => true, (Date32, Int32 | Int64) => true, - (Time32(_), Int32) => true, + (Time32(_), Int32 | Int64) => true, (Int64, Date64 | Date32 | Time64(_)) => true, (Date64, Int64 | Int32) => true, (Time64(_), Int64) => true, @@ -342,7 +392,7 @@ where false => array.try_unary::<_, D, _>(|v| { v.as_() .div_checked(scale_factor) - .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v)) })?, } } else { @@ -356,7 +406,7 @@ where false => array.try_unary::<_, D, _>(|v| { v.as_() .mul_checked(scale_factor) - .and_then(|v| D::validate_decimal_precision(v, precision).map(|_| v)) + .and_then(|v| D::validate_decimal_precision(v, precision, scale).map(|_| v)) })?, } }; @@ -603,12 +653,28 @@ fn timestamp_to_date32( /// * Temporal to/from backing Primitive: zero-copy with data type change /// * `Float32/Float64` to `Decimal(precision, scale)` rounds to the `scale` decimals /// (i.e. casting `6.4999` to `Decimal(10, 1)` becomes `6.5`). +/// * `Decimal` to `Float32/Float64` is lossy and values outside the representable +/// range become `INFINITY` or `-INFINITY` without error. /// /// Unsupported Casts (check with `can_cast_types` before calling): /// * To or from `StructArray` /// * `List` to `Primitive` /// * `Interval` and `Duration` /// +/// # Durations and Intervals +/// +/// Casting integer types directly to interval types such as +/// [`IntervalMonthDayNano`] is not supported because the meaning of the integer +/// is ambiguous. For example, the integer could represent either nanoseconds +/// or months. +/// +/// To cast an integer type to an interval type, first convert to a Duration +/// type, and then cast that to the desired interval type. +/// +/// For example, to convert an `Int64` representing nanoseconds to an +/// `IntervalMonthDayNano` you would first convert the `Int64` to a +/// `DurationNanoseconds`, and then cast that to `IntervalMonthDayNano`. +/// /// # Timestamps and Timezones /// /// Timestamps are stored with an optional timezone in Arrow. @@ -705,40 +771,38 @@ pub fn cast_with_options( return Ok(make_array(array.to_data())); } match (from_type, to_type) { - ( - Null, - Boolean - | Int8 - | UInt8 - | Int16 - | UInt16 - | Int32 - | UInt32 - | Float32 - | Date32 - | Time32(_) - | Int64 - | UInt64 - | Float64 - | Date64 - | Timestamp(_, _) - | Time64(_) - | Duration(_) - | Interval(_) - | FixedSizeBinary(_) - | Binary - | Utf8 - | LargeBinary - | LargeUtf8 - | BinaryView - | Utf8View - | List(_) - | LargeList(_) - | FixedSizeList(_, _) - | Struct(_) - | Map(_, _) - | Dictionary(_, _), - ) => Ok(new_null_array(to_type, array.len())), + (Null, _) => Ok(new_null_array(to_type, array.len())), + (RunEndEncoded(index_type, _), _) => match index_type.data_type() { + Int16 => run_end_encoded_cast::(array, to_type, cast_options), + Int32 => run_end_encoded_cast::(array, to_type, cast_options), + Int64 => run_end_encoded_cast::(array, to_type, cast_options), + _ => Err(ArrowError::CastError(format!( + "Casting from run end encoded type {from_type:?} to {to_type:?} not supported", + ))), + }, + (_, RunEndEncoded(index_type, value_type)) => { + let array_ref = make_array(array.to_data()); + match index_type.data_type() { + Int16 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int32 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + Int64 => cast_to_run_end_encoded::( + &array_ref, + value_type.data_type(), + cast_options, + ), + _ => Err(ArrowError::CastError(format!( + "Casting from type {from_type:?} to run end encoded type {to_type:?} not supported", + ))), + } + } (Dictionary(index_type, _), _) => match **index_type { Int8 => dictionary_cast::(array, to_type, cast_options), Int16 => dictionary_cast::(array, to_type, cast_options), @@ -749,7 +813,7 @@ pub fn cast_with_options( UInt32 => dictionary_cast::(array, to_type, cast_options), UInt64 => dictionary_cast::(array, to_type, cast_options), _ => Err(ArrowError::CastError(format!( - "Casting from dictionary type {from_type:?} to {to_type:?} not supported", + "Casting from dictionary type {from_type} to {to_type} not supported", ))), }, (_, Dictionary(index_type, value_type)) => match **index_type { @@ -762,7 +826,7 @@ pub fn cast_with_options( UInt32 => cast_to_dictionary::(array, value_type, cast_options), UInt64 => cast_to_dictionary::(array, value_type, cast_options), _ => Err(ArrowError::CastError(format!( - "Casting from type {from_type:?} to dictionary type {to_type:?} not supported", + "Casting from type {from_type} to dictionary type {to_type} not supported", ))), }, (List(_), List(to)) => cast_list_values::(array, to, cast_options), @@ -777,6 +841,18 @@ pub fn cast_with_options( let array = array.as_list::(); cast_list_to_fixed_size_list::(array, field, *size, cast_options) } + (ListView(_), List(list_to)) => cast_list_view_to_list::(array, list_to, cast_options), + (LargeListView(_), LargeList(list_to)) => { + cast_list_view_to_list::(array, list_to, cast_options) + } + (ListView(_), LargeListView(list_to)) => { + cast_list_view::(array, list_to, cast_options) + } + (LargeListView(_), ListView(list_to)) => { + cast_list_view::(array, list_to, cast_options) + } + (List(_), ListView(_)) => cast_list_to_list_view::(array), + (LargeList(_), LargeListView(_)) => cast_list_to_list_view::(array), (List(_) | LargeList(_), _) => match to_type { Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), @@ -819,9 +895,9 @@ pub fn cast_with_options( array.nulls().cloned(), )?)) } - (_, List(ref to)) => cast_values_to_list::(array, to, cast_options), - (_, LargeList(ref to)) => cast_values_to_list::(array, to, cast_options), - (_, FixedSizeList(ref to, size)) if *size == 1 => { + (_, List(to)) => cast_values_to_list::(array, to, cast_options), + (_, LargeList(to)) => cast_values_to_list::(array, to, cast_options), + (_, FixedSizeList(to, size)) if *size == 1 => { cast_values_to_fixed_size_list(array, to, *size, cast_options) } (FixedSizeList(_, size), _) if *size == 1 => { @@ -831,6 +907,26 @@ pub fn cast_with_options( cast_map_values(array.as_map(), to_type, cast_options, ordered1.to_owned()) } // Decimal to decimal, same width + (Decimal32(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal_same_type::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal128(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal_same_type::( array.as_primitive(), @@ -852,6 +948,86 @@ pub fn cast_with_options( ) } // Decimal to decimal, different width + (Decimal32(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal32(p1, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal32(p1, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal128(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal64(p1, s1), Decimal256(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal128(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal128(p1, s1), Decimal256(p2, s2)) => { cast_decimal_to_decimal::( array.as_primitive(), @@ -862,6 +1038,26 @@ pub fn cast_with_options( cast_options, ) } + (Decimal256(p1, s1), Decimal32(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } + (Decimal256(p1, s1), Decimal64(p2, s2)) => { + cast_decimal_to_decimal::( + array.as_primitive(), + *p1, + *s1, + *p2, + *s2, + cast_options, + ) + } (Decimal256(p1, s1), Decimal128(p2, s2)) => { cast_decimal_to_decimal::( array.as_primitive(), @@ -873,6 +1069,28 @@ pub fn cast_with_options( ) } // Decimal to non-decimal + (Decimal32(_, scale), _) if !to_type.is_temporal() => { + cast_from_decimal::( + array, + 10_i32, + scale, + from_type, + to_type, + |x: i32| x as f64, + cast_options, + ) + } + (Decimal64(_, scale), _) if !to_type.is_temporal() => { + cast_from_decimal::( + array, + 10_i64, + scale, + from_type, + to_type, + |x: i64| x as f64, + cast_options, + ) + } (Decimal128(_, scale), _) if !to_type.is_temporal() => { cast_from_decimal::( array, @@ -891,11 +1109,33 @@ pub fn cast_with_options( scale, from_type, to_type, - |x: i256| x.to_f64().unwrap(), + |x: i256| x.to_f64().expect("All i256 values fit in f64"), cast_options, ) } // Non-decimal to decimal + (_, Decimal32(precision, scale)) if !from_type.is_temporal() => { + cast_to_decimal::( + array, + 10_i32, + precision, + scale, + from_type, + to_type, + cast_options, + ) + } + (_, Decimal64(precision, scale)) if !from_type.is_temporal() => { + cast_to_decimal::( + array, + 10_i64, + precision, + scale, + from_type, + to_type, + cast_options, + ) + } (_, Decimal128(precision, scale)) if !from_type.is_temporal() => { cast_to_decimal::( array, @@ -918,22 +1158,17 @@ pub fn cast_with_options( cast_options, ) } - (Struct(_), Struct(to_fields)) => { - let array = array.as_struct(); - let fields = array - .columns() - .iter() - .zip(to_fields.iter()) - .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) - .collect::, ArrowError>>()?; - let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; - Ok(Arc::new(array) as ArrayRef) - } + (Struct(from_fields), Struct(to_fields)) => cast_struct_to_struct( + array.as_struct(), + from_fields.clone(), + to_fields.clone(), + cast_options, + ), (Struct(_), _) => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" + "Casting from {from_type} to {to_type} not supported" ))), (_, Struct(_)) => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" + "Casting from {from_type} to {to_type} not supported" ))), (_, Boolean) => match from_type { UInt8 => cast_numeric_to_bool::(array), @@ -951,7 +1186,7 @@ pub fn cast_with_options( Utf8 => cast_utf8_to_boolean::(array, cast_options), LargeUtf8 => cast_utf8_to_boolean::(array, cast_options), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (Boolean, _) => match to_type { @@ -970,7 +1205,7 @@ pub fn cast_with_options( Utf8 => value_to_string::(array, cast_options), LargeUtf8 => value_to_string::(array, cast_options), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (Utf8, _) => match to_type { @@ -1032,7 +1267,7 @@ pub fn cast_with_options( cast_string_to_month_day_nano_interval::(array, cast_options) } _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (Utf8View, _) => match to_type { @@ -1083,7 +1318,7 @@ pub fn cast_with_options( cast_view_to_month_day_nano_interval(array, cast_options) } _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (LargeUtf8, _) => match to_type { @@ -1149,7 +1384,7 @@ pub fn cast_with_options( cast_string_to_month_day_nano_interval::(array, cast_options) } _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (Binary, _) => match to_type { @@ -1167,7 +1402,7 @@ pub fn cast_with_options( cast_binary_to_string::(array, cast_options)?.as_string::(), ))), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (LargeBinary, _) => match to_type { @@ -1186,7 +1421,7 @@ pub fn cast_with_options( Ok(Arc::new(StringViewArray::from(array.as_string::()))) } _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (FixedSizeBinary(size), _) => match to_type { @@ -1194,7 +1429,7 @@ pub fn cast_with_options( LargeBinary => cast_fixed_size_binary_to_binary::(array, *size), BinaryView => cast_fixed_size_binary_to_binary_view(array, *size), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), }, (BinaryView, Binary) => cast_view_to_byte::>(array), @@ -1209,11 +1444,9 @@ pub fn cast_with_options( let binary_arr = cast_view_to_byte::>(array)?; cast_binary_to_string::(&binary_arr, cast_options) } - (BinaryView, Utf8View) => { - Ok(Arc::new(array.as_binary_view().clone().to_string_view()?) as ArrayRef) - } + (BinaryView, Utf8View) => cast_binary_view_to_string_view(array, cast_options), (BinaryView, _) => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), (from_type, Utf8View) if from_type.is_primitive() => { value_to_string_view(array, cast_options) @@ -1395,6 +1628,16 @@ pub fn cast_with_options( (Time32(TimeUnit::Millisecond), Int32) => { cast_reinterpret_arrays::(array) } + (Time32(TimeUnit::Second), Int64) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Int64, + cast_options, + ), + (Time32(TimeUnit::Millisecond), Int64) => cast_with_options( + &cast_with_options(array, &Int32, cast_options)?, + &Int64, + cast_options, + ), (Int64, Date64) => cast_reinterpret_arrays::(array), (Int64, Date32) => cast_with_options( &cast_with_options(array, &Int32, cast_options)?, @@ -1947,11 +2190,79 @@ pub fn cast_with_options( cast_reinterpret_arrays::(array) } (_, _) => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported", + "Casting from {from_type} to {to_type} not supported", ))), } } +fn cast_struct_to_struct( + array: &StructArray, + from_fields: Fields, + to_fields: Fields, + cast_options: &CastOptions, +) -> Result { + // Fast path: if field names are in the same order, we can just zip and cast + let fields_match_order = from_fields.len() == to_fields.len() + && from_fields + .iter() + .zip(to_fields.iter()) + .all(|(f1, f2)| f1.name() == f2.name()); + + let fields = if fields_match_order { + // Fast path: cast columns in order if their names match + cast_struct_fields_in_order(array, to_fields.clone(), cast_options)? + } else { + let all_fields_match_by_name = to_fields.iter().all(|to_field| { + from_fields + .iter() + .any(|from_field| from_field.name() == to_field.name()) + }); + + if all_fields_match_by_name { + // Slow path: match fields by name and reorder + cast_struct_fields_by_name(array, from_fields.clone(), to_fields.clone(), cast_options)? + } else { + // Fallback: cast field by field in order + cast_struct_fields_in_order(array, to_fields.clone(), cast_options)? + } + }; + + let array = StructArray::try_new(to_fields.clone(), fields, array.nulls().cloned())?; + Ok(Arc::new(array) as ArrayRef) +} + +fn cast_struct_fields_by_name( + array: &StructArray, + from_fields: Fields, + to_fields: Fields, + cast_options: &CastOptions, +) -> Result, ArrowError> { + to_fields + .iter() + .map(|to_field| { + let from_field_idx = from_fields + .iter() + .position(|from_field| from_field.name() == to_field.name()) + .unwrap(); // safe because we checked above + let column = array.column(from_field_idx); + cast_with_options(column, to_field.data_type(), cast_options) + }) + .collect::, ArrowError>>() +} + +fn cast_struct_fields_in_order( + array: &StructArray, + to_fields: Fields, + cast_options: &CastOptions, +) -> Result, ArrowError> { + array + .columns() + .iter() + .zip(to_fields.iter()) + .map(|(l, field)| cast_with_options(l, field.data_type(), cast_options)) + .collect::, ArrowError>>() +} + fn cast_from_decimal( array: &dyn Array, base: D::Native, @@ -1988,7 +2299,7 @@ where LargeUtf8 => value_to_string::(array, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" + "Casting from {from_type} to {to_type} not supported" ))), } } @@ -2005,14 +2316,14 @@ fn cast_to_decimal( where D: DecimalType + ArrowPrimitiveType, M: ArrowNativeTypeOp + DecimalCast, - u8: num::traits::AsPrimitive, - u16: num::traits::AsPrimitive, - u32: num::traits::AsPrimitive, - u64: num::traits::AsPrimitive, - i8: num::traits::AsPrimitive, - i16: num::traits::AsPrimitive, - i32: num::traits::AsPrimitive, - i64: num::traits::AsPrimitive, + u8: num_traits::AsPrimitive, + u16: num_traits::AsPrimitive, + u32: num_traits::AsPrimitive, + u64: num_traits::AsPrimitive, + i8: num_traits::AsPrimitive, + i16: num_traits::AsPrimitive, + i32: num_traits::AsPrimitive, + i64: num_traits::AsPrimitive, { use DataType::*; // cast data to decimal @@ -2091,7 +2402,7 @@ where LargeUtf8 => cast_string_to_decimal::(array, *precision, *scale, cast_options), Null => Ok(new_null_array(to_type, array.len())), _ => Err(ArrowError::CastError(format!( - "Casting from {from_type:?} to {to_type:?} not supported" + "Casting from {from_type} to {to_type} not supported" ))), } } @@ -2140,7 +2451,7 @@ where R::Native: NumCast, { from.try_unary(|value| { - num::cast::cast::(value).ok_or_else(|| { + num_traits::cast::cast::(value).ok_or_else(|| { ArrowError::CastError(format!( "Can't cast value {:?} to type {}", value, @@ -2159,7 +2470,7 @@ where T::Native: NumCast, R::Native: NumCast, { - from.unary_opt::<_, R>(num::cast::cast::) + from.unary_opt::<_, R>(num_traits::cast::cast::) } fn cast_numeric_to_binary( @@ -2167,12 +2478,12 @@ fn cast_numeric_to_binary( ) -> Result { let array = array.as_primitive::(); let size = std::mem::size_of::(); - let offsets = OffsetBuffer::from_lengths(std::iter::repeat(size).take(array.len())); - Ok(Arc::new(GenericBinaryArray::::new( + let offsets = OffsetBuffer::from_repeated_length(size, array.len()); + Ok(Arc::new(GenericBinaryArray::::try_new( offsets, array.values().inner().clone(), array.nulls().cloned(), - ))) + )?)) } fn adjust_timestamp_to_timezone( @@ -2235,7 +2546,7 @@ fn cast_bool_to_numeric( ) -> Result where TO: ArrowPrimitiveType, - TO::Native: num::cast::NumCast, + TO::Native: num_traits::cast::NumCast, { Ok(Arc::new(bool_to_numeric_cast::( from.as_any().downcast_ref::().unwrap(), @@ -2246,14 +2557,14 @@ where fn bool_to_numeric_cast(from: &BooleanArray, _cast_options: &CastOptions) -> PrimitiveArray where T: ArrowPrimitiveType, - T::Native: num::NumCast, + T::Native: num_traits::NumCast, { let iter = (0..from.len()).map(|i| { if from.is_null(i) { None } else if from.value(i) { // a workaround to cast a primitive to T::Native, infallible - num::cast::cast(1) + num_traits::cast::cast(1) } else { Some(T::default_value()) } @@ -2426,9 +2737,14 @@ where #[cfg(test)] mod tests { use super::*; + use DataType::*; + use arrow_array::{Int64Array, RunArray, StringArray}; use arrow_buffer::{Buffer, IntervalDayTime, NullBuffer}; + use arrow_buffer::{ScalarBuffer, i256}; + use arrow_schema::{DataType, Field}; use chrono::NaiveDate; use half::f16; + use std::sync::Arc; #[derive(Clone)] struct DecimalCastTestConfig { @@ -2507,33 +2823,55 @@ mod tests { } } - fn create_decimal128_array( - array: Vec>, + fn create_decimal32_array( + array: Vec>, precision: u8, scale: i8, - ) -> Result { + ) -> Result { array .into_iter() - .collect::() + .collect::() .with_precision_and_scale(precision, scale) } - fn create_decimal256_array( - array: Vec>, + fn create_decimal64_array( + array: Vec>, precision: u8, scale: i8, - ) -> Result { + ) -> Result { array .into_iter() - .collect::() + .collect::() .with_precision_and_scale(precision, scale) } - #[test] - #[cfg(not(feature = "force_validate"))] - #[should_panic( - expected = "Cannot cast to Decimal128(20, 3). Overflowing on 57896044618658097711785492504343953926634992332820282019728792003956564819967" - )] + fn create_decimal128_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + fn create_decimal256_array( + array: Vec>, + precision: u8, + scale: i8, + ) -> Result { + array + .into_iter() + .collect::() + .with_precision_and_scale(precision, scale) + } + + #[test] + #[cfg(not(feature = "force_validate"))] + #[should_panic( + expected = "Cannot cast to Decimal128(20, 3). Overflowing on 57896044618658097711785492504343953926634992332820282019728792003956564819967" + )] fn test_cast_decimal_to_decimal_round_with_error() { // decimal256 to decimal128 overflow let array = vec![ @@ -2655,8 +2993,81 @@ mod tests { ); } + #[test] + fn test_cast_decimal32_to_decimal32() { + // test changing precision + let input_type = DataType::Decimal32(9, 3); + let output_type = DataType::Decimal32(9, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, 9, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal32Array, + &output_type, + vec![ + Some(11234560_i32), + Some(21234560_i32), + Some(31234560_i32), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal32_array(array, 9, 0).unwrap(); + let result_safe = cast(&array, &DataType::Decimal32(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal32(2, 2), &options); + assert_eq!( + "Invalid argument error: 123456.00 is too large to store in a Decimal32 of precision 2. Max is 0.99", + result_unsafe.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_decimal64_to_decimal64() { + // test changing precision + let input_type = DataType::Decimal64(17, 3); + let output_type = DataType::Decimal64(17, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, 17, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal64Array, + &output_type, + vec![ + Some(11234560_i64), + Some(21234560_i64), + Some(31234560_i64), + None + ] + ); + // negative test + let array = vec![Some(123456), None]; + let array = create_decimal64_array(array, 9, 0).unwrap(); + let result_safe = cast(&array, &DataType::Decimal64(2, 2)); + assert!(result_safe.is_ok()); + let options = CastOptions { + safe: false, + ..Default::default() + }; + + let result_unsafe = cast_with_options(&array, &DataType::Decimal64(2, 2), &options); + assert_eq!( + "Invalid argument error: 123456.00 is too large to store in a Decimal64 of precision 2. Max is 0.99", + result_unsafe.unwrap_err().to_string() + ); + } + #[test] fn test_cast_decimal128_to_decimal128() { + // test changing precision let input_type = DataType::Decimal128(20, 3); let output_type = DataType::Decimal128(20, 4); assert!(can_cast_types(&input_type, &output_type)); @@ -2684,8 +3095,42 @@ mod tests { }; let result_unsafe = cast_with_options(&array, &DataType::Decimal128(2, 2), &options); - assert_eq!("Invalid argument error: 12345600 is too large to store in a Decimal128 of precision 2. Max is 99", - result_unsafe.unwrap_err().to_string()); + assert_eq!( + "Invalid argument error: 123456.00 is too large to store in a Decimal128 of precision 2. Max is 0.99", + result_unsafe.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_decimal32_to_decimal32_dict() { + let p = 9; + let s = 3; + let input_type = DataType::Decimal32(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal32(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); + } + + #[test] + fn test_cast_decimal64_to_decimal64_dict() { + let p = 15; + let s = 3; + let input_type = DataType::Decimal64(p, s); + let output_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(DataType::Decimal64(p, s)), + ); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, p, s).unwrap(); + let cast_array = cast_with_options(&array, &output_type, &CastOptions::default()).unwrap(); + assert_eq!(cast_array.data_type(), &output_type); } #[test] @@ -2720,6 +3165,136 @@ mod tests { assert_eq!(cast_array.data_type(), &output_type); } + #[test] + fn test_cast_decimal32_to_decimal32_overflow() { + let input_type = DataType::Decimal32(9, 3); + let output_type = DataType::Decimal32(9, 9); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i32::MAX)]; + let array = create_decimal32_array(array, 9, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: Cannot cast to Decimal32(9, 9). Overflowing on 2147483647", + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_decimal32_to_decimal32_large_scale_reduction() { + let array = vec![Some(-999999999), Some(0), Some(999999999), None]; + let array = create_decimal32_array(array, 9, 3).unwrap(); + + // Divide out all digits of precision -- rounding could still produce +/- 1 + let output_type = DataType::Decimal32(9, -6); + assert!(can_cast_types(array.data_type(), &output_type)); + generate_cast_test_case!( + &array, + Decimal32Array, + &output_type, + vec![Some(-1), Some(0), Some(1), None] + ); + + // Divide out more digits than we have precision -- all-zero result + let output_type = DataType::Decimal32(9, -7); + assert!(can_cast_types(array.data_type(), &output_type)); + generate_cast_test_case!( + &array, + Decimal32Array, + &output_type, + vec![Some(0), Some(0), Some(0), None] + ); + } + + #[test] + fn test_cast_decimal64_to_decimal64_overflow() { + let input_type = DataType::Decimal64(18, 3); + let output_type = DataType::Decimal64(18, 18); + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(i64::MAX)]; + let array = create_decimal64_array(array, 18, 3).unwrap(); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert_eq!( + "Cast error: Cannot cast to Decimal64(18, 18). Overflowing on 9223372036854775807", + result.unwrap_err().to_string() + ); + } + + #[test] + fn test_cast_decimal64_to_decimal64_large_scale_reduction() { + let array = vec![ + Some(-999999999999999999), + Some(0), + Some(999999999999999999), + None, + ]; + let array = create_decimal64_array(array, 18, 3).unwrap(); + + // Divide out all digits of precision -- rounding could still produce +/- 1 + let output_type = DataType::Decimal64(18, -15); + assert!(can_cast_types(array.data_type(), &output_type)); + generate_cast_test_case!( + &array, + Decimal64Array, + &output_type, + vec![Some(-1), Some(0), Some(1), None] + ); + + // Divide out more digits than we have precision -- all-zero result + let output_type = DataType::Decimal64(18, -16); + assert!(can_cast_types(array.data_type(), &output_type)); + generate_cast_test_case!( + &array, + Decimal64Array, + &output_type, + vec![Some(0), Some(0), Some(0), None] + ); + } + + #[test] + fn test_cast_floating_to_decimals() { + for output_type in [ + DataType::Decimal32(9, 3), + DataType::Decimal64(9, 3), + DataType::Decimal128(9, 3), + DataType::Decimal256(9, 3), + ] { + let input_type = DataType::Float64; + assert!(can_cast_types(&input_type, &output_type)); + + let array = vec![Some(1.1_f64)]; + let array = PrimitiveArray::::from_iter(array); + let result = cast_with_options( + &array, + &output_type, + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ); + assert!( + result.is_ok(), + "Failed to cast to {output_type} with: {}", + result.unwrap_err() + ); + } + } + #[test] fn test_cast_decimal128_to_decimal128_overflow() { let input_type = DataType::Decimal128(38, 3); @@ -2736,8 +3311,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Cast error: Cannot cast to Decimal128(38, 38). Overflowing on 170141183460469231731687303715884105727", - result.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast to Decimal128(38, 38). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string() + ); } #[test] @@ -2756,10 +3333,50 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Cast error: Cannot cast to Decimal256(76, 76). Overflowing on 170141183460469231731687303715884105727", - result.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast to Decimal256(76, 76). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string() + ); } + #[test] + fn test_cast_decimal32_to_decimal256() { + let input_type = DataType::Decimal32(8, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal32_array(array, 8, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } + #[test] + fn test_cast_decimal64_to_decimal256() { + let input_type = DataType::Decimal64(12, 3); + let output_type = DataType::Decimal256(20, 4); + assert!(can_cast_types(&input_type, &output_type)); + let array = vec![Some(1123456), Some(2123456), Some(3123456), None]; + let array = create_decimal64_array(array, 12, 3).unwrap(); + generate_cast_test_case!( + &array, + Decimal256Array, + &output_type, + vec![ + Some(i256::from_i128(11234560_i128)), + Some(i256::from_i128(21234560_i128)), + Some(i256::from_i128(31234560_i128)), + None + ] + ); + } #[test] fn test_cast_decimal128_to_decimal256() { let input_type = DataType::Decimal128(20, 3); @@ -2795,8 +3412,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Cast error: Cannot cast to Decimal128(38, 7). Overflowing on 170141183460469231731687303715884105727", - result.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast to Decimal128(38, 7). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string() + ); } #[test] @@ -2814,8 +3433,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Cast error: Cannot cast to Decimal256(76, 55). Overflowing on 170141183460469231731687303715884105727", - result.unwrap_err().to_string()); + assert_eq!( + "Cast error: Cannot cast to Decimal256(76, 55). Overflowing on 170141183460469231731687303715884105727", + result.unwrap_err().to_string() + ); } #[test] @@ -2956,6 +3577,22 @@ mod tests { ); } + #[test] + fn test_cast_decimal32_to_numeric() { + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal32_array(value_array, 8, 2).unwrap(); + + generate_decimal_to_numeric_cast_test_case(&array); + } + + #[test] + fn test_cast_decimal64_to_numeric() { + let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; + let array = create_decimal64_array(value_array, 8, 2).unwrap(); + + generate_decimal_to_numeric_cast_test_case(&array); + } + #[test] fn test_cast_decimal128_to_numeric() { let value_array: Vec> = vec![Some(125), Some(225), Some(325), None, Some(525)]; @@ -3861,9 +4498,11 @@ mod tests { match casted { Ok(_) => panic!("expected error"), Err(e) => { - assert!(e - .to_string() - .contains("Cast error: Cannot cast value 'invalid' to value of Boolean type")) + assert!( + e.to_string().contains( + "Cast error: Cannot cast value 'invalid' to value of Boolean type" + ) + ) } } } @@ -4075,26 +4714,16 @@ mod tests { #[test] fn test_cast_list_i32_to_list_u16() { - let value_data = Int32Array::from(vec![0, 0, 0, -1, -2, -1, 2, 100000000]).into_data(); - - let value_offsets = Buffer::from_slice_ref([0, 3, 6, 8]); - - // Construct a list array from the above two - // [[0,0,0], [-1, -2, -1], [2, 100000000]] - let list_data_type = DataType::List(Arc::new(Field::new_list_field(DataType::Int32, true))); - let list_data = ArrayData::builder(list_data_type) - .len(3) - .add_buffer(value_offsets) - .add_child_data(value_data) - .build() - .unwrap(); - let list_array = ListArray::from(list_data); + let values = vec![ + Some(vec![Some(0), Some(0), Some(0)]), + Some(vec![Some(-1), Some(-2), Some(-1)]), + Some(vec![Some(2), Some(100000000)]), + ]; + let list_array = ListArray::from_iter_primitive::(values); - let cast_array = cast( - &list_array, - &DataType::List(Arc::new(Field::new_list_field(DataType::UInt16, true))), - ) - .unwrap(); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::UInt16, true))); + assert!(can_cast_types(list_array.data_type(), &target_type)); + let cast_array = cast(&list_array, &target_type).unwrap(); // For the ListArray itself, there are no null values (as there were no nulls when they went in) // @@ -4441,7 +5070,10 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Second) type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(s) type" + ); } } @@ -4483,7 +5115,10 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(Millisecond) type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string '08:08:61.091323414' to value of Time32(ms) type" + ); } } @@ -4517,7 +5152,10 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Microsecond) type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid time' to value of Time64(µs) type" + ); } } @@ -4551,7 +5189,10 @@ mod tests { format_options: FormatOptions::default(), }; let err = cast_with_options(array, &to_type, &options).unwrap_err(); - assert_eq!(err.to_string(), "Cast error: Cannot cast string 'Not a valid time' to value of Time64(Nanosecond) type"); + assert_eq!( + err.to_string(), + "Cast error: Cannot cast string 'Not a valid time' to value of Time64(ns) type" + ); } } @@ -5339,28 +5980,9 @@ mod tests { assert!(c.is_null(2)); } - #[test] - fn test_cast_date32_to_string() { - let array = Date32Array::from(vec![10000, 17890]); - let b = cast(&array, &DataType::Utf8).unwrap(); - let c = b.as_any().downcast_ref::().unwrap(); - assert_eq!(&DataType::Utf8, c.data_type()); - assert_eq!("1997-05-19", c.value(0)); - assert_eq!("2018-12-25", c.value(1)); - } - - #[test] - fn test_cast_date64_to_string() { - let array = Date64Array::from(vec![10000 * 86400000, 17890 * 86400000]); - let b = cast(&array, &DataType::Utf8).unwrap(); - let c = b.as_any().downcast_ref::().unwrap(); - assert_eq!(&DataType::Utf8, c.data_type()); - assert_eq!("1997-05-19T00:00:00", c.value(0)); - assert_eq!("2018-12-25T00:00:00", c.value(1)); - } - - macro_rules! assert_cast_timestamp_to_string { + macro_rules! assert_cast { ($array:expr, $datatype:expr, $output_array_type: ty, $expected:expr) => {{ + assert!(can_cast_types($array.data_type(), &$datatype)); let out = cast(&$array, &$datatype).unwrap(); let actual = out .as_any() @@ -5371,6 +5993,7 @@ mod tests { assert_eq!(actual, $expected); }}; ($array:expr, $datatype:expr, $output_array_type: ty, $options:expr, $expected:expr) => {{ + assert!(can_cast_types($array.data_type(), &$datatype)); let out = cast_with_options(&$array, &$datatype, &$options).unwrap(); let actual = out .as_any() @@ -5382,6 +6005,44 @@ mod tests { }}; } + #[test] + fn test_cast_date32_to_string() { + let array = Date32Array::from(vec![Some(0), Some(10000), Some(13036), Some(17890), None]); + let expected = vec![ + Some("1970-01-01"), + Some("1997-05-19"), + Some("2005-09-10"), + Some("2018-12-25"), + None, + ]; + + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); + } + + #[test] + fn test_cast_date64_to_string() { + let array = Date64Array::from(vec![ + Some(0), + Some(10000 * 86400000), + Some(13036 * 86400000), + Some(17890 * 86400000), + None, + ]); + let expected = vec![ + Some("1970-01-01T00:00:00"), + Some("1997-05-19T00:00:00"), + Some("2005-09-10T00:00:00"), + Some("2018-12-25T00:00:00"), + None, + ]; + + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); + } + #[test] fn test_cast_date32_to_timestamp_and_timestamp_with_timezone() { let tz = "+0545"; // UTC + 0545 is Asia/Kathmandu @@ -5584,9 +6245,9 @@ mod tests { None, ]; - assert_cast_timestamp_to_string!(array, DataType::Utf8View, StringViewArray, expected); - assert_cast_timestamp_to_string!(array, DataType::Utf8, StringArray, expected); - assert_cast_timestamp_to_string!(array, DataType::LargeUtf8, LargeStringArray, expected); + assert_cast!(array, DataType::Utf8View, StringViewArray, expected); + assert_cast!(array, DataType::Utf8, StringArray, expected); + assert_cast!(array, DataType::LargeUtf8, LargeStringArray, expected); } #[test] @@ -5608,21 +6269,21 @@ mod tests { Some("2018-12-25 00:00:02.001000"), None, ]; - assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::Utf8View, StringViewArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::Utf8, StringArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_without_tz, DataType::LargeUtf8, LargeStringArray, @@ -5638,21 +6299,21 @@ mod tests { Some("2018-12-25 05:45:02.001000"), None, ]; - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::Utf8View, StringViewArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::Utf8, StringArray, cast_options, expected ); - assert_cast_timestamp_to_string!( + assert_cast!( array_with_tz, DataType::LargeUtf8, LargeStringArray, @@ -5892,6 +6553,38 @@ mod tests { assert_eq!(string_view_array.as_ref(), &expect_string_view_array); } + #[test] + fn test_binary_view_to_string_view_with_invalid_utf8() { + let binary_view_array = BinaryViewArray::from_iter(vec![ + Some("valid".as_bytes()), + Some(&[0xff]), + Some("utf8".as_bytes()), + None, + ]); + + let strict_options = CastOptions { + safe: false, + ..Default::default() + }; + + assert!( + cast_with_options(&binary_view_array, &DataType::Utf8View, &strict_options).is_err() + ); + + let safe_options = CastOptions { + safe: true, + ..Default::default() + }; + + let string_view_array = + cast_with_options(&binary_view_array, &DataType::Utf8View, &safe_options).unwrap(); + assert_eq!(string_view_array.data_type(), &DataType::Utf8View); + + let values: Vec<_> = string_view_array.as_string_view().iter().collect(); + + assert_eq!(values, vec![Some("valid"), None, Some("utf8"), None]); + } + #[test] fn test_string_to_view() { _test_string_to_view::(); @@ -7192,8 +7885,6 @@ mod tests { #[test] fn test_cast_utf8_dict() { // FROM a dictionary with of Utf8 values - use DataType::*; - let mut builder = StringDictionaryBuilder::::new(); builder.append("one").unwrap(); builder.append_null(); @@ -7248,7 +7939,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_primitive() { - use DataType::*; // test converting from an array that has indexes of a type // that are out of bounds for a particular other kind of // index. @@ -7276,7 +7966,6 @@ mod tests { #[test] fn test_cast_dict_to_dict_bad_index_value_utf8() { - use DataType::*; // Same test as test_cast_dict_to_dict_bad_index_value but use // string values (and encode the expected behavior here); @@ -7305,8 +7994,6 @@ mod tests { #[test] fn test_cast_primitive_dict() { // FROM a dictionary with of INT32 values - use DataType::*; - let mut builder = PrimitiveDictionaryBuilder::::new(); builder.append(1).unwrap(); builder.append_null(); @@ -7327,8 +8014,6 @@ mod tests { #[test] fn test_cast_primitive_array_to_dict() { - use DataType::*; - let mut builder = PrimitiveBuilder::::new(); builder.append_value(1); builder.append_null(); @@ -7438,6 +8123,7 @@ mod tests { typed_test!(UInt32Array, UInt32, UInt32Type); typed_test!(UInt64Array, UInt64, UInt64Type); + typed_test!(Float16Array, Float16, Float16Type); typed_test!(Float32Array, Float32, Float32Type); typed_test!(Float64Array, Float64, Float64Type); @@ -7445,19 +8131,29 @@ mod tests { typed_test!(Date64Array, Date64, Date64Type); } - fn cast_from_null_to_other(data_type: &DataType) { + fn cast_from_null_to_other_base(data_type: &DataType, is_complex: bool) { // Cast from null to data_type - { - let array = new_null_array(&DataType::Null, 4); - assert_eq!(array.data_type(), &DataType::Null); - let cast_array = cast(&array, data_type).expect("cast failed"); - assert_eq!(cast_array.data_type(), data_type); - for i in 0..4 { + let array = new_null_array(&DataType::Null, 4); + assert_eq!(array.data_type(), &DataType::Null); + let cast_array = cast(&array, data_type).expect("cast failed"); + assert_eq!(cast_array.data_type(), data_type); + for i in 0..4 { + if is_complex { + assert!(cast_array.logical_nulls().unwrap().is_null(i)); + } else { assert!(cast_array.is_null(i)); } } } + fn cast_from_null_to_other(data_type: &DataType) { + cast_from_null_to_other_base(data_type, false); + } + + fn cast_from_null_to_other_complex(data_type: &DataType) { + cast_from_null_to_other_base(data_type, true); + } + #[test] fn test_cast_null_from_and_to_variable_sized() { cast_from_null_to_other(&DataType::Utf8); @@ -7501,6 +8197,23 @@ mod tests { // Cast null from and to struct let data_type = DataType::Struct(vec![Field::new("data", DataType::Int64, false)].into()); cast_from_null_to_other(&data_type); + + let target_type = DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))); + cast_from_null_to_other(&target_type); + + let target_type = + DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))); + cast_from_null_to_other(&target_type); + + let fields = UnionFields::from_fields(vec![Field::new("a", DataType::Int64, false)]); + let target_type = DataType::Union(fields, UnionMode::Sparse); + cast_from_null_to_other_complex(&target_type); + + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("item", DataType::Int32, true)), + Arc::new(Field::new("item", DataType::Int32, true)), + ); + cast_from_null_to_other_complex(&target_type); } /// Print the `DictionaryArray` `array` as a vector of strings @@ -7695,13 +8408,11 @@ mod tests { ); let list_array = cast(&array, expected.data_type()) - .unwrap_or_else(|_| panic!("Failed to cast {:?} to {:?}", array, expected)); + .unwrap_or_else(|_| panic!("Failed to cast {array:?} to {expected:?}")); assert_eq!( list_array.as_ref(), &expected, - "Incorrect result from casting {:?} to {:?}", - array, - expected + "Incorrect result from casting {array:?} to {expected:?}", ); } } @@ -7935,8 +8646,10 @@ mod tests { }, ); assert!(res.is_err()); - assert!(format!("{:?}", res) - .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2")); + assert!( + format!("{res:?}") + .contains("Cannot cast to FixedSizeList(3): value at index 1 has length 2") + ); // When safe=true (default), the cast will fill nulls for lists that are // too short and truncate lists that are too long. @@ -8026,7 +8739,7 @@ mod tests { }, ); assert!(res.is_err()); - assert!(format!("{:?}", res).contains("Can't cast value 2147483647 to type Int16")); + assert!(format!("{res:?}").contains("Can't cast value 2147483647 to type Int16")); } #[test] @@ -8166,8 +8879,12 @@ mod tests { let new_array_result = cast(&array, &new_type.clone()); assert!(!can_cast_types(array.data_type(), &new_type)); - assert!( - matches!(new_array_result, Err(ArrowError::CastError(t)) if t == r#"Casting from Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Utf8, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) to Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, true) not supported"#) + let Err(ArrowError::CastError(t)) = new_array_result else { + panic!(); + }; + assert_eq!( + t, + r#"Casting from Map("entries": non-null Struct("key": non-null Utf8, "value": Utf8), unsorted) to Map("entries": non-null Struct("key": non-null Utf8, "value": non-null Utf8), sorted) not supported"# ); } @@ -8213,8 +8930,12 @@ mod tests { let new_array_result = cast(&array, &new_type.clone()); assert!(!can_cast_types(array.data_type(), &new_type)); - assert!( - matches!(new_array_result, Err(ArrowError::CastError(t)) if t == r#"Casting from Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Interval(DayTime), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, false) to Map(Field { name: "entries", data_type: Struct([Field { name: "key", data_type: Utf8, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, Field { name: "value", data_type: Duration(Second), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]), nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }, true) not supported"#) + let Err(ArrowError::CastError(t)) = new_array_result else { + panic!(); + }; + assert_eq!( + t, + r#"Casting from Map("entries": non-null Struct("key": non-null Utf8, "value": Interval(DayTime)), unsorted) to Map("entries": non-null Struct("key": non-null Utf8, "value": non-null Duration(s)), sorted) not supported"# ); } @@ -8563,7 +9284,7 @@ mod tests { }, ); let err = casted_array.unwrap_err().to_string(); - let expected_error = "Invalid argument error: 110 is too large to store in a Decimal128 of precision 2. Max is 99"; + let expected_error = "Invalid argument error: 1.10 is too large to store in a Decimal128 of precision 2. Max is 0.99"; assert!( err.contains(expected_error), "did not find expected error '{expected_error}' in actual error '{err}'" @@ -8594,11 +9315,8 @@ mod tests { }, ); let err = casted_array.unwrap_err().to_string(); - let expected_error = "Invalid argument error: 110 is too large to store in a Decimal256 of precision 2. Max is 99"; - assert!( - err.contains(expected_error), - "did not find expected error '{expected_error}' in actual error '{err}'" - ); + let expected_error = "Invalid argument error: 1.10 is too large to store in a Decimal256 of precision 2. Max is 0.99"; + assert_eq!(err, expected_error); } #[test] @@ -8662,6 +9380,28 @@ mod tests { "did not find expected error '{expected_error}' in actual error '{err}'" ); } + #[test] + fn test_cast_decimal256_to_f64_no_overflow() { + // Test casting i256::MAX: should produce a large finite positive value + let array = vec![Some(i256::MAX)]; + let array = create_decimal256_array(array, 76, 2).unwrap(); + let array = Arc::new(array) as ArrayRef; + + let result = cast(&array, &DataType::Float64).unwrap(); + let result = result.as_primitive::(); + assert!(result.value(0).is_finite()); + assert!(result.value(0) > 0.0); // Positive result + + // Test casting i256::MIN: should produce a large finite negative value + let array = vec![Some(i256::MIN)]; + let array = create_decimal256_array(array, 76, 2).unwrap(); + let array = Arc::new(array) as ArrayRef; + + let result = cast(&array, &DataType::Float64).unwrap(); + let result = result.as_primitive::(); + assert!(result.value(0).is_finite()); + assert!(result.value(0) < 0.0); // Negative result + } #[test] fn test_cast_decimal128_to_decimal128_negative_scale() { @@ -8691,6 +9431,15 @@ mod tests { assert_eq!("3123460", decimal_arr.value_as_string(2)); } + #[test] + fn decimal128_min_max_to_f64() { + // Ensure Decimal128 i128::MIN/MAX round-trip cast + let min128 = i128::MIN; + let max128 = i128::MAX; + assert_eq!(min128 as f64, min128 as f64); + assert_eq!(max128 as f64, max128 as f64); + } + #[test] fn test_cast_numeric_to_decimal128_negative() { let decimal_type = DataType::Decimal128(38, -1); @@ -9090,7 +9839,7 @@ mod tests { Some(array.value_as_string(i)) }; let actual = actual.as_ref().map(|s| s.as_ref()); - assert_eq!(*expected, actual, "Expected at position {}", i); + assert_eq!(*expected, actual, "Expected at position {i}"); } } @@ -9119,16 +9868,20 @@ mod tests { format_options: FormatOptions::default(), }; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type")); + assert!( + casted_err + .to_string() + .contains("Cannot cast string '4.4.5' to value of Decimal128(38, 10) type") + ); let str_array = StringArray::from(vec![". 0.123"]); let array = Arc::new(str_array) as ArrayRef; let casted_err = cast_with_options(&array, &output_type, &option).unwrap_err(); - assert!(casted_err - .to_string() - .contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type")); + assert!( + casted_err + .to_string() + .contains("Cannot cast string '. 0.123' to value of Decimal128(38, 10) type") + ); } fn test_cast_string_to_decimal128_overflow(overflow_array: ArrayRef) { @@ -9172,7 +9925,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal128 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Invalid argument error: 1000.00000000 is too large to store in a Decimal128 of precision 10. Max is 99.99999999", + err.unwrap_err().to_string() + ); } #[test] @@ -9255,7 +10011,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 100000000000 is too large to store in a Decimal256 of precision 10. Max is 9999999999", err.unwrap_err().to_string()); + assert_eq!( + "Invalid argument error: 1000.00000000 is too large to store in a Decimal256 of precision 10. Max is 99.99999999", + err.unwrap_err().to_string() + ); } #[test] @@ -9513,6 +10272,14 @@ mod tests { #[test] fn test_cast_decimal_to_string() { + assert!(can_cast_types( + &DataType::Decimal32(9, 4), + &DataType::Utf8View + )); + assert!(can_cast_types( + &DataType::Decimal64(16, 4), + &DataType::Utf8View + )); assert!(can_cast_types( &DataType::Decimal128(10, 4), &DataType::Utf8View @@ -9557,7 +10324,7 @@ mod tests { } } - let array128: Vec> = vec![ + let array32: Vec> = vec![ Some(1123454), Some(2123456), Some(-3123453), @@ -9568,11 +10335,40 @@ mod tests { Some(-123456789), None, ]; + let array64: Vec> = array32.iter().map(|num| num.map(|x| x as i64)).collect(); + let array128: Vec> = + array64.iter().map(|num| num.map(|x| x as i128)).collect(); let array256: Vec> = array128 .iter() .map(|num| num.map(i256::from_i128)) .collect(); + test_decimal_to_string::( + DataType::Utf8View, + create_decimal32_array(array32.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal32_array(array32.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal32_array(array32, 7, 3).unwrap(), + ); + + test_decimal_to_string::( + DataType::Utf8View, + create_decimal64_array(array64.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::Utf8, + create_decimal64_array(array64.clone(), 7, 3).unwrap(), + ); + test_decimal_to_string::( + DataType::LargeUtf8, + create_decimal64_array(array64, 7, 3).unwrap(), + ); + test_decimal_to_string::( DataType::Utf8View, create_decimal128_array(array128.clone(), 7, 3).unwrap(), @@ -9623,7 +10419,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal128 of precision 7. Max is 9999999", err.unwrap_err().to_string()); + assert_eq!( + "Invalid argument error: 1234567.000 is too large to store in a Decimal128 of precision 7. Max is 9999.999", + err.unwrap_err().to_string() + ); } #[test] @@ -9649,7 +10448,10 @@ mod tests { format_options: FormatOptions::default(), }, ); - assert_eq!("Invalid argument error: 1234567000 is too large to store in a Decimal256 of precision 7. Max is 9999999", err.unwrap_err().to_string()); + assert_eq!( + "Invalid argument error: 1234567.000 is too large to store in a Decimal256 of precision 7. Max is 9999.999", + err.unwrap_err().to_string() + ); } /// helper function to test casting from duration to interval @@ -10238,7 +11040,7 @@ mod tests { let to_type = DataType::Utf8; let result = cast(&struct_array, &to_type); assert_eq!( - r#"Cast error: Casting from Struct([Field { name: "a", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]) to Utf8 not supported"#, + r#"Cast error: Casting from Struct("a": non-null Boolean) to Utf8 not supported"#, result.unwrap_err().to_string() ); } @@ -10249,11 +11051,170 @@ mod tests { let to_type = DataType::Struct(vec![Field::new("a", DataType::Boolean, false)].into()); let result = cast(&array, &to_type); assert_eq!( - r#"Cast error: Casting from Utf8 to Struct([Field { name: "a", data_type: Boolean, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }]) not supported"#, + r#"Cast error: Casting from Utf8 to Struct("a": non-null Boolean) not supported"#, result.unwrap_err().to_string() ); } + #[test] + fn test_cast_struct_with_different_field_order() { + // Test slow path: fields are in different order + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + string.clone() as ArrayRef, + ), + ]); + + // Target has fields in different order: c, a, b instead of a, b, c + let to_type = DataType::Struct( + vec![ + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), // Boolean to Utf8 + Field::new("b", DataType::Utf8, false), // Int32 to Utf8 + ] + .into(), + ); + + let result = cast(&struct_array, &to_type).unwrap(); + let result_struct = result.as_struct(); + + assert_eq!(result_struct.data_type(), &to_type); + assert_eq!(result_struct.num_columns(), 3); + + // Verify field "c" (originally position 2, now position 0) remains Utf8 + let c_column = result_struct.column(0).as_string::(); + assert_eq!( + c_column.into_iter().flatten().collect::>(), + vec!["foo", "bar", "baz", "qux"] + ); + + // Verify field "a" (originally position 0, now position 1) was cast from Boolean to Utf8 + let a_column = result_struct.column(1).as_string::(); + assert_eq!( + a_column.into_iter().flatten().collect::>(), + vec!["false", "false", "true", "true"] + ); + + // Verify field "b" (originally position 1, now position 2) was cast from Int32 to Utf8 + let b_column = result_struct.column(2).as_string::(); + assert_eq!( + b_column.into_iter().flatten().collect::>(), + vec!["42", "28", "19", "31"] + ); + } + + #[test] + fn test_cast_struct_with_missing_field() { + // Test that casting fails when target has a field not present in source + let boolean = Arc::new(BooleanArray::from(vec![false, true])); + let struct_array = StructArray::from(vec![( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + )]); + + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", DataType::Int32, false), // Field "b" doesn't exist in source + ] + .into(), + ); + + let result = cast(&struct_array, &to_type); + assert!(result.is_err()); + assert_eq!( + result.unwrap_err().to_string(), + "Invalid argument error: Incorrect number of arrays for StructArray fields, expected 2 got 1" + ); + } + + #[test] + fn test_cast_struct_with_subset_of_fields() { + // Test casting to a struct with fewer fields (selecting a subset) + let boolean = Arc::new(BooleanArray::from(vec![false, false, true, true])); + let int = Arc::new(Int32Array::from(vec![42, 28, 19, 31])); + let string = Arc::new(StringArray::from(vec!["foo", "bar", "baz", "qux"])); + + let struct_array = StructArray::from(vec![ + ( + Arc::new(Field::new("a", DataType::Boolean, false)), + boolean.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("b", DataType::Int32, false)), + int.clone() as ArrayRef, + ), + ( + Arc::new(Field::new("c", DataType::Utf8, false)), + string.clone() as ArrayRef, + ), + ]); + + // Target has only fields "c" and "a", omitting "b" + let to_type = DataType::Struct( + vec![ + Field::new("c", DataType::Utf8, false), + Field::new("a", DataType::Utf8, false), + ] + .into(), + ); + + let result = cast(&struct_array, &to_type).unwrap(); + let result_struct = result.as_struct(); + + assert_eq!(result_struct.data_type(), &to_type); + assert_eq!(result_struct.num_columns(), 2); + + // Verify field "c" remains Utf8 + let c_column = result_struct.column(0).as_string::(); + assert_eq!( + c_column.into_iter().flatten().collect::>(), + vec!["foo", "bar", "baz", "qux"] + ); + + // Verify field "a" was cast from Boolean to Utf8 + let a_column = result_struct.column(1).as_string::(); + assert_eq!( + a_column.into_iter().flatten().collect::>(), + vec!["false", "false", "true", "true"] + ); + } + + #[test] + fn test_can_cast_struct_rename_field() { + // Test that can_cast_types returns false when target has a field not in source + let from_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Utf8, false), + ] + .into(), + ); + + let to_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int64, false), + Field::new("c", DataType::Boolean, false), // Field "c" not in source + ] + .into(), + ); + + assert!(can_cast_types(&from_type, &to_type)); + } + fn run_decimal_cast_test_case_between_multiple_types(t: DecimalCastTestConfig) { run_decimal_cast_test_case::(t.clone()); run_decimal_cast_test_case::(t.clone()); @@ -10289,7 +11250,7 @@ mod tests { input_repr: 99999, // 9999.9 output_prec: 7, output_scale: 6, - expected_output_repr: Err("Invalid argument error: 9999900000 is too large to store in a {} of precision 7. Max is 9999999".to_string()) // max is 9.999999 + expected_output_repr: Err("Invalid argument error: 9999.900000 is too large to store in a {} of precision 7. Max is 9.999999".to_string()) // max is 9.999999 }, // increase precision, decrease scale, always infallible DecimalCastTestConfig { @@ -10334,7 +11295,7 @@ mod tests { input_repr: 9999999, // 99.99999 output_prec: 8, output_scale: 7, - expected_output_repr: Err("Invalid argument error: 999999900 is too large to store in a {} of precision 8. Max is 99999999".to_string()) // max is 9.9999999 + expected_output_repr: Err("Invalid argument error: 99.9999900 is too large to store in a {} of precision 8. Max is 9.9999999".to_string()) // max is 9.9999999 }, // decrease precision, decrease scale, safe, infallible DecimalCastTestConfig { @@ -10361,7 +11322,7 @@ mod tests { input_repr: 9999999, // 99.99999 output_prec: 4, output_scale: 3, - expected_output_repr: Err("Invalid argument error: 100000 is too large to store in a {} of precision 4. Max is 9999".to_string()) // max is 9.999 + expected_output_repr: Err("Invalid argument error: 100.000 is too large to store in a {} of precision 4. Max is 9.999".to_string()) // max is 9.999 }, // decrease precision, same scale, safe DecimalCastTestConfig { @@ -10379,7 +11340,7 @@ mod tests { input_repr: 9999999, // 99.99999 output_prec: 6, output_scale: 5, - expected_output_repr: Err("Invalid argument error: 9999999 is too large to store in a {} of precision 6. Max is 999999".to_string()) // max is 9.99999 + expected_output_repr: Err("Invalid argument error: 99.99999 is too large to store in a {} of precision 6. Max is 9.99999".to_string()) // max is 9.99999 }, // same precision, increase scale, safe DecimalCastTestConfig { @@ -10397,7 +11358,7 @@ mod tests { input_repr: 123456, // 12.3456 output_prec: 7, output_scale: 6, - expected_output_repr: Err("Invalid argument error: 12345600 is too large to store in a {} of precision 7. Max is 9999999".to_string()) // max is 9.99999 + expected_output_repr: Err("Invalid argument error: 12.345600 is too large to store in a {} of precision 7. Max is 9.999999".to_string()) // max is 9.99999 }, // same precision, decrease scale, infallible DecimalCastTestConfig { @@ -10492,7 +11453,7 @@ mod tests { input_repr: -12345, output_prec: 6, output_scale: 5, - expected_output_repr: Err("Invalid argument error: -1234500 is too small to store in a {} of precision 6. Min is -999999".to_string()) + expected_output_repr: Err("Invalid argument error: -12.34500 is too small to store in a {} of precision 6. Min is -9.99999".to_string()) }, ]; @@ -10543,7 +11504,7 @@ mod tests { output_prec: 6, output_scale: 3, expected_output_repr: - Err("Invalid argument error: 1000000 is too large to store in a {} of precision 6. Max is 999999".to_string()), + Err("Invalid argument error: 1000.000 is too large to store in a {} of precision 6. Max is 999.999".to_string()), }, ]; for t in test_cases { @@ -10564,8 +11525,10 @@ mod tests { ..Default::default() }; let result = cast_with_options(&array, &output_type, &options); - assert_eq!(result.unwrap_err().to_string(), - "Invalid argument error: 123456789 is too large to store in a Decimal128 of precision 6. Max is 999999"); + assert_eq!( + result.unwrap_err().to_string(), + "Invalid argument error: 1234567.89 is too large to store in a Decimal128 of precision 6. Max is 9999.99" + ); } #[test] @@ -10610,8 +11573,10 @@ mod tests { ..Default::default() }; let result = cast_with_options(&array, &output_type, &options); - assert_eq!(result.unwrap_err().to_string(), - "Invalid argument error: 1234568 is too large to store in a Decimal128 of precision 6. Max is 999999"); + assert_eq!( + result.unwrap_err().to_string(), + "Invalid argument error: 12345.68 is too large to store in a Decimal128 of precision 6. Max is 9999.99" + ); } #[test] @@ -10627,8 +11592,10 @@ mod tests { ..Default::default() }; let result = cast_with_options(&array, &output_type, &options); - assert_eq!(result.unwrap_err().to_string(), - "Invalid argument error: 1234567890 is too large to store in a Decimal128 of precision 6. Max is 999999"); + assert_eq!( + result.unwrap_err().to_string(), + "Invalid argument error: 1234567.890 is too large to store in a Decimal128 of precision 6. Max is 999.999" + ); } #[test] @@ -10643,9 +11610,11 @@ mod tests { safe: false, ..Default::default() }; - let result = cast_with_options(&array, &output_type, &options); - assert_eq!(result.unwrap_err().to_string(), - "Invalid argument error: 123456789 is too large to store in a Decimal256 of precision 6. Max is 999999"); + let result = cast_with_options(&array, &output_type, &options).unwrap_err(); + assert_eq!( + result.to_string(), + "Invalid argument error: 1234567.89 is too large to store in a Decimal256 of precision 6. Max is 9999.99" + ); } #[test] @@ -10684,4 +11653,802 @@ mod tests { )) as ArrayRef; assert_eq!(*fixed_array, *r); } + + #[test] + fn test_cast_decimal_error_output() { + let array = Int64Array::from(vec![1]); + let error = cast_with_options( + &array, + &DataType::Decimal32(1, 1), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap_err(); + assert_eq!( + error.to_string(), + "Invalid argument error: 1.0 is too large to store in a Decimal32 of precision 1. Max is 0.9" + ); + + let array = Int64Array::from(vec![-1]); + let error = cast_with_options( + &array, + &DataType::Decimal32(1, 1), + &CastOptions { + safe: false, + format_options: FormatOptions::default(), + }, + ) + .unwrap_err(); + assert_eq!( + error.to_string(), + "Invalid argument error: -1.0 is too small to store in a Decimal32 of precision 1. Min is -0.9" + ); + } + + #[test] + fn test_run_end_encoded_to_primitive() { + // Create a RunEndEncoded array: [1, 1, 2, 2, 2, 3] + let run_ends = Int32Array::from(vec![2, 5, 6]); + let values = Int32Array::from(vec![1, 2, 3]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + // Cast to Int64 + let cast_result = cast(&array_ref, &DataType::Int64).unwrap(); + // Verify the result is a RunArray with Int64 values + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!( + result_run_array.values(), + &[1i64, 1i64, 2i64, 2i64, 2i64, 3i64] + ); + } + + #[test] + fn test_sliced_run_end_encoded_to_primitive() { + let run_ends = Int32Array::from(vec![2, 5, 6]); + let values = Int32Array::from(vec![1, 2, 3]); + // [1, 1, 2, 2, 2, 3] + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let run_array = run_array.slice(3, 3); // [2, 2, 3] + let array_ref = Arc::new(run_array) as ArrayRef; + + let cast_result = cast(&array_ref, &DataType::Int64).unwrap(); + let result_run_array = cast_result.as_primitive::(); + assert_eq!(result_run_array.values(), &[2, 2, 3]); + } + + #[test] + fn test_run_end_encoded_to_string() { + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![10, 20, 30]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result is a RunArray with String values + let result_array = cast_result.as_any().downcast_ref::().unwrap(); + // Check that values are correct + assert_eq!(result_array.value(0), "10"); + assert_eq!(result_array.value(1), "10"); + assert_eq!(result_array.value(2), "20"); + } + + #[test] + fn test_primitive_to_run_end_encoded() { + // Create an Int32 array with repeated values: [1, 1, 2, 2, 2, 3] + let source_array = Int32Array::from(vec![1, 1, 2, 2, 2, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 5, 6] + assert_eq!(result_run_array.run_ends().values(), &[2, 5, 6]); + + // Check values: should be [1, 2, 3] + let values_array = result_run_array.values().as_primitive::(); + assert_eq!(values_array.values(), &[1, 2, 3]); + } + + #[test] + fn test_primitive_to_run_end_encoded_with_nulls() { + let source_array = Int32Array::from(vec![ + Some(1), + Some(1), + None, + None, + Some(2), + Some(2), + Some(3), + Some(3), + None, + None, + Some(4), + Some(4), + Some(5), + Some(5), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 4, 6, 8, 10, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 2, 3, 0, 4, 5, 0] + ); + assert_eq!(result_run_array.values().null_count(), 3); + } + + #[test] + fn test_primitive_to_run_end_encoded_with_nulls_consecutive() { + let source_array = Int64Array::from(vec![ + Some(1), + Some(1), + None, + None, + None, + None, + None, + None, + None, + None, + Some(4), + Some(20), + Some(500), + Some(500), + None, + None, + ]); + let array_ref = Arc::new(source_array) as ArrayRef; + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int64, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!( + result_run_array.run_ends().values(), + &[2, 10, 11, 12, 14, 16] + ); + assert_eq!( + result_run_array + .values() + .as_primitive::() + .values(), + &[1, 0, 4, 20, 500, 0] + ); + assert_eq!(result_run_array.values().null_count(), 2); + } + + #[test] + fn test_string_to_run_end_encoded() { + // Create a String array with repeated values: ["a", "a", "b", "c", "c"] + let source_array = StringArray::from(vec!["a", "a", "b", "c", "c"]); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is a RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check run structure: runs should end at positions [2, 3, 5] + assert_eq!(result_run_array.run_ends().values(), &[2, 3, 5]); + + // Check values: should be ["a", "b", "c"] + let values_array = result_run_array.values().as_string::(); + assert_eq!(values_array.value(0), "a"); + assert_eq!(values_array.value(1), "b"); + assert_eq!(values_array.value(2), "c"); + } + + #[test] + fn test_empty_array_to_run_end_encoded() { + // Create an empty Int32 array + let source_array = Int32Array::from(Vec::::new()); + let array_ref = Arc::new(source_array) as ArrayRef; + + // Cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int32, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + + // Verify the result is an empty RunArray + let result_run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check that both run_ends and values are empty + assert_eq!(result_run_array.run_ends().len(), 0); + assert_eq!(result_run_array.values().len(), 0); + } + + #[test] + fn test_run_end_encoded_with_nulls() { + // Create a RunEndEncoded array with nulls: [1, 1, null, 2, 2] + let run_ends = Int32Array::from(vec![2, 3, 5]); + let values = Int32Array::from(vec![Some(1), None, Some(2)]); + let run_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(run_array) as ArrayRef; + + // Cast to String + let cast_result = cast(&array_ref, &DataType::Utf8).unwrap(); + + // Verify the result preserves nulls + let result_run_array = cast_result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_run_array.value(0), "1"); + assert!(result_run_array.is_null(2)); + assert_eq!(result_run_array.value(4), "2"); + } + + #[test] + fn test_different_index_types() { + // Test with Int16 index type + let source_array = Int32Array::from(vec![1, 1, 2, 3, 3]); + let array_ref = Arc::new(source_array) as ArrayRef; + + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i16, 3i16, 5i16]); + + // Test again with Int64 index type + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Int32, true)), + ); + let cast_result = cast(&array_ref, &target_type).unwrap(); + assert_eq!(cast_result.data_type(), &target_type); + + // Verify the cast worked correctly: values are [1, 2, 3] + // and run-ends are [2, 3, 5] + let run_array = cast_result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_primitive::().value(0), 1); + assert_eq!(run_array.values().as_primitive::().value(1), 2); + assert_eq!(run_array.values().as_primitive::().value(2), 3); + assert_eq!(run_array.run_ends().values(), &[2i64, 3i64, 5i64]); + } + + #[test] + fn test_unsupported_cast_to_run_end_encoded() { + // Create a Struct array - complex nested type that might not be supported + let field = Field::new("item", DataType::Int32, false); + let struct_array = StructArray::from(vec![( + Arc::new(field), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + let array_ref = Arc::new(struct_array) as ArrayRef; + + // This should fail because: + // 1. The target type is not RunEndEncoded + // 2. The target type is not supported for casting from StructArray + let cast_result = cast(&array_ref, &DataType::FixedSizeBinary(10)); + + // Expect this to fail + assert!(cast_result.is_err()); + } + + /// Test casting RunEndEncoded to RunEndEncoded should fail + #[test] + fn test_cast_run_end_encoded_int64_to_int16_should_fail() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, // This should make it fail instead of returning nulls + format_options: FormatOptions::default(), + }; + + // This should fail due to run-end overflow + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Cast error: Can't cast value 100000 to type Int16") + ); + } + + #[test] + fn test_cast_run_end_encoded_int64_to_int16_with_safe_should_fail_with_null_invalid_error() { + // Construct a valid REE array with Int64 run-ends + let run_ends = Int64Array::from(vec![100_000, 400_000, 700_000]); // values too large for Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int16, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: true, + format_options: FormatOptions::default(), + }; + + // This fails even though safe is true because the run_ends array has null values + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + let e = result.expect_err("Cast should have failed but succeeded"); + assert!( + e.to_string() + .contains("Invalid argument error: Found null values in run_ends array. The run_ends array should not have null values.") + ); + } + + /// Test casting RunEndEncoded to RunEndEncoded should succeed + #[test] + fn test_cast_run_end_encoded_int16_to_int64_should_succeed() { + // Construct a valid REE array with Int16 run-ends + let run_ends = Int16Array::from(vec![2, 5, 8]); // values that fit in Int16 + let values = StringArray::from(vec!["a", "b", "c"]); + + let ree_array = RunArray::::try_new(&run_ends, &values).unwrap(); + let array_ref = Arc::new(ree_array) as ArrayRef; + + // Attempt to cast to RunEndEncoded (upcast should succeed) + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed due to valid upcast + let result: Result, ArrowError> = + cast_with_options(&array_ref, &target_type, &cast_options); + + let array_ref = result.expect("Cast should have succeeded but failed"); + // Downcast to RunArray + let run_array = array_ref + .as_any() + .downcast_ref::>() + .unwrap(); + + // Verify the cast worked correctly + // Assert the values were cast correctly + assert_eq!(run_array.run_ends().values(), &[2i64, 5i64, 8i64]); + assert_eq!(run_array.values().as_string::().value(0), "a"); + assert_eq!(run_array.values().as_string::().value(1), "b"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + } + + #[test] + fn test_cast_run_end_encoded_dictionary_to_run_end_encoded() { + // Construct a valid dictionary encoded array + let values = StringArray::from_iter([Some("a"), Some("b"), Some("c")]); + let keys = UInt64Array::from_iter(vec![1, 1, 1, 0, 0, 0, 2, 2, 2]); + let array_ref = Arc::new(DictionaryArray::new(keys, Arc::new(values))) as ArrayRef; + + // Attempt to cast to RunEndEncoded + let target_type = DataType::RunEndEncoded( + Arc::new(Field::new("run_ends", DataType::Int64, false)), + Arc::new(Field::new("values", DataType::Utf8, true)), + ); + let cast_options = CastOptions { + safe: false, + format_options: FormatOptions::default(), + }; + + // This should succeed + let result = cast_with_options(&array_ref, &target_type, &cast_options) + .expect("Cast should have succeeded but failed"); + + // Verify the cast worked correctly + // Assert the values were cast correctly + let run_array = result + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(run_array.values().as_string::().value(0), "b"); + assert_eq!(run_array.values().as_string::().value(1), "a"); + assert_eq!(run_array.values().as_string::().value(2), "c"); + + // Verify the run-ends were cast correctly (run ends at 3, 6, 9) + assert_eq!(run_array.run_ends().values(), &[3i64, 6i64, 9i64]); + } + + fn int32_list_values() -> Vec>>> { + vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5), Some(6)]), + None, + Some(vec![Some(7), Some(8), Some(9)]), + Some(vec![None, Some(10)]), + ] + } + + #[test] + fn test_cast_list_view_to_list() { + let list_view = ListViewArray::from_iter_primitive::(int32_list_values()); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + let expected_list = ListArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_to_list_view() { + let list = ListArray::from_iter_primitive::(int32_list_values()); + let target_type = DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list.data_type(), &target_type)); + let cast_result = cast(&list, &target_type).unwrap(); + + let got_list_view = cast_result + .as_any() + .downcast_ref::() + .unwrap(); + let expected_list_view = + ListViewArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got_list_view, &expected_list_view); + } + + #[test] + fn test_cast_large_list_view_to_large_list() { + let list_view = + LargeListViewArray::from_iter_primitive::(int32_list_values()); + let target_type = DataType::LargeList(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected_list = + LargeListArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_large_list_to_large_list_view() { + let list = LargeListArray::from_iter_primitive::(int32_list_values()); + let target_type = + DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list.data_type(), &target_type)); + let cast_result = cast(&list, &target_type).unwrap(); + + let got_list_view = cast_result + .as_any() + .downcast_ref::() + .unwrap(); + let expected_list_view = + LargeListViewArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got_list_view, &expected_list_view); + } + + #[test] + fn test_cast_list_view_to_list_out_of_order() { + let list_view = ListViewArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::from(vec![0, 6, 3]), + ScalarBuffer::from(vec![3, 3, 3]), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9])), + None, + ); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + let expected_list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(7), Some(8), Some(9)]), + Some(vec![Some(4), Some(5), Some(6)]), + ]); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_view_to_list_overlapping() { + let list_view = ListViewArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::from(vec![0, 0]), + ScalarBuffer::from(vec![1, 2]), + Arc::new(Int32Array::from(vec![1, 2])), + None, + ); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + let expected_list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1)]), + Some(vec![Some(1), Some(2)]), + ]); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_view_to_list_empty() { + let values: Vec>>> = vec![]; + let list_view = ListViewArray::from_iter_primitive::(values.clone()); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + let expected_list = ListArray::from_iter_primitive::(values); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_view_to_list_different_inner_type() { + let values = int32_list_values(); + let list_view = ListViewArray::from_iter_primitive::(values.clone()); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int64, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + + let expected_list = + ListArray::from_iter_primitive::(values.into_iter().map(|list| { + list.map(|list| { + list.into_iter() + .map(|v| v.map(|v| v as i64)) + .collect::>() + }) + })); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_view_to_list_out_of_order_with_nulls() { + let list_view = ListViewArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + ScalarBuffer::from(vec![0, 6, 3]), + ScalarBuffer::from(vec![3, 3, 3]), + Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9])), + Some(NullBuffer::from(vec![false, true, false])), + ); + let target_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got_list = cast_result.as_any().downcast_ref::().unwrap(); + let expected_list = ListArray::new( + Arc::new(Field::new("item", DataType::Int32, true)), + OffsetBuffer::from_lengths([3, 3, 3]), + Arc::new(Int32Array::from(vec![1, 2, 3, 7, 8, 9, 4, 5, 6])), + Some(NullBuffer::from(vec![false, true, false])), + ); + assert_eq!(got_list, &expected_list); + } + + #[test] + fn test_cast_list_view_to_large_list_view() { + let list_view = ListViewArray::from_iter_primitive::(int32_list_values()); + let target_type = + DataType::LargeListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got = cast_result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = + LargeListViewArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got, &expected); + } + + #[test] + fn test_cast_large_list_view_to_list_view() { + let list_view = + LargeListViewArray::from_iter_primitive::(int32_list_values()); + let target_type = DataType::ListView(Arc::new(Field::new("item", DataType::Int32, true))); + assert!(can_cast_types(list_view.data_type(), &target_type)); + let cast_result = cast(&list_view, &target_type).unwrap(); + let got = cast_result + .as_any() + .downcast_ref::() + .unwrap(); + + let expected = ListViewArray::from_iter_primitive::(int32_list_values()); + assert_eq!(got, &expected); + } + + #[test] + fn test_cast_time32_second_to_int64() { + let array = Time32SecondArray::from(vec![1000, 2000, 3000]); + let array = Arc::new(array) as Arc; + let to_type = DataType::Int64; + let cast_options = CastOptions::default(); + + assert!(can_cast_types(array.data_type(), &to_type)); + + let result = cast_with_options(&array, &to_type, &cast_options); + assert!( + result.is_ok(), + "Failed to cast Time32(Second) to Int64: {:?}", + result.err() + ); + + let cast_array = result.unwrap(); + let cast_array = cast_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(cast_array.value(0), 1000); + assert_eq!(cast_array.value(1), 2000); + assert_eq!(cast_array.value(2), 3000); + } + + #[test] + fn test_cast_time32_millisecond_to_int64() { + let array = Time32MillisecondArray::from(vec![1000, 2000, 3000]); + let array = Arc::new(array) as Arc; + let to_type = DataType::Int64; + let cast_options = CastOptions::default(); + + assert!(can_cast_types(array.data_type(), &to_type)); + + let result = cast_with_options(&array, &to_type, &cast_options); + assert!( + result.is_ok(), + "Failed to cast Time32(Millisecond) to Int64: {:?}", + result.err() + ); + + let cast_array = result.unwrap(); + let cast_array = cast_array.as_any().downcast_ref::().unwrap(); + + assert_eq!(cast_array.value(0), 1000); + assert_eq!(cast_array.value(1), 2000); + assert_eq!(cast_array.value(2), 3000); + } + + #[test] + fn test_cast_string_to_time32_second_to_int64() { + // Mimic: select arrow_cast('03:12:44'::time, 'Time32(Second)')::bigint; + // raised in https://github.com/apache/datafusion/issues/19036 + let array = StringArray::from(vec!["03:12:44"]); + let array = Arc::new(array) as Arc; + let cast_options = CastOptions::default(); + + // 1. Cast String to Time32(Second) + let time32_type = DataType::Time32(TimeUnit::Second); + let time32_array = cast_with_options(&array, &time32_type, &cast_options).unwrap(); + + // 2. Cast Time32(Second) to Int64 + let int64_type = DataType::Int64; + assert!(can_cast_types(time32_array.data_type(), &int64_type)); + + let result = cast_with_options(&time32_array, &int64_type, &cast_options); + + assert!( + result.is_ok(), + "Failed to cast Time32(Second) to Int64: {:?}", + result.err() + ); + + let cast_array = result.unwrap(); + let cast_array = cast_array.as_any().downcast_ref::().unwrap(); + + // 03:12:44 = 3*3600 + 12*60 + 44 = 10800 + 720 + 44 = 11564 + assert_eq!(cast_array.value(0), 11564); + } + #[test] + fn test_string_dicts_to_binary_view() { + let expected = BinaryViewArray::from_iter(vec![ + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[0], + None, + VIEW_TEST_DATA[3], + None, + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[4], + ]); + + let values_arrays: [ArrayRef; _] = [ + Arc::new(StringArray::from_iter(VIEW_TEST_DATA)), + Arc::new(StringViewArray::from_iter(VIEW_TEST_DATA)), + Arc::new(LargeStringArray::from_iter(VIEW_TEST_DATA)), + ]; + for values in values_arrays { + let keys = + Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let string_dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let casted = cast(&string_dict_array, &DataType::BinaryView).unwrap(); + assert_eq!(casted.as_ref(), &expected); + } + } + + #[test] + fn test_binary_dicts_to_string_view() { + let expected = StringViewArray::from_iter(vec![ + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[0], + None, + VIEW_TEST_DATA[3], + None, + VIEW_TEST_DATA[1], + VIEW_TEST_DATA[4], + ]); + + let values_arrays: [ArrayRef; _] = [ + Arc::new(BinaryArray::from_iter(VIEW_TEST_DATA)), + Arc::new(BinaryViewArray::from_iter(VIEW_TEST_DATA)), + Arc::new(LargeBinaryArray::from_iter(VIEW_TEST_DATA)), + ]; + for values in values_arrays { + let keys = + Int8Array::from_iter([Some(1), Some(0), None, Some(3), None, Some(1), Some(4)]); + let string_dict_array = + DictionaryArray::::try_new(keys, Arc::new(values)).unwrap(); + + let casted = cast(&string_dict_array, &DataType::Utf8View).unwrap(); + assert_eq!(casted.as_ref(), &expected); + } + } } diff --git a/arrow-cast/src/cast/run_array.rs b/arrow-cast/src/cast/run_array.rs new file mode 100644 index 000000000000..3e14804dc824 --- /dev/null +++ b/arrow-cast/src/cast/run_array.rs @@ -0,0 +1,169 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::cast::*; +use arrow_ord::partition::partition; + +/// Attempts to cast a `RunArray` with index type K into +/// `to_type` for supported types. +pub(crate) fn run_end_encoded_cast( + array: &dyn Array, + to_type: &DataType, + cast_options: &CastOptions, +) -> Result { + match array.data_type() { + DataType::RunEndEncoded(_, _) => { + let run_array = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| ArrowError::CastError("Expected RunArray".to_string()))?; + + let values = run_array.values(); + + match to_type { + // Stay as RunEndEncoded, cast only the values + DataType::RunEndEncoded(target_index_field, target_value_field) => { + let cast_values = + cast_with_options(values, target_value_field.data_type(), cast_options)?; + + let run_ends_array = PrimitiveArray::::from_iter_values( + run_array.run_ends().values().iter().copied(), + ); + let cast_run_ends = cast_with_options( + &run_ends_array, + target_index_field.data_type(), + cast_options, + )?; + let new_run_array: ArrayRef = match target_index_field.data_type() { + DataType::Int16 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int32 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + DataType::Int64 => { + let re = cast_run_ends.as_primitive::(); + Arc::new(RunArray::::try_new(re, cast_values.as_ref())?) + } + _ => { + return Err(ArrowError::CastError( + "Run-end type must be i16, i32, or i64".to_string(), + )); + } + }; + Ok(Arc::new(new_run_array)) + } + + // Expand to logical form + _ => { + let len = run_array.len(); + let offset = run_array.offset(); + let run_ends = run_array.run_ends().values(); + + let mut indices = Vec::with_capacity(len); + let mut physical_idx = run_array.get_start_physical_index(); + + for logical_idx in offset..offset + len { + if logical_idx == run_ends[physical_idx].as_usize() { + // If the logical index is equal to the (next) run end, increment the physical index, + // since we are at the end of a run. + physical_idx += 1; + } + indices.push(physical_idx as i32); + } + + let taken = take(&values, &Int32Array::from_iter_values(indices), None)?; + if taken.data_type() != to_type { + cast_with_options(taken.as_ref(), to_type, cast_options) + } else { + Ok(taken) + } + } + } + } + + _ => Err(ArrowError::CastError(format!( + "Cannot cast array of type {:?} to RunEndEncodedArray", + array.data_type() + ))), + } +} + +/// Attempts to encode an array into a `RunArray` with index type K +/// and value type `value_type` +pub(crate) fn cast_to_run_end_encoded( + array: &ArrayRef, + value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let mut run_ends_builder = PrimitiveBuilder::::new(); + + // Cast the input array to the target value type if necessary + let cast_array = if array.data_type() == value_type { + array + } else { + &cast_with_options(array, value_type, cast_options)? + }; + + // Return early if the array to cast is empty + if cast_array.is_empty() { + let empty_run_ends = run_ends_builder.finish(); + let empty_values = make_array(ArrayData::new_empty(value_type)); + return Ok(Arc::new(RunArray::::try_new( + &empty_run_ends, + empty_values.as_ref(), + )?)); + } + + // REE arrays are handled by run_end_encoded_cast + if let DataType::RunEndEncoded(_, _) = array.data_type() { + return Err(ArrowError::CastError( + "Source array is already a RunEndEncoded array, should have been handled by run_end_encoded_cast".to_string() + )); + } + + // Partition the array to identify runs of consecutive equal values + let partitions = partition(&[Arc::clone(cast_array)])?; + let size = partitions.len(); + let mut run_ends = Vec::with_capacity(size); + let mut values_indexes = Vec::with_capacity(size); + let mut last_partition_end = 0; + for partition in partitions.ranges() { + values_indexes.push(last_partition_end); + run_ends.push(partition.end); + last_partition_end = partition.end; + } + + // Build the run_ends array + for run_end in run_ends { + run_ends_builder.append_value(K::Native::from_usize(run_end).ok_or_else(|| { + ArrowError::CastError(format!("Run end index out of range: {}", run_end)) + })?); + } + let run_ends_array = run_ends_builder.finish(); + // Build the values array by taking elements at the run start positions + let indices = PrimitiveArray::::from_iter_values( + values_indexes.iter().map(|&idx| idx as u32), + ); + let values_array = take(&cast_array, &indices, None)?; + + // Create and return the RunArray + let run_array = RunArray::::try_new(&run_ends_array, values_array.as_ref())?; + Ok(Arc::new(run_array)) +} diff --git a/arrow-cast/src/cast/string.rs b/arrow-cast/src/cast/string.rs index 7f22c4fd64de..77696ae0d8cc 100644 --- a/arrow-cast/src/cast/string.rs +++ b/arrow-cast/src/cast/string.rs @@ -107,15 +107,14 @@ fn parse_string_iter< .map(|x| match x { Some(v) => P::parse(v).ok_or_else(|| { ArrowError::CastError(format!( - "Cannot cast string '{}' to value of {:?} type", - v, + "Cannot cast string '{v}' to value of {} type", P::DATA_TYPE )) }), None => Ok(P::Native::default()), }) .collect::, ArrowError>>()?; - PrimitiveArray::new(v.into(), nulls()) + PrimitiveArray::try_new(v.into(), nulls())? }; Ok(Arc::new(array) as ArrayRef) @@ -339,6 +338,14 @@ where /// A specified helper to cast from `GenericBinaryArray` to `GenericStringArray` when they have same /// offset size so re-encoding offset is unnecessary. +fn extend_valid_utf8<'a, B, I>(builder: &mut B, iter: I) +where + B: Extend>, + I: Iterator>, +{ + builder.extend(iter.map(|value| value.and_then(|bytes| std::str::from_utf8(bytes).ok()))); +} + pub(crate) fn cast_binary_to_string( array: &dyn Array, cast_options: &CastOptions, @@ -356,11 +363,7 @@ pub(crate) fn cast_binary_to_string( let mut builder = GenericStringBuilder::::with_capacity(array.len(), array.value_data().len()); - let iter = array - .iter() - .map(|v| v.and_then(|v| std::str::from_utf8(v).ok())); - - builder.extend(iter); + extend_valid_utf8(&mut builder, array.iter()); Ok(Arc::new(builder.finish())) } false => Err(e), @@ -368,6 +371,25 @@ pub(crate) fn cast_binary_to_string( } } +pub(crate) fn cast_binary_view_to_string_view( + array: &dyn Array, + cast_options: &CastOptions, +) -> Result { + let array = array.as_binary_view(); + + match array.clone().to_string_view() { + Ok(result) => Ok(Arc::new(result)), + Err(error) => match cast_options.safe { + true => { + let mut builder = StringViewBuilder::with_capacity(array.len()); + extend_valid_utf8(&mut builder, array.iter()); + Ok(Arc::new(builder.finish())) + } + false => Err(error), + }, + } +} + /// Casts string to boolean fn cast_string_to_boolean<'a, StrArray>( array: &StrArray, diff --git a/arrow-cast/src/display.rs b/arrow-cast/src/display.rs index 6761ac22fa1d..bfd0f06dbef5 100644 --- a/arrow-cast/src/display.rs +++ b/arrow-cast/src/display.rs @@ -23,7 +23,8 @@ //! record batch pretty printing. //! //! [`pretty`]: crate::pretty -use std::fmt::{Display, Formatter, Write}; +use std::fmt::{Debug, Display, Formatter, Write}; +use std::hash::{Hash, Hasher}; use std::ops::Range; use arrow_array::cast::*; @@ -53,7 +54,12 @@ pub enum DurationFormat { /// By default nulls are formatted as `""` and temporal types formatted /// according to RFC3339 /// -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +/// # Equality +/// +/// Most fields in [`FormatOptions`] are compared by value, except `formatter_factory`. As the trait +/// does not require an [`Eq`] and [`Hash`] implementation, this struct only compares the pointer of +/// the factories. +#[derive(Debug, Clone)] pub struct FormatOptions<'a> { /// If set to `true` any formatting errors will be written to the output /// instead of being converted into a [`std::fmt::Error`] @@ -74,6 +80,9 @@ pub struct FormatOptions<'a> { duration_format: DurationFormat, /// Show types in visual representation batches types_info: bool, + /// Formatter factory used to instantiate custom [`ArrayFormatter`]s. This allows users to + /// provide custom formatters. + formatter_factory: Option<&'a dyn ArrayFormatterFactory>, } impl Default for FormatOptions<'_> { @@ -82,6 +91,44 @@ impl Default for FormatOptions<'_> { } } +impl PartialEq for FormatOptions<'_> { + fn eq(&self, other: &Self) -> bool { + self.safe == other.safe + && self.null == other.null + && self.date_format == other.date_format + && self.datetime_format == other.datetime_format + && self.timestamp_format == other.timestamp_format + && self.timestamp_tz_format == other.timestamp_tz_format + && self.time_format == other.time_format + && self.duration_format == other.duration_format + && self.types_info == other.types_info + && match (self.formatter_factory, other.formatter_factory) { + (Some(f1), Some(f2)) => std::ptr::eq(f1, f2), + (None, None) => true, + _ => false, + } + } +} + +impl Eq for FormatOptions<'_> {} + +impl Hash for FormatOptions<'_> { + fn hash(&self, state: &mut H) { + self.safe.hash(state); + self.null.hash(state); + self.date_format.hash(state); + self.datetime_format.hash(state); + self.timestamp_format.hash(state); + self.timestamp_tz_format.hash(state); + self.time_format.hash(state); + self.duration_format.hash(state); + self.types_info.hash(state); + self.formatter_factory + .map(|f| f as *const dyn ArrayFormatterFactory) + .hash(state); + } +} + impl<'a> FormatOptions<'a> { /// Creates a new set of format options pub const fn new() -> Self { @@ -95,6 +142,7 @@ impl<'a> FormatOptions<'a> { time_format: None, duration_format: DurationFormat::ISO8601, types_info: false, + formatter_factory: None, } } @@ -169,10 +217,172 @@ impl<'a> FormatOptions<'a> { Self { types_info, ..self } } - /// Returns true if type info should be included in visual representation of batches + /// Overrides the [`ArrayFormatterFactory`] used to instantiate custom [`ArrayFormatter`]s. + /// + /// Using [`None`] causes pretty-printers to use the default [`ArrayFormatter`]s. + pub const fn with_formatter_factory( + self, + formatter_factory: Option<&'a dyn ArrayFormatterFactory>, + ) -> Self { + Self { + formatter_factory, + ..self + } + } + + /// Returns whether formatting errors should be written to the output instead of being converted + /// into a [`std::fmt::Error`]. + pub const fn safe(&self) -> bool { + self.safe + } + + /// Returns the string used for displaying nulls. + pub const fn null(&self) -> &'a str { + self.null + } + + /// Returns the format used for [`DataType::Date32`] columns. + pub const fn date_format(&self) -> TimeFormat<'a> { + self.date_format + } + + /// Returns the format used for [`DataType::Date64`] columns. + pub const fn datetime_format(&self) -> TimeFormat<'a> { + self.datetime_format + } + + /// Returns the format used for [`DataType::Timestamp`] columns without a timezone. + pub const fn timestamp_format(&self) -> TimeFormat<'a> { + self.timestamp_format + } + + /// Returns the format used for [`DataType::Timestamp`] columns with a timezone. + pub const fn timestamp_tz_format(&self) -> TimeFormat<'a> { + self.timestamp_tz_format + } + + /// Returns the format used for [`DataType::Time32`] and [`DataType::Time64`] columns. + pub const fn time_format(&self) -> TimeFormat<'a> { + self.time_format + } + + /// Returns the [`DurationFormat`] used for duration columns. + pub const fn duration_format(&self) -> DurationFormat { + self.duration_format + } + + /// Returns true if type info should be included in a visual representation of batches. pub const fn types_info(&self) -> bool { self.types_info } + + /// Returns the [`ArrayFormatterFactory`] used to instantiate custom [`ArrayFormatter`]s. + pub const fn formatter_factory(&self) -> Option<&'a dyn ArrayFormatterFactory> { + self.formatter_factory + } +} + +/// Allows creating a new [`ArrayFormatter`] for a given [`Array`] and an optional [`Field`]. +/// +/// # Example +/// +/// The example below shows how to create a custom formatter for a custom type `my_money`. Note that +/// this example requires the `prettyprint` feature. +/// +/// ```rust +/// # #[cfg(feature = "prettyprint")]{ +/// use std::fmt::Write; +/// use arrow_array::{cast::AsArray, Array, Int32Array}; +/// use arrow_cast::display::{ArrayFormatter, ArrayFormatterFactory, DisplayIndex, FormatOptions, FormatResult}; +/// use arrow_cast::pretty::pretty_format_batches_with_options; +/// use arrow_schema::{ArrowError, Field}; +/// +/// /// A custom formatter factory that can create a formatter for the special type `my_money`. +/// /// +/// /// This struct could have access to some kind of extension type registry that can lookup the +/// /// correct formatter for an extension type on-demand. +/// #[derive(Debug)] +/// struct MyFormatters {} +/// +/// impl ArrayFormatterFactory for MyFormatters { +/// fn create_array_formatter<'formatter>( +/// &self, +/// array: &'formatter dyn Array, +/// options: &FormatOptions<'formatter>, +/// field: Option<&'formatter Field>, +/// ) -> Result>, ArrowError> { +/// // check if this is the money type +/// if field +/// .map(|f| f.extension_type_name() == Some("my_money")) +/// .unwrap_or(false) +/// { +/// // We assume that my_money always is an Int32. +/// let array = array.as_primitive(); +/// let display_index = Box::new(MyMoneyFormatter { array, options: options.clone() }); +/// return Ok(Some(ArrayFormatter::new(display_index, options.safe()))); +/// } +/// +/// Ok(None) // None indicates that the default formatter should be used. +/// } +/// } +/// +/// /// A formatter for the type `my_money` that wraps a specific array and has access to the +/// /// formatting options. +/// struct MyMoneyFormatter<'a> { +/// array: &'a Int32Array, +/// options: FormatOptions<'a>, +/// } +/// +/// impl<'a> DisplayIndex for MyMoneyFormatter<'a> { +/// fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { +/// match self.array.is_valid(idx) { +/// true => write!(f, "{} €", self.array.value(idx))?, +/// false => write!(f, "{}", self.options.null())?, +/// } +/// +/// Ok(()) +/// } +/// } +/// +/// // Usually, here you would provide your record batches. +/// let my_batches = vec![]; +/// +/// // Call the pretty printer with the custom formatter factory. +/// pretty_format_batches_with_options( +/// &my_batches, +/// &FormatOptions::new().with_formatter_factory(Some(&MyFormatters {})) +/// ); +/// # } +/// ``` +pub trait ArrayFormatterFactory: Debug + Send + Sync { + /// Creates a new [`ArrayFormatter`] for the given [`Array`] and an optional [`Field`]. If the + /// default implementation should be used, return [`None`]. + /// + /// The field shall be used to look up metadata about the `array` while `options` provide + /// information on formatting, for example, dates and times which should be considered by an + /// implementor. + fn create_array_formatter<'formatter>( + &self, + array: &'formatter dyn Array, + options: &FormatOptions<'formatter>, + field: Option<&'formatter Field>, + ) -> Result>, ArrowError>; +} + +/// Used to create a new [`ArrayFormatter`] from the given `array`, while also checking whether +/// there is an override available in the [`ArrayFormatterFactory`]. +pub(crate) fn make_array_formatter<'a>( + array: &'a dyn Array, + options: &FormatOptions<'a>, + field: Option<&'a Field>, +) -> Result, ArrowError> { + match options.formatter_factory() { + None => ArrayFormatter::try_new(array, options), + Some(formatters) => formatters + .create_array_formatter(array, options, field) + .transpose() + .unwrap_or_else(|| ArrayFormatter::try_new(array, options)), + } } /// Implements [`Display`] for a specific array value @@ -272,14 +482,19 @@ pub struct ArrayFormatter<'a> { } impl<'a> ArrayFormatter<'a> { + /// Returns an [`ArrayFormatter`] using the provided formatter. + pub fn new(format: Box, safe: bool) -> Self { + Self { format, safe } + } + /// Returns an [`ArrayFormatter`] that can be used to format `array` /// /// This returns an error if an array of the given data type cannot be formatted pub fn try_new(array: &'a dyn Array, options: &FormatOptions<'a>) -> Result { - Ok(Self { - format: make_formatter(array, options)?, - safe: options.safe, - }) + Ok(Self::new( + make_default_display_index(array, options)?, + options.safe, + )) } /// Returns a [`ValueFormatter`] that implements [`Display`] for @@ -292,7 +507,7 @@ impl<'a> ArrayFormatter<'a> { } } -fn make_formatter<'a>( +fn make_default_display_index<'a>( array: &'a dyn Array, options: &FormatOptions<'a>, ) -> Result, ArrowError> { @@ -332,12 +547,15 @@ fn make_formatter<'a>( } /// Either an [`ArrowError`] or [`std::fmt::Error`] -enum FormatError { +pub enum FormatError { + /// An error occurred while formatting the array Format(std::fmt::Error), + /// An Arrow error occurred while formatting the array. Arrow(ArrowError), } -type FormatResult = Result<(), FormatError>; +/// The result of formatting an array element via [`DisplayIndex::write`]. +pub type FormatResult = Result<(), FormatError>; impl From for FormatError { fn from(value: std::fmt::Error) -> Self { @@ -352,7 +570,8 @@ impl From for FormatError { } /// [`Display`] but accepting an index -trait DisplayIndex { +pub trait DisplayIndex { + /// Write the value of the underlying array at `idx` to `f`. fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult; } @@ -489,7 +708,7 @@ macro_rules! decimal_display { }; } -decimal_display!(Decimal128Type, Decimal256Type); +decimal_display!(Decimal32Type, Decimal64Type, Decimal128Type, Decimal256Type); fn write_timestamp( f: &mut dyn Write, @@ -710,6 +929,12 @@ impl DisplayIndex for &PrimitiveArray { impl DisplayIndex for &PrimitiveArray { fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { let value = self.value(idx); + + if value.is_zero() { + write!(f, "0 secs")?; + return Ok(()); + } + let mut prefix = ""; if value.days != 0 { @@ -733,6 +958,12 @@ impl DisplayIndex for &PrimitiveArray { impl DisplayIndex for &PrimitiveArray { fn write(&self, idx: usize, f: &mut dyn Write) -> FormatResult { let value = self.value(idx); + + if value.is_zero() { + write!(f, "0 secs")?; + return Ok(()); + } + let mut prefix = ""; if value.months != 0 { @@ -776,12 +1007,12 @@ impl Display for NanosecondsFormatter<'_> { let nanoseconds = self.nanoseconds % 1_000_000_000; if hours != 0 { - write!(f, "{prefix}{} hours", hours)?; + write!(f, "{prefix}{hours} hours")?; prefix = " "; } if mins != 0 { - write!(f, "{prefix}{} mins", mins)?; + write!(f, "{prefix}{mins} mins")?; prefix = " "; } @@ -819,12 +1050,12 @@ impl Display for MillisecondsFormatter<'_> { let milliseconds = self.milliseconds % 1_000; if hours != 0 { - write!(f, "{prefix}{} hours", hours,)?; + write!(f, "{prefix}{hours} hours")?; prefix = " "; } if mins != 0 { - write!(f, "{prefix}{} mins", mins,)?; + write!(f, "{prefix}{mins} mins")?; prefix = " "; } @@ -896,7 +1127,7 @@ impl<'a, K: ArrowDictionaryKeyType> DisplayIndexState<'a> for &'a DictionaryArra type State = Box; fn prepare(&self, options: &FormatOptions<'a>) -> Result { - make_formatter(self.values().as_ref(), options) + make_default_display_index(self.values().as_ref(), options) } fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { @@ -906,68 +1137,82 @@ impl<'a, K: ArrowDictionaryKeyType> DisplayIndexState<'a> for &'a DictionaryArra } impl<'a, K: RunEndIndexType> DisplayIndexState<'a> for &'a RunArray { - type State = Box; + type State = ArrayFormatter<'a>; fn prepare(&self, options: &FormatOptions<'a>) -> Result { - make_formatter(self.values().as_ref(), options) + let field = match (*self).data_type() { + DataType::RunEndEncoded(_, values_field) => values_field, + _ => unreachable!(), + }; + make_array_formatter(self.values().as_ref(), options, Some(field)) } fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { let value_idx = self.get_physical_index(idx); - s.as_ref().write(value_idx, f) + write!(f, "{}", s.value(value_idx))?; + Ok(()) } } fn write_list( f: &mut dyn Write, mut range: Range, - values: &dyn DisplayIndex, + values: &ArrayFormatter<'_>, ) -> FormatResult { f.write_char('[')?; if let Some(idx) = range.next() { - values.write(idx, f)?; + write!(f, "{}", values.value(idx))?; } for idx in range { - write!(f, ", ")?; - values.write(idx, f)?; + write!(f, ", {}", values.value(idx))?; } f.write_char(']')?; Ok(()) } impl<'a, O: OffsetSizeTrait> DisplayIndexState<'a> for &'a GenericListArray { - type State = Box; + type State = ArrayFormatter<'a>; fn prepare(&self, options: &FormatOptions<'a>) -> Result { - make_formatter(self.values().as_ref(), options) + let field = match (*self).data_type() { + DataType::List(f) => f, + DataType::LargeList(f) => f, + _ => unreachable!(), + }; + make_array_formatter(self.values().as_ref(), options, Some(field.as_ref())) } fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { let offsets = self.value_offsets(); let end = offsets[idx + 1].as_usize(); let start = offsets[idx].as_usize(); - write_list(f, start..end, s.as_ref()) + write_list(f, start..end, s) } } impl<'a> DisplayIndexState<'a> for &'a FixedSizeListArray { - type State = (usize, Box); + type State = (usize, ArrayFormatter<'a>); fn prepare(&self, options: &FormatOptions<'a>) -> Result { - let values = make_formatter(self.values().as_ref(), options)?; + let field = match (*self).data_type() { + DataType::FixedSizeList(f, _) => f, + _ => unreachable!(), + }; + let formatter = + make_array_formatter(self.values().as_ref(), options, Some(field.as_ref()))?; let length = self.value_length(); - Ok((length as usize, values)) + Ok((length as usize, formatter)) } fn write(&self, s: &Self::State, idx: usize, f: &mut dyn Write) -> FormatResult { let start = idx * s.0; let end = start + s.0; - write_list(f, start..end, s.1.as_ref()) + write_list(f, start..end, &s.1) } } -/// Pairs a boxed [`DisplayIndex`] with its field name -type FieldDisplay<'a> = (&'a str, Box); +/// Pairs an [`ArrayFormatter`] with its field name +type FieldDisplay<'a> = (&'a str, ArrayFormatter<'a>); impl<'a> DisplayIndexState<'a> for &'a StructArray { type State = Vec>; @@ -982,7 +1227,7 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray { .iter() .zip(fields) .map(|(a, f)| { - let format = make_formatter(a.as_ref(), options)?; + let format = make_array_formatter(a.as_ref(), options, Some(f))?; Ok((f.name().as_str(), format)) }) .collect() @@ -992,12 +1237,10 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray { let mut iter = s.iter(); f.write_char('{')?; if let Some((name, display)) = iter.next() { - write!(f, "{name}: ")?; - display.as_ref().write(idx, f)?; + write!(f, "{name}: {}", display.value(idx))?; } for (name, display) in iter { - write!(f, ", {name}: ")?; - display.as_ref().write(idx, f)?; + write!(f, ", {name}: {}", display.value(idx))?; } f.write_char('}')?; Ok(()) @@ -1005,11 +1248,13 @@ impl<'a> DisplayIndexState<'a> for &'a StructArray { } impl<'a> DisplayIndexState<'a> for &'a MapArray { - type State = (Box, Box); + type State = (ArrayFormatter<'a>, ArrayFormatter<'a>); fn prepare(&self, options: &FormatOptions<'a>) -> Result { - let keys = make_formatter(self.keys().as_ref(), options)?; - let values = make_formatter(self.values().as_ref(), options)?; + let (key_field, value_field) = (*self).entries_fields(); + + let keys = make_array_formatter(self.keys().as_ref(), options, Some(key_field))?; + let values = make_array_formatter(self.values().as_ref(), options, Some(value_field))?; Ok((keys, values)) } @@ -1021,16 +1266,12 @@ impl<'a> DisplayIndexState<'a> for &'a MapArray { f.write_char('{')?; if let Some(idx) = iter.next() { - s.0.write(idx, f)?; - write!(f, ": ")?; - s.1.write(idx, f)?; + write!(f, "{}: {}", s.0.value(idx), s.1.value(idx))?; } for idx in iter { - write!(f, ", ")?; - s.0.write(idx, f)?; - write!(f, ": ")?; - s.1.write(idx, f)?; + write!(f, ", {}", s.0.value(idx))?; + write!(f, ": {}", s.1.value(idx))?; } f.write_char('}')?; @@ -1039,10 +1280,7 @@ impl<'a> DisplayIndexState<'a> for &'a MapArray { } impl<'a> DisplayIndexState<'a> for &'a UnionArray { - type State = ( - Vec)>>, - UnionMode, - ); + type State = (Vec>>, UnionMode); fn prepare(&self, options: &FormatOptions<'a>) -> Result { let (fields, mode) = match (*self).data_type() { @@ -1053,7 +1291,7 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray { let max_id = fields.iter().map(|(id, _)| id).max().unwrap_or_default() as usize; let mut out: Vec> = (0..max_id + 1).map(|_| None).collect(); for (i, field) in fields.iter() { - let formatter = make_formatter(self.child(i).as_ref(), options)?; + let formatter = make_array_formatter(self.child(i).as_ref(), options, Some(field))?; out[i as usize] = Some((field.name().as_str(), formatter)) } Ok((out, *mode)) @@ -1067,9 +1305,7 @@ impl<'a> DisplayIndexState<'a> for &'a UnionArray { }; let (name, field) = s.0[id as usize].as_ref().unwrap(); - write!(f, "{{{name}=")?; - field.write(idx, f)?; - f.write_char('}')?; + write!(f, "{{{name}={}}}", field.value(idx))?; Ok(()) } } @@ -1118,6 +1354,19 @@ mod tests { assert_eq!(TEST_CONST_OPTIONS.date_format, Some("foo")); } + /// See https://github.com/apache/arrow-rs/issues/8875 + #[test] + fn test_options_send_sync() { + fn assert_send_sync() + where + T: Send + Sync, + { + // nothing – the compiler does the work + } + + assert_send_sync::>(); + } + #[test] fn test_map_array_to_string() { let keys = vec!["a", "b", "c", "d", "e", "f", "g", "h"]; diff --git a/arrow-cast/src/lib.rs b/arrow-cast/src/lib.rs index b042a7338519..3412616c5caf 100644 --- a/arrow-cast/src/lib.rs +++ b/arrow-cast/src/lib.rs @@ -21,7 +21,7 @@ html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] pub mod cast; pub use cast::*; diff --git a/arrow-cast/src/parse.rs b/arrow-cast/src/parse.rs index 28d36db89af0..b266cc4aa360 100644 --- a/arrow-cast/src/parse.rs +++ b/arrow-cast/src/parse.rs @@ -18,9 +18,9 @@ //! [`Parser`] implementations for converting strings to Arrow types //! //! Used by the CSV and JSON readers to convert strings to Arrow types +use arrow_array::ArrowNativeTypeOp; use arrow_array::timezone::Tz; use arrow_array::types::*; -use arrow_array::ArrowNativeTypeOp; use arrow_buffer::ArrowNativeType; use arrow_schema::ArrowError; use chrono::prelude::*; @@ -794,7 +794,7 @@ fn parse_e_notation( None => { return Err(ArrowError::ParseError(format!( "can't parse the string value {s} to decimal" - ))) + ))); } }; @@ -1235,8 +1235,7 @@ impl Interval { match (self.months, self.days, self.nanos) { (months, days, nanos) if days == 0 && nanos == 0 => Ok(months), _ => Err(ArrowError::InvalidArgumentError(format!( - "Unable to represent interval with days and nanos as year-months: {:?}", - self + "Unable to represent interval with days and nanos as year-months: {self:?}" ))), } } @@ -2690,26 +2689,10 @@ mod tests { 0i128, 15, ), - ( - "1.016744e-320", - 0i128, - 15, - ), - ( - "-1e3", - -1000000000i128, - 6, - ), - ( - "+1e3", - 1000000000i128, - 6, - ), - ( - "-1e31", - -10000000000000000000000000000000000000i128, - 6, - ), + ("1.016744e-320", 0i128, 15), + ("-1e3", -1000000000i128, 6), + ("+1e3", 1000000000i128, 6), + ("-1e31", -10000000000000000000000000000000000000i128, 6), ]; for (s, i, scale) in edge_tests_128 { let result_128 = parse_decimal::(s, 38, scale); diff --git a/arrow-cast/src/pretty.rs b/arrow-cast/src/pretty.rs index c3fc00e4b911..e7c199dbed97 100644 --- a/arrow-cast/src/pretty.rs +++ b/arrow-cast/src/pretty.rs @@ -22,14 +22,12 @@ //! [`RecordBatch`]: arrow_array::RecordBatch //! [`Array`]: arrow_array::Array -use std::fmt::Display; - -use comfy_table::{Cell, Table}; - use arrow_array::{Array, ArrayRef, RecordBatch}; use arrow_schema::{ArrowError, SchemaRef}; +use comfy_table::{Cell, Table}; +use std::fmt::Display; -use crate::display::{ArrayFormatter, FormatOptions}; +use crate::display::{ArrayFormatter, FormatOptions, make_array_formatter}; /// Create a visual representation of [`RecordBatch`]es /// @@ -60,7 +58,7 @@ use crate::display::{ArrayFormatter, FormatOptions}; /// | 5 | e | /// +---+---+"#); /// ``` -pub fn pretty_format_batches(results: &[RecordBatch]) -> Result { +pub fn pretty_format_batches(results: &[RecordBatch]) -> Result, ArrowError> { let options = FormatOptions::default().with_display_error(true); pretty_format_batches_with_options(results, &options) } @@ -92,7 +90,7 @@ pub fn pretty_format_batches(results: &[RecordBatch]) -> Result Result { +) -> Result, ArrowError> { let options = FormatOptions::default().with_display_error(true); create_table(Some(schema), results, &options) } @@ -130,7 +128,7 @@ pub fn pretty_format_batches_with_schema( pub fn pretty_format_batches_with_options( results: &[RecordBatch], options: &FormatOptions, -) -> Result { +) -> Result, ArrowError> { create_table(None, results, options) } @@ -142,7 +140,7 @@ pub fn pretty_format_batches_with_options( pub fn pretty_format_columns( col_name: &str, results: &[ArrayRef], -) -> Result { +) -> Result, ArrowError> { let options = FormatOptions::default().with_display_error(true); pretty_format_columns_with_options(col_name, results, &options) } @@ -154,7 +152,7 @@ pub fn pretty_format_columns_with_options( col_name: &str, results: &[ArrayRef], options: &FormatOptions, -) -> Result { +) -> Result, ArrowError> { create_column(col_name, results, options) } @@ -187,7 +185,7 @@ fn create_table( } }); - if let Some(schema) = schema_opt { + if let Some(schema) = &schema_opt { let mut header = Vec::new(); for field in schema.fields() { if options.types_info() { @@ -208,10 +206,22 @@ fn create_table( } for batch in results { + let schema = schema_opt.as_ref().unwrap_or(batch.schema_ref()); + + // Could be a custom schema that was provided. + if batch.columns().len() != schema.fields().len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Expected the same number of columns in a record batch ({}) as the number of fields ({}) in the schema", + batch.columns().len(), + schema.fields.len() + ))); + } + let formatters = batch .columns() .iter() - .map(|c| ArrayFormatter::try_new(c.as_ref(), options)) + .zip(schema.fields().iter()) + .map(|(c, field)| make_array_formatter(c, options, Some(field))) .collect::, ArrowError>>()?; for row in 0..batch.num_rows() { @@ -242,7 +252,13 @@ fn create_column( table.set_header(header); for col in columns { - let formatter = ArrayFormatter::try_new(col.as_ref(), options)?; + let formatter = match options.formatter_factory() { + None => ArrayFormatter::try_new(col.as_ref(), options)?, + Some(formatters) => formatters + .create_array_formatter(col.as_ref(), options, None) + .transpose() + .unwrap_or_else(|| ArrayFormatter::try_new(col.as_ref(), options))?, + }; for row in 0..col.len() { let cells = vec![Cell::new(formatter.value(row))]; table.add_row(cells); @@ -254,18 +270,21 @@ fn create_column( #[cfg(test)] mod tests { + use std::collections::HashMap; use std::fmt::Write; use std::sync::Arc; - use half::f16; - use arrow_array::builder::*; + use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use arrow_schema::*; + use half::f16; - use crate::display::{array_value_to_string, DurationFormat}; + use crate::display::{ + ArrayFormatterFactory, DisplayIndex, DurationFormat, array_value_to_string, + }; use super::*; @@ -1089,6 +1108,7 @@ mod tests { Some(IntervalDayTime::new(0, 1)), Some(IntervalDayTime::new(0, 10)), Some(IntervalDayTime::new(0, 100)), + Some(IntervalDayTime::new(0, 0)), ])); let schema = Arc::new(Schema::new(vec![Field::new( @@ -1111,6 +1131,7 @@ mod tests { "| 0.001 secs |", "| 0.010 secs |", "| 0.100 secs |", + "| 0 secs |", "+------------------+", ]; @@ -1135,6 +1156,7 @@ mod tests { Some(IntervalMonthDayNano::new(0, 0, 10_000_000)), Some(IntervalMonthDayNano::new(0, 0, 100_000_000)), Some(IntervalMonthDayNano::new(0, 0, 1_000_000_000)), + Some(IntervalMonthDayNano::new(0, 0, 0)), ])); let schema = Arc::new(Schema::new(vec![Field::new( @@ -1164,6 +1186,7 @@ mod tests { "| 0.010000000 secs |", "| 0.100000000 secs |", "| 1.000000000 secs |", + "| 0 secs |", "+--------------------------+", ]; @@ -1240,9 +1263,10 @@ mod tests { // Pretty formatting let opts = FormatOptions::default().with_null("null"); let opts = opts.with_duration_format(DurationFormat::Pretty); - let pretty = pretty_format_columns_with_options("pretty", &[array.clone()], &opts) - .unwrap() - .to_string(); + let pretty = + pretty_format_columns_with_options("pretty", std::slice::from_ref(&array), &opts) + .unwrap() + .to_string(); // Expected output let expected_pretty = vec![ @@ -1282,4 +1306,474 @@ mod tests { let actual: Vec<&str> = iso.lines().collect(); assert_eq!(expected_iso, actual, "Actual result:\n{iso}"); } + + // + // Custom Formatting + // + + /// The factory that will create the [`ArrayFormatter`]s. + #[derive(Debug)] + struct TestFormatters {} + + impl ArrayFormatterFactory for TestFormatters { + fn create_array_formatter<'formatter>( + &self, + array: &'formatter dyn Array, + options: &FormatOptions<'formatter>, + field: Option<&'formatter Field>, + ) -> Result>, ArrowError> { + if field + .map(|f| f.extension_type_name() == Some("my_money")) + .unwrap_or(false) + { + // We assume that my_money always is an Int32. + let array = array.as_primitive(); + let display_index = Box::new(MyMoneyFormatter { + array, + options: options.clone(), + }); + return Ok(Some(ArrayFormatter::new(display_index, options.safe()))); + } + + if array.data_type() == &DataType::Int32 { + let array = array.as_primitive(); + let display_index = Box::new(MyInt32Formatter { + array, + options: options.clone(), + }); + return Ok(Some(ArrayFormatter::new(display_index, options.safe()))); + } + + Ok(None) + } + } + + /// A format that will append a "€" sign to the end of the Int32 values. + struct MyMoneyFormatter<'a> { + array: &'a Int32Array, + options: FormatOptions<'a>, + } + + impl<'a> DisplayIndex for MyMoneyFormatter<'a> { + fn write(&self, idx: usize, f: &mut dyn Write) -> crate::display::FormatResult { + match self.array.is_valid(idx) { + true => write!(f, "{} €", self.array.value(idx))?, + false => write!(f, "{}", self.options.null())?, + } + + Ok(()) + } + } + + /// The actual formatter + struct MyInt32Formatter<'a> { + array: &'a Int32Array, + options: FormatOptions<'a>, + } + + impl<'a> DisplayIndex for MyInt32Formatter<'a> { + fn write(&self, idx: usize, f: &mut dyn Write) -> crate::display::FormatResult { + match self.array.is_valid(idx) { + true => write!(f, "{} (32-Bit)", self.array.value(idx))?, + false => write!(f, "{}", self.options.null())?, + } + + Ok(()) + } + } + + #[test] + fn test_format_batches_with_custom_formatters() { + // define a schema. + let options = FormatOptions::new() + .with_null("") + .with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + let schema = Arc::new(Schema::new(vec![ + Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(array::Int32Array::from(vec![ + Some(1), + None, + Some(10), + Some(100), + ]))], + ) + .unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_batches_with_options(&[batch], &options).unwrap() + ) + .unwrap(); + + let s = [ + "+--------+", + "| income |", + "+--------+", + "| 1 € |", + "| |", + "| 10 € |", + "| 100 € |", + "+--------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_batches_with_custom_formatters_multi_nested_list() { + // define a schema. + let options = FormatOptions::new() + .with_null("") + .with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + let nested_field = Arc::new( + Field::new_list_field(DataType::Int32, true).with_metadata(money_metadata.clone()), + ); + + // Create nested data + let inner_list = ListBuilder::new(Int32Builder::new()).with_field(nested_field); + let mut outer_list = FixedSizeListBuilder::new(inner_list, 2); + outer_list.values().append_value([Some(1)]); + outer_list.values().append_null(); + outer_list.append(true); + outer_list.values().append_value([Some(2), Some(8)]); + outer_list + .values() + .append_value([Some(50), Some(25), Some(25)]); + outer_list.append(true); + let outer_list = outer_list.finish(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "income", + outer_list.data_type().clone(), + true, + )])); + + // define data. + let batch = RecordBatch::try_new(schema, vec![Arc::new(outer_list)]).unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_batches_with_options(&[batch], &options).unwrap() + ) + .unwrap(); + + let s = [ + "+----------------------------------+", + "| income |", + "+----------------------------------+", + "| [[1 €], ] |", + "| [[2 €, 8 €], [50 €, 25 €, 25 €]] |", + "+----------------------------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_batches_with_custom_formatters_nested_struct() { + // define a schema. + let options = FormatOptions::new() + .with_null("") + .with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + let fields = Fields::from(vec![ + Field::new("name", DataType::Utf8, true), + Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()), + ]); + + let schema = Arc::new(Schema::new(vec![Field::new( + "income", + DataType::Struct(fields.clone()), + true, + )])); + + // Create nested data + let mut nested_data = StructBuilder::new( + fields, + vec![ + Box::new(StringBuilder::new()), + Box::new(Int32Builder::new()), + ], + ); + nested_data + .field_builder::(0) + .unwrap() + .extend([Some("Gimli"), Some("Legolas"), Some("Aragorn")]); + nested_data + .field_builder::(1) + .unwrap() + .extend([Some(10), None, Some(30)]); + nested_data.append(true); + nested_data.append(true); + nested_data.append(true); + + // define data. + let batch = RecordBatch::try_new(schema, vec![Arc::new(nested_data.finish())]).unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_batches_with_options(&[batch], &options).unwrap() + ) + .unwrap(); + + let s = [ + "+---------------------------------+", + "| income |", + "+---------------------------------+", + "| {name: Gimli, income: 10 €} |", + "| {name: Legolas, income: } |", + "| {name: Aragorn, income: 30 €} |", + "+---------------------------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_batches_with_custom_formatters_nested_map() { + // define a schema. + let options = FormatOptions::new() + .with_null("") + .with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + + let mut array = MapBuilder::::new( + None, + StringBuilder::new(), + Int32Builder::new(), + ) + .with_values_field( + Field::new("values", DataType::Int32, true).with_metadata(money_metadata.clone()), + ); + array + .keys() + .extend([Some("Gimli"), Some("Legolas"), Some("Aragorn")]); + array.values().extend([Some(10), None, Some(30)]); + array.append(true).unwrap(); + let array = array.finish(); + + // define data. + let schema = Arc::new(Schema::new(vec![Field::new( + "income", + array.data_type().clone(), + true, + )])); + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_batches_with_options(&[batch], &options).unwrap() + ) + .unwrap(); + + let s = [ + "+-----------------------------------------------+", + "| income |", + "+-----------------------------------------------+", + "| {Gimli: 10 €, Legolas: , Aragorn: 30 €} |", + "+-----------------------------------------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_batches_with_custom_formatters_nested_union() { + // define a schema. + let options = FormatOptions::new() + .with_null("") + .with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + let fields = UnionFields::try_new( + vec![0], + vec![Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone())], + ) + .unwrap(); + + // Create nested data and construct it with the correct metadata + let mut array_builder = UnionBuilder::new_dense(); + array_builder.append::("income", 1).unwrap(); + let (_, type_ids, offsets, children) = array_builder.build().unwrap().into_parts(); + let array = UnionArray::try_new(fields, type_ids, offsets, children).unwrap(); + + let schema = Arc::new(Schema::new(vec![Field::new( + "income", + array.data_type().clone(), + true, + )])); + + // define data. + let batch = RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_batches_with_options(&[batch], &options).unwrap() + ) + .unwrap(); + + let s = [ + "+--------------+", + "| income |", + "+--------------+", + "| {income=1 €} |", + "+--------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_batches_with_custom_formatters_custom_schema_overrules_batch_schema() { + // define a schema. + let options = FormatOptions::new().with_formatter_factory(Some(&TestFormatters {})); + let money_metadata = HashMap::from([( + extension::EXTENSION_TYPE_NAME_KEY.to_owned(), + "my_money".to_owned(), + )]); + let schema = Arc::new(Schema::new(vec![ + Field::new("income", DataType::Int32, true).with_metadata(money_metadata.clone()), + ])); + + // define data. + let batch = RecordBatch::try_new( + schema, + vec![Arc::new(array::Int32Array::from(vec![ + Some(1), + None, + Some(10), + Some(100), + ]))], + ) + .unwrap(); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + create_table( + // No metadata compared to test_format_batches_with_custom_formatters + Some(Arc::new(Schema::new(vec![Field::new( + "income", + DataType::Int32, + true + ),]))), + &[batch], + &options, + ) + .unwrap() + ) + .unwrap(); + + // No € formatting as in test_format_batches_with_custom_formatters + let s = [ + "+--------------+", + "| income |", + "+--------------+", + "| 1 (32-Bit) |", + "| |", + "| 10 (32-Bit) |", + "| 100 (32-Bit) |", + "+--------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_format_column_with_custom_formatters() { + // define data. + let array = Arc::new(array::Int32Array::from(vec![ + Some(1), + None, + Some(10), + Some(100), + ])); + + let mut buf = String::new(); + write!( + &mut buf, + "{}", + pretty_format_columns_with_options( + "income", + &[array], + &FormatOptions::default().with_formatter_factory(Some(&TestFormatters {})) + ) + .unwrap() + ) + .unwrap(); + + let s = [ + "+--------------+", + "| income |", + "+--------------+", + "| 1 (32-Bit) |", + "| |", + "| 10 (32-Bit) |", + "| 100 (32-Bit) |", + "+--------------+", + ]; + let expected = s.join("\n"); + assert_eq!(expected, buf); + } + + #[test] + fn test_pretty_format_batches_with_schema_with_wrong_number_of_fields() { + let schema_a = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Utf8, true), + ])); + let schema_b = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + // define data. + let batch = RecordBatch::try_new( + schema_b, + vec![Arc::new(array::Int32Array::from(vec![ + Some(1), + None, + Some(10), + Some(100), + ]))], + ) + .unwrap(); + + let error = pretty_format_batches_with_schema(schema_a, &[batch]) + .err() + .unwrap(); + assert_eq!( + &error.to_string(), + "Invalid argument error: Expected the same number of columns in a record batch (1) as the number of fields (2) in the schema" + ); + } } diff --git a/arrow-csv/examples/README.md b/arrow-csv/examples/README.md deleted file mode 100644 index 340413e76d94..000000000000 --- a/arrow-csv/examples/README.md +++ /dev/null @@ -1,21 +0,0 @@ - - -# Examples -- [`csv_calculation.rs`](csv_calculation.rs): performs a simple calculation using the CSV reader \ No newline at end of file diff --git a/arrow-csv/examples/csv_calculation.rs b/arrow-csv/examples/csv_calculation.rs deleted file mode 100644 index 6ce963e2b012..000000000000 --- a/arrow-csv/examples/csv_calculation.rs +++ /dev/null @@ -1,56 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -use arrow_array::cast::AsArray; -use arrow_array::types::Int16Type; -use arrow_csv::ReaderBuilder; - -use arrow_schema::{DataType, Field, Schema}; -use std::fs::File; -use std::sync::Arc; - -fn main() { - // read csv from file - let file = File::open("arrow-csv/test/data/example.csv").unwrap(); - let csv_schema = Schema::new(vec![ - Field::new("c1", DataType::Int16, true), - Field::new("c2", DataType::Float32, true), - Field::new("c3", DataType::Utf8, true), - Field::new("c4", DataType::Boolean, true), - ]); - let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) - .with_header(true) - .build(file) - .unwrap(); - - match reader.next() { - Some(r) => match r { - Ok(r) => { - // get the column(0) max value - let col = r.column(0).as_primitive::(); - let max = col.iter().max().flatten(); - println!("max value column(0): {max:?}") - } - Err(e) => { - println!("{e:?}"); - } - }, - None => { - println!("csv is empty"); - } - } -} diff --git a/arrow-csv/src/lib.rs b/arrow-csv/src/lib.rs index 8532cf59a218..4c4b04098175 100644 --- a/arrow-csv/src/lib.rs +++ b/arrow-csv/src/lib.rs @@ -15,21 +15,24 @@ // specific language governing permissions and limitations // under the License. -//! Transfer data between the Arrow memory format and CSV (comma-separated values). +//! Transfer data between the [Apache Arrow] memory format and CSV (comma-separated values). +//! +//! [Apache Arrow]: https://arrow.apache.org/ #![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] pub mod reader; pub mod writer; -pub use self::reader::infer_schema_from_files; pub use self::reader::Reader; pub use self::reader::ReaderBuilder; +pub use self::reader::infer_schema_from_files; +pub use self::writer::QuoteStyle; pub use self::writer::Writer; pub use self::writer::WriterBuilder; use arrow_schema::ArrowError; @@ -51,8 +54,8 @@ fn map_csv_error(error: csv::Error) -> ArrowError { } => ArrowError::CsvError(format!( "Encountered unequal lengths between records on CSV file. Expected {} \ records, found {} records{}", - len, expected_len, + len, pos.as_ref() .map(|pos| format!(" at line {}", pos.line())) .unwrap_or_default(), diff --git a/arrow-csv/src/reader/mod.rs b/arrow-csv/src/reader/mod.rs index e9f612557e0a..e26072fea917 100644 --- a/arrow-csv/src/reader/mod.rs +++ b/arrow-csv/src/reader/mod.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! CSV Reader +//! CSV Reading: [`Reader`] and [`ReaderBuilder`] //! //! # Basic Usage //! @@ -42,6 +42,46 @@ //! let batch = csv.next().unwrap().unwrap(); //! ``` //! +//! # Example: Numeric calculations on CSV +//! This code finds the maximum value in column 0 of a CSV file containing +//! ```csv +//! c1,c2,c3,c4 +//! 1,1.1,"hong kong",true +//! 3,323.12,"XiAn",false +//! 10,131323.12,"cheng du",false +//! ``` +//! +//! ``` +//! # use arrow_array::cast::AsArray; +//! # use arrow_array::types::Int16Type; +//! # use arrow_csv::ReaderBuilder; +//! # use arrow_schema::{DataType, Field, Schema}; +//! # use std::fs::File; +//! # use std::sync::Arc; +//! // Open the example file +//! let file = File::open("test/data/example.csv").unwrap(); +//! let csv_schema = Schema::new(vec![ +//! Field::new("c1", DataType::Int16, true), +//! Field::new("c2", DataType::Float32, true), +//! Field::new("c3", DataType::Utf8, true), +//! Field::new("c4", DataType::Boolean, true), +//! ]); +//! let mut reader = ReaderBuilder::new(Arc::new(csv_schema)) +//! .with_header(true) +//! .build(file) +//! .unwrap(); +//! // find the maximum value in column 0 across all batches +//! let mut max_c0 = 0; +//! while let Some(r) = reader.next() { +//! let r = r.unwrap(); // handle error +//! // get the max value in column(0) for this batch +//! let col = r.column(0).as_primitive::(); +//! let batch_max = col.iter().max().flatten().unwrap_or_default(); +//! max_c0 = max_c0.max(batch_max); +//! } +//! assert_eq!(max_c0, 10); +//!``` +//! //! # Async Usage //! //! The lower-level [`Decoder`] can be integrated with various forms of async data streams, @@ -128,7 +168,7 @@ mod records; use arrow_array::builder::{NullBuilder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; -use arrow_cast::parse::{parse_decimal, string_to_datetime, Parser}; +use arrow_cast::parse::{Parser, parse_decimal, string_to_datetime}; use arrow_schema::*; use chrono::{TimeZone, Utc}; use csv::StringRecord; @@ -441,13 +481,18 @@ pub fn infer_schema_from_files( type Bounds = Option<(usize, usize)>; /// CSV file reader using [`std::io::BufReader`] +/// +/// See [`ReaderBuilder`] to construct a CSV reader with options and the +/// [module-level documentation](crate::reader) for more details and examples pub type Reader = BufReader>; -/// CSV file reader +/// CSV file reader implementation. See [`Reader`] for usage +/// +/// Despite having the same name as [`std::io::BufReader`, this structure does +/// not buffer reads itself pub struct BufReader { /// File reader reader: R, - /// The decoder decoder: Decoder, } @@ -654,6 +699,22 @@ fn parse( let field = &fields[i]; match field.data_type() { DataType::Boolean => build_boolean_array(line_number, rows, i, null_regex), + DataType::Decimal32(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), + DataType::Decimal64(precision, scale) => build_decimal_array::( + line_number, + rows, + i, + *precision, + *scale, + null_regex, + ), DataType::Decimal128(precision, scale) => build_decimal_array::( line_number, rows, @@ -844,7 +905,7 @@ fn parse( .collect::>(), ) as ArrayRef), _ => Err(ArrowError::ParseError(format!( - "Unsupported dictionary key type {key_type:?}" + "Unsupported dictionary key type {key_type}" ))), } } @@ -1037,7 +1098,7 @@ fn build_boolean_array( .map(|e| Arc::new(e) as ArrayRef) } -/// CSV file reader builder +/// Builder for CSV [`Reader`]s #[derive(Debug)] pub struct ReaderBuilder { /// Schema of the CSV file @@ -1055,9 +1116,10 @@ pub struct ReaderBuilder { } impl ReaderBuilder { - /// Create a new builder for configuring CSV parsing options. + /// Create a new builder for configuring [`Reader`] CSV parsing options. /// - /// To convert a builder into a reader, call `ReaderBuilder::build` + /// To convert a builder into a reader, call [`ReaderBuilder::build`]. See + /// the [module-level documentation](crate::reader) for more details and examples. /// /// # Example /// @@ -1315,6 +1377,54 @@ mod tests { assert_eq!("0.290472", lng.value_as_string(9)); } + #[test] + fn test_csv_reader_with_decimal_3264() { + let schema = Arc::new(Schema::new(vec![ + Field::new("city", DataType::Utf8, false), + Field::new("lat", DataType::Decimal32(9, 6), false), + Field::new("lng", DataType::Decimal64(16, 6), false), + ])); + + let file = File::open("test/data/decimal_test.csv").unwrap(); + + let mut csv = ReaderBuilder::new(schema).build(file).unwrap(); + let batch = csv.next().unwrap().unwrap(); + // access data from a primitive array + let lat = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("57.653484", lat.value_as_string(0)); + assert_eq!("53.002666", lat.value_as_string(1)); + assert_eq!("52.412811", lat.value_as_string(2)); + assert_eq!("51.481583", lat.value_as_string(3)); + assert_eq!("12.123456", lat.value_as_string(4)); + assert_eq!("50.760000", lat.value_as_string(5)); + assert_eq!("0.123000", lat.value_as_string(6)); + assert_eq!("123.000000", lat.value_as_string(7)); + assert_eq!("123.000000", lat.value_as_string(8)); + assert_eq!("-50.760000", lat.value_as_string(9)); + + let lng = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + + assert_eq!("-3.335724", lng.value_as_string(0)); + assert_eq!("-2.179404", lng.value_as_string(1)); + assert_eq!("-1.778197", lng.value_as_string(2)); + assert_eq!("-3.179090", lng.value_as_string(3)); + assert_eq!("-3.179090", lng.value_as_string(4)); + assert_eq!("0.290472", lng.value_as_string(5)); + assert_eq!("0.290472", lng.value_as_string(6)); + assert_eq!("0.290472", lng.value_as_string(7)); + assert_eq!("0.290472", lng.value_as_string(8)); + assert_eq!("0.290472", lng.value_as_string(9)); + } + #[test] fn test_csv_from_buf_reader() { let schema = Schema::new(vec![ @@ -1789,7 +1899,10 @@ mod tests { let file_name = "test/data/various_invalid_types/invalid_float.csv"; let error = invalid_csv_helper(file_name); - assert_eq!("Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", error); + assert_eq!( + "Parser error: Error while parsing value '4.x4' as type 'Float32' for column 1 at line 4. Row data: '[4,4.x4,,false]'", + error + ); } #[test] @@ -1797,7 +1910,10 @@ mod tests { let file_name = "test/data/various_invalid_types/invalid_int.csv"; let error = invalid_csv_helper(file_name); - assert_eq!("Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", error); + assert_eq!( + "Parser error: Error while parsing value '2.3' as type 'UInt64' for column 0 at line 2. Row data: '[2.3,2.2,2.22,false]'", + error + ); } #[test] @@ -1805,7 +1921,10 @@ mod tests { let file_name = "test/data/various_invalid_types/invalid_bool.csv"; let error = invalid_csv_helper(file_name); - assert_eq!("Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", error); + assert_eq!( + "Parser error: Error while parsing value 'none' as type 'Boolean' for column 3 at line 2. Row data: '[2,2.2,2.22,none]'", + error + ); } /// Infer the data type of a record @@ -2633,7 +2752,10 @@ mod tests { .infer_schema(&mut read, None); assert!(result.is_err()); // Include line number in the error message to help locate and fix the issue - assert_eq!(result.err().unwrap().to_string(), "Csv error: Encountered unequal lengths between records on CSV file. Expected 2 records, found 3 records at line 3"); + assert_eq!( + result.err().unwrap().to_string(), + "Csv error: Encountered unequal lengths between records on CSV file. Expected 3 records, found 2 records at line 3" + ); } #[test] diff --git a/arrow-csv/src/writer.rs b/arrow-csv/src/writer.rs index c5a0a0b76d59..c38d1cdec337 100644 --- a/arrow-csv/src/writer.rs +++ b/arrow-csv/src/writer.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -//! CSV Writer +//! CSV Writing: [`Writer`] and [`WriterBuilder`] //! //! This CSV writer allows Arrow data (in record batches) to be written as CSV files. //! The writer does not support writing `ListArray` and `StructArray`. //! -//! Example: -//! +//! # Example //! ``` //! # use arrow_array::*; //! # use arrow_array::types::*; @@ -62,6 +61,117 @@ //! writer.write(batch).unwrap(); //! } //! ``` +//! +//! # Whitespace Handling +//! +//! The writer supports trimming leading and trailing whitespace from string values, +//! compatible with Apache Spark's CSV options `ignoreLeadingWhiteSpace` and +//! `ignoreTrailingWhiteSpace`. This is useful when working with data that may have +//! unwanted padding. +//! +//! Whitespace trimming is applied to all string data types: +//! - `DataType::Utf8` +//! - `DataType::LargeUtf8` +//! - `DataType::Utf8View` +//! +//! ## Example: Use [`WriterBuilder`] to control whitespace handling +//! +//! ``` +//! # use arrow_array::*; +//! # use arrow_csv::WriterBuilder; +//! # use arrow_schema::*; +//! # use std::sync::Arc; +//! let schema = Schema::new(vec![ +//! Field::new("name", DataType::Utf8, false), +//! Field::new("comment", DataType::Utf8, false), +//! ]); +//! +//! let name = StringArray::from(vec![ +//! " Alice ", // Leading and trailing spaces +//! "Bob", // No spaces +//! " Charlie", // Leading spaces only +//! ]); +//! let comment = StringArray::from(vec![ +//! " Great job! ", +//! "Well done", +//! "Excellent ", +//! ]); +//! +//! let batch = RecordBatch::try_new( +//! Arc::new(schema), +//! vec![Arc::new(name), Arc::new(comment)], +//! ) +//! .unwrap(); +//! +//! // Trim both leading and trailing whitespace +//! let mut output = Vec::new(); +//! WriterBuilder::new() +//! .with_ignore_leading_whitespace(true) +//! .with_ignore_trailing_whitespace(true) +//! .build(&mut output) +//! .write(&batch) +//! .unwrap(); +//! assert_eq!( +//! String::from_utf8(output).unwrap(), +//! "\ +//! name,comment\n\ +//! Alice,Great job!\n\ +//! Bob,Well done\n\ +//! Charlie,Excellent\n" +//! ); +//! ``` +//! +//! # Quoting Styles +//! +//! The writer supports different quoting styles for fields, compatible with Apache Spark's +//! CSV options like `quoteAll`. You can control when fields are quoted using the +//! [`QuoteStyle`] enum. +//! +//! ## Example +//! +//! ``` +//! # use arrow_array::*; +//! # use arrow_csv::{WriterBuilder, QuoteStyle}; +//! # use arrow_schema::*; +//! # use std::sync::Arc; +//! +//! let schema = Schema::new(vec![ +//! Field::new("product", DataType::Utf8, false), +//! Field::new("price", DataType::Float64, false), +//! ]); +//! +//! let product = StringArray::from(vec!["apple", "banana,organic", "cherry"]); +//! let price = Float64Array::from(vec![1.50, 2.25, 3.00]); +//! +//! let batch = RecordBatch::try_new( +//! Arc::new(schema), +//! vec![Arc::new(product), Arc::new(price)], +//! ) +//! .unwrap(); +//! +//! // Default behavior (QuoteStyle::Necessary) +//! let mut output = Vec::new(); +//! WriterBuilder::new() +//! .build(&mut output) +//! .write(&batch) +//! .unwrap(); +//! assert_eq!( +//! String::from_utf8(output).unwrap(), +//! "product,price\napple,1.5\n\"banana,organic\",2.25\ncherry,3.0\n" +//! ); +//! +//! // Quote all fields (Spark's quoteAll=true) +//! let mut output = Vec::new(); +//! WriterBuilder::new() +//! .with_quote_style(QuoteStyle::Always) +//! .build(&mut output) +//! .write(&batch) +//! .unwrap(); +//! assert_eq!( +//! String::from_utf8(output).unwrap(), +//! "\"product\",\"price\"\n\"apple\",\"1.5\"\n\"banana,organic\",\"2.25\"\n\"cherry\",\"3.0\"\n" +//! ); +//! ``` use arrow_array::*; use arrow_cast::display::*; @@ -72,7 +182,25 @@ use std::io::Write; use crate::map_csv_error; const DEFAULT_NULL_VALUE: &str = ""; +/// The quoting style to use when writing CSV files. +/// +/// This type is re-exported from the `csv` crate and supports different +/// strategies for quoting fields. It is compatible with Apache Spark's +/// CSV options like `quoteAll`. +/// +/// # Example +/// +/// ``` +/// use arrow_csv::{WriterBuilder, QuoteStyle}; +/// +/// let builder = WriterBuilder::new() +/// .with_quote_style(QuoteStyle::Always); // Equivalent to Spark's quoteAll=true +/// ``` +pub use csv::QuoteStyle; + /// A CSV writer +/// +/// See the [module documentation](crate::writer) for examples. #[derive(Debug)] pub struct Writer { /// The object to write to @@ -93,16 +221,23 @@ pub struct Writer { beginning: bool, /// The value to represent null entries, defaults to [`DEFAULT_NULL_VALUE`] null_value: Option, + /// Whether to ignore leading whitespace in string values + ignore_leading_whitespace: bool, + /// Whether to ignore trailing whitespace in string values + ignore_trailing_whitespace: bool, } impl Writer { /// Create a new CsvWriter from a writable object, with default options + /// + /// See [`WriterBuilder`] for configure options, and the [module + /// documentation](crate::writer) for examples. pub fn new(writer: W) -> Self { let delimiter = b','; WriterBuilder::new().with_delimiter(delimiter).build(writer) } - /// Write a vector of record batches to a writable object + /// Write a RecordBatch to the underlying writer pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { let num_columns = batch.num_columns(); if self.beginning { @@ -157,7 +292,10 @@ impl Writer { col_idx + 1 )) })?; - byte_record.push_field(buffer.as_bytes()); + + let field_bytes = + self.get_trimmed_field_bytes(&buffer, batch.column(col_idx).data_type()); + byte_record.push_field(field_bytes); } self.writer @@ -169,6 +307,29 @@ impl Writer { Ok(()) } + /// Returns the bytes for a field, applying whitespace trimming if configured and applicable + fn get_trimmed_field_bytes<'a>(&self, buffer: &'a str, data_type: &DataType) -> &'a [u8] { + // Only trim string types when trimming is enabled + let should_trim = (self.ignore_leading_whitespace || self.ignore_trailing_whitespace) + && matches!( + data_type, + DataType::Utf8 | DataType::LargeUtf8 | DataType::Utf8View + ); + + if !should_trim { + return buffer.as_bytes(); + } + + let mut trimmed = buffer; + if self.ignore_leading_whitespace { + trimmed = trimmed.trim_start(); + } + if self.ignore_trailing_whitespace { + trimmed = trimmed.trim_end(); + } + trimmed.as_bytes() + } + /// Unwraps this `Writer`, returning the underlying writer. pub fn into_inner(self) -> W { // Safe to call `unwrap` since `write` always flushes the writer. @@ -211,6 +372,12 @@ pub struct WriterBuilder { time_format: Option, /// Optional value to represent null null_value: Option, + /// Whether to ignore leading whitespace in string values. Defaults to `false` + ignore_leading_whitespace: bool, + /// Whether to ignore trailing whitespace in string values. Defaults to `false` + ignore_trailing_whitespace: bool, + /// The quoting style to use. Defaults to `QuoteStyle::Necessary` + quote_style: QuoteStyle, } impl Default for WriterBuilder { @@ -227,14 +394,18 @@ impl Default for WriterBuilder { timestamp_tz_format: None, time_format: None, null_value: None, + ignore_leading_whitespace: false, + ignore_trailing_whitespace: false, + quote_style: QuoteStyle::default(), } } } impl WriterBuilder { - /// Create a new builder for configuring CSV writing options. + /// Create a new builder for configuring CSV [`Writer`] options. /// - /// To convert a builder into a writer, call `WriterBuilder::build` + /// To convert a builder into a writer, call [`WriterBuilder::build`]. See + /// the [module documentation](crate::writer) for more examples. /// /// # Example /// @@ -389,12 +560,62 @@ impl WriterBuilder { self.null_value.as_deref().unwrap_or(DEFAULT_NULL_VALUE) } + /// Set whether to ignore leading whitespace in string values + /// For example, a string value such as " foo" will be written as "foo" + pub fn with_ignore_leading_whitespace(mut self, ignore: bool) -> Self { + self.ignore_leading_whitespace = ignore; + self + } + + /// Get whether to ignore leading whitespace in string values + pub fn ignore_leading_whitespace(&self) -> bool { + self.ignore_leading_whitespace + } + + /// Set whether to ignore trailing whitespace in string values + /// For example, a string value such as "foo " will be written as "foo" + pub fn with_ignore_trailing_whitespace(mut self, ignore: bool) -> Self { + self.ignore_trailing_whitespace = ignore; + self + } + + /// Get whether to ignore trailing whitespace in string values + pub fn ignore_trailing_whitespace(&self) -> bool { + self.ignore_trailing_whitespace + } + + /// Set the quoting style for writing CSV files + /// + /// # Example + /// + /// ``` + /// use arrow_csv::{WriterBuilder, QuoteStyle}; + /// + /// // Quote all fields (equivalent to Spark's quoteAll=true) + /// let builder = WriterBuilder::new() + /// .with_quote_style(QuoteStyle::Always); + /// + /// // Only quote when necessary (default) + /// let builder = WriterBuilder::new() + /// .with_quote_style(QuoteStyle::Necessary); + /// ``` + pub fn with_quote_style(mut self, quote_style: QuoteStyle) -> Self { + self.quote_style = quote_style; + self + } + + /// Get the configured quoting style + pub fn quote_style(&self) -> QuoteStyle { + self.quote_style + } + /// Create a new `Writer` pub fn build(self, writer: W) -> Writer { let mut builder = csv::WriterBuilder::new(); let writer = builder .delimiter(self.delimiter) .quote(self.quote) + .quote_style(self.quote_style) .double_quote(self.double_quote) .escape(self.escape) .from_writer(writer); @@ -408,6 +629,8 @@ impl WriterBuilder { timestamp_format: self.timestamp_format, timestamp_tz_format: self.timestamp_tz_format, null_value: self.null_value, + ignore_leading_whitespace: self.ignore_leading_whitespace, + ignore_trailing_whitespace: self.ignore_trailing_whitespace, } } } @@ -418,8 +641,8 @@ mod tests { use crate::ReaderBuilder; use arrow_array::builder::{ - BinaryBuilder, Decimal128Builder, Decimal256Builder, FixedSizeBinaryBuilder, - LargeBinaryBuilder, + BinaryBuilder, Decimal32Builder, Decimal64Builder, Decimal128Builder, Decimal256Builder, + FixedSizeBinaryBuilder, LargeBinaryBuilder, }; use arrow_array::types::*; use arrow_buffer::i256; @@ -496,25 +719,38 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo #[test] fn test_write_csv_decimal() { let schema = Schema::new(vec![ - Field::new("c1", DataType::Decimal128(38, 6), true), - Field::new("c2", DataType::Decimal256(76, 6), true), + Field::new("c1", DataType::Decimal32(9, 6), true), + Field::new("c2", DataType::Decimal64(17, 6), true), + Field::new("c3", DataType::Decimal128(38, 6), true), + Field::new("c4", DataType::Decimal256(76, 6), true), ]); - let mut c1_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + let mut c1_builder = Decimal32Builder::new().with_data_type(DataType::Decimal32(9, 6)); c1_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); let c1 = c1_builder.finish(); - let mut c2_builder = Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); - c2_builder.extend(vec![ + let mut c2_builder = Decimal64Builder::new().with_data_type(DataType::Decimal64(17, 6)); + c2_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); + let c2 = c2_builder.finish(); + + let mut c3_builder = Decimal128Builder::new().with_data_type(DataType::Decimal128(38, 6)); + c3_builder.extend(vec![Some(-3335724), Some(2179404), None, Some(290472)]); + let c3 = c3_builder.finish(); + + let mut c4_builder = Decimal256Builder::new().with_data_type(DataType::Decimal256(76, 6)); + c4_builder.extend(vec![ Some(i256::from_i128(-3335724)), Some(i256::from_i128(2179404)), None, Some(i256::from_i128(290472)), ]); - let c2 = c2_builder.finish(); + let c4 = c4_builder.finish(); - let batch = - RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1), Arc::new(c2)]).unwrap(); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3), Arc::new(c4)], + ) + .unwrap(); let mut file = tempfile::tempfile().unwrap(); @@ -530,15 +766,15 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo let mut buffer: Vec = vec![]; file.read_to_end(&mut buffer).unwrap(); - let expected = r#"c1,c2 --3.335724,-3.335724 -2.179404,2.179404 -, -0.290472,0.290472 --3.335724,-3.335724 -2.179404,2.179404 -, -0.290472,0.290472 + let expected = r#"c1,c2,c3,c4 +-3.335724,-3.335724,-3.335724,-3.335724 +2.179404,2.179404,2.179404,2.179404 +,,, +0.290472,0.290472,0.290472,0.290472 +-3.335724,-3.335724,-3.335724,-3.335724 +2.179404,2.179404,2.179404,2.179404 +,,, +0.290472,0.290472,0.290472,0.290472 "#; assert_eq!(expected, str::from_utf8(&buffer).unwrap()); } @@ -704,7 +940,10 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo for batch in batches { let err = writer.write(batch).unwrap_err().to_string(); - assert_eq!(err, "Csv error: Error processing row 2, col 2: Cast error: Failed to convert 1926632005177685347 to temporal for Date64") + assert_eq!( + err, + "Csv error: Error processing row 2, col 2: Cast error: Failed to convert 1926632005177685347 to temporal for Date64" + ) } drop(writer); } @@ -844,4 +1083,279 @@ sed do eiusmod tempor,-556132.25,1,,2019-04-18T02:45:55.555,23:46:03,foo String::from_utf8(buf).unwrap() ); } + + #[test] + fn test_write_csv_whitespace_handling() { + let schema = Schema::new(vec![ + Field::new("c1", DataType::Utf8, false), + Field::new("c2", DataType::Float64, true), + Field::new("c3", DataType::Utf8, true), + ]); + + let c1 = StringArray::from(vec![ + " leading space", + "trailing space ", + " both spaces ", + "no spaces", + ]); + let c2 = PrimitiveArray::::from(vec![ + Some(123.45), + Some(678.90), + None, + Some(111.22), + ]); + let c3 = StringArray::from(vec![ + Some(" test "), + Some("value "), + None, + Some(" another"), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(c1), Arc::new(c2), Arc::new(c3)], + ) + .unwrap(); + + // Test with no whitespace handling (default) + let mut buf = Vec::new(); + let builder = WriterBuilder::new(); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "c1,c2,c3\n leading space,123.45, test \ntrailing space ,678.9,value \n both spaces ,,\nno spaces,111.22, another\n", + String::from_utf8(buf).unwrap() + ); + + // Test with ignore leading whitespace only + let mut buf = Vec::new(); + let builder = WriterBuilder::new().with_ignore_leading_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "c1,c2,c3\nleading space,123.45,test \ntrailing space ,678.9,value \nboth spaces ,,\nno spaces,111.22,another\n", + String::from_utf8(buf).unwrap() + ); + + // Test with ignore trailing whitespace only + let mut buf = Vec::new(); + let builder = WriterBuilder::new().with_ignore_trailing_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "c1,c2,c3\n leading space,123.45, test\ntrailing space,678.9,value\n both spaces,,\nno spaces,111.22, another\n", + String::from_utf8(buf).unwrap() + ); + + // Test with both ignore leading and trailing whitespace + let mut buf = Vec::new(); + let builder = WriterBuilder::new() + .with_ignore_leading_whitespace(true) + .with_ignore_trailing_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "c1,c2,c3\nleading space,123.45,test\ntrailing space,678.9,value\nboth spaces,,\nno spaces,111.22,another\n", + String::from_utf8(buf).unwrap() + ); + } + + #[test] + fn test_write_csv_whitespace_with_special_chars() { + let schema = Schema::new(vec![Field::new("c1", DataType::Utf8, false)]); + + let c1 = StringArray::from(vec![ + " quoted \"value\" ", + " new\nline ", + " comma,value ", + "\ttab\tvalue\t", + ]); + + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(c1)]).unwrap(); + + // Test with both ignore leading and trailing whitespace + let mut buf = Vec::new(); + let builder = WriterBuilder::new() + .with_ignore_leading_whitespace(true) + .with_ignore_trailing_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + + // Note: tabs are trimmed as they are whitespace characters + assert_eq!( + "c1\n\"quoted \"\"value\"\"\"\n\"new\nline\"\n\"comma,value\"\ntab\tvalue\n", + String::from_utf8(buf).unwrap() + ); + } + + #[test] + fn test_write_csv_whitespace_all_string_types() { + use arrow_array::{LargeStringArray, StringViewArray}; + + let schema = Schema::new(vec![ + Field::new("utf8", DataType::Utf8, false), + Field::new("large_utf8", DataType::LargeUtf8, false), + Field::new("utf8_view", DataType::Utf8View, false), + ]); + + let utf8 = StringArray::from(vec![" leading", "trailing ", " both ", "no_spaces"]); + + let large_utf8 = + LargeStringArray::from(vec![" leading", "trailing ", " both ", "no_spaces"]); + + let utf8_view = + StringViewArray::from(vec![" leading", "trailing ", " both ", "no_spaces"]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(utf8), Arc::new(large_utf8), Arc::new(utf8_view)], + ) + .unwrap(); + + // Test with no whitespace handling (default) + let mut buf = Vec::new(); + let builder = WriterBuilder::new(); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "utf8,large_utf8,utf8_view\n leading, leading, leading\ntrailing ,trailing ,trailing \n both , both , both \nno_spaces,no_spaces,no_spaces\n", + String::from_utf8(buf).unwrap() + ); + + // Test with both ignore leading and trailing whitespace + let mut buf = Vec::new(); + let builder = WriterBuilder::new() + .with_ignore_leading_whitespace(true) + .with_ignore_trailing_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "utf8,large_utf8,utf8_view\nleading,leading,leading\ntrailing,trailing,trailing\nboth,both,both\nno_spaces,no_spaces,no_spaces\n", + String::from_utf8(buf).unwrap() + ); + + // Test with only leading whitespace trimming + let mut buf = Vec::new(); + let builder = WriterBuilder::new().with_ignore_leading_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "utf8,large_utf8,utf8_view\nleading,leading,leading\ntrailing ,trailing ,trailing \nboth ,both ,both \nno_spaces,no_spaces,no_spaces\n", + String::from_utf8(buf).unwrap() + ); + + // Test with only trailing whitespace trimming + let mut buf = Vec::new(); + let builder = WriterBuilder::new().with_ignore_trailing_whitespace(true); + let mut writer = builder.build(&mut buf); + writer.write(&batch).unwrap(); + drop(writer); + assert_eq!( + "utf8,large_utf8,utf8_view\n leading, leading, leading\ntrailing,trailing,trailing\n both, both, both\nno_spaces,no_spaces,no_spaces\n", + String::from_utf8(buf).unwrap() + ); + } + + fn write_quote_style(batch: &RecordBatch, quote_style: QuoteStyle) -> String { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_quote_style(quote_style) + .build(&mut buf); + writer.write(batch).unwrap(); + drop(writer); + String::from_utf8(buf).unwrap() + } + + fn write_quote_style_with_null( + batch: &RecordBatch, + quote_style: QuoteStyle, + null_value: &str, + ) -> String { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_quote_style(quote_style) + .with_null(null_value.to_string()) + .build(&mut buf); + writer.write(batch).unwrap(); + drop(writer); + String::from_utf8(buf).unwrap() + } + + #[test] + fn test_write_csv_quote_style() { + let schema = Schema::new(vec![ + Field::new("text", DataType::Utf8, false), + Field::new("number", DataType::Int32, false), + Field::new("float", DataType::Float64, false), + ]); + + let text = StringArray::from(vec!["hello", "world", "comma,value", "quote\"test"]); + let number = Int32Array::from(vec![1, 2, 3, 4]); + let float = Float64Array::from(vec![1.1, 2.2, 3.3, 4.4]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(text), Arc::new(number), Arc::new(float)], + ) + .unwrap(); + + // Test with QuoteStyle::Necessary (default) + assert_eq!( + "text,number,float\nhello,1,1.1\nworld,2,2.2\n\"comma,value\",3,3.3\n\"quote\"\"test\",4,4.4\n", + write_quote_style(&batch, QuoteStyle::Necessary) + ); + + // Test with QuoteStyle::Always (equivalent to Spark's quoteAll=true) + assert_eq!( + "\"text\",\"number\",\"float\"\n\"hello\",\"1\",\"1.1\"\n\"world\",\"2\",\"2.2\"\n\"comma,value\",\"3\",\"3.3\"\n\"quote\"\"test\",\"4\",\"4.4\"\n", + write_quote_style(&batch, QuoteStyle::Always) + ); + + // Test with QuoteStyle::NonNumeric + assert_eq!( + "\"text\",\"number\",\"float\"\n\"hello\",1,1.1\n\"world\",2,2.2\n\"comma,value\",3,3.3\n\"quote\"\"test\",4,4.4\n", + write_quote_style(&batch, QuoteStyle::NonNumeric) + ); + + // Test with QuoteStyle::Never (warning: can produce invalid CSV) + // Note: This produces invalid CSV for fields with commas or quotes + assert_eq!( + "text,number,float\nhello,1,1.1\nworld,2,2.2\ncomma,value,3,3.3\nquote\"test,4,4.4\n", + write_quote_style(&batch, QuoteStyle::Never) + ); + } + + #[test] + fn test_write_csv_quote_style_with_nulls() { + let schema = Schema::new(vec![ + Field::new("text", DataType::Utf8, true), + Field::new("number", DataType::Int32, true), + ]); + + let text = StringArray::from(vec![Some("hello"), None, Some("world")]); + let number = Int32Array::from(vec![Some(1), Some(2), None]); + + let batch = + RecordBatch::try_new(Arc::new(schema), vec![Arc::new(text), Arc::new(number)]).unwrap(); + + // Test with QuoteStyle::Always + assert_eq!( + "\"text\",\"number\"\n\"hello\",\"1\"\n\"\",\"2\"\n\"world\",\"\"\n", + write_quote_style(&batch, QuoteStyle::Always) + ); + + // Test with QuoteStyle::Always and custom null value + assert_eq!( + "\"text\",\"number\"\n\"hello\",\"1\"\n\"NULL\",\"2\"\n\"world\",\"NULL\"\n", + write_quote_style_with_null(&batch, QuoteStyle::Always, "NULL") + ); + } } diff --git a/arrow-data/Cargo.toml b/arrow-data/Cargo.toml index fbed24fea1fa..9c7a5206b2f4 100644 --- a/arrow-data/Cargo.toml +++ b/arrow-data/Cargo.toml @@ -48,7 +48,8 @@ all-features = true arrow-buffer = { workspace = true } arrow-schema = { workspace = true } -num = { version = "0.4", default-features = false, features = ["std"] } +num-integer = { version = "0.1.46", default-features = false, features = ["std"] } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } half = { version = "2.1", default-features = false } [dev-dependencies] diff --git a/arrow-data/src/byte_view.rs b/arrow-data/src/byte_view.rs index 3b3ec6246066..270f4f9948ac 100644 --- a/arrow-data/src/byte_view.rs +++ b/arrow-data/src/byte_view.rs @@ -18,6 +18,14 @@ use arrow_buffer::Buffer; use arrow_schema::ArrowError; +/// The maximum number of bytes that can be stored inline in a byte view. +/// +/// See [`ByteView`] and [`GenericByteViewArray`] for more information on the +/// layout of the views. +/// +/// [`GenericByteViewArray`]: https://docs.rs/arrow/latest/arrow/array/struct.GenericByteViewArray.html +pub const MAX_INLINE_VIEW_LEN: u32 = 12; + /// Helper to access views of [`GenericByteViewArray`] (`StringViewArray` and /// `BinaryViewArray`) where the length is greater than 12 bytes. /// @@ -76,15 +84,15 @@ impl ByteView { /// See example on [`ByteView`] docs /// /// Notes: - /// * the length should always be greater than 12 (Data less than 12 - /// bytes is stored as an inline view) + /// * the length should always be greater than [`MAX_INLINE_VIEW_LEN`] + /// (Data less than 12 bytes is stored as an inline view) /// * buffer and offset are set to `0` /// /// # Panics /// If the prefix is not exactly 4 bytes #[inline] pub fn new(length: u32, prefix: &[u8]) -> Self { - debug_assert!(length > 12); + debug_assert!(length > MAX_INLINE_VIEW_LEN); Self { length, prefix: u32::from_le_bytes(prefix.try_into().unwrap()), @@ -159,8 +167,8 @@ where { for (idx, v) in views.iter().enumerate() { let len = *v as u32; - if len <= 12 { - if len < 12 && (v >> (32 + len * 8)) != 0 { + if len <= MAX_INLINE_VIEW_LEN { + if len < MAX_INLINE_VIEW_LEN && (v >> (32 + len * 8)) != 0 { return Err(ArrowError::InvalidArgumentError(format!( "View at index {idx} contained non-zero padding for string of length {len}", ))); diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 4c117184de79..4917691e23f8 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -21,7 +21,7 @@ use crate::bit_iterator::BitSliceIterator; use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; use arrow_buffer::{ - bit_util, i256, ArrowNativeType, Buffer, IntervalDayTime, IntervalMonthDayNano, MutableBuffer, + ArrowNativeType, Buffer, IntervalDayTime, IntervalMonthDayNano, MutableBuffer, bit_util, i256, }; use arrow_schema::{ArrowError, DataType, UnionMode}; use std::mem; @@ -83,6 +83,8 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff | DataType::Float16 | DataType::Float32 | DataType::Float64 + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Date32 @@ -279,7 +281,7 @@ impl ArrayData { ) -> Self { let mut skip_validation = UnsafeFlag::new(); // SAFETY: caller responsible for ensuring data is valid - skip_validation.set(true); + unsafe { skip_validation.set(true) }; ArrayDataBuilder { data_type, @@ -307,6 +309,9 @@ impl ArrayData { /// /// Note: This is a low level API and most users of the arrow crate should create /// arrays using the builders found in [arrow_array](https://docs.rs/arrow-array) + /// or [`ArrayDataBuilder`]. + /// + /// See also [`Self::into_parts`] to recover the fields pub fn try_new( data_type: DataType, len: usize, @@ -349,6 +354,33 @@ impl ArrayData { Ok(new_self) } + /// Return the constituent parts of this ArrayData + /// + /// This is the inverse of [`ArrayData::try_new`]. + /// + /// Returns `(data_type, len, nulls, offset, buffers, child_data)` + pub fn into_parts( + self, + ) -> ( + DataType, + usize, + Option, + usize, + Vec, + Vec, + ) { + let Self { + data_type, + len, + nulls, + offset, + buffers, + child_data, + } = self; + + (data_type, len, nulls, offset, buffers, child_data) + } + /// Returns a builder to construct a [`ArrayData`] instance of the same [`DataType`] #[inline] pub const fn builder(data_type: DataType) -> ArrayDataBuilder { @@ -474,21 +506,20 @@ impl ArrayData { result += buffer_size; } BufferSpec::VariableWidth => { - let buffer_len: usize; - match self.data_type { + let buffer_len = match self.data_type { DataType::Utf8 | DataType::Binary => { let offsets = self.typed_offsets::()?; - buffer_len = (offsets[self.len] - offsets[0] ) as usize; + (offsets[self.len] - offsets[0]) as usize } DataType::LargeUtf8 | DataType::LargeBinary => { let offsets = self.typed_offsets::()?; - buffer_len = (offsets[self.len] - offsets[0]) as usize; + (offsets[self.len] - offsets[0]) as usize } _ => { return Err(ArrowError::NotYetImplemented(format!( - "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {}", - self.data_type - ))) + "Invalid data type for VariableWidth buffer. Expected Utf8, LargeUtf8, Binary or LargeBinary. Got {}", + self.data_type + ))); } }; result += buffer_len; @@ -552,7 +583,7 @@ impl ArrayData { if let DataType::Struct(_) = self.data_type() { // Slice into children let new_offset = self.offset + offset; - let new_data = ArrayData { + ArrayData { data_type: self.data_type().clone(), len: length, offset: new_offset, @@ -564,9 +595,7 @@ impl ArrayData { .map(|data| data.slice(offset, length)) .collect(), nulls: self.nulls.as_ref().map(|x| x.slice(offset, length)), - }; - - new_data + } } else { let mut new_data = self.clone(); @@ -616,6 +645,16 @@ impl ArrayData { vec![ArrayData::new_empty(f.data_type())], true, ), + DataType::ListView(f) => ( + vec![zeroed(len * 4), zeroed(len * 4)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), + DataType::LargeListView(f) => ( + vec![zeroed(len * 8), zeroed(len * 8)], + vec![ArrayData::new_empty(f.data_type())], + true, + ), DataType::FixedSizeList(f, list_len) => ( vec![], vec![ArrayData::new_null(f.data_type(), *list_len as usize * len)], @@ -636,7 +675,7 @@ impl ArrayData { ), DataType::Union(f, mode) => { let (id, _) = f.iter().next().unwrap(); - let ids = Buffer::from_iter(std::iter::repeat(id).take(len)); + let ids = Buffer::from_iter(std::iter::repeat_n(id, len)); let buffers = match mode { UnionMode::Sparse => vec![ids], UnionMode::Dense => { @@ -689,7 +728,29 @@ impl ArrayData { false, ) } - d => unreachable!("{d}"), + // Handled by Some(width) branch above + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Timestamp(_, _) + | DataType::Date32 + | DataType::Date64 + | DataType::Time32(_) + | DataType::Time64(_) + | DataType::Duration(_) + | DataType::Interval(_) + | DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) => unreachable!("{data_type}"), }, }; @@ -782,7 +843,10 @@ impl ArrayData { if buffer.len() < min_buffer_size { return Err(ArrowError::InvalidArgumentError(format!( "Need at least {} bytes in buffers[{}] in array of type {:?}, but got {}", - min_buffer_size, i, self.data_type, buffer.len() + min_buffer_size, + i, + self.data_type, + buffer.len() ))); } @@ -790,7 +854,8 @@ impl ArrayData { if align_offset != 0 { return Err(ArrowError::InvalidArgumentError(format!( "Misaligned buffers[{i}] in array of type {:?}, offset from expected alignment of {alignment} by {}", - self.data_type, align_offset.min(alignment - align_offset) + self.data_type, + align_offset.min(alignment - align_offset) ))); } } @@ -804,7 +869,10 @@ impl ArrayData { if buffer.len() < min_buffer_size { return Err(ArrowError::InvalidArgumentError(format!( "Need at least {} bytes for bitmap in buffers[{}] in array of type {:?}, but got {}", - min_buffer_size, i, self.data_type, buffer.len() + min_buffer_size, + i, + self.data_type, + buffer.len() ))); } } @@ -884,7 +952,7 @@ impl ArrayData { /// entries. /// /// For an empty array, the `buffer` can also be empty. - fn typed_offsets(&self) -> Result<&[T], ArrowError> { + fn typed_offsets(&self) -> Result<&[T], ArrowError> { // An empty list-like array can have 0 offsets if self.len == 0 && self.buffers[0].is_empty() { return Ok(&[]); @@ -894,7 +962,7 @@ impl ArrayData { } /// Returns a reference to the data in `buffers[idx]` as a typed slice after validating - fn typed_buffer( + fn typed_buffer( &self, idx: usize, len: usize, @@ -918,7 +986,7 @@ impl ArrayData { /// Does a cheap sanity check that the `self.len` values in `buffer` are valid /// offsets (of type T) into some other buffer of `values_length` bytes long - fn validate_offsets( + fn validate_offsets( &self, values_length: usize, ) -> Result<(), ArrowError> { @@ -968,13 +1036,21 @@ impl ArrayData { /// Does a cheap sanity check that the `self.len` values in `buffer` are valid /// offsets and sizes (of type T) into some other buffer of `values_length` bytes long - fn validate_offsets_and_sizes( + fn validate_offsets_and_sizes( &self, values_length: usize, ) -> Result<(), ArrowError> { let offsets: &[T] = self.typed_buffer(0, self.len)?; let sizes: &[T] = self.typed_buffer(1, self.len)?; - for i in 0..values_length { + if offsets.len() != sizes.len() { + return Err(ArrowError::ComputeError(format!( + "ListView offsets len {} does not match sizes len {}", + offsets.len(), + sizes.len() + ))); + } + + for i in 0..sizes.len() { let size = sizes[i].to_usize().ok_or_else(|| { ArrowError::InvalidArgumentError(format!( "Error converting size[{}] ({}) to usize for {}", @@ -1056,7 +1132,11 @@ impl ArrayData { if field_data.len < self.len { return Err(ArrowError::InvalidArgumentError(format!( "{} child array #{} for field {} has length smaller than expected for struct array ({} < {})", - self.data_type, i, field.name(), field_data.len, self.len + self.data_type, + i, + field.name(), + field_data.len, + self.len ))); } } @@ -1088,7 +1168,9 @@ impl ArrayData { if mode == &UnionMode::Sparse && field_data.len < (self.len + self.offset) { return Err(ArrowError::InvalidArgumentError(format!( "Sparse union child array #{} has length smaller than expected for union array ({} < {})", - i, field_data.len, self.len + self.offset + i, + field_data.len, + self.len + self.offset ))); } } @@ -1280,7 +1362,7 @@ impl ArrayData { "non-nullable child of type {} contains nulls not present in parent {}", child.data_type, self.data_type ))), - } + }; } }; @@ -1371,7 +1453,7 @@ impl ArrayData { /// function would call `validate([1,2])`, and `validate([2,4])` fn validate_each_offset(&self, offset_limit: usize, validate: V) -> Result<(), ArrowError> where - T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num_traits::Num + std::fmt::Display, V: Fn(usize, Range) -> Result<(), ArrowError>, { self.typed_offsets::()? @@ -1418,7 +1500,7 @@ impl ArrayData { /// into `buffers[1]` are valid utf8 sequences fn validate_utf8(&self) -> Result<(), ArrowError> where - T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num_traits::Num + std::fmt::Display, { let values_buffer = &self.buffers[1].as_slice(); if let Ok(values_str) = std::str::from_utf8(values_buffer) { @@ -1450,7 +1532,7 @@ impl ArrayData { /// between `0` and `offset_limit` fn validate_offsets_full(&self, offset_limit: usize) -> Result<(), ArrowError> where - T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num_traits::Num + std::fmt::Display, { self.validate_each_offset::(offset_limit, |_string_index, _range| { // No validation applied to each value, but the iteration @@ -1463,7 +1545,7 @@ impl ArrayData { /// is within the range [0, max_value], inclusive fn check_bounds(&self, max_value: i64) -> Result<(), ArrowError> where - T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num_traits::Num + std::fmt::Display, { let required_len = self.len + self.offset; let buffer = &self.buffers[0]; @@ -1498,7 +1580,7 @@ impl ArrayData { /// Validates that each value in run_ends array is positive and strictly increasing. fn check_run_ends(&self) -> Result<(), ArrowError> where - T: ArrowNativeType + TryInto + num::Num + std::fmt::Display, + T: ArrowNativeType + TryInto + num_traits::Num + std::fmt::Display, { let values = self.typed_buffer::(0, self.len)?; let mut prev_value: i64 = 0_i64; @@ -1612,6 +1694,8 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { DataTypeLayout::new_fixed_width::() } DataType::Duration(_) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal32(_, _) => DataTypeLayout::new_fixed_width::(), + DataType::Decimal64(_, _) => DataTypeLayout::new_fixed_width::(), DataType::Decimal128(_, _) => DataTypeLayout::new_fixed_width::(), DataType::Decimal256(_, _) => DataTypeLayout::new_fixed_width::(), DataType::FixedSizeBinary(size) => { @@ -1761,7 +1845,7 @@ impl DataTypeLayout { }, ], can_contain_null_mask: true, - variadic: true, + variadic: false, } } } @@ -1984,6 +2068,7 @@ impl ArrayDataBuilder { /// /// Note: This is shorthand for /// ```rust + /// # #[expect(unsafe_op_in_unsafe_fn)] /// # let mut builder = arrow_data::ArrayDataBuilder::new(arrow_schema::DataType::Null); /// # let _ = unsafe { /// builder.skip_validation(true).build().unwrap() @@ -1995,7 +2080,7 @@ impl ArrayDataBuilder { /// The same caveats as [`ArrayData::new_unchecked`] /// apply. pub unsafe fn build_unchecked(self) -> ArrayData { - self.skip_validation(true).build().unwrap() + unsafe { self.skip_validation(true) }.build().unwrap() } /// Creates an `ArrayData`, consuming `self` @@ -2094,7 +2179,9 @@ impl ArrayDataBuilder { /// If validation is skipped, the buffers must form a valid Arrow array, /// otherwise undefined behavior will result pub unsafe fn skip_validation(mut self, skip_validation: bool) -> Self { - self.skip_validation.set(skip_validation); + unsafe { + self.skip_validation.set(skip_validation); + } self } } @@ -2447,5 +2534,23 @@ mod tests { for i in 0..array.len() { assert!(array.is_null(i)); } + + let array = ArrayData::new_null( + &DataType::ListView(Arc::new(Field::new_list_field(DataType::Int32, true))), + array_len, + ); + assert_eq!(array.len(), array_len); + for i in 0..array.len() { + assert!(array.is_null(i)); + } + + let array = ArrayData::new_null( + &DataType::LargeListView(Arc::new(Field::new_list_field(DataType::Int32, true))), + array_len, + ); + assert_eq!(array.len(), array_len); + for i in 0..array.len() { + assert!(array.is_null(i)); + } } } diff --git a/arrow-data/src/decimal.rs b/arrow-data/src/decimal.rs index e84461f2ec3a..2c997753bd5f 100644 --- a/arrow-data/src/decimal.rs +++ b/arrow-data/src/decimal.rs @@ -15,19 +15,22 @@ // specific language governing permissions and limitations // under the License. -//! Maximum and minimum values for [`Decimal256`] and [`Decimal128`]. +//! Maximum and minimum values for [`Decimal256`], [`Decimal128`], [`Decimal64`] and [`Decimal32`]. //! //! Also provides functions to validate if a given decimal value is within //! the valid range of the decimal type. //! +//! [`Decimal32`]: arrow_schema::DataType::Decimal32 +//! [`Decimal64`]: arrow_schema::DataType::Decimal64 //! [`Decimal128`]: arrow_schema::DataType::Decimal128 //! [`Decimal256`]: arrow_schema::DataType::Decimal256 use arrow_buffer::i256; use arrow_schema::ArrowError; pub use arrow_schema::{ + DECIMAL_DEFAULT_SCALE, DECIMAL32_DEFAULT_SCALE, DECIMAL32_MAX_PRECISION, DECIMAL32_MAX_SCALE, + DECIMAL64_DEFAULT_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, - DECIMAL_DEFAULT_SCALE, }; /// `MAX_DECIMAL256_FOR_EACH_PRECISION[p]` holds the maximum [`i256`] value that can @@ -899,26 +902,264 @@ pub const MIN_DECIMAL128_FOR_EACH_PRECISION: [i128; 39] = [ -99999999999999999999999999999999999999, ]; +/// `MAX_DECIMAL64_FOR_EACH_PRECISION[p]` holds the maximum `i64` value that can +/// be stored in [`Decimal64`] value of precision `p`. +/// +/// # Notes +/// +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +/// +/// # Example +/// ``` +/// # use arrow_data::decimal::MAX_DECIMAL64_FOR_EACH_PRECISION; +/// assert_eq!(MAX_DECIMAL64_FOR_EACH_PRECISION[3], 999); +/// ``` +/// +/// [`Decimal64`]: arrow_schema::DataType::Decimal64 +pub const MAX_DECIMAL64_FOR_EACH_PRECISION: [i64; 19] = [ + 0, // unused first element + 9, + 99, + 999, + 9999, + 99999, + 999999, + 9999999, + 99999999, + 999999999, + 9999999999, + 99999999999, + 999999999999, + 9999999999999, + 99999999999999, + 999999999999999, + 9999999999999999, + 99999999999999999, + 999999999999999999, +]; + +/// `MIN_DECIMAL64_FOR_EACH_PRECISION[p]` holds the minimum `i64` value that can +/// be stored in a [`Decimal64`] value of precision `p`. +/// +/// # Notes +/// +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +/// +/// # Example +/// ``` +/// # use arrow_data::decimal::MIN_DECIMAL64_FOR_EACH_PRECISION; +/// assert_eq!(MIN_DECIMAL64_FOR_EACH_PRECISION[3], -999); +/// ``` +/// +/// [`Decimal64`]: arrow_schema::DataType::Decimal64 +pub const MIN_DECIMAL64_FOR_EACH_PRECISION: [i64; 19] = [ + 0, // unused first element + -9, + -99, + -999, + -9999, + -99999, + -999999, + -9999999, + -99999999, + -999999999, + -9999999999, + -99999999999, + -999999999999, + -9999999999999, + -99999999999999, + -999999999999999, + -9999999999999999, + -99999999999999999, + -999999999999999999, +]; + +/// `MAX_DECIMAL32_FOR_EACH_PRECISION[p]` holds the maximum `i32` value that can +/// be stored in [`Decimal32`] value of precision `p`. +/// +/// # Notes +/// +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +/// +/// # Example +/// ``` +/// # use arrow_data::decimal::MAX_DECIMAL32_FOR_EACH_PRECISION; +/// assert_eq!(MAX_DECIMAL32_FOR_EACH_PRECISION[3], 999); +/// ``` +/// +/// [`Decimal32`]: arrow_schema::DataType::Decimal32 +pub const MAX_DECIMAL32_FOR_EACH_PRECISION: [i32; 10] = [ + 0, // unused first element + 9, 99, 999, 9999, 99999, 999999, 9999999, 99999999, 999999999, +]; + +/// `MIN_DECIMAL32_FOR_EACH_PRECISION[p]` holds the minimum `ialue that can +/// be stored in a [`Decimal32`] value of precision `p`. +/// +/// # Notes +/// +/// The first element is unused and is inserted so that we can look up using +/// precision as the index without the need to subtract 1 first. +/// +/// # Example +/// ``` +/// # use arrow_data::decimal::MIN_DECIMAL32_FOR_EACH_PRECISION; +/// assert_eq!(MIN_DECIMAL32_FOR_EACH_PRECISION[3], -999); +/// ``` +/// +/// [`Decimal32`]: arrow_schema::DataType::Decimal32 +pub const MIN_DECIMAL32_FOR_EACH_PRECISION: [i32; 10] = [ + 0, // unused first element + -9, -99, -999, -9999, -99999, -999999, -9999999, -99999999, -999999999, +]; + +/// Validates that the specified `i32` value can be properly +/// interpreted as a [`Decimal32`] number with precision `precision` +/// +/// [`Decimal32`]: arrow_schema::DataType::Decimal32 +#[inline] +pub fn validate_decimal32_precision( + value: i32, + precision: u8, + scale: i8, +) -> Result<(), ArrowError> { + if precision > DECIMAL32_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal32 is {DECIMAL32_MAX_PRECISION}, but got {precision}", + ))); + } + if value > MAX_DECIMAL32_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscale_max_value = format_decimal_str( + &MAX_DECIMAL32_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); + Err(ArrowError::InvalidArgumentError(format!( + "{unscaled_value} is too large to store in a Decimal32 of precision {precision}. Max is {}", + unscale_max_value + ))) + } else if value < MIN_DECIMAL32_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscale_min_value = format_decimal_str( + &MIN_DECIMAL32_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); + Err(ArrowError::InvalidArgumentError(format!( + "{unscaled_value} is too small to store in a Decimal32 of precision {precision}. Min is {}", + unscale_min_value + ))) + } else { + Ok(()) + } +} + +/// Returns true if the specified `i32` value can be properly +/// interpreted as a [`Decimal32`] number with precision `precision` +/// +/// [`Decimal32`]: arrow_schema::DataType::Decimal32 +#[inline] +pub fn is_validate_decimal32_precision(value: i32, precision: u8) -> bool { + precision <= DECIMAL32_MAX_PRECISION + && value >= MIN_DECIMAL32_FOR_EACH_PRECISION[precision as usize] + && value <= MAX_DECIMAL32_FOR_EACH_PRECISION[precision as usize] +} + +/// Validates that the specified `i64` value can be properly +/// interpreted as a [`Decimal64`] number with precision `precision` +/// +/// [`Decimal64`]: arrow_schema::DataType::Decimal64 +#[inline] +pub fn validate_decimal64_precision( + value: i64, + precision: u8, + scale: i8, +) -> Result<(), ArrowError> { + if precision > DECIMAL64_MAX_PRECISION { + return Err(ArrowError::InvalidArgumentError(format!( + "Max precision of a Decimal64 is {DECIMAL64_MAX_PRECISION}, but got {precision}", + ))); + } + if value > MAX_DECIMAL64_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_max_value = format_decimal_str( + &MAX_DECIMAL64_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); + Err(ArrowError::InvalidArgumentError(format!( + "{unscaled_value} is too large to store in a Decimal64 of precision {precision}. Max is {}", + unscaled_max_value + ))) + } else if value < MIN_DECIMAL64_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_min_value = format_decimal_str( + &MIN_DECIMAL64_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); + Err(ArrowError::InvalidArgumentError(format!( + "{unscaled_value} is too small to store in a Decimal64 of precision {precision}. Min is {}", + unscaled_min_value + ))) + } else { + Ok(()) + } +} + +/// Returns true if the specified `i64` value can be properly +/// interpreted as a [`Decimal64`] number with precision `precision` +/// +/// [`Decimal64`]: arrow_schema::DataType::Decimal64 +#[inline] +pub fn is_validate_decimal64_precision(value: i64, precision: u8) -> bool { + precision <= DECIMAL64_MAX_PRECISION + && value >= MIN_DECIMAL64_FOR_EACH_PRECISION[precision as usize] + && value <= MAX_DECIMAL64_FOR_EACH_PRECISION[precision as usize] +} + /// Validates that the specified `i128` value can be properly /// interpreted as a [`Decimal128`] number with precision `precision` /// /// [`Decimal128`]: arrow_schema::DataType::Decimal128 #[inline] -pub fn validate_decimal_precision(value: i128, precision: u8) -> Result<(), ArrowError> { +pub fn validate_decimal_precision(value: i128, precision: u8, scale: i8) -> Result<(), ArrowError> { if precision > DECIMAL128_MAX_PRECISION { return Err(ArrowError::InvalidArgumentError(format!( "Max precision of a Decimal128 is {DECIMAL128_MAX_PRECISION}, but got {precision}", ))); } if value > MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_max_value = format_decimal_str( + &MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); Err(ArrowError::InvalidArgumentError(format!( - "{value} is too large to store in a Decimal128 of precision {precision}. Max is {}", - MAX_DECIMAL128_FOR_EACH_PRECISION[precision as usize] + "{unscaled_value} is too large to store in a Decimal128 of precision {precision}. Max is {}", + unscaled_max_value ))) } else if value < MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_min_value = format_decimal_str( + &MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); Err(ArrowError::InvalidArgumentError(format!( - "{value} is too small to store in a Decimal128 of precision {precision}. Min is {}", - MIN_DECIMAL128_FOR_EACH_PRECISION[precision as usize] + "{unscaled_value} is too small to store in a Decimal128 of precision {precision}. Min is {}", + unscaled_min_value ))) } else { Ok(()) @@ -941,21 +1182,40 @@ pub fn is_validate_decimal_precision(value: i128, precision: u8) -> bool { /// /// [`Decimal256`]: arrow_schema::DataType::Decimal256 #[inline] -pub fn validate_decimal256_precision(value: i256, precision: u8) -> Result<(), ArrowError> { +pub fn validate_decimal256_precision( + value: i256, + precision: u8, + scale: i8, +) -> Result<(), ArrowError> { if precision > DECIMAL256_MAX_PRECISION { return Err(ArrowError::InvalidArgumentError(format!( "Max precision of a Decimal256 is {DECIMAL256_MAX_PRECISION}, but got {precision}", ))); } + if value > MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_max_value = format_decimal_str( + &MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); Err(ArrowError::InvalidArgumentError(format!( - "{value:?} is too large to store in a Decimal256 of precision {precision}. Max is {:?}", - MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize] + "{unscaled_value} is too large to store in a Decimal256 of precision {precision}. Max is {}", + unscaled_max_value ))) } else if value < MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize] { + let unscaled_value = + format_decimal_str_internal(&value.to_string(), precision.into(), scale, false); + let unscaled_min_value = format_decimal_str( + &MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize].to_string(), + precision.into(), + scale, + ); Err(ArrowError::InvalidArgumentError(format!( - "{value:?} is too small to store in a Decimal256 of precision {precision}. Min is {:?}", - MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize] + "{unscaled_value} is too small to store in a Decimal256 of precision {precision}. Min is {}", + unscaled_min_value ))) } else { Ok(()) @@ -972,3 +1232,44 @@ pub fn is_validate_decimal256_precision(value: i256, precision: u8) -> bool { && value >= MIN_DECIMAL256_FOR_EACH_PRECISION[precision as usize] && value <= MAX_DECIMAL256_FOR_EACH_PRECISION[precision as usize] } + +#[inline] +/// Formats a decimal string given the precision and scale. +pub fn format_decimal_str(value_str: &str, precision: usize, scale: i8) -> String { + format_decimal_str_internal(value_str, precision, scale, true) +} + +// Format a decimal string given the precision and scale. +// If `safe_decimal` is true, the function will ensure that the output string +// does not exceed the specified precision. +fn format_decimal_str_internal( + value_str: &str, + precision: usize, + scale: i8, + safe_decimal: bool, +) -> String { + let (sign, rest) = match value_str.strip_prefix('-') { + Some(stripped) => ("-", stripped), + None => ("", value_str), + }; + let bound = if safe_decimal { + precision.min(rest.len()) + sign.len() + } else { + value_str.len() + }; + let value_str = &value_str[0..bound]; + + if scale == 0 { + value_str.to_string() + } else if scale < 0 { + let padding = value_str.len() + scale.unsigned_abs() as usize; + format!("{value_str:0 scale as usize { + // Decimal separator is in the middle of the string + let (whole, decimal) = value_str.split_at(value_str.len() - scale as usize); + format!("{whole}.{decimal}") + } else { + // String has to be padded + format!("{}0.{:0>width$}", sign, rest, width = scale as usize) + } +} diff --git a/arrow-data/src/equal/boolean.rs b/arrow-data/src/equal/boolean.rs index addae936f118..64b7125e3688 100644 --- a/arrow-data/src/equal/boolean.rs +++ b/arrow-data/src/equal/boolean.rs @@ -16,7 +16,7 @@ // under the License. use crate::bit_iterator::BitIndexIterator; -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use arrow_buffer::bit_util::get_bit; use super::utils::{equal_bits, equal_len}; diff --git a/arrow-data/src/equal/dictionary.rs b/arrow-data/src/equal/dictionary.rs index 1d9c4b8d964f..a906ec030580 100644 --- a/arrow-data/src/equal/dictionary.rs +++ b/arrow-data/src/equal/dictionary.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use arrow_buffer::ArrowNativeType; use super::equal_range; diff --git a/arrow-data/src/equal/fixed_list.rs b/arrow-data/src/equal/fixed_list.rs index 4b79e5c33fab..9a5d64d217ad 100644 --- a/arrow-data/src/equal/fixed_list.rs +++ b/arrow-data/src/equal/fixed_list.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use arrow_schema::DataType; use super::equal_range; diff --git a/arrow-data/src/equal/list.rs b/arrow-data/src/equal/list.rs index cc4ba3cacf9f..ba5e5a8c93c1 100644 --- a/arrow-data/src/equal/list.rs +++ b/arrow-data/src/equal/list.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{count_nulls, ArrayData}; +use crate::data::{ArrayData, count_nulls}; use arrow_buffer::ArrowNativeType; -use num::Integer; +use num_integer::Integer; use super::equal_range; diff --git a/arrow-data/src/equal/list_view.rs b/arrow-data/src/equal/list_view.rs new file mode 100644 index 000000000000..c7cb31db9099 --- /dev/null +++ b/arrow-data/src/equal/list_view.rs @@ -0,0 +1,129 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::data::count_nulls; +use crate::equal::equal_values; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; + +pub(super) fn list_view_equal( + lhs: &ArrayData, + rhs: &ArrayData, + lhs_start: usize, + rhs_start: usize, + len: usize, +) -> bool { + let lhs_offsets = lhs.buffer::(0); + let lhs_sizes = lhs.buffer::(1); + + let rhs_offsets = rhs.buffer::(0); + let rhs_sizes = rhs.buffer::(1); + + let lhs_data = &lhs.child_data()[0]; + let rhs_data = &rhs.child_data()[0]; + + let lhs_null_count = count_nulls(lhs.nulls(), lhs_start, len); + let rhs_null_count = count_nulls(rhs.nulls(), rhs_start, len); + + if lhs_null_count != rhs_null_count { + return false; + } + + if lhs_null_count == 0 { + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + if lhs_range_sizes != rhs_range_sizes { + return false; + } + + // Check values for equality + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for ((&lhs_offset, &rhs_offset), &size) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + { + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let size = size.to_usize().unwrap(); + + // Check if offsets are valid for the given range + if !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } else { + // Need to integrate validity check in the inner loop. + // non-null pathway: all sizes must be equal, and all values must be equal + let lhs_range_sizes = &lhs_sizes[lhs_start..lhs_start + len]; + let rhs_range_sizes = &rhs_sizes[rhs_start..rhs_start + len]; + + let lhs_nulls = lhs.nulls().unwrap().slice(lhs_start, len); + let rhs_nulls = rhs.nulls().unwrap().slice(rhs_start, len); + + // Sizes can differ if values are null + if lhs_range_sizes.len() != rhs_range_sizes.len() { + return false; + } + + // Check values for equality, with null checking + let lhs_range_offsets = &lhs_offsets[lhs_start..lhs_start + len]; + let rhs_range_offsets = &rhs_offsets[rhs_start..rhs_start + len]; + + if lhs_range_offsets.len() != rhs_range_offsets.len() { + return false; + } + + for (index, ((&lhs_offset, &rhs_offset), &size)) in lhs_range_offsets + .iter() + .zip(rhs_range_offsets) + .zip(lhs_range_sizes) + .enumerate() + { + let lhs_is_null = lhs_nulls.is_null(index); + let rhs_is_null = rhs_nulls.is_null(index); + + if lhs_is_null != rhs_is_null { + return false; + } + + let lhs_offset = lhs_offset.to_usize().unwrap(); + let rhs_offset = rhs_offset.to_usize().unwrap(); + let size = size.to_usize().unwrap(); + + // Check if values match in the range + if !lhs_is_null && !equal_values(lhs_data, rhs_data, lhs_offset, rhs_offset, size) { + return false; + } + } + } + + true +} diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index f24179b61700..7a310b1240df 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -30,6 +30,7 @@ mod dictionary; mod fixed_binary; mod fixed_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -41,6 +42,8 @@ mod variable_size; // these methods assume the same type, len and null count. // For this reason, they are not exposed and are instead used // to build the generic functions below (`equal_range` and `equal`). +use self::run::run_equal; +use crate::equal::list_view::list_view_equal; use boolean::boolean_equal; use byte_view::byte_view_equal; use dictionary::dictionary_equal; @@ -53,8 +56,6 @@ use structure::struct_equal; use union::union_equal; use variable_size::variable_sized_equal; -use self::run::run_equal; - /// Compares the values of two [ArrayData] starting at `lhs_start` and `rhs_start` respectively /// for `len` slots. #[inline] @@ -78,6 +79,8 @@ fn equal_values( DataType::Int64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Float32 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Float64 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal32(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::Decimal64(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Decimal128(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Decimal256(_, _) => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Date32 | DataType::Time32(_) | DataType::Interval(IntervalUnit::YearMonth) => { @@ -102,10 +105,9 @@ fn equal_values( byte_view_equal(lhs, rhs, lhs_start, rhs_start, len) } DataType::List(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not yet implemented") - } DataType::LargeList(_) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::ListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), + DataType::LargeListView(_) => list_view_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::FixedSizeList(_, _) => fixed_list_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Struct(_) => struct_equal(lhs, rhs, lhs_start, rhs_start, len), DataType::Union(_, _) => union_equal(lhs, rhs, lhs_start, rhs_start, len), diff --git a/arrow-data/src/equal/structure.rs b/arrow-data/src/equal/structure.rs index e4751c26f489..d6efaff9e4a8 100644 --- a/arrow-data/src/equal/structure.rs +++ b/arrow-data/src/equal/structure.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use super::equal_range; diff --git a/arrow-data/src/equal/utils.rs b/arrow-data/src/equal/utils.rs index f1f4be44730e..464907c78b21 100644 --- a/arrow-data/src/equal/utils.rs +++ b/arrow-data/src/equal/utils.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use arrow_buffer::bit_chunk_iterator::BitChunks; use arrow_schema::DataType; diff --git a/arrow-data/src/equal/variable_size.rs b/arrow-data/src/equal/variable_size.rs index d6e8e6a95481..c83a39ebd808 100644 --- a/arrow-data/src/equal/variable_size.rs +++ b/arrow-data/src/equal/variable_size.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::data::{contains_nulls, ArrayData}; +use crate::data::{ArrayData, contains_nulls}; use arrow_buffer::ArrowNativeType; -use num::Integer; +use num_integer::Integer; use super::utils::equal_len; diff --git a/arrow-data/src/ffi.rs b/arrow-data/src/ffi.rs index 3b446ef255fe..408dfbaac909 100644 --- a/arrow-data/src/ffi.rs +++ b/arrow-data/src/ffi.rs @@ -18,7 +18,7 @@ //! Contains declarations to bind to the [C Data Interface](https://arrow.apache.org/docs/format/CDataInterface.html). use crate::bit_mask::set_bits; -use crate::{layout, ArrayData}; +use crate::{ArrayData, layout}; use arrow_buffer::buffer::NullBuffer; use arrow_buffer::{Buffer, MutableBuffer, ScalarBuffer}; use arrow_schema::DataType; @@ -71,15 +71,15 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { if array.is_null() { return; } - let array = &mut *array; + let array = unsafe { &mut *array }; // take ownership of `private_data`, therefore dropping it` - let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); + let private = unsafe { Box::from_raw(array.private_data as *mut ArrayPrivateData) }; for child in private.children.iter() { - let _ = Box::from_raw(*child); + let _ = unsafe { Box::from_raw(*child) }; } if !private.dictionary.is_null() { - let _ = Box::from_raw(private.dictionary); + let _ = unsafe { Box::from_raw(private.dictionary) }; } array.release = None; @@ -222,7 +222,7 @@ impl FFI_ArrowArray { /// [move]: https://arrow.apache.org/docs/format/CDataInterface.html#moving-an-array /// [valid]: https://doc.rust-lang.org/std/ptr/index.html#safety pub unsafe fn from_raw(array: *mut FFI_ArrowArray) -> Self { - std::ptr::replace(array, Self::empty()) + unsafe { std::ptr::replace(array, Self::empty()) } } /// create an empty `FFI_ArrowArray`, which can be used to import data into diff --git a/arrow-data/src/lib.rs b/arrow-data/src/lib.rs index a023b1d98cb6..07e7553b2b43 100644 --- a/arrow-data/src/lib.rs +++ b/arrow-data/src/lib.rs @@ -23,7 +23,7 @@ html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] mod data; pub use data::*; diff --git a/arrow-data/src/transform/boolean.rs b/arrow-data/src/transform/boolean.rs index d93fa15a4e0f..1f3bd8f885c0 100644 --- a/arrow-data/src/transform/boolean.rs +++ b/arrow-data/src/transform/boolean.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use super::{Extend, _MutableArrayData, utils::resize_for_bits}; -use crate::bit_mask::set_bits; +use super::{_MutableArrayData, Extend, utils::resize_for_bits}; use crate::ArrayData; +use crate::bit_mask::set_bits; -pub(super) fn build_extend(array: &ArrayData) -> Extend { +pub(super) fn build_extend(array: &ArrayData) -> Extend<'_> { let values = array.buffers()[0].as_slice(); Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { diff --git a/arrow-data/src/transform/fixed_binary.rs b/arrow-data/src/transform/fixed_binary.rs index 44c6f46ebf7e..626ecbee0261 100644 --- a/arrow-data/src/transform/fixed_binary.rs +++ b/arrow-data/src/transform/fixed_binary.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; use crate::ArrayData; use arrow_schema::DataType; -pub(super) fn build_extend(array: &ArrayData) -> Extend { +pub(super) fn build_extend(array: &ArrayData) -> Extend<'_> { let size = match array.data_type() { DataType::FixedSizeBinary(i) => *i as usize, _ => unreachable!(), diff --git a/arrow-data/src/transform/fixed_size_list.rs b/arrow-data/src/transform/fixed_size_list.rs index 8eef7bce9bb3..ada1a2f763c4 100644 --- a/arrow-data/src/transform/fixed_size_list.rs +++ b/arrow-data/src/transform/fixed_size_list.rs @@ -18,9 +18,9 @@ use crate::ArrayData; use arrow_schema::DataType; -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; -pub(super) fn build_extend(array: &ArrayData) -> Extend { +pub(super) fn build_extend(array: &ArrayData) -> Extend<'_> { let size = match array.data_type() { DataType::FixedSizeList(_, i) => *i as usize, _ => unreachable!(), diff --git a/arrow-data/src/transform/list.rs b/arrow-data/src/transform/list.rs index d9a1c62a8e8e..b7a9ab6da0ed 100644 --- a/arrow-data/src/transform/list.rs +++ b/arrow-data/src/transform/list.rs @@ -16,14 +16,17 @@ // under the License. use super::{ - Extend, _MutableArrayData, + _MutableArrayData, Extend, utils::{extend_offsets, get_last_offset}, }; use crate::ArrayData; use arrow_buffer::ArrowNativeType; -use num::{CheckedAdd, Integer}; +use num_integer::Integer; +use num_traits::CheckedAdd; -pub(super) fn build_extend(array: &ArrayData) -> Extend { +pub(super) fn build_extend( + array: &ArrayData, +) -> Extend<'_> { let offsets = array.buffer::(0); Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { diff --git a/arrow-data/src/transform/list_view.rs b/arrow-data/src/transform/list_view.rs new file mode 100644 index 000000000000..9b66a6a6abb1 --- /dev/null +++ b/arrow-data/src/transform/list_view.rs @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::ArrayData; +use crate::transform::_MutableArrayData; +use arrow_buffer::ArrowNativeType; +use num_integer::Integer; +use num_traits::CheckedAdd; + +pub(super) fn build_extend( + array: &ArrayData, +) -> crate::transform::Extend<'_> { + let offsets = array.buffer::(0); + let sizes = array.buffer::(1); + Box::new( + move |mutable: &mut _MutableArrayData, _index: usize, start: usize, len: usize| { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + for &offset in &offsets[start..start + len] { + offset_buffer.push(offset); + } + + // sizes + for &size in &sizes[start..start + len] { + sizes_buffer.push(size); + } + + // the beauty of views is that we don't need to copy child_data, we just splat + // the offsets and sizes. + }, + ) +} + +pub(super) fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { + let offset_buffer = &mut mutable.buffer1; + let sizes_buffer = &mut mutable.buffer2; + + // We push 0 as a placeholder for NULL values in both the offsets and sizes + (0..len).for_each(|_| offset_buffer.push(T::default())); + (0..len).for_each(|_| sizes_buffer.push(T::default())); +} diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index af0e1c104f6a..c6052817bfb6 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -20,19 +20,20 @@ //! Provides utilities for creating, manipulating, and converting Arrow arrays //! made of primitive types, strings, and nested types. -use super::{data::new_buffers, ArrayData, ArrayDataBuilder, ByteView}; +use super::{ArrayData, ArrayDataBuilder, ByteView, data::new_buffers}; use crate::bit_mask::set_bits; use arrow_buffer::buffer::{BooleanBuffer, NullBuffer}; -use arrow_buffer::{bit_util, i256, ArrowNativeType, Buffer, MutableBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer, bit_util, i256}; use arrow_schema::{ArrowError, DataType, IntervalUnit, UnionMode}; use half::f16; -use num::Integer; +use num_integer::Integer; use std::mem; mod boolean; mod fixed_binary; mod fixed_size_list; mod list; +mod list_view; mod null; mod primitive; mod run; @@ -73,7 +74,7 @@ impl _MutableArrayData<'_> { } } -fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits { +fn build_extend_null_bits(array: &ArrayData, use_nulls: bool) -> ExtendNullBits<'_> { if let Some(nulls) = array.nulls() { let bytes = nulls.validity(); Box::new(move |mutable, start, len| { @@ -190,7 +191,7 @@ impl std::fmt::Debug for MutableArrayData<'_> { /// Builds an extend that adds `offset` to the source primitive /// Additionally validates that `max` fits into the /// the underlying primitive returning None if not -fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Option { +fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Option> { macro_rules! validate_and_build { ($dt: ty) => {{ let _: $dt = max.try_into().ok()?; @@ -215,7 +216,7 @@ fn build_extend_dictionary(array: &ArrayData, offset: usize, max: usize) -> Opti } /// Builds an extend that adds `buffer_offset` to any buffer indices encountered -fn build_extend_view(array: &ArrayData, buffer_offset: u32) -> Extend { +fn build_extend_view(array: &ArrayData, buffer_offset: u32) -> Extend<'_> { let views = array.buffer::(0); Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { @@ -234,7 +235,7 @@ fn build_extend_view(array: &ArrayData, buffer_offset: u32) -> Extend { ) } -fn build_extend(array: &ArrayData) -> Extend { +fn build_extend(array: &ArrayData) -> Extend<'_> { match array.data_type() { DataType::Null => null::build_extend(array), DataType::Boolean => boolean::build_extend(array), @@ -257,16 +258,17 @@ fn build_extend(array: &ArrayData) -> Extend { | DataType::Duration(_) | DataType::Interval(IntervalUnit::DayTime) => primitive::build_extend::(array), DataType::Interval(IntervalUnit::MonthDayNano) => primitive::build_extend::(array), + DataType::Decimal32(_, _) => primitive::build_extend::(array), + DataType::Decimal64(_, _) => primitive::build_extend::(array), DataType::Decimal128(_, _) => primitive::build_extend::(array), DataType::Decimal256(_, _) => primitive::build_extend::(array), DataType::Utf8 | DataType::Binary => variable_size::build_extend::(array), DataType::LargeUtf8 | DataType::LargeBinary => variable_size::build_extend::(array), DataType::BinaryView | DataType::Utf8View => unreachable!("should use build_extend_view"), DataType::Map(_, _) | DataType::List(_) => list::build_extend::(array), - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::build_extend::(array), + DataType::ListView(_) => list_view::build_extend::(array), + DataType::LargeListView(_) => list_view::build_extend::(array), DataType::Dictionary(_, _) => unreachable!("should use build_extend_dictionary"), DataType::Struct(_) => structure::build_extend(array), DataType::FixedSizeBinary(_) => fixed_binary::build_extend(array), @@ -303,16 +305,17 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { | DataType::Duration(_) | DataType::Interval(IntervalUnit::DayTime) => primitive::extend_nulls::, DataType::Interval(IntervalUnit::MonthDayNano) => primitive::extend_nulls::, + DataType::Decimal32(_, _) => primitive::extend_nulls::, + DataType::Decimal64(_, _) => primitive::extend_nulls::, DataType::Decimal128(_, _) => primitive::extend_nulls::, DataType::Decimal256(_, _) => primitive::extend_nulls::, DataType::Utf8 | DataType::Binary => variable_size::extend_nulls::, DataType::LargeUtf8 | DataType::LargeBinary => variable_size::extend_nulls::, DataType::BinaryView | DataType::Utf8View => primitive::extend_nulls::, DataType::Map(_, _) | DataType::List(_) => list::extend_nulls::, - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } DataType::LargeList(_) => list::extend_nulls::, + DataType::ListView(_) => list_view::extend_nulls::, + DataType::LargeListView(_) => list_view::extend_nulls::, DataType::Dictionary(child_data_type, _) => match child_data_type.as_ref() { DataType::UInt8 => primitive::extend_nulls::, DataType::UInt16 => primitive::extend_nulls::, @@ -446,7 +449,11 @@ impl<'a> MutableArrayData<'a> { new_buffers(data_type, *capacity) } ( - DataType::List(_) | DataType::LargeList(_) | DataType::FixedSizeList(_, _), + DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) + | DataType::FixedSizeList(_, _), Capacities::List(capacity, _), ) => { array_capacity = *capacity; @@ -456,7 +463,9 @@ impl<'a> MutableArrayData<'a> { }; let child_data = match &data_type { - DataType::Decimal128(_, _) + DataType::Decimal32(_, _) + | DataType::Decimal64(_, _) + | DataType::Decimal128(_, _) | DataType::Decimal256(_, _) | DataType::Null | DataType::Boolean @@ -485,10 +494,11 @@ impl<'a> MutableArrayData<'a> { | DataType::Utf8View | DataType::Interval(_) | DataType::FixedSizeBinary(_) => vec![], - DataType::ListView(_) | DataType::LargeListView(_) => { - unimplemented!("ListView/LargeListView not implemented") - } - DataType::Map(_, _) | DataType::List(_) | DataType::LargeList(_) => { + DataType::Map(_, _) + | DataType::List(_) + | DataType::LargeList(_) + | DataType::ListView(_) + | DataType::LargeListView(_) => { let children = arrays .iter() .map(|array| &array.child_data()[0]) @@ -779,7 +789,12 @@ impl<'a> MutableArrayData<'a> { b.insert(0, data.buffer1.into()); b } - DataType::Utf8 | DataType::Binary | DataType::LargeUtf8 | DataType::LargeBinary => { + DataType::Utf8 + | DataType::Binary + | DataType::LargeUtf8 + | DataType::LargeBinary + | DataType::ListView(_) + | DataType::LargeListView(_) => { vec![data.buffer1.into(), data.buffer2.into()] } DataType::Union(_, mode) => { diff --git a/arrow-data/src/transform/null.rs b/arrow-data/src/transform/null.rs index 5d1535564d9e..7355a5420b8e 100644 --- a/arrow-data/src/transform/null.rs +++ b/arrow-data/src/transform/null.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; use crate::ArrayData; -pub(super) fn build_extend(_: &ArrayData) -> Extend { +pub(super) fn build_extend(_: &ArrayData) -> Extend<'_> { Box::new(move |_, _, _, _| {}) } diff --git a/arrow-data/src/transform/primitive.rs b/arrow-data/src/transform/primitive.rs index 627dc00de1df..8f9929c4305d 100644 --- a/arrow-data/src/transform/primitive.rs +++ b/arrow-data/src/transform/primitive.rs @@ -20,9 +20,9 @@ use arrow_buffer::ArrowNativeType; use std::mem::size_of; use std::ops::Add; -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; -pub(super) fn build_extend(array: &ArrayData) -> Extend { +pub(super) fn build_extend(array: &ArrayData) -> Extend<'_> { let values = array.buffer::(0); Box::new( move |mutable: &mut _MutableArrayData, _, start: usize, len: usize| { @@ -33,7 +33,7 @@ pub(super) fn build_extend(array: &ArrayData) -> Extend { ) } -pub(super) fn build_extend_with_offset(array: &ArrayData, offset: T) -> Extend +pub(super) fn build_extend_with_offset(array: &ArrayData, offset: T) -> Extend<'_> where T: ArrowNativeType + Add, { diff --git a/arrow-data/src/transform/run.rs b/arrow-data/src/transform/run.rs index 0d37a8374c6d..6ae3a034f340 100644 --- a/arrow-data/src/transform/run.rs +++ b/arrow-data/src/transform/run.rs @@ -15,19 +15,17 @@ // specific language governing permissions and limitations // under the License. -use super::{ArrayData, Extend, _MutableArrayData}; +use super::{_MutableArrayData, ArrayData, Extend}; use arrow_buffer::{ArrowNativeType, Buffer, ToByteSlice}; use arrow_schema::DataType; -use num::CheckedAdd; +use num_traits::CheckedAdd; /// Generic helper to get the last run end value from a run ends array fn get_last_run_end(run_ends_data: &super::MutableArrayData) -> T { if run_ends_data.data.len == 0 { T::default() } else { - // Convert buffer to typed slice and get the last element - let buffer = Buffer::from(run_ends_data.data.buffer1.as_slice()); - let typed_slice: &[T] = buffer.typed_data(); + let typed_slice: &[T] = run_ends_data.data.buffer1.typed_data(); if typed_slice.len() >= run_ends_data.data.len { typed_slice[run_ends_data.data.len - 1] } else { @@ -75,10 +73,7 @@ pub fn extend_nulls(mutable: &mut _MutableArrayData, len: usize) { DataType::Int16 => extend_nulls_impl!(i16), DataType::Int32 => extend_nulls_impl!(i32), DataType::Int64 => extend_nulls_impl!(i64), - _ => panic!( - "Invalid run end type for RunEndEncoded array: {:?}", - run_end_type - ), + _ => panic!("Invalid run end type for RunEndEncoded array: {run_end_type}"), }; mutable.child_data[0].data.len += 1; @@ -184,7 +179,7 @@ fn process_extends_batch( /// Returns a function that extends the run encoded array. /// /// It finds the physical indices in the source array that correspond to the logical range to copy, and adjusts the runs to the logical indices of the array to extend. The values are copied from the source array to the destination array verbatim. -pub fn build_extend(array: &ArrayData) -> Extend { +pub fn build_extend(array: &ArrayData) -> Extend<'_> { Box::new( move |mutable: &mut _MutableArrayData, array_idx: usize, start: usize, len: usize| { if len == 0 { @@ -211,7 +206,7 @@ pub fn build_extend(array: &ArrayData) -> Extend { let (run_ends_bytes, values_range) = build_extend_arrays::<$run_end_type>( source_buffer, source_run_ends.len(), - start, + start + array.offset(), len, dest_last_run_end, ); @@ -228,10 +223,7 @@ pub fn build_extend(array: &ArrayData) -> Extend { DataType::Int16 => build_and_process_impl!(i16), DataType::Int32 => build_and_process_impl!(i32), DataType::Int64 => build_and_process_impl!(i64), - _ => panic!( - "Invalid run end type for RunEndEncoded array: {:?}", - dest_run_end_type - ), + _ => panic!("Invalid run end type for RunEndEncoded array: {dest_run_end_type}",), } }, ) diff --git a/arrow-data/src/transform/structure.rs b/arrow-data/src/transform/structure.rs index 7330dcaa3705..588cc00f446b 100644 --- a/arrow-data/src/transform/structure.rs +++ b/arrow-data/src/transform/structure.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; use crate::ArrayData; -pub(super) fn build_extend(_: &ArrayData) -> Extend { +pub(super) fn build_extend(_: &ArrayData) -> Extend<'_> { Box::new( move |mutable: &mut _MutableArrayData, index: usize, start: usize, len: usize| { mutable diff --git a/arrow-data/src/transform/union.rs b/arrow-data/src/transform/union.rs index d7083588d782..f6f291e3f05d 100644 --- a/arrow-data/src/transform/union.rs +++ b/arrow-data/src/transform/union.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. -use super::{Extend, _MutableArrayData}; +use super::{_MutableArrayData, Extend}; use crate::ArrayData; -pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { +pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend<'_> { let type_ids = array.buffer::(0); Box::new( @@ -36,7 +36,7 @@ pub(super) fn build_extend_sparse(array: &ArrayData) -> Extend { ) } -pub(super) fn build_extend_dense(array: &ArrayData) -> Extend { +pub(super) fn build_extend_dense(array: &ArrayData) -> Extend<'_> { let type_ids = array.buffer::(0); let offsets = array.buffer::(1); let arrow_schema::DataType::Union(src_fields, _) = array.data_type() else { diff --git a/arrow-data/src/transform/utils.rs b/arrow-data/src/transform/utils.rs index 5407f68e0d0c..979738d057fd 100644 --- a/arrow-data/src/transform/utils.rs +++ b/arrow-data/src/transform/utils.rs @@ -15,8 +15,9 @@ // specific language governing permissions and limitations // under the License. -use arrow_buffer::{bit_util, ArrowNativeType, MutableBuffer}; -use num::{CheckedAdd, Integer}; +use arrow_buffer::{ArrowNativeType, MutableBuffer, bit_util}; +use num_integer::Integer; +use num_traits::CheckedAdd; /// extends the `buffer` to be able to hold `len` bits, setting all bits of the new size to zero. #[inline] @@ -52,9 +53,9 @@ pub(super) unsafe fn get_last_offset(offset_buffer: &Mutable // Soundness // * offset buffer is always extended in slices of T and aligned accordingly. // * Buffer[0] is initialized with one element, 0, and thus `mutable_offsets.len() - 1` is always valid. - let (prefix, offsets, suffix) = offset_buffer.as_slice().align_to::(); + let (prefix, offsets, suffix) = unsafe { offset_buffer.as_slice().align_to::() }; debug_assert!(prefix.is_empty() && suffix.is_empty()); - *offsets.get_unchecked(offsets.len() - 1) + *unsafe { offsets.get_unchecked(offsets.len() - 1) } } #[cfg(test)] diff --git a/arrow-data/src/transform/variable_size.rs b/arrow-data/src/transform/variable_size.rs index ec0174bf8cb2..ec9dcf1fd1c2 100644 --- a/arrow-data/src/transform/variable_size.rs +++ b/arrow-data/src/transform/variable_size.rs @@ -17,11 +17,11 @@ use crate::ArrayData; use arrow_buffer::{ArrowNativeType, MutableBuffer}; -use num::traits::AsPrimitive; -use num::{CheckedAdd, Integer}; +use num_integer::Integer; +use num_traits::{AsPrimitive, CheckedAdd}; use super::{ - Extend, _MutableArrayData, + _MutableArrayData, Extend, utils::{extend_offsets, get_last_offset}, }; @@ -41,7 +41,7 @@ fn extend_offset_values>( pub(super) fn build_extend>( array: &ArrayData, -) -> Extend { +) -> Extend<'_> { let offsets = array.buffer::(0); let values = array.buffers()[1].as_slice(); Box::new( diff --git a/arrow-flight/Cargo.toml b/arrow-flight/Cargo.toml index 041901e4915a..8f95e1995a67 100644 --- a/arrow-flight/Cargo.toml +++ b/arrow-flight/Cargo.toml @@ -44,11 +44,12 @@ bytes = { version = "1", default-features = false } futures = { version = "0.3", default-features = false, features = ["alloc"] } once_cell = { version = "1", optional = true } paste = { version = "1.0" , optional = true } -prost = { version = "0.13.1", default-features = false, features = ["prost-derive"] } +prost = { version = "0.14.1", default-features = false, features = ["derive"] } # For Timestamp type -prost-types = { version = "0.13.1", default-features = false } +prost-types = { version = "0.14.1", default-features = false } tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"], optional = true } -tonic = { version = "0.12.3", default-features = false, features = ["transport", "codegen", "prost"] } +tonic = { version = "0.14.1", default-features = false, features = ["transport", "codegen", "router"] } +tonic-prost = { version = "0.14.1", default-features = false } # CLI-related dependencies anyhow = { version = "1.0", optional = true } @@ -64,9 +65,13 @@ default = [] flight-sql = ["dep:arrow-arith", "dep:arrow-data", "dep:arrow-ord", "dep:arrow-row", "dep:arrow-select", "dep:arrow-string", "dep:once_cell", "dep:paste"] # TODO: Remove in the next release flight-sql-experimental = ["flight-sql"] -tls = ["tonic/tls"] +tls-aws-lc= ["tonic/tls-aws-lc"] +tls-native-roots = ["tonic/tls-native-roots"] +tls-ring = ["tonic/tls-ring"] +tls-webpki-roots = ["tonic/tls-webpki-roots"] + # Enable CLI tools -cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber"] +cli = ["arrow-array/chrono-tz", "arrow-cast/prettyprint", "tonic/tls-webpki-roots", "tonic/gzip", "tonic/deflate", "tonic/zstd", "dep:anyhow", "dep:clap", "dep:tracing-log", "dep:tracing-subscriber", "dep:tokio"] [dev-dependencies] arrow-cast = { workspace = true, features = ["prettyprint"] } @@ -85,18 +90,18 @@ uuid = { version = "1.10.0", features = ["v4"] } [[example]] name = "flight_sql_server" -required-features = ["flight-sql", "tls"] +required-features = ["flight-sql", "tls-ring"] [[bin]] name = "flight_sql_client" -required-features = ["cli", "flight-sql", "tls"] +required-features = ["cli", "flight-sql", "tls-ring"] [[test]] name = "flight_sql_client" path = "tests/flight_sql_client.rs" -required-features = ["flight-sql", "tls"] +required-features = ["flight-sql", "tls-ring"] [[test]] name = "flight_sql_client_cli" path = "tests/flight_sql_client_cli.rs" -required-features = ["cli", "flight-sql", "tls"] +required-features = ["cli", "flight-sql", "tls-ring"] diff --git a/arrow-flight/README.md b/arrow-flight/README.md index 381a63048b69..1cd8f5cfe21b 100644 --- a/arrow-flight/README.md +++ b/arrow-flight/README.md @@ -43,12 +43,16 @@ that demonstrate how to build a Flight server implemented with [tonic](https://d ## Feature Flags -- `flight-sql`: Enables experimental support for - [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. +- `flight-sql`: Support for [Apache Arrow FlightSQL], a protocol for interacting with SQL databases. -- `flight-sql-experimental` : Deprecated feature and will be removed in next release +You can enable TLS using the following features (not enabled by default) -- `tls`: Enables `tls` on `tonic` +- `tls-aws-lc`: enables [tonic feature] `tls-aws-lc` +- `tls-native-roots`: enables [tonic feature] `tls-native-roots` +- `tls-ring`: enables [tonic feature] `tls-ring` +- `tls-webpki`: enables [tonic feature] `tls-webpki-roots` + +[tonic feature]: https://docs.rs/tonic/latest/tonic/#feature-flags ## CLI diff --git a/arrow-flight/examples/flight_sql_server.rs b/arrow-flight/examples/flight_sql_server.rs index 396b72f4cb22..ae03cac28515 100644 --- a/arrow-flight/examples/flight_sql_server.rs +++ b/arrow-flight/examples/flight_sql_server.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use arrow_flight::sql::server::PeekableFlightDataStream; use arrow_flight::sql::DoPutPreparedStatementResult; -use base64::prelude::BASE64_STANDARD; +use arrow_flight::sql::server::PeekableFlightDataStream; use base64::Engine; +use base64::prelude::BASE64_STANDARD; use core::str; -use futures::{stream, Stream, TryStreamExt}; +use futures::{Stream, TryStreamExt, stream}; use once_cell::sync::Lazy; use prost::Message; use std::collections::HashSet; @@ -39,23 +39,23 @@ use arrow_flight::sql::metadata::{ SqlInfoData, SqlInfoDataBuilder, XdbcTypeInfo, XdbcTypeInfoData, XdbcTypeInfoDataBuilder, }; use arrow_flight::sql::{ - server::FlightSqlService, ActionBeginSavepointRequest, ActionBeginSavepointResult, - ActionBeginTransactionRequest, ActionBeginTransactionResult, ActionCancelQueryRequest, - ActionCancelQueryResult, ActionClosePreparedStatementRequest, - ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, - ActionCreatePreparedSubstraitPlanRequest, ActionEndSavepointRequest, - ActionEndTransactionRequest, Any, CommandGetCatalogs, CommandGetCrossReference, - CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, CommandGetPrimaryKeys, - CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, CommandGetXdbcTypeInfo, - CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementIngest, - CommandStatementQuery, CommandStatementSubstraitPlan, CommandStatementUpdate, Nullable, - ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, XdbcDataType, + ActionBeginSavepointRequest, ActionBeginSavepointResult, ActionBeginTransactionRequest, + ActionBeginTransactionResult, ActionCancelQueryRequest, ActionCancelQueryResult, + ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest, + ActionCreatePreparedStatementResult, ActionCreatePreparedSubstraitPlanRequest, + ActionEndSavepointRequest, ActionEndTransactionRequest, Any, CommandGetCatalogs, + CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys, + CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables, + CommandGetXdbcTypeInfo, CommandPreparedStatementQuery, CommandPreparedStatementUpdate, + CommandStatementIngest, CommandStatementQuery, CommandStatementSubstraitPlan, + CommandStatementUpdate, Nullable, ProstMessageExt, Searchable, SqlInfo, TicketStatementQuery, + XdbcDataType, server::FlightSqlService, }; use arrow_flight::utils::batches_to_flight_data; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, - FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, - IpcMessage, SchemaAsIpc, Ticket, + Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema}; @@ -189,7 +189,7 @@ impl FlightSqlService for FlightSqlServiceImpl { let result = Ok(result); let output = futures::stream::iter(vec![result]); - let token = format!("Bearer {}", FAKE_TOKEN); + let token = format!("Bearer {FAKE_TOKEN}"); let mut response: Response + Send>>> = Response::new(Box::pin(output)); response.metadata_mut().append( @@ -745,7 +745,7 @@ async fn main() -> Result<(), Box> { let addr_str = "0.0.0.0:50051"; let addr = addr_str.parse()?; - println!("Listening on {:?}", addr); + println!("Listening on {addr:?}"); if std::env::var("USE_TLS").ok().is_some() { let cert = std::fs::read_to_string("arrow-flight/examples/data/server.pem")?; @@ -814,7 +814,7 @@ mod tests { async fn bind_tcp() -> (TcpIncoming, SocketAddr) { let listener = TcpListener::bind("0.0.0.0:0").await.unwrap(); let addr = listener.local_addr().unwrap(); - let incoming = TcpIncoming::from_listener(listener, true, None).unwrap(); + let incoming = TcpIncoming::from(listener).with_nodelay(Some(true)); (incoming, addr) } diff --git a/arrow-flight/examples/server.rs b/arrow-flight/examples/server.rs index 8c766b075957..ca856dce28cb 100644 --- a/arrow-flight/examples/server.rs +++ b/arrow-flight/examples/server.rs @@ -20,9 +20,9 @@ use tonic::transport::Server; use tonic::{Request, Response, Status, Streaming}; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, - ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, - HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, }; #[derive(Clone)] diff --git a/arrow-flight/gen/Cargo.toml b/arrow-flight/gen/Cargo.toml index 79d46cd377fa..2ce3f814d89b 100644 --- a/arrow-flight/gen/Cargo.toml +++ b/arrow-flight/gen/Cargo.toml @@ -32,5 +32,5 @@ publish = false [dependencies] # Pin specific version of the tonic-build dependencies to avoid auto-generated # (and checked in) arrow.flight.protocol.rs from changing -prost-build = { version = "=0.13.5", default-features = false } -tonic-build = { version = "=0.12.3", default-features = false, features = ["transport", "prost"] } +prost-build = { version = "0.14.1", default-features = false } +tonic-prost-build = { version = "0.14.1", default-features = false } diff --git a/arrow-flight/gen/src/main.rs b/arrow-flight/gen/src/main.rs index a69134e7acbe..6db70dc10938 100644 --- a/arrow-flight/gen/src/main.rs +++ b/arrow-flight/gen/src/main.rs @@ -25,11 +25,11 @@ fn main() -> Result<(), Box> { let proto_dir = Path::new("../format"); let proto_path = Path::new("../format/Flight.proto"); - tonic_build::configure() + tonic_prost_build::configure() // protoc in Ubuntu builder needs this option .protoc_arg("--experimental_allow_proto3_optional") .out_dir("src") - .compile_protos_with_config(prost_config(), &[proto_path], &[proto_dir])?; + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; // read file contents to string let mut file = OpenOptions::new() @@ -48,11 +48,11 @@ fn main() -> Result<(), Box> { let proto_dir = Path::new("../format"); let proto_path = Path::new("../format/FlightSql.proto"); - tonic_build::configure() + tonic_prost_build::configure() // protoc in Ubuntu builder needs this option .protoc_arg("--experimental_allow_proto3_optional") .out_dir("src/sql") - .compile_protos_with_config(prost_config(), &[proto_path], &[proto_dir])?; + .compile_with_config(prost_config(), &[proto_path], &[proto_dir])?; // read file contents to string let mut file = OpenOptions::new() diff --git a/arrow-flight/src/arrow.flight.protocol.rs b/arrow-flight/src/arrow.flight.protocol.rs index 0cd4f6948b77..bb6370d1acec 100644 --- a/arrow-flight/src/arrow.flight.protocol.rs +++ b/arrow-flight/src/arrow.flight.protocol.rs @@ -3,7 +3,7 @@ // This file is @generated by prost-build. /// /// The request that a client provides to a server on handshake. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct HandshakeRequest { /// /// A defined protocol version @@ -14,7 +14,7 @@ pub struct HandshakeRequest { #[prost(bytes = "bytes", tag = "2")] pub payload: ::prost::bytes::Bytes, } -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct HandshakeResponse { /// /// A defined protocol version @@ -27,19 +27,19 @@ pub struct HandshakeResponse { } /// /// A message for doing simple auth. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct BasicAuth { #[prost(string, tag = "2")] pub username: ::prost::alloc::string::String, #[prost(string, tag = "3")] pub password: ::prost::alloc::string::String, } -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct Empty {} /// /// Describes an available action, including both the name used for execution /// along with a short description of the purpose of the action. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionType { #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, @@ -49,14 +49,14 @@ pub struct ActionType { /// /// A service specific expression that can be used to return a limited set /// of available Arrow Flight streams. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Criteria { #[prost(bytes = "bytes", tag = "1")] pub expression: ::prost::bytes::Bytes, } /// /// An opaque action specific for the service. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Action { #[prost(string, tag = "1")] pub r#type: ::prost::alloc::string::String, @@ -83,7 +83,7 @@ pub struct RenewFlightEndpointRequest { } /// /// An opaque result returned after executing an action. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Result { #[prost(bytes = "bytes", tag = "1")] pub body: ::prost::bytes::Bytes, @@ -92,14 +92,14 @@ pub struct Result { /// The result of the CancelFlightInfo action. /// /// The result should be stored in Result.body. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct CancelFlightInfoResult { #[prost(enumeration = "CancelStatus", tag = "1")] pub status: i32, } /// /// Wrap the result of a getSchema call -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct SchemaResult { /// The schema of the dataset in its IPC form: /// 4 bytes - an optional IPC_CONTINUATION_TOKEN prefix @@ -111,7 +111,7 @@ pub struct SchemaResult { /// /// The name or tag for a Flight. May be used as a way to retrieve or generate /// a flight or be used to expose a set of previously defined flights. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FlightDescriptor { #[prost(enumeration = "flight_descriptor::DescriptorType", tag = "1")] pub r#type: i32, @@ -322,7 +322,7 @@ pub struct FlightEndpoint { /// /// A location where a Flight service will accept retrieval of a particular /// stream given a ticket. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Location { #[prost(string, tag = "1")] pub uri: ::prost::alloc::string::String, @@ -333,14 +333,14 @@ pub struct Location { /// /// Tickets are meant to be single use. It is an error/application-defined /// behavior to reuse a ticket. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct Ticket { #[prost(bytes = "bytes", tag = "1")] pub ticket: ::prost::bytes::Bytes, } /// /// A batch of Arrow data as part of a stream of batches. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct FlightData { /// /// The descriptor of the data. This is only relevant when a client is @@ -365,7 +365,7 @@ pub struct FlightData { } /// * /// The response message associated with the submission of a DoPut. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct PutResult { #[prost(bytes = "bytes", tag = "1")] pub app_metadata: ::prost::bytes::Bytes, @@ -435,20 +435,9 @@ pub mod flight_service_client { pub struct FlightServiceClient { inner: tonic::client::Grpc, } - impl FlightServiceClient { - /// Attempt to create a new client by connecting to a given endpoint. - pub async fn connect(dst: D) -> Result - where - D: TryInto, - D::Error: Into, - { - let conn = tonic::transport::Endpoint::new(dst)?.connect().await?; - Ok(Self::new(conn)) - } - } impl FlightServiceClient where - T: tonic::client::GrpcService, + T: tonic::client::GrpcService, T::Error: Into, T::ResponseBody: Body + std::marker::Send + 'static, ::Error: Into + std::marker::Send, @@ -469,13 +458,13 @@ pub mod flight_service_client { F: tonic::service::Interceptor, T::ResponseBody: Default, T: tonic::codegen::Service< - http::Request, + http::Request, Response = http::Response< - >::ResponseBody, + >::ResponseBody, >, >, , + http::Request, >>::Error: Into + std::marker::Send + std::marker::Sync, { FlightServiceClient::new(InterceptedService::new(inner, interceptor)) @@ -531,7 +520,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/Handshake", ); @@ -564,7 +553,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListFlights", ); @@ -598,7 +587,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetFlightInfo", ); @@ -647,7 +636,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/PollFlightInfo", ); @@ -678,7 +667,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/GetSchema", ); @@ -709,7 +698,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoGet", ); @@ -740,7 +729,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoPut", ); @@ -770,7 +759,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoExchange", ); @@ -803,7 +792,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/DoAction", ); @@ -833,7 +822,7 @@ pub mod flight_service_client { format!("Service was not ready: {}", e.into()), ) })?; - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let path = http::uri::PathAndQuery::from_static( "/arrow.flight.protocol.FlightService/ListActions", ); @@ -1098,7 +1087,7 @@ pub mod flight_service_server { B: Body + std::marker::Send + 'static, B::Error: Into + std::marker::Send + 'static, { - type Response = http::Response; + type Response = http::Response; type Error = std::convert::Infallible; type Future = BoxFuture; fn poll_ready( @@ -1142,7 +1131,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = HandshakeSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1188,7 +1177,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = ListFlightsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1233,7 +1222,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = GetFlightInfoSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1279,7 +1268,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = PollFlightInfoSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1324,7 +1313,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = GetSchemaSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1370,7 +1359,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = DoGetSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1416,7 +1405,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = DoPutSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1462,7 +1451,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = DoExchangeSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1508,7 +1497,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = DoActionSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1554,7 +1543,7 @@ pub mod flight_service_server { let inner = self.inner.clone(); let fut = async move { let method = ListActionsSvc(inner); - let codec = tonic::codec::ProstCodec::default(); + let codec = tonic_prost::ProstCodec::default(); let mut grpc = tonic::server::Grpc::new(codec) .apply_compression_config( accept_compression_encodings, @@ -1571,7 +1560,9 @@ pub mod flight_service_server { } _ => { Box::pin(async move { - let mut response = http::Response::new(empty_body()); + let mut response = http::Response::new( + tonic::body::Body::default(), + ); let headers = response.headers_mut(); headers .insert( diff --git a/arrow-flight/src/bin/flight_sql_client.rs b/arrow-flight/src/bin/flight_sql_client.rs index 7b9e34898ac8..554c6339aac2 100644 --- a/arrow-flight/src/bin/flight_sql_client.rs +++ b/arrow-flight/src/bin/flight_sql_client.rs @@ -17,15 +17,16 @@ use std::{sync::Arc, time::Duration}; -use anyhow::{bail, Context, Result}; +use anyhow::{Context, Result, bail}; use arrow_array::{ArrayRef, Datum, RecordBatch, StringArray}; -use arrow_cast::{cast_with_options, pretty::pretty_format_batches, CastOptions}; +use arrow_cast::{CastOptions, cast_with_options, pretty::pretty_format_batches}; use arrow_flight::{ - sql::{client::FlightSqlServiceClient, CommandGetDbSchemas, CommandGetTables}, FlightInfo, + flight_service_client::FlightServiceClient, + sql::{CommandGetDbSchemas, CommandGetTables, client::FlightSqlServiceClient}, }; use arrow_schema::Schema; -use clap::{Parser, Subcommand}; +use clap::{Parser, Subcommand, ValueEnum}; use core::str; use futures::TryStreamExt; use tonic::{ @@ -53,6 +54,24 @@ pub struct LoggingArgs { log_verbose_count: u8, } +/// gRPC/HTTP compression algorithms. +#[derive(Clone, Copy, Debug, PartialEq, Eq, ValueEnum)] +pub enum CompressionEncoding { + Gzip, + Deflate, + Zstd, +} + +impl From for tonic::codec::CompressionEncoding { + fn from(encoding: CompressionEncoding) -> Self { + match encoding { + CompressionEncoding::Gzip => Self::Gzip, + CompressionEncoding::Deflate => Self::Deflate, + CompressionEncoding::Zstd => Self::Zstd, + } + } +} + #[derive(Debug, Parser)] struct ClientArgs { /// Additional headers. @@ -85,6 +104,14 @@ struct ClientArgs { #[clap(long)] tls: bool, + /// Dump TLS key log. + /// + /// The target file is specified by the `SSLKEYLOGFILE` environment variable. + /// + /// Requires `--tls`. + #[clap(long, requires = "tls")] + key_log: bool, + /// Server host. /// /// Required. @@ -96,6 +123,34 @@ struct ClientArgs { /// Defaults to `443` if `tls` is set, otherwise defaults to `80`. #[clap(long)] port: Option, + + /// Compression accepted by the client for responses sent by the server. + /// + /// The client will send this information to the server as part of the request. The server is free to pick an + /// algorithm from that list or use no compression (called "identity" encoding). + /// + /// You may define multiple algorithms by using a comma-separated list. + #[clap(long, value_delimiter = ',')] + accept_compression: Vec, + + /// Compression of requests sent by the client to the server. + /// + /// Since the client needs to decide on the compression before sending the request, there is no client<->server + /// negotiation. If the server does NOT support the chosen compression, it will respond with an error a la: + /// + /// ``` + /// Ipc error: Status { + /// code: Unimplemented, + /// message: "Content is compressed with `zstd` which isn't supported", + /// metadata: MetadataMap { headers: {"grpc-accept-encoding": "identity", ...} }, + /// ... + /// } + /// ``` + /// + /// Based on the algorithms listed in the `grpc-accept-encoding` header, you may make a more educated guess for + /// your next request. Note that `identity` is a synonym for "no compression". + #[clap(long)] + send_compression: Option, } #[derive(Debug, Parser)] @@ -323,7 +378,7 @@ fn construct_record_batch_from_params( } fn setup_logging(args: LoggingArgs) -> Result<()> { - use tracing_subscriber::{util::SubscriberInitExt, EnvFilter, FmtSubscriber}; + use tracing_subscriber::{EnvFilter, FmtSubscriber, util::SubscriberInitExt}; tracing_log::LogTracer::init().context("tracing log init")?; @@ -357,7 +412,11 @@ async fn setup_client(args: ClientArgs) -> Result Result bool { - self.schema().is_some() - } - /// Return schema for the stream, if it has been received pub fn schema(&self) -> Option<&SchemaRef> { self.inner.schema() diff --git a/arrow-flight/src/encode.rs b/arrow-flight/src/encode.rs index 57ac9f3173fe..187de400f6c0 100644 --- a/arrow-flight/src/encode.rs +++ b/arrow-flight/src/encode.rs @@ -17,14 +17,14 @@ use std::{collections::VecDeque, fmt::Debug, pin::Pin, sync::Arc, task::Poll}; -use crate::{error::Result, FlightData, FlightDescriptor, SchemaAsIpc}; +use crate::{FlightData, FlightDescriptor, SchemaAsIpc, error::Result}; use arrow_array::{Array, ArrayRef, RecordBatch, RecordBatchOptions, UnionArray}; -use arrow_ipc::writer::{DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; +use arrow_ipc::writer::{CompressionContext, DictionaryTracker, IpcDataGenerator, IpcWriteOptions}; use arrow_schema::{DataType, Field, FieldRef, Fields, Schema, SchemaRef, UnionMode}; use bytes::Bytes; -use futures::{ready, stream::BoxStream, Stream, StreamExt}; +use futures::{Stream, StreamExt, ready, stream::BoxStream}; /// Creates a [`Stream`] of [`FlightData`]s from a /// `Stream` of [`Result`]<[`RecordBatch`], [`FlightError`]>. @@ -535,15 +535,13 @@ fn prepare_field_for_flight( ) .with_metadata(field.metadata().clone()) } else { - #[allow(deprecated)] - let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); - + dictionary_tracker.next_dict_id(); #[allow(deprecated)] Field::new_dict( field.name(), field.data_type().clone(), field.is_nullable(), - dict_id, + 0, field.dict_is_ordered().unwrap_or_default(), ) .with_metadata(field.metadata().clone()) @@ -585,14 +583,13 @@ fn prepare_schema_for_flight( ) .with_metadata(field.metadata().clone()) } else { - #[allow(deprecated)] - let dict_id = dictionary_tracker.set_dict_id(field.as_ref()); + dictionary_tracker.next_dict_id(); #[allow(deprecated)] Field::new_dict( field.name(), field.data_type().clone(), field.is_nullable(), - dict_id, + 0, field.dict_is_ordered().unwrap_or_default(), ) .with_metadata(field.metadata().clone()) @@ -650,20 +647,16 @@ struct FlightIpcEncoder { options: IpcWriteOptions, data_gen: IpcDataGenerator, dictionary_tracker: DictionaryTracker, + compression_context: CompressionContext, } impl FlightIpcEncoder { fn new(options: IpcWriteOptions, error_on_replacement: bool) -> Self { - #[allow(deprecated)] - let preserve_dict_id = options.preserve_dict_id(); Self { options, data_gen: IpcDataGenerator::default(), - #[allow(deprecated)] - dictionary_tracker: DictionaryTracker::new_with_preserve_dict_id( - error_on_replacement, - preserve_dict_id, - ), + dictionary_tracker: DictionaryTracker::new(error_on_replacement), + compression_context: CompressionContext::default(), } } @@ -675,9 +668,12 @@ impl FlightIpcEncoder { /// Convert a `RecordBatch` to a Vec of `FlightData` representing /// dictionaries and a `FlightData` representing the batch fn encode_batch(&mut self, batch: &RecordBatch) -> Result<(Vec, FlightData)> { - let (encoded_dictionaries, encoded_batch) = - self.data_gen - .encoded_batch(batch, &mut self.dictionary_tracker, &self.options)?; + let (encoded_dictionaries, encoded_batch) = self.data_gen.encode( + batch, + &mut self.dictionary_tracker, + &self.options, + &mut self.compression_context, + )?; let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); let flight_batch = encoded_batch.into(); @@ -1547,9 +1543,8 @@ mod tests { async fn verify_flight_round_trip(mut batches: Vec) { let expected_schema = batches.first().unwrap().schema(); - #[allow(deprecated)] let encoder = FlightDataEncoderBuilder::default() - .with_options(IpcWriteOptions::default().with_preserve_dict_id(false)) + .with_options(IpcWriteOptions::default()) .with_dictionary_handling(DictionaryHandling::Resend) .build(futures::stream::iter(batches.clone().into_iter().map(Ok))); @@ -1575,8 +1570,7 @@ mod tests { HashMap::from([("some_key".to_owned(), "some_value".to_owned())]), ); - #[allow(deprecated)] - let mut dictionary_tracker = DictionaryTracker::new_with_preserve_dict_id(false, true); + let mut dictionary_tracker = DictionaryTracker::new(false); let got = prepare_schema_for_flight(&schema, &mut dictionary_tracker, false); assert!(got.metadata().contains_key("some_key")); @@ -1606,12 +1600,16 @@ mod tests { options: &IpcWriteOptions, ) -> (Vec, FlightData) { let data_gen = IpcDataGenerator::default(); - #[allow(deprecated)] - let mut dictionary_tracker = - DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let mut dictionary_tracker = DictionaryTracker::new(false); + let mut compression_context = CompressionContext::default(); let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, options) + .encode( + batch, + &mut dictionary_tracker, + options, + &mut compression_context, + ) .expect("DictionaryTracker configured above to not error on replacement"); let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect(); @@ -1695,9 +1693,9 @@ mod tests { #[tokio::test] async fn flight_data_size_even() { - let s1 = StringArray::from_iter_values(std::iter::repeat(".10 bytes.").take(1024)); + let s1 = StringArray::from_iter_values(std::iter::repeat_n(".10 bytes.", 1024)); let i1 = Int16Array::from_iter_values(0..1024); - let s2 = StringArray::from_iter_values(std::iter::repeat("6bytes").take(1024)); + let s2 = StringArray::from_iter_values(std::iter::repeat_n("6bytes", 1024)); let i2 = Int64Array::from_iter_values(0..1024); let batch = RecordBatch::try_from_iter(vec![ diff --git a/arrow-flight/src/error.rs b/arrow-flight/src/error.rs index ac8030583299..d22c24eea6d4 100644 --- a/arrow-flight/src/error.rs +++ b/arrow-flight/src/error.rs @@ -51,12 +51,12 @@ impl FlightError { impl std::fmt::Display for FlightError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - FlightError::Arrow(source) => write!(f, "Arrow error: {}", source), - FlightError::NotYetImplemented(desc) => write!(f, "Not yet implemented: {}", desc), - FlightError::Tonic(source) => write!(f, "Tonic error: {}", source), - FlightError::ProtocolError(desc) => write!(f, "Protocol error: {}", desc), - FlightError::DecodeError(desc) => write!(f, "Decode error: {}", desc), - FlightError::ExternalError(source) => write!(f, "External error: {}", source), + FlightError::Arrow(source) => write!(f, "Arrow error: {source}"), + FlightError::NotYetImplemented(desc) => write!(f, "Not yet implemented: {desc}"), + FlightError::Tonic(source) => write!(f, "Tonic error: {source}"), + FlightError::ProtocolError(desc) => write!(f, "Protocol error: {desc}"), + FlightError::DecodeError(desc) => write!(f, "Decode error: {desc}"), + FlightError::ExternalError(source) => write!(f, "External error: {source}"), } } } @@ -78,6 +78,12 @@ impl From for FlightError { } } +impl From for FlightError { + fn from(error: prost::DecodeError) -> Self { + Self::DecodeError(error.to_string()) + } +} + impl From for FlightError { fn from(value: ArrowError) -> Self { Self::Arrow(value) diff --git a/arrow-flight/src/lib.rs b/arrow-flight/src/lib.rs index 72dd07040920..db900341560c 100644 --- a/arrow-flight/src/lib.rs +++ b/arrow-flight/src/lib.rs @@ -35,15 +35,13 @@ //! 3. Support for [Flight SQL] in [`sql`]. Requires the //! `flight-sql` feature of this crate to be activated. //! -//! 4. The feature [`flight-sql-experimental`] is deprecated and will be removed in a future release. -//! //! [Flight SQL]: https://arrow.apache.org/docs/format/FlightSql.html #![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![allow(rustdoc::invalid_html_tags)] #![warn(missing_docs)] // The unused_crate_dependencies lint does not work well for crates defining additional examples/bin targets @@ -53,8 +51,8 @@ use arrow_ipc::{convert, writer, writer::EncodedData, writer::IpcWriteOptions}; use arrow_schema::{ArrowError, Schema}; use arrow_ipc::convert::try_schema_from_ipc_buffer; -use base64::prelude::BASE64_STANDARD; use base64::Engine; +use base64::prelude::BASE64_STANDARD; use bytes::Bytes; use prost_types::Timestamp; use std::{fmt, ops::Deref}; @@ -62,7 +60,7 @@ use std::{fmt, ops::Deref}; type ArrowResult = std::result::Result; #[allow(clippy::all)] -mod gen { +mod r#gen { // Since this file is auto-generated, we suppress all warnings #![allow(missing_docs)] include!("arrow.flight.protocol.rs"); @@ -70,22 +68,22 @@ mod gen { /// Defines a `Flight` for generation or retrieval. pub mod flight_descriptor { - use super::gen; - pub use gen::flight_descriptor::DescriptorType; + use super::r#gen; + pub use r#gen::flight_descriptor::DescriptorType; } /// Low Level [tonic] [`FlightServiceClient`](gen::flight_service_client::FlightServiceClient). pub mod flight_service_client { - use super::gen; - pub use gen::flight_service_client::FlightServiceClient; + use super::r#gen; + pub use r#gen::flight_service_client::FlightServiceClient; } /// Low Level [tonic] [`FlightServiceServer`](gen::flight_service_server::FlightServiceServer) /// and [`FlightService`](gen::flight_service_server::FlightService). pub mod flight_service_server { - use super::gen; - pub use gen::flight_service_server::FlightService; - pub use gen::flight_service_server::FlightServiceServer; + use super::r#gen; + pub use r#gen::flight_service_server::FlightService; + pub use r#gen::flight_service_server::FlightServiceServer; } /// Mid Level [`FlightClient`] @@ -103,27 +101,27 @@ pub mod encode; /// Common error types pub mod error; -pub use gen::Action; -pub use gen::ActionType; -pub use gen::BasicAuth; -pub use gen::CancelFlightInfoRequest; -pub use gen::CancelFlightInfoResult; -pub use gen::CancelStatus; -pub use gen::Criteria; -pub use gen::Empty; -pub use gen::FlightData; -pub use gen::FlightDescriptor; -pub use gen::FlightEndpoint; -pub use gen::FlightInfo; -pub use gen::HandshakeRequest; -pub use gen::HandshakeResponse; -pub use gen::Location; -pub use gen::PollInfo; -pub use gen::PutResult; -pub use gen::RenewFlightEndpointRequest; -pub use gen::Result; -pub use gen::SchemaResult; -pub use gen::Ticket; +pub use r#gen::Action; +pub use r#gen::ActionType; +pub use r#gen::BasicAuth; +pub use r#gen::CancelFlightInfoRequest; +pub use r#gen::CancelFlightInfoResult; +pub use r#gen::CancelStatus; +pub use r#gen::Criteria; +pub use r#gen::Empty; +pub use r#gen::FlightData; +pub use r#gen::FlightDescriptor; +pub use r#gen::FlightEndpoint; +pub use r#gen::FlightInfo; +pub use r#gen::HandshakeRequest; +pub use r#gen::HandshakeResponse; +pub use r#gen::Location; +pub use r#gen::PollInfo; +pub use r#gen::PutResult; +pub use r#gen::RenewFlightEndpointRequest; +pub use r#gen::Result; +pub use r#gen::SchemaResult; +pub use r#gen::Ticket; /// Helper to extract HTTP/gRPC trailers from a tonic stream. mod trailers; @@ -151,9 +149,7 @@ pub struct IpcMessage(pub Bytes); fn flight_schema_as_encoded_data(arrow_schema: &Schema, options: &IpcWriteOptions) -> EncodedData { let data_gen = writer::IpcDataGenerator::default(); - #[allow(deprecated)] - let mut dict_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let mut dict_tracker = writer::DictionaryTracker::new(false); data_gen.schema_to_bytes_with_dictionary_tracker(arrow_schema, &mut dict_tracker, options) } @@ -607,6 +603,12 @@ impl FlightInfo { self } + /// Add endpoints for fetching all data + pub fn with_endpoints(mut self, endpoints: Vec) -> Self { + self.endpoint = endpoints; + self + } + /// Add a [`FlightDescriptor`] describing what this data is pub fn with_descriptor(mut self, flight_descriptor: FlightDescriptor) -> Self { self.flight_descriptor = Some(flight_descriptor); diff --git a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs index 7a37a0b28856..e7083c583edd 100644 --- a/arrow-flight/src/sql/arrow.flight.protocol.sql.rs +++ b/arrow-flight/src/sql/arrow.flight.protocol.sql.rs @@ -19,7 +19,7 @@ /// int32_to_int32_list_map: map> /// > /// where there is one row per requested piece of metadata information. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetSqlInfo { /// /// Values are modelled after ODBC's SQLGetInfo() function. This information is intended to provide @@ -99,7 +99,7 @@ pub struct CommandGetSqlInfo { /// is only relevant to be used by ODBC). /// > /// The returned data should be ordered by data_type and then by type_name. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetXdbcTypeInfo { /// /// Specifies the data type to search for the info. @@ -118,7 +118,7 @@ pub struct CommandGetXdbcTypeInfo { /// catalog_name: utf8 not null /// > /// The returned data should be ordered by catalog_name. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetCatalogs {} /// /// Represents a request to retrieve the list of database schemas on a Flight SQL enabled backend. @@ -133,7 +133,7 @@ pub struct CommandGetCatalogs {} /// db_schema_name: utf8 not null /// > /// The returned data should be ordered by catalog_name, then db_schema_name. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetDbSchemas { /// /// Specifies the Catalog to search for the tables. @@ -177,7 +177,7 @@ pub struct CommandGetDbSchemas { /// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. /// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// The returned data should be ordered by catalog_name, db_schema_name, table_name, then table_type, followed by table_schema if requested. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetTables { /// /// Specifies the Catalog to search for the tables. @@ -226,7 +226,7 @@ pub struct CommandGetTables { /// table_type: utf8 not null /// > /// The returned data should be ordered by table_type. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetTableTypes {} /// /// Represents a request to retrieve the primary keys of a table on a Flight SQL enabled backend. @@ -244,7 +244,7 @@ pub struct CommandGetTableTypes {} /// key_sequence: int32 not null /// > /// The returned data should be ordered by catalog_name, db_schema_name, table_name, key_name, then key_sequence. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetPrimaryKeys { /// /// Specifies the catalog to search for the table. @@ -287,7 +287,7 @@ pub struct CommandGetPrimaryKeys { /// > /// The returned data should be ordered by fk_catalog_name, fk_db_schema_name, fk_table_name, fk_key_name, then key_sequence. /// update_rule and delete_rule returns a byte that is equivalent to actions declared on UpdateDeleteRules enum. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetExportedKeys { /// /// Specifies the catalog to search for the foreign key table. @@ -334,7 +334,7 @@ pub struct CommandGetExportedKeys { /// - 2 = SET NULL /// - 3 = NO ACTION /// - 4 = SET DEFAULT -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetImportedKeys { /// /// Specifies the catalog to search for the primary key table. @@ -383,7 +383,7 @@ pub struct CommandGetImportedKeys { /// - 2 = SET NULL /// - 3 = NO ACTION /// - 4 = SET DEFAULT -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandGetCrossReference { /// * /// The catalog name where the parent table is. @@ -420,7 +420,7 @@ pub struct CommandGetCrossReference { } /// /// Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionCreatePreparedStatementRequest { /// The valid SQL string to create a prepared statement for. #[prost(string, tag = "1")] @@ -432,7 +432,7 @@ pub struct ActionCreatePreparedStatementRequest { } /// /// An embedded message describing a Substrait plan to execute. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct SubstraitPlan { /// The serialized substrait.Plan to create a prepared statement for. /// XXX(ARROW-16902): this is bytes instead of an embedded message @@ -448,7 +448,7 @@ pub struct SubstraitPlan { } /// /// Request message for the "CreatePreparedSubstraitPlan" action on a Flight SQL enabled backend. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionCreatePreparedSubstraitPlanRequest { /// The serialized substrait.Plan to create a prepared statement for. #[prost(message, optional, tag = "1")] @@ -466,7 +466,7 @@ pub struct ActionCreatePreparedSubstraitPlanRequest { /// - Automatically, by a server timeout. /// /// The result should be wrapped in a google.protobuf.Any message. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionCreatePreparedStatementResult { /// Opaque handle for the prepared statement on the server. #[prost(bytes = "bytes", tag = "1")] @@ -486,7 +486,7 @@ pub struct ActionCreatePreparedStatementResult { /// /// Request message for the "ClosePreparedStatement" action on a Flight SQL enabled backend. /// Closes server resources associated with the prepared statement handle. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionClosePreparedStatementRequest { /// Opaque handle for the prepared statement on the server. #[prost(bytes = "bytes", tag = "1")] @@ -495,7 +495,7 @@ pub struct ActionClosePreparedStatementRequest { /// /// Request message for the "BeginTransaction" action. /// Begins a transaction. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionBeginTransactionRequest {} /// /// Request message for the "BeginSavepoint" action. @@ -503,7 +503,7 @@ pub struct ActionBeginTransactionRequest {} /// /// Only supported if FLIGHT_SQL_TRANSACTION is /// FLIGHT_SQL_TRANSACTION_SUPPORT_SAVEPOINT. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionBeginSavepointRequest { /// The transaction to which a savepoint belongs. #[prost(bytes = "bytes", tag = "1")] @@ -520,7 +520,7 @@ pub struct ActionBeginSavepointRequest { /// automatically rolled back. /// /// The result should be wrapped in a google.protobuf.Any message. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionBeginTransactionResult { /// Opaque handle for the transaction on the server. #[prost(bytes = "bytes", tag = "1")] @@ -534,7 +534,7 @@ pub struct ActionBeginTransactionResult { /// out, then the savepoint is also invalidated. /// /// The result should be wrapped in a google.protobuf.Any message. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionBeginSavepointResult { /// Opaque handle for the savepoint on the server. #[prost(bytes = "bytes", tag = "1")] @@ -547,7 +547,7 @@ pub struct ActionBeginSavepointResult { /// /// If the action completes successfully, the transaction handle is /// invalidated, as are all associated savepoints. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionEndTransactionRequest { /// Opaque handle for the transaction on the server. #[prost(bytes = "bytes", tag = "1")] @@ -609,7 +609,7 @@ pub mod action_end_transaction_request { /// Releasing a savepoint invalidates that savepoint. Rolling back to /// a savepoint does not invalidate the savepoint, but invalidates all /// savepoints created after the current savepoint. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionEndSavepointRequest { /// Opaque handle for the savepoint on the server. #[prost(bytes = "bytes", tag = "1")] @@ -678,7 +678,7 @@ pub mod action_end_savepoint_request { /// - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. /// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// - GetFlightInfo: execute the query. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandStatementQuery { /// The SQL syntax. #[prost(string, tag = "1")] @@ -704,7 +704,7 @@ pub struct CommandStatementQuery { /// - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. /// - GetFlightInfo: execute the query. /// - DoPut: execute the query. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandStatementSubstraitPlan { /// A serialized substrait.Plan #[prost(message, optional, tag = "1")] @@ -716,7 +716,7 @@ pub struct CommandStatementSubstraitPlan { /// * /// Represents a ticket resulting from GetFlightInfo with a CommandStatementQuery. /// This should be used only once and treated as an opaque value, that is, clients should not attempt to parse this. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct TicketStatementQuery { /// Unique identifier for the instance of the statement to execute. #[prost(bytes = "bytes", tag = "1")] @@ -742,7 +742,7 @@ pub struct TicketStatementQuery { /// for the parameters when determining the schema. /// - DoPut: bind parameter values. All of the bound parameter sets will be executed as a single atomic execution. /// - GetFlightInfo: execute the prepared statement instance. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandPreparedStatementQuery { /// Opaque handle for the prepared statement on the server. #[prost(bytes = "bytes", tag = "1")] @@ -751,7 +751,7 @@ pub struct CommandPreparedStatementQuery { /// /// Represents a SQL update query. Used in the command member of FlightDescriptor /// for the RPC call DoPut to cause the server to execute the included SQL update. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandStatementUpdate { /// The SQL syntax. #[prost(string, tag = "1")] @@ -764,7 +764,7 @@ pub struct CommandStatementUpdate { /// Represents a SQL update query. Used in the command member of FlightDescriptor /// for the RPC call DoPut to cause the server to execute the included /// prepared statement handle as an update. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct CommandPreparedStatementUpdate { /// Opaque handle for the prepared statement on the server. #[prost(bytes = "bytes", tag = "1")] @@ -810,7 +810,7 @@ pub struct CommandStatementIngest { /// Nested message and enum types in `CommandStatementIngest`. pub mod command_statement_ingest { /// Options for table definition behavior - #[derive(Clone, Copy, PartialEq, ::prost::Message)] + #[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct TableDefinitionOptions { #[prost( enumeration = "table_definition_options::TableNotExistOption", @@ -918,7 +918,7 @@ pub mod command_statement_ingest { /// Returned from the RPC call DoPut when a CommandStatementUpdate, /// CommandPreparedStatementUpdate, or CommandStatementIngest was /// in the request, containing results from the update. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct DoPutUpdateResult { /// The number of records updated. A return value of -1 represents /// an unknown updated record count. @@ -930,7 +930,7 @@ pub struct DoPutUpdateResult { /// *Note on legacy behavior*: previous versions of the protocol did not return any result for /// this command, and that behavior should still be supported by clients. In that case, the client /// can continue as though the fields in this message were not provided or set to sensible default values. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct DoPutPreparedStatementResult { /// Represents a (potentially updated) opaque handle for the prepared statement on the server. /// Because the handle could potentially be updated, any previous handles for this prepared @@ -959,7 +959,7 @@ pub struct DoPutPreparedStatementResult { /// /// This command is deprecated since 13.0.0. Use the "CancelFlightInfo" /// action with DoAction instead. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionCancelQueryRequest { /// The result of the GetFlightInfo RPC that initiated the query. /// XXX(ARROW-16902): this must be a serialized FlightInfo, but is @@ -975,7 +975,7 @@ pub struct ActionCancelQueryRequest { /// /// This command is deprecated since 13.0.0. Use the "CancelFlightInfo" /// action with DoAction instead. -#[derive(Clone, Copy, PartialEq, ::prost::Message)] +#[derive(Clone, Copy, PartialEq, Eq, Hash, ::prost::Message)] pub struct ActionCancelQueryResult { #[prost(enumeration = "action_cancel_query_result::CancelResult", tag = "1")] pub result: i32, diff --git a/arrow-flight/src/sql/client.rs b/arrow-flight/src/sql/client.rs index 6791b68b757d..5476d4ede9a4 100644 --- a/arrow-flight/src/sql/client.rs +++ b/arrow-flight/src/sql/client.rs @@ -17,8 +17,14 @@ //! A FlightSQL Client [`FlightSqlServiceClient`] -use base64::prelude::BASE64_STANDARD; +use arrow_buffer::Buffer; +use arrow_ipc::MessageHeader; +use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::reader::read_record_batch; +use arrow_ipc::root_as_message; +use arrow_schema::SchemaRef; use base64::Engine; +use base64::prelude::BASE64_STANDARD; use bytes::Bytes; use std::collections::HashMap; use std::str::FromStr; @@ -27,8 +33,9 @@ use tonic::metadata::AsciiMetadataKey; use crate::decode::FlightRecordBatchStream; use crate::encode::FlightDataEncoderBuilder; use crate::error::FlightError; +use crate::error::Result; use crate::flight_service_client::FlightServiceClient; -use crate::sql::gen::action_end_transaction_request::EndTransaction; +use crate::sql::r#gen::action_end_transaction_request::EndTransaction; use crate::sql::server::{ BEGIN_TRANSACTION, CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT, END_TRANSACTION, }; @@ -49,19 +56,15 @@ use crate::{ IpcMessage, PutResult, Ticket, }; use arrow_array::RecordBatch; -use arrow_buffer::Buffer; -use arrow_ipc::convert::fb_to_schema; -use arrow_ipc::reader::read_record_batch; -use arrow_ipc::{root_as_message, MessageHeader}; -use arrow_schema::{ArrowError, Schema, SchemaRef}; -use futures::{stream, Stream, TryStreamExt}; +use arrow_schema::{ArrowError, Schema}; +use futures::{Stream, TryStreamExt, stream}; use prost::Message; -use tonic::transport::Channel; +use tonic::codegen::{Body, StdError}; use tonic::{IntoRequest, IntoStreamingRequest, Streaming}; /// A FlightSQLServiceClient is an endpoint for retrieving or storing Arrow data /// by FlightSQL protocol. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct FlightSqlServiceClient { token: Option, headers: HashMap, @@ -71,14 +74,20 @@ pub struct FlightSqlServiceClient { /// A FlightSql protocol client that can run queries against FlightSql servers /// This client is in the "experimental" stage. It is not guaranteed to follow the spec in all instances. /// Github issues are welcomed. -impl FlightSqlServiceClient { +impl FlightSqlServiceClient +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, +{ /// Creates a new FlightSql client that connects to a server over an arbitrary tonic `Channel` - pub fn new(channel: Channel) -> Self { + pub fn new(channel: T) -> Self { Self::new_from_inner(FlightServiceClient::new(channel)) } /// Creates a new higher level client with the provided lower level client - pub fn new_from_inner(inner: FlightServiceClient) -> Self { + pub fn new_from_inner(inner: FlightServiceClient) -> Self { Self { token: None, flight_client: inner, @@ -87,17 +96,17 @@ impl FlightSqlServiceClient { } /// Return a reference to the underlying [`FlightServiceClient`] - pub fn inner(&self) -> &FlightServiceClient { + pub fn inner(&self) -> &FlightServiceClient { &self.flight_client } /// Return a mutable reference to the underlying [`FlightServiceClient`] - pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { &mut self.flight_client } /// Consume this client and return the underlying [`FlightServiceClient`] - pub fn into_inner(self) -> FlightServiceClient { + pub fn into_inner(self) -> FlightServiceClient { self.flight_client } @@ -126,15 +135,10 @@ impl FlightSqlServiceClient { async fn get_flight_info_for_command( &mut self, cmd: M, - ) -> Result { + ) -> Result { let descriptor = FlightDescriptor::new_cmd(cmd.as_any().encode_to_vec()); let req = self.set_request_headers(descriptor.into_request())?; - let fi = self - .flight_client - .get_flight_info(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let fi = self.flight_client.get_flight_info(req).await?.into_inner(); Ok(fi) } @@ -143,7 +147,7 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result { + ) -> Result { let cmd = CommandStatementQuery { query, transaction_id, @@ -156,7 +160,7 @@ impl FlightSqlServiceClient { /// If the server returns an "authorization" header, it is automatically parsed and set as /// a token for future requests. Any other data returned by the server in the handshake /// response is returned as a binary blob. - pub async fn handshake(&mut self, username: &str, password: &str) -> Result { + pub async fn handshake(&mut self, username: &str, password: &str) -> Result { let cmd = HandshakeRequest { protocol_version: 0, payload: Default::default(), @@ -179,7 +183,7 @@ impl FlightSqlServiceClient { .map_err(|_| ArrowError::ParseError("Can't read auth header".to_string()))?; let bearer = "Bearer "; if !auth.starts_with(bearer) { - Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; + return Err(ArrowError::ParseError("Invalid auth header!".to_string()))?; } let auth = auth[bearer.len()..].to_string(); self.token = Some(auth); @@ -204,7 +208,7 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result { + ) -> Result { let cmd = CommandStatementUpdate { query, transaction_id, @@ -217,19 +221,9 @@ impl FlightSqlServiceClient { }]) .into_request(), )?; - let mut result = self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_put(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } @@ -238,7 +232,7 @@ impl FlightSqlServiceClient { &mut self, command: CommandStatementIngest, stream: S, - ) -> Result + ) -> Result where S: Stream> + Send + 'static, { @@ -255,41 +249,28 @@ impl FlightSqlServiceClient { FallibleRequestStream::new(sender, flight_data); let req = self.set_request_headers(flight_data.into_streaming_request())?; - let mut result = self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let mut result = self.flight_client.do_put(req).await?.into_inner(); // check if the there were any errors in the input stream provided note // if receiver.await fails, it means the sender was dropped and there is // no message to return. if let Ok(msg) = receiver.await { - return Err(ArrowError::ExternalError(Box::new(msg))); + return Err(FlightError::ExternalError(Box::new(msg))); } - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } /// Request a list of catalogs as tabular FlightInfo results - pub async fn get_catalogs(&mut self) -> Result { + pub async fn get_catalogs(&mut self) -> Result { self.get_flight_info_for_command(CommandGetCatalogs {}) .await } /// Request a list of database schemas as tabular FlightInfo results - pub async fn get_db_schemas( - &mut self, - request: CommandGetDbSchemas, - ) -> Result { + pub async fn get_db_schemas(&mut self, request: CommandGetDbSchemas) -> Result { self.get_flight_info_for_command(request).await } @@ -297,15 +278,10 @@ impl FlightSqlServiceClient { pub async fn do_get( &mut self, ticket: impl IntoRequest, - ) -> Result { + ) -> Result { let req = self.set_request_headers(ticket.into_request())?; - let (md, response_stream, _ext) = self - .flight_client - .do_get(req) - .await - .map_err(status_to_arrow_error)? - .into_parts(); + let (md, response_stream, _ext) = self.flight_client.do_get(req).await?.into_parts(); let (response_stream, trailers) = extract_lazy_trailers(response_stream); Ok(FlightRecordBatchStream::new_from_flight_data( @@ -319,43 +295,27 @@ impl FlightSqlServiceClient { pub async fn do_put( &mut self, request: impl tonic::IntoStreamingRequest, - ) -> Result, ArrowError> { + ) -> Result> { let req = self.set_request_headers(request.into_streaming_request())?; - Ok(self - .flight_client - .do_put(req) - .await - .map_err(status_to_arrow_error)? - .into_inner()) + Ok(self.flight_client.do_put(req).await?.into_inner()) } /// DoAction allows a flight client to do a specific action against a flight service pub async fn do_action( &mut self, request: impl IntoRequest, - ) -> Result, ArrowError> { + ) -> Result> { let req = self.set_request_headers(request.into_request())?; - Ok(self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner()) + Ok(self.flight_client.do_action(req).await?.into_inner()) } /// Request a list of tables. - pub async fn get_tables( - &mut self, - request: CommandGetTables, - ) -> Result { + pub async fn get_tables(&mut self, request: CommandGetTables) -> Result { self.get_flight_info_for_command(request).await } /// Request the primary keys for a table. - pub async fn get_primary_keys( - &mut self, - request: CommandGetPrimaryKeys, - ) -> Result { + pub async fn get_primary_keys(&mut self, request: CommandGetPrimaryKeys) -> Result { self.get_flight_info_for_command(request).await } @@ -364,7 +324,7 @@ impl FlightSqlServiceClient { pub async fn get_exported_keys( &mut self, request: CommandGetExportedKeys, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -372,7 +332,7 @@ impl FlightSqlServiceClient { pub async fn get_imported_keys( &mut self, request: CommandGetImportedKeys, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -382,21 +342,18 @@ impl FlightSqlServiceClient { pub async fn get_cross_reference( &mut self, request: CommandGetCrossReference, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } /// Request a list of table types. - pub async fn get_table_types(&mut self) -> Result { + pub async fn get_table_types(&mut self) -> Result { self.get_flight_info_for_command(CommandGetTableTypes {}) .await } /// Request a list of SQL information. - pub async fn get_sql_info( - &mut self, - sql_infos: Vec, - ) -> Result { + pub async fn get_sql_info(&mut self, sql_infos: Vec) -> Result { let request = CommandGetSqlInfo { info: sql_infos.iter().map(|sql_info| *sql_info as u32).collect(), }; @@ -407,7 +364,7 @@ impl FlightSqlServiceClient { pub async fn get_xdbc_type_info( &mut self, request: CommandGetXdbcTypeInfo, - ) -> Result { + ) -> Result { self.get_flight_info_for_command(request).await } @@ -416,7 +373,10 @@ impl FlightSqlServiceClient { &mut self, query: String, transaction_id: Option, - ) -> Result, ArrowError> { + ) -> Result> + where + T: Clone, + { let cmd = ActionCreatePreparedStatementRequest { query, transaction_id, @@ -426,18 +386,9 @@ impl FlightSqlServiceClient { body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let mut result = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_action(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let any = Any::decode(&*result.body)?; let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap(); let dataset_schema = match prepared_result.dataset_schema.len() { 0 => Schema::empty(), @@ -456,25 +407,16 @@ impl FlightSqlServiceClient { } /// Request to begin a transaction. - pub async fn begin_transaction(&mut self) -> Result { + pub async fn begin_transaction(&mut self) -> Result { let cmd = ActionBeginTransactionRequest {}; let action = Action { r#type: BEGIN_TRANSACTION.to_string(), body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let mut result = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?; + let mut result = self.flight_client.do_action(req).await?.into_inner(); + let result = result.message().await?.unwrap(); + let any = Any::decode(&*result.body)?; let begin_result: ActionBeginTransactionResult = any.unpack()?.unwrap(); Ok(begin_result.transaction_id) } @@ -484,7 +426,7 @@ impl FlightSqlServiceClient { &mut self, transaction_id: Bytes, action: EndTransaction, - ) -> Result<(), ArrowError> { + ) -> Result<()> { let cmd = ActionEndTransactionRequest { transaction_id, action: action as i32, @@ -494,25 +436,17 @@ impl FlightSqlServiceClient { body: cmd.as_any().encode_to_vec().into(), }; let req = self.set_request_headers(action.into_request())?; - let _ = self - .flight_client - .do_action(req) - .await - .map_err(status_to_arrow_error)? - .into_inner(); + let _ = self.flight_client.do_action(req).await?.into_inner(); Ok(()) } /// Explicitly shut down and clean up the client. - pub async fn close(&mut self) -> Result<(), ArrowError> { + pub async fn close(&mut self) -> Result<()> { // TODO: consume self instead of &mut self to explicitly prevent reuse? Ok(()) } - fn set_request_headers( - &self, - mut req: tonic::Request, - ) -> Result, ArrowError> { + fn set_request_headers(&self, mut req: tonic::Request) -> Result> { for (k, v) in &self.headers { let k = AsciiMetadataKey::from_str(k.as_str()).map_err(|e| { ArrowError::ParseError(format!("Cannot convert header key \"{k}\": {e}")) @@ -532,6 +466,16 @@ impl FlightSqlServiceClient { } } +impl Clone for FlightSqlServiceClient { + fn clone(&self) -> Self { + Self { + headers: self.headers.clone(), + token: self.token.clone(), + flight_client: self.flight_client.clone(), + } + } +} + /// A PreparedStatement #[derive(Debug, Clone)] pub struct PreparedStatement { @@ -542,9 +486,15 @@ pub struct PreparedStatement { parameter_schema: Schema, } -impl PreparedStatement { +impl PreparedStatement +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + Send + 'static, + ::Error: Into + Send, +{ pub(crate) fn new( - flight_client: FlightSqlServiceClient, + flight_client: FlightSqlServiceClient, handle: impl Into, dataset_schema: Schema, parameter_schema: Schema, @@ -559,7 +509,7 @@ impl PreparedStatement { } /// Executes the prepared statement query on the server. - pub async fn execute(&mut self) -> Result { + pub async fn execute(&mut self) -> Result { self.write_bind_params().await?; let cmd = CommandPreparedStatementQuery { @@ -574,7 +524,7 @@ impl PreparedStatement { } /// Executes the prepared statement update query on the server. - pub async fn execute_update(&mut self) -> Result { + pub async fn execute_update(&mut self) -> Result { self.write_bind_params().await?; let cmd = CommandPreparedStatementUpdate { @@ -588,35 +538,30 @@ impl PreparedStatement { ..Default::default() }])) .await?; - let result = result - .message() - .await - .map_err(status_to_arrow_error)? - .unwrap(); - let result: DoPutUpdateResult = - Message::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?; + let result = result.message().await?.unwrap(); + let result: DoPutUpdateResult = Message::decode(&*result.app_metadata)?; Ok(result.record_count) } /// Retrieve the parameter schema from the query. - pub fn parameter_schema(&self) -> Result<&Schema, ArrowError> { + pub fn parameter_schema(&self) -> Result<&Schema> { Ok(&self.parameter_schema) } /// Retrieve the ResultSet schema from the query. - pub fn dataset_schema(&self) -> Result<&Schema, ArrowError> { + pub fn dataset_schema(&self) -> Result<&Schema> { Ok(&self.dataset_schema) } /// Set a RecordBatch that contains the parameters that will be bind. - pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<(), ArrowError> { + pub fn set_parameters(&mut self, parameter_binding: RecordBatch) -> Result<()> { self.parameter_binding = Some(parameter_binding); Ok(()) } /// Submit parameters to the server, if any have been set on this prepared statement instance /// Updates our stored prepared statement handle with the handle given by the server response. - async fn write_bind_params(&mut self) -> Result<(), ArrowError> { + async fn write_bind_params(&mut self) -> Result<()> { if let Some(ref params_batch) = self.parameter_binding { let cmd = CommandPreparedStatementQuery { prepared_statement_handle: self.handle.clone(), @@ -631,8 +576,7 @@ impl PreparedStatement { self.parameter_binding.clone().map(Ok), )) .try_collect::>() - .await - .map_err(flight_error_to_arrow_error)?; + .await?; // Attempt to update the stored handle with any updated handle in the DoPut result. // Older servers do not respond with a result for DoPut, so skip this step when @@ -642,8 +586,7 @@ impl PreparedStatement { .do_put(stream::iter(flight_data)) .await? .message() - .await - .map_err(status_to_arrow_error)? + .await? { if let Some(handle) = self.unpack_prepared_statement_handle(&result)? { self.handle = handle; @@ -656,18 +599,14 @@ impl PreparedStatement { /// Decodes the app_metadata stored in a [`PutResult`] as a /// [`DoPutPreparedStatementResult`] and then returns /// the inner prepared statement handle as [`Bytes`] - fn unpack_prepared_statement_handle( - &self, - put_result: &PutResult, - ) -> Result, ArrowError> { - let result: DoPutPreparedStatementResult = - Message::decode(&*put_result.app_metadata).map_err(decode_error_to_arrow_error)?; + fn unpack_prepared_statement_handle(&self, put_result: &PutResult) -> Result> { + let result: DoPutPreparedStatementResult = Message::decode(&*put_result.app_metadata)?; Ok(result.prepared_statement_handle) } /// Close the prepared statement, so that this PreparedStatement can not used /// anymore and server can free up any resources. - pub async fn close(mut self) -> Result<(), ArrowError> { + pub async fn close(mut self) -> Result<()> { let cmd = ActionClosePreparedStatementRequest { prepared_statement_handle: self.handle.clone(), }; @@ -680,21 +619,6 @@ impl PreparedStatement { } } -fn decode_error_to_arrow_error(err: prost::DecodeError) -> ArrowError { - ArrowError::IpcError(err.to_string()) -} - -fn status_to_arrow_error(status: tonic::Status) -> ArrowError { - ArrowError::IpcError(format!("{status:?}")) -} - -fn flight_error_to_arrow_error(err: FlightError) -> ArrowError { - match err { - FlightError::Arrow(e) => e, - e => ArrowError::ExternalError(Box::new(e)), - } -} - /// A polymorphic structure to natively represent different types of data contained in `FlightData` pub enum ArrowFlightData { /// A record batch @@ -707,7 +631,7 @@ pub enum ArrowFlightData { pub fn arrow_data_from_flight_data( flight_data: FlightData, arrow_schema_ref: &SchemaRef, -) -> Result { +) -> std::result::Result { let ipc_message = root_as_message(&flight_data.data_header[..]) .map_err(|err| ArrowError::ParseError(format!("Unable to get root as message: {err:?}")))?; diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs index 68e8b497336e..c182140e58f3 100644 --- a/arrow-flight/src/sql/metadata/db_schemas.rs +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow_arith::boolean::and; -use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch, StringArray}; +use arrow_array::{ArrayRef, RecordBatch, StringArray, builder::StringBuilder}; use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::{filter::filter_record_batch, take::take}; diff --git a/arrow-flight/src/sql/metadata/mod.rs b/arrow-flight/src/sql/metadata/mod.rs index fd71149a3180..66c12fce9af4 100644 --- a/arrow-flight/src/sql/metadata/mod.rs +++ b/arrow-flight/src/sql/metadata/mod.rs @@ -70,8 +70,7 @@ mod tests { let actual_lines: Vec<_> = formatted.trim().lines().collect(); assert_eq!( &actual_lines, expected_lines, - "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n", - expected_lines, actual_lines + "\n\nexpected:\n\n{expected_lines:#?}\nactual:\n\n{actual_lines:#?}\n\n", ); } } diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs index 58b228530942..155946ea6ce6 100644 --- a/arrow-flight/src/sql/metadata/sql_info.rs +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -30,7 +30,7 @@ use std::sync::Arc; use arrow_arith::boolean::or; use arrow_array::array::{Array, UInt32Array, UnionArray}; use arrow_array::builder::{ - ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, + ArrayBuilder, BooleanBuilder, Int8Builder, Int32Builder, Int64Builder, ListBuilder, MapBuilder, StringBuilder, UInt32Builder, }; use arrow_array::{RecordBatch, Scalar}; @@ -196,10 +196,7 @@ static UNION_TYPE: Lazy = Lazy::new(|| { ), ]; - // create "type ids", one for each type, assume they go from 0 .. num_fields - let type_ids: Vec = (0..fields.len()).map(|v| v as i8).collect(); - - DataType::Union(UnionFields::new(type_ids, fields), UnionMode::Dense) + DataType::Union(UnionFields::from_fields(fields), UnionMode::Dense) }); impl SqlInfoUnionBuilder { @@ -444,7 +441,7 @@ pub struct GetSqlInfoBuilder<'a> { impl CommandGetSqlInfo { /// Create a builder suitable for constructing a response - pub fn into_builder(self, infos: &SqlInfoData) -> GetSqlInfoBuilder { + pub fn into_builder(self, infos: &SqlInfoData) -> GetSqlInfoBuilder<'_> { GetSqlInfoBuilder { info: self.info, infos, diff --git a/arrow-flight/src/sql/metadata/table_types.rs b/arrow-flight/src/sql/metadata/table_types.rs index 54cfe6fe27a7..7f525da05f90 100644 --- a/arrow-flight/src/sql/metadata/table_types.rs +++ b/arrow-flight/src/sql/metadata/table_types.rs @@ -21,7 +21,7 @@ use std::sync::Arc; -use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch}; +use arrow_array::{ArrayRef, RecordBatch, builder::StringBuilder}; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::take::take; use once_cell::sync::Lazy; diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs index a3a18ca10888..62e2de9e5d97 100644 --- a/arrow-flight/src/sql/metadata/xdbc_info.rs +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -299,7 +299,7 @@ pub struct GetXdbcTypeInfoBuilder<'a> { impl CommandGetXdbcTypeInfo { /// Create a builder suitable for constructing a response - pub fn into_builder(self, infos: &XdbcTypeInfoData) -> GetXdbcTypeInfoBuilder { + pub fn into_builder(self, infos: &XdbcTypeInfoData) -> GetXdbcTypeInfoBuilder<'_> { GetXdbcTypeInfoBuilder { data_type: self.data_type, infos, diff --git a/arrow-flight/src/sql/mod.rs b/arrow-flight/src/sql/mod.rs index 955f1904a6d6..e076f7aa0747 100644 --- a/arrow-flight/src/sql/mod.rs +++ b/arrow-flight/src/sql/mod.rs @@ -44,70 +44,70 @@ use paste::paste; use prost::Message; #[allow(clippy::all)] -mod gen { +mod r#gen { // Since this file is auto-generated, we suppress all warnings #![allow(missing_docs)] include!("arrow.flight.protocol.sql.rs"); } -pub use gen::action_end_transaction_request::EndTransaction; -pub use gen::command_statement_ingest::table_definition_options::{ +pub use r#gen::ActionBeginSavepointRequest; +pub use r#gen::ActionBeginSavepointResult; +pub use r#gen::ActionBeginTransactionRequest; +pub use r#gen::ActionBeginTransactionResult; +pub use r#gen::ActionCancelQueryRequest; +pub use r#gen::ActionCancelQueryResult; +pub use r#gen::ActionClosePreparedStatementRequest; +pub use r#gen::ActionCreatePreparedStatementRequest; +pub use r#gen::ActionCreatePreparedStatementResult; +pub use r#gen::ActionCreatePreparedSubstraitPlanRequest; +pub use r#gen::ActionEndSavepointRequest; +pub use r#gen::ActionEndTransactionRequest; +pub use r#gen::CommandGetCatalogs; +pub use r#gen::CommandGetCrossReference; +pub use r#gen::CommandGetDbSchemas; +pub use r#gen::CommandGetExportedKeys; +pub use r#gen::CommandGetImportedKeys; +pub use r#gen::CommandGetPrimaryKeys; +pub use r#gen::CommandGetSqlInfo; +pub use r#gen::CommandGetTableTypes; +pub use r#gen::CommandGetTables; +pub use r#gen::CommandGetXdbcTypeInfo; +pub use r#gen::CommandPreparedStatementQuery; +pub use r#gen::CommandPreparedStatementUpdate; +pub use r#gen::CommandStatementIngest; +pub use r#gen::CommandStatementQuery; +pub use r#gen::CommandStatementSubstraitPlan; +pub use r#gen::CommandStatementUpdate; +pub use r#gen::DoPutPreparedStatementResult; +pub use r#gen::DoPutUpdateResult; +pub use r#gen::Nullable; +pub use r#gen::Searchable; +pub use r#gen::SqlInfo; +pub use r#gen::SqlNullOrdering; +pub use r#gen::SqlOuterJoinsSupportLevel; +pub use r#gen::SqlSupportedCaseSensitivity; +pub use r#gen::SqlSupportedElementActions; +pub use r#gen::SqlSupportedGroupBy; +pub use r#gen::SqlSupportedPositionedCommands; +pub use r#gen::SqlSupportedResultSetConcurrency; +pub use r#gen::SqlSupportedResultSetType; +pub use r#gen::SqlSupportedSubqueries; +pub use r#gen::SqlSupportedTransaction; +pub use r#gen::SqlSupportedTransactions; +pub use r#gen::SqlSupportedUnions; +pub use r#gen::SqlSupportsConvert; +pub use r#gen::SqlTransactionIsolationLevel; +pub use r#gen::SubstraitPlan; +pub use r#gen::SupportedSqlGrammar; +pub use r#gen::TicketStatementQuery; +pub use r#gen::UpdateDeleteRules; +pub use r#gen::XdbcDataType; +pub use r#gen::XdbcDatetimeSubcode; +pub use r#gen::action_end_transaction_request::EndTransaction; +pub use r#gen::command_statement_ingest::TableDefinitionOptions; +pub use r#gen::command_statement_ingest::table_definition_options::{ TableExistsOption, TableNotExistOption, }; -pub use gen::command_statement_ingest::TableDefinitionOptions; -pub use gen::ActionBeginSavepointRequest; -pub use gen::ActionBeginSavepointResult; -pub use gen::ActionBeginTransactionRequest; -pub use gen::ActionBeginTransactionResult; -pub use gen::ActionCancelQueryRequest; -pub use gen::ActionCancelQueryResult; -pub use gen::ActionClosePreparedStatementRequest; -pub use gen::ActionCreatePreparedStatementRequest; -pub use gen::ActionCreatePreparedStatementResult; -pub use gen::ActionCreatePreparedSubstraitPlanRequest; -pub use gen::ActionEndSavepointRequest; -pub use gen::ActionEndTransactionRequest; -pub use gen::CommandGetCatalogs; -pub use gen::CommandGetCrossReference; -pub use gen::CommandGetDbSchemas; -pub use gen::CommandGetExportedKeys; -pub use gen::CommandGetImportedKeys; -pub use gen::CommandGetPrimaryKeys; -pub use gen::CommandGetSqlInfo; -pub use gen::CommandGetTableTypes; -pub use gen::CommandGetTables; -pub use gen::CommandGetXdbcTypeInfo; -pub use gen::CommandPreparedStatementQuery; -pub use gen::CommandPreparedStatementUpdate; -pub use gen::CommandStatementIngest; -pub use gen::CommandStatementQuery; -pub use gen::CommandStatementSubstraitPlan; -pub use gen::CommandStatementUpdate; -pub use gen::DoPutPreparedStatementResult; -pub use gen::DoPutUpdateResult; -pub use gen::Nullable; -pub use gen::Searchable; -pub use gen::SqlInfo; -pub use gen::SqlNullOrdering; -pub use gen::SqlOuterJoinsSupportLevel; -pub use gen::SqlSupportedCaseSensitivity; -pub use gen::SqlSupportedElementActions; -pub use gen::SqlSupportedGroupBy; -pub use gen::SqlSupportedPositionedCommands; -pub use gen::SqlSupportedResultSetConcurrency; -pub use gen::SqlSupportedResultSetType; -pub use gen::SqlSupportedSubqueries; -pub use gen::SqlSupportedTransaction; -pub use gen::SqlSupportedTransactions; -pub use gen::SqlSupportedUnions; -pub use gen::SqlSupportsConvert; -pub use gen::SqlTransactionIsolationLevel; -pub use gen::SubstraitPlan; -pub use gen::SupportedSqlGrammar; -pub use gen::TicketStatementQuery; -pub use gen::UpdateDeleteRules; -pub use gen::XdbcDataType; -pub use gen::XdbcDatetimeSubcode; pub mod client; pub mod metadata; diff --git a/arrow-flight/src/sql/server.rs b/arrow-flight/src/sql/server.rs index add7c8db40c2..871a67b72cd6 100644 --- a/arrow-flight/src/sql/server.rs +++ b/arrow-flight/src/sql/server.rs @@ -34,11 +34,11 @@ use super::{ SqlInfo, TicketStatementQuery, }; use crate::{ - flight_service_server::FlightService, gen::PollInfo, Action, ActionType, Criteria, Empty, - FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, - SchemaResult, Ticket, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, r#gen::PollInfo, }; -use futures::{stream::Peekable, Stream, StreamExt}; +use futures::{Stream, StreamExt, stream::Peekable}; use prost::Message; use tonic::{Request, Response, Status, Streaming}; @@ -392,7 +392,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static { _request: Request, error: DoPutError, ) -> Result::DoPutStream>, Status> { - Err(Status::unimplemented(format!("Unhandled Error: {}", error))) + Err(Status::unimplemented(format!("Unhandled Error: {error}"))) } /// Execute an update SQL statement. @@ -628,7 +628,7 @@ where self.get_flight_info_catalogs(token, request).await } Command::CommandGetDbSchemas(token) => { - return self.get_flight_info_schemas(token, request).await + return self.get_flight_info_schemas(token, request).await; } Command::CommandGetTables(token) => self.get_flight_info_tables(token, request).await, Command::CommandGetTableTypes(token) => { @@ -879,7 +879,7 @@ where let stmt = self .do_action_create_prepared_statement(cmd, request) .await?; - let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + let output = futures::stream::iter(vec![Ok(super::super::r#gen::Result { body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); @@ -921,7 +921,7 @@ where Status::invalid_argument("Unable to unpack ActionBeginTransactionRequest.") })?; let stmt = self.do_action_begin_transaction(cmd, request).await?; - let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + let output = futures::stream::iter(vec![Ok(super::super::r#gen::Result { body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); @@ -946,7 +946,7 @@ where Status::invalid_argument("Unable to unpack ActionBeginSavepointRequest.") })?; let stmt = self.do_action_begin_savepoint(cmd, request).await?; - let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + let output = futures::stream::iter(vec![Ok(super::super::r#gen::Result { body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); @@ -971,7 +971,7 @@ where Status::invalid_argument("Unable to unpack ActionCancelQueryRequest.") })?; let stmt = self.do_action_cancel_query(cmd, request).await?; - let output = futures::stream::iter(vec![Ok(super::super::gen::Result { + let output = futures::stream::iter(vec![Ok(super::super::r#gen::Result { body: stmt.as_any().encode_to_vec().into(), })]); return Ok(Response::new(Box::pin(output))); diff --git a/arrow-flight/src/streams.rs b/arrow-flight/src/streams.rs index 0cd3aa41a547..8a9d5ab30667 100644 --- a/arrow-flight/src/streams.rs +++ b/arrow-flight/src/streams.rs @@ -19,11 +19,11 @@ use crate::error::FlightError; use futures::{ - channel::oneshot::{Receiver, Sender}, FutureExt, Stream, StreamExt, + channel::oneshot::{Receiver, Sender}, }; use std::pin::Pin; -use std::task::{ready, Poll}; +use std::task::{Poll, ready}; /// Wrapper around a fallible stream (one that returns errors) that makes it infallible. /// diff --git a/arrow-flight/src/trailers.rs b/arrow-flight/src/trailers.rs index 73136379d69f..7929b53a41a0 100644 --- a/arrow-flight/src/trailers.rs +++ b/arrow-flight/src/trailers.rs @@ -21,8 +21,8 @@ use std::{ task::{Context, Poll}, }; -use futures::{ready, FutureExt, Stream, StreamExt}; -use tonic::{metadata::MetadataMap, Status, Streaming}; +use futures::{FutureExt, Stream, StreamExt, ready}; +use tonic::{Status, Streaming, metadata::MetadataMap}; /// Extract [`LazyTrailers`] from [`Streaming`] [tonic] response. /// diff --git a/arrow-flight/src/utils.rs b/arrow-flight/src/utils.rs index 428dde73ca6c..6effb5f86aaf 100644 --- a/arrow-flight/src/utils.rs +++ b/arrow-flight/src/utils.rs @@ -24,6 +24,7 @@ use std::sync::Arc; use arrow_array::{ArrayRef, RecordBatch}; use arrow_buffer::Buffer; use arrow_ipc::convert::fb_to_schema; +use arrow_ipc::writer::CompressionContext; use arrow_ipc::{reader, root_as_message, writer, writer::IpcWriteOptions}; use arrow_schema::{ArrowError, Schema, SchemaRef}; @@ -90,13 +91,16 @@ pub fn batches_to_flight_data( let mut flight_data = vec![]; let data_gen = writer::IpcDataGenerator::default(); - #[allow(deprecated)] - let mut dictionary_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let mut dictionary_tracker = writer::DictionaryTracker::new(false); + let mut compression_context = CompressionContext::default(); for batch in batches.iter() { - let (encoded_dictionaries, encoded_batch) = - data_gen.encoded_batch(batch, &mut dictionary_tracker, &options)?; + let (encoded_dictionaries, encoded_batch) = data_gen.encode( + batch, + &mut dictionary_tracker, + &options, + &mut compression_context, + )?; dictionaries.extend(encoded_dictionaries.into_iter().map(Into::into)); flight_data.push(encoded_batch.into()); diff --git a/arrow-flight/tests/client.rs b/arrow-flight/tests/client.rs index 25dad0e77a3e..ab566f578cbb 100644 --- a/arrow-flight/tests/client.rs +++ b/arrow-flight/tests/client.rs @@ -22,10 +22,10 @@ mod common; use crate::common::fixture::TestFixture; use arrow_array::{RecordBatch, UInt64Array}; use arrow_flight::{ - decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, Action, - ActionType, CancelFlightInfoRequest, CancelFlightInfoResult, CancelStatus, Criteria, Empty, - FlightClient, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, PollInfo, PutResult, RenewFlightEndpointRequest, Ticket, + Action, ActionType, CancelFlightInfoRequest, CancelFlightInfoResult, CancelStatus, Criteria, + Empty, FlightClient, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, RenewFlightEndpointRequest, Ticket, + decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, error::FlightError, }; use arrow_schema::{DataType, Field, Schema}; use bytes::Bytes; diff --git a/arrow-flight/tests/common/server.rs b/arrow-flight/tests/common/server.rs index a004ccb0737e..5aa22a869627 100644 --- a/arrow-flight/tests/common/server.rs +++ b/arrow-flight/tests/common/server.rs @@ -19,14 +19,14 @@ use std::sync::{Arc, Mutex}; use arrow_array::RecordBatch; use arrow_schema::Schema; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; -use tonic::{metadata::MetadataMap, Request, Response, Status, Streaming}; +use futures::{StreamExt, TryStreamExt, stream::BoxStream}; +use tonic::{Request, Response, Status, Streaming, metadata::MetadataMap}; use arrow_flight::{ - encode::FlightDataEncoderBuilder, - flight_service_server::{FlightService, FlightServiceServer}, Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket, + encode::FlightDataEncoderBuilder, + flight_service_server::{FlightService, FlightServiceServer}, }; #[derive(Debug, Clone)] diff --git a/arrow-flight/tests/common/utils.rs b/arrow-flight/tests/common/utils.rs index 0f70e4b31021..f36b41cba344 100644 --- a/arrow-flight/tests/common/utils.rs +++ b/arrow-flight/tests/common/utils.rs @@ -20,8 +20,8 @@ use std::sync::Arc; use arrow_array::{ - types::Int32Type, ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, - StringViewArray, UInt8Array, + ArrayRef, BinaryViewArray, DictionaryArray, Float64Array, RecordBatch, StringViewArray, + UInt8Array, types::Int32Type, }; use arrow_schema::{DataType, Field, Schema}; diff --git a/arrow-flight/tests/encode_decode.rs b/arrow-flight/tests/encode_decode.rs index cbfae1825845..fcd6b39ab0a1 100644 --- a/arrow-flight/tests/encode_decode.rs +++ b/arrow-flight/tests/encode_decode.rs @@ -21,8 +21,8 @@ use std::{collections::HashMap, sync::Arc}; use arrow_array::{ArrayRef, RecordBatch}; use arrow_cast::pretty::pretty_format_batches; -use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::FlightDescriptor; +use arrow_flight::flight_descriptor::DescriptorType; use arrow_flight::{ decode::{DecodedPayload, FlightDataDecoder, FlightRecordBatchStream}, encode::FlightDataEncoderBuilder, diff --git a/arrow-flight/tests/flight_sql_client.rs b/arrow-flight/tests/flight_sql_client.rs index f3b7114dbafa..97687c3dea37 100644 --- a/arrow-flight/tests/flight_sql_client.rs +++ b/arrow-flight/tests/flight_sql_client.rs @@ -64,10 +64,12 @@ pub async fn test_begin_end_transaction() { // unknown transaction id let transaction_id = "UnknownTransactionId".to_string().into(); - assert!(flight_sql_client - .end_transaction(transaction_id, EndTransaction::Commit) - .await - .is_err()); + assert!( + flight_sql_client + .end_transaction(transaction_id, EndTransaction::Commit) + .await + .is_err() + ); } #[tokio::test] @@ -139,9 +141,10 @@ pub async fn test_do_put_empty_stream() { // Execute a `do_put` and verify that the server error contains the expected message let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); - assert!(err - .to_string() - .contains("Unhandled Error: Command is missing."),); + assert!( + err.to_string() + .contains("Unhandled Error: Command is missing."), + ); } #[tokio::test] @@ -172,9 +175,10 @@ pub async fn test_do_put_first_element_err() { // Execute a `do_put` and verify that the server error contains the expected message let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); - assert!(err - .to_string() - .contains("Unhandled Error: Command is missing."),); + assert!( + err.to_string() + .contains("Unhandled Error: Command is missing."), + ); } #[tokio::test] @@ -196,9 +200,10 @@ pub async fn test_do_put_missing_flight_descriptor() { // Execute a `do_put` and verify that the server error contains the expected message let err = flight_sql_client.do_put(request_stream).await.unwrap_err(); - assert!(err - .to_string() - .contains("Unhandled Error: Flight descriptor is missing."),); + assert!( + err.to_string() + .contains("Unhandled Error: Flight descriptor is missing."), + ); } fn make_ingest_command() -> CommandStatementIngest { diff --git a/arrow-flight/tests/flight_sql_client_cli.rs b/arrow-flight/tests/flight_sql_client_cli.rs index c8e9190e246f..c161caae8ca4 100644 --- a/arrow-flight/tests/flight_sql_client_cli.rs +++ b/arrow-flight/tests/flight_sql_client_cli.rs @@ -22,19 +22,19 @@ use std::{pin::Pin, sync::Arc}; use crate::common::fixture::TestFixture; use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray, TimestampNanosecondArray}; use arrow_flight::{ + Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, + HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, decode::FlightRecordBatchStream, encode::FlightDataEncoderBuilder, flight_service_server::{FlightService, FlightServiceServer}, sql::{ - server::{FlightSqlService, PeekableFlightDataStream}, ActionCreatePreparedStatementRequest, ActionCreatePreparedStatementResult, Any, CommandGetCatalogs, CommandGetDbSchemas, CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery, CommandStatementQuery, DoPutPreparedStatementResult, ProstMessageExt, SqlInfo, + server::{FlightSqlService, PeekableFlightDataStream}, }, utils::batches_to_flight_data, - Action, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, - HandshakeResponse, IpcMessage, SchemaAsIpc, Ticket, }; use arrow_ipc::writer::IpcWriteOptions; use arrow_schema::{ArrowError, DataType, Field, Schema, TimeUnit}; @@ -46,6 +46,11 @@ use tonic::{Request, Response, Status, Streaming}; const QUERY: &str = "SELECT * FROM table;"; +/// Return a Command instance for running the `flight_sql_client` CLI +fn flight_sql_client_cmd() -> Command { + Command::new(assert_cmd::cargo::cargo_bin!("flight_sql_client")) +} + #[tokio::test] async fn test_simple() { let test_server = FlightSqlServiceImpl::default(); @@ -53,8 +58,7 @@ async fn test_simple() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -94,8 +98,7 @@ async fn test_get_catalogs() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -133,8 +136,7 @@ async fn test_get_db_schemas() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -173,8 +175,7 @@ async fn test_get_tables() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -212,8 +213,7 @@ async fn test_get_tables_db_filter() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -253,8 +253,7 @@ async fn test_get_tables_types() { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") @@ -295,8 +294,7 @@ async fn test_do_put_prepared_statement(test_server: FlightSqlServiceImpl) { let addr = fixture.addr; let stdout = tokio::task::spawn_blocking(move || { - Command::cargo_bin("flight_sql_client") - .unwrap() + flight_sql_client_cmd() .env_clear() .env("RUST_BACKTRACE", "1") .env("RUST_LOG", "warn") diff --git a/arrow-integration-test/Cargo.toml b/arrow-integration-test/Cargo.toml index d560d4fd8363..39ea3b60b1ab 100644 --- a/arrow-integration-test/Cargo.toml +++ b/arrow-integration-test/Cargo.toml @@ -39,6 +39,7 @@ all-features = true arrow = { workspace = true } arrow-buffer = { workspace = true } hex = { version = "0.4", default-features = false, features = ["std"] } +num-bigint = { version = "0.4", default-features = false } +num-traits = { version = "0.2.19", default-features = false, features = ["std"] } serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } serde_json = { version = "1.0", default-features = false, features = ["std"] } -num = { version = "0.4", default-features = false, features = ["std"] } diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 24e02c8430c7..4c17fbe76be7 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -61,6 +61,8 @@ pub fn data_type_from_json(json: &serde_json::Value) -> Result { }; match bit_width { + 32 => Ok(DataType::Decimal32(precision, scale)), + 64 => Ok(DataType::Decimal64(precision, scale)), 128 => Ok(DataType::Decimal128(precision, scale)), 256 => Ok(DataType::Decimal256(precision, scale)), _ => Err(ArrowError::ParseError( @@ -335,6 +337,12 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { TimeUnit::Nanosecond => "NANOSECOND", }}), DataType::Dictionary(_, _) => json!({ "name": "dictionary"}), + DataType::Decimal32(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 32}) + } + DataType::Decimal64(precision, scale) => { + json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 64}) + } DataType::Decimal128(precision, scale) => { json!({"name": "decimal", "precision": precision, "scale": scale, "bitWidth": 128}) } diff --git a/arrow-integration-test/src/field.rs b/arrow-integration-test/src/field.rs index 4b896ed391be..8b0ca264e02e 100644 --- a/arrow-integration-test/src/field.rs +++ b/arrow-integration-test/src/field.rs @@ -142,7 +142,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { Some(_) => { return Err(ArrowError::ParseError( "Field 'children' must be an array".to_string(), - )) + )); } None => { return Err(ArrowError::ParseError( @@ -158,7 +158,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { Some(_) => { return Err(ArrowError::ParseError( "Field 'children' must be an array".to_string(), - )) + )); } None => { return Err(ArrowError::ParseError( @@ -177,15 +177,15 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { } t => { return Err(ArrowError::ParseError(format!( - "Map children should be a struct with 2 fields, found {t:?}" - ))) + "Map children should be a struct with 2 fields, found {t:?}" + ))); } } } Some(_) => { return Err(ArrowError::ParseError( "Field 'children' must be an array with 1 element".to_string(), - )) + )); } None => { return Err(ArrowError::ParseError( @@ -207,7 +207,7 @@ pub fn field_from_json(json: &serde_json::Value) -> Result { Some(_) => { return Err(ArrowError::ParseError( "Field 'children' must be an array".to_string(), - )) + )); } None => { return Err(ArrowError::ParseError( @@ -275,7 +275,7 @@ pub fn field_to_json(field: &Field) -> serde_json::Value { }; match field.data_type() { - DataType::Dictionary(ref index_type, ref value_type) => { + DataType::Dictionary(index_type, value_type) => { #[allow(deprecated)] let dict_id = field.dict_id().unwrap(); serde_json::json!({ diff --git a/arrow-integration-test/src/lib.rs b/arrow-integration-test/src/lib.rs index baa76059f9c6..0f0b4fe2ffee 100644 --- a/arrow-integration-test/src/lib.rs +++ b/arrow-integration-test/src/lib.rs @@ -15,22 +15,30 @@ // specific language governing permissions and limitations // under the License. -//! Support for the [Apache Arrow JSON test data format](https://github.com/apache/arrow/blob/master/docs/source/format/Integration.rst#json-test-data-format) +//! Partial support for the [Apache Arrow JSON test data format](https://github.com/apache/arrow/blob/master/docs/source/format/Integration.rst#json-test-data-format) //! //! These utilities define structs that read the integration JSON format for integration testing purposes. //! //! This is not a canonical format, but provides a human-readable way of verifying language implementations +//! +//!

+//! +//! This crate is **only intended for integration testing the +//! [Arrow project](https://github.com/apache/arrow-rs)**. It is not [intended for usage outside of +//! this context](https://github.com/apache/arrow-rs/issues/8684#issuecomment-3433193158). +//! +//!
#![doc( html_logo_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_white-bg.svg", html_favicon_url = "https://arrow.apache.org/img/arrow-logo_chevrons_black-txt_transparent-bg.svg" )] -#![cfg_attr(docsrs, feature(doc_auto_cfg))] +#![cfg_attr(docsrs, feature(doc_cfg))] #![warn(missing_docs)] use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer}; use hex::decode; -use num::BigInt; -use num::Signed; +use num_bigint::BigInt; +use num_traits::Signed; use serde::{Deserialize, Serialize}; use serde_json::{Map as SJMap, Value}; use std::collections::HashMap; @@ -794,13 +802,13 @@ pub fn array_from_json( DataType::Dictionary(key_type, value_type) => { #[allow(deprecated)] let dict_id = field.dict_id().ok_or_else(|| { - ArrowError::JsonError(format!("Unable to find dict_id for field {field:?}")) + ArrowError::JsonError(format!("Unable to find dict_id for field {field}")) })?; // find dictionary let dictionary = dictionaries .ok_or_else(|| { ArrowError::JsonError(format!( - "Unable to find any dictionaries for field {field:?}" + "Unable to find any dictionaries for field {field}" )) })? .get(&dict_id); @@ -814,10 +822,46 @@ pub fn array_from_json( dictionaries, ), None => Err(ArrowError::JsonError(format!( - "Unable to find dictionary for field {field:?}" + "Unable to find dictionary for field {field}" ))), } } + DataType::Decimal32(precision, scale) => { + let mut b = Decimal32Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) + } + DataType::Decimal64(precision, scale) => { + let mut b = Decimal64Builder::with_capacity(json_col.count); + for (is_valid, value) in json_col + .validity + .as_ref() + .unwrap() + .iter() + .zip(json_col.data.unwrap()) + { + match is_valid { + 1 => b.append_value(value.as_str().unwrap().parse::().unwrap()), + _ => b.append_null(), + }; + } + Ok(Arc::new( + b.finish().with_precision_and_scale(*precision, *scale)?, + )) + } DataType::Decimal128(precision, scale) => { let mut b = Decimal128Builder::with_capacity(json_col.count); for (is_valid, value) in json_col @@ -910,7 +954,7 @@ pub fn array_from_json( Ok(Arc::new(array)) } t => Err(ArrowError::JsonError(format!( - "data type {t:?} not supported" + "data type {t} not supported" ))), } } @@ -1007,6 +1051,16 @@ fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer { impl ArrowJsonBatch { /// Convert a [`RecordBatch`] to an [`ArrowJsonBatch`] + /// + ///
+ /// + /// This function is **deliberately incomplete**! As noted in the crate-level documentation, + /// this crate is only intended for use within the Arrow project itself. + /// + /// Right now, this function only supports `DataType::Int8` columns. Other data types will lead + /// to an empty `ArrowJsonColumn`. + /// + ///
pub fn from_batch(batch: &RecordBatch) -> ArrowJsonBatch { let mut json_batch = ArrowJsonBatch { count: batch.num_rows(), diff --git a/arrow-integration-test/src/schema.rs b/arrow-integration-test/src/schema.rs index 512f0aed8e54..7777c48c1f4b 100644 --- a/arrow-integration-test/src/schema.rs +++ b/arrow-integration-test/src/schema.rs @@ -40,7 +40,7 @@ pub fn schema_from_json(json: &serde_json::Value) -> Result { _ => { return Err(ArrowError::ParseError( "Schema fields should be an array".to_string(), - )) + )); } }; diff --git a/arrow-integration-testing/Cargo.toml b/arrow-integration-testing/Cargo.toml index 8654b4b92734..ae13d32b57a9 100644 --- a/arrow-integration-testing/Cargo.toml +++ b/arrow-integration-testing/Cargo.toml @@ -39,11 +39,10 @@ arrow-flight = { path = "../arrow-flight", default-features = false } arrow-integration-test = { path = "../arrow-integration-test", default-features = false } clap = { version = "4", default-features = false, features = ["std", "derive", "help", "error-context", "usage"] } futures = { version = "0.3", default-features = false } -prost = { version = "0.13", default-features = false } -serde = { version = "1.0", default-features = false, features = ["rc", "derive"] } +prost = { version = "0.14.1", default-features = false } serde_json = { version = "1.0", default-features = false, features = ["std"] } tokio = { version = "1.0", default-features = false, features = [ "rt-multi-thread"] } -tonic = { version = "0.12", default-features = false } +tonic = { version = "0.14.1", default-features = false } tracing-subscriber = { version = "0.3.1", default-features = false, features = ["fmt"], optional = true } flate2 = { version = "1", default-features = false, features = ["rust_backend"] } diff --git a/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs index 34c3c7706df5..4c12be6d6c42 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/auth_basic_proto.rs @@ -19,10 +19,10 @@ use crate::{AUTH_PASSWORD, AUTH_USERNAME}; -use arrow_flight::{flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest}; -use futures::{stream, StreamExt}; +use arrow_flight::{BasicAuth, HandshakeRequest, flight_service_client::FlightServiceClient}; +use futures::{StreamExt, stream}; use prost::Message; -use tonic::{metadata::MetadataValue, Request, Status}; +use tonic::{Request, Status, metadata::MetadataValue, transport::Endpoint}; type Error = Box; type Result = std::result::Result; @@ -32,7 +32,9 @@ type Client = FlightServiceClient; /// Run a scenario that tests basic auth. pub async fn run_scenario(host: &str, port: u16) -> Result { let url = format!("http://{host}:{port}"); - let mut client = FlightServiceClient::connect(url).await?; + let endpoint = Endpoint::new(url)?; + let channel = endpoint.connect().await?; + let mut client = FlightServiceClient::new(channel); let action = arrow_flight::Action::default(); diff --git a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs index 406419028d00..05ca5627ecd8 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/integration_test.rs @@ -24,15 +24,18 @@ use arrow::{ array::ArrayRef, buffer::Buffer, datatypes::SchemaRef, - ipc::{self, reader, writer}, + ipc::{ + self, reader, + writer::{self, CompressionContext}, + }, record_batch::RecordBatch, }; use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, - utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, IpcMessage, Location, Ticket, + FlightData, FlightDescriptor, IpcMessage, Location, Ticket, flight_descriptor::DescriptorType, + flight_service_client::FlightServiceClient, utils::flight_data_to_arrow_batch, }; -use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt}; -use tonic::{Request, Streaming}; +use futures::{StreamExt, channel::mpsc, sink::SinkExt, stream}; +use tonic::{Request, Streaming, transport::Endpoint}; use arrow::datatypes::Schema; use std::sync::Arc; @@ -46,7 +49,9 @@ type Client = FlightServiceClient; pub async fn run_scenario(host: &str, port: u16, path: &str) -> Result { let url = format!("http://{host}:{port}"); - let client = FlightServiceClient::connect(url).await?; + let endpoint = Endpoint::new(url)?; + let channel = endpoint.connect().await?; + let client = FlightServiceClient::new(channel); let json_file = open_json_file(path)?; @@ -72,9 +77,7 @@ async fn upload_data( let (mut upload_tx, upload_rx) = mpsc::channel(10); let options = arrow::ipc::writer::IpcWriteOptions::default(); - #[allow(deprecated)] - let mut dict_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let mut dict_tracker = writer::DictionaryTracker::new(false); let data_gen = writer::IpcDataGenerator::default(); let data = IpcMessage( data_gen @@ -92,6 +95,8 @@ async fn upload_data( let mut original_data_iter = original_data.iter().enumerate(); + let mut compression_context = CompressionContext::default(); + if let Some((counter, first_batch)) = original_data_iter.next() { let metadata = counter.to_string().into_bytes(); // Preload the first batch into the channel before starting the request @@ -101,6 +106,7 @@ async fn upload_data( first_batch, &options, &mut dict_tracker, + &mut compression_context, ) .await?; @@ -123,6 +129,7 @@ async fn upload_data( batch, &options, &mut dict_tracker, + &mut compression_context, ) .await?; @@ -152,11 +159,12 @@ async fn send_batch( batch: &RecordBatch, options: &writer::IpcWriteOptions, dictionary_tracker: &mut writer::DictionaryTracker, + compression_context: &mut CompressionContext, ) -> Result { let data_gen = writer::IpcDataGenerator::default(); let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, dictionary_tracker, options) + .encode(batch, dictionary_tracker, options, compression_context) .expect("DictionaryTracker configured above to not error on replacement"); let dictionary_flight_data: Vec = @@ -213,7 +221,9 @@ async fn consume_flight_location( // more details: https://github.com/apache/arrow-rs/issues/1398 location.uri = location.uri.replace("grpc+tcp://", "http://"); - let mut client = FlightServiceClient::connect(location.uri).await?; + let endpoint = Endpoint::new(location.uri)?; + let channel = endpoint.connect().await?; + let mut client = FlightServiceClient::new(channel); let resp = client.do_get(ticket).await?; let mut resp = resp.into_inner(); diff --git a/arrow-integration-testing/src/flight_client_scenarios/middleware.rs b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs index 495825738aec..e8836c34c47d 100644 --- a/arrow-integration-testing/src/flight_client_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_client_scenarios/middleware.rs @@ -18,7 +18,7 @@ //! Scenario for testing middleware. use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, FlightDescriptor, + FlightDescriptor, flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient, }; use prost::bytes::Bytes; use tonic::{Request, Status}; diff --git a/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs index 5462e5bd674b..38582e6fef68 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/auth_basic_proto.rs @@ -21,13 +21,13 @@ use std::pin::Pin; use std::sync::Arc; use arrow_flight::{ - flight_service_server::FlightService, flight_service_server::FlightServiceServer, Action, - ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + flight_service_server::FlightService, flight_service_server::FlightServiceServer, }; -use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use futures::{Stream, StreamExt, channel::mpsc, sink::SinkExt}; use tokio::sync::Mutex; -use tonic::{metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming, metadata::MetadataMap, transport::Server}; type TonicStream = Pin + Send + Sync + 'static>>; type Error = Box; diff --git a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs index 92989a20393e..ae316886381a 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/integration_test.rs @@ -31,14 +31,14 @@ use arrow::{ record_batch::RecordBatch, }; use arrow_flight::{ - flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, - FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest, HandshakeResponse, IpcMessage, - PollInfo, PutResult, SchemaAsIpc, SchemaResult, Ticket, + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, + HandshakeRequest, HandshakeResponse, IpcMessage, PollInfo, PutResult, SchemaAsIpc, + SchemaResult, Ticket, flight_descriptor::DescriptorType, flight_service_server::FlightService, + flight_service_server::FlightServiceServer, }; -use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt}; +use futures::{Stream, StreamExt, channel::mpsc, sink::SinkExt}; use tokio::sync::Mutex; -use tonic::{transport::Server, Request, Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; type TonicStream = Pin + Send + Sync + 'static>>; @@ -119,9 +119,7 @@ impl FlightService for FlightServiceImpl { .ok_or_else(|| Status::not_found(format!("Could not find flight. {key}")))?; let options = arrow::ipc::writer::IpcWriteOptions::default(); - #[allow(deprecated)] - let mut dictionary_tracker = - writer::DictionaryTracker::new_with_preserve_dict_id(false, options.preserve_dict_id()); + let mut dictionary_tracker = writer::DictionaryTracker::new(false); let data_gen = writer::IpcDataGenerator::default(); let data = IpcMessage( data_gen @@ -146,7 +144,12 @@ impl FlightService for FlightServiceImpl { .enumerate() .flat_map(|(counter, batch)| { let (encoded_dictionaries, encoded_batch) = data_gen - .encoded_batch(batch, &mut dictionary_tracker, &options) + .encode( + batch, + &mut dictionary_tracker, + &options, + &mut Default::default(), + ) .expect("DictionaryTracker configured above to not error on replacement"); let dictionary_flight_data = encoded_dictionaries.into_iter().map(Into::into); @@ -380,7 +383,7 @@ async fn save_uploaded_chunks( ipc::MessageHeader::Schema => { return Err(Status::internal( "Not expecting a schema when messages are read", - )) + )); } ipc::MessageHeader::RecordBatch => { send_app_metadata(&mut response_tx, &data.app_metadata).await?; diff --git a/arrow-integration-testing/src/flight_server_scenarios/middleware.rs b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs index 6685d45dffac..6bafb4843316 100644 --- a/arrow-integration-testing/src/flight_server_scenarios/middleware.rs +++ b/arrow-integration-testing/src/flight_server_scenarios/middleware.rs @@ -20,13 +20,13 @@ use std::pin::Pin; use arrow_flight::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, flight_descriptor::DescriptorType, flight_service_server::FlightService, - flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty, FlightData, - FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse, PollInfo, PutResult, - SchemaResult, Ticket, + flight_service_server::FlightServiceServer, }; use futures::Stream; -use tonic::{transport::Server, Request, Response, Status, Streaming}; +use tonic::{Request, Response, Status, Streaming, transport::Server}; type TonicStream = Pin + Send + Sync + 'static>>; diff --git a/arrow-integration-testing/src/lib.rs b/arrow-integration-testing/src/lib.rs index e669690ef4f5..cf572d769df5 100644 --- a/arrow-integration-testing/src/lib.rs +++ b/arrow-integration-testing/src/lib.rs @@ -25,12 +25,12 @@ use serde_json::Value; use arrow::array::{Array, StructArray}; use arrow::datatypes::{DataType, Field, Fields, Schema}; use arrow::error::{ArrowError, Result}; -use arrow::ffi::{from_ffi_and_data_type, FFI_ArrowArray, FFI_ArrowSchema}; +use arrow::ffi::{FFI_ArrowArray, FFI_ArrowSchema, from_ffi_and_data_type}; use arrow::record_batch::RecordBatch; use arrow::util::test_util::arrow_test_data; use arrow_integration_test::*; use std::collections::HashMap; -use std::ffi::{c_char, c_int, CStr, CString}; +use std::ffi::{CStr, CString, c_char, c_int}; use std::fs::File; use std::io::BufReader; use std::iter::zip; @@ -207,8 +207,7 @@ fn cdata_integration_import_schema_and_compare_to_json( // compare schemas if canonicalize_schema(&json_schema) != canonicalize_schema(&imported_schema) { return Err(ArrowError::ComputeError(format!( - "Schemas do not match.\n- JSON: {:?}\n- Imported: {:?}", - json_schema, imported_schema + "Schemas do not match.\n- JSON: {json_schema:?}\n- Imported: {imported_schema:?}", ))); } Ok(()) @@ -253,7 +252,7 @@ fn cdata_integration_import_batch_and_compare_to_json( fn result_to_c_error(result: &std::result::Result) -> *mut c_char { match result { Ok(_) => ptr::null_mut(), - Err(e) => CString::new(format!("{}", e)).unwrap().into_raw(), + Err(e) => CString::new(format!("{e}")).unwrap().into_raw(), } } @@ -262,7 +261,7 @@ fn result_to_c_error(result: &std::result::Result /// # Safety /// /// The pointer is assumed to have been obtained using CString::into_raw. -#[no_mangle] +#[unsafe(no_mangle)] pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut c_char) { if !c_error.is_null() { drop(unsafe { CString::from_raw(c_error) }); @@ -270,7 +269,7 @@ pub unsafe extern "C" fn arrow_rs_free_error(c_error: *mut c_char) { } /// A C-ABI for exporting an Arrow schema from a JSON file -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json( c_json_name: *const c_char, out: *mut FFI_ArrowSchema, @@ -280,7 +279,7 @@ pub extern "C" fn arrow_rs_cdata_integration_export_schema_from_json( } /// A C-ABI to compare an Arrow schema against a JSON file -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json( c_json_name: *const c_char, c_schema: *mut FFI_ArrowSchema, @@ -290,7 +289,7 @@ pub extern "C" fn arrow_rs_cdata_integration_import_schema_and_compare_to_json( } /// A C-ABI for exporting a RecordBatch from a JSON file -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json( c_json_name: *const c_char, batch_num: c_int, @@ -301,7 +300,7 @@ pub extern "C" fn arrow_rs_cdata_integration_export_batch_from_json( } /// A C-ABI to compare a RecordBatch against a JSON file -#[no_mangle] +#[unsafe(no_mangle)] pub extern "C" fn arrow_rs_cdata_integration_import_batch_and_compare_to_json( c_json_name: *const c_char, batch_num: c_int, diff --git a/arrow-ipc/Cargo.toml b/arrow-ipc/Cargo.toml index a1f826ef7d10..943852ffdec9 100644 --- a/arrow-ipc/Cargo.toml +++ b/arrow-ipc/Cargo.toml @@ -40,8 +40,9 @@ arrow-array = { workspace = true } arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } +arrow-select = { workspace = true} flatbuffers = { version = "25.2.10", default-features = false } -lz4_flex = { version = "0.11", default-features = false, features = ["std", "frame"], optional = true } +lz4_flex = { version = "0.12", default-features = false, features = ["std", "frame"], optional = true } zstd = { version = "0.13.0", default-features = false, optional = true } [features] @@ -49,7 +50,7 @@ default = [] lz4 = ["lz4_flex"] [dev-dependencies] -criterion = "0.5.1" +criterion = { workspace = true } tempfile = "3.3" tokio = "1.43.0" # used in benches diff --git a/arrow-ipc/benches/ipc_reader.rs b/arrow-ipc/benches/ipc_reader.rs index ab77449eeb7d..ef1de88d328d 100644 --- a/arrow-ipc/benches/ipc_reader.rs +++ b/arrow-ipc/benches/ipc_reader.rs @@ -16,14 +16,14 @@ // under the License. use arrow_array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; -use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_array::{RecordBatch, builder::StringBuilder}; use arrow_buffer::Buffer; use arrow_ipc::convert::fb_to_schema; -use arrow_ipc::reader::{read_footer_length, FileDecoder, FileReader, StreamReader}; +use arrow_ipc::reader::{FileDecoder, FileReader, StreamReader, read_footer_length}; use arrow_ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; -use arrow_ipc::{root_as_footer, Block, CompressionType}; +use arrow_ipc::{Block, CompressionType, root_as_footer}; use arrow_schema::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use std::io::{Cursor, Write}; use std::sync::Arc; use tempfile::tempdir; @@ -240,7 +240,7 @@ impl IPCBufferDecoder { } unsafe fn with_skip_validation(mut self, skip_validation: bool) -> Self { - self.decoder = self.decoder.with_skip_validation(skip_validation); + self.decoder = unsafe { self.decoder.with_skip_validation(skip_validation) }; self } diff --git a/arrow-ipc/benches/ipc_writer.rs b/arrow-ipc/benches/ipc_writer.rs index 6b4d184b4556..eda7e3c58fe0 100644 --- a/arrow-ipc/benches/ipc_writer.rs +++ b/arrow-ipc/benches/ipc_writer.rs @@ -16,11 +16,11 @@ // under the License. use arrow_array::builder::{Date32Builder, Decimal128Builder, Int32Builder}; -use arrow_array::{builder::StringBuilder, RecordBatch}; -use arrow_ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; +use arrow_array::{RecordBatch, builder::StringBuilder}; use arrow_ipc::CompressionType; +use arrow_ipc::writer::{FileWriter, IpcWriteOptions, StreamWriter}; use arrow_schema::{DataType, Field, Schema}; -use criterion::{criterion_group, criterion_main, Criterion}; +use criterion::{Criterion, criterion_group, criterion_main}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { diff --git a/arrow-ipc/regen.sh b/arrow-ipc/regen.sh index b368bd1bc7cc..676ec9933c55 100755 --- a/arrow-ipc/regen.sh +++ b/arrow-ipc/regen.sh @@ -88,9 +88,9 @@ use flatbuffers::EndianScalar; HEREDOC ) -SCHEMA_IMPORT="\nuse crate::gen::Schema::*;" -SPARSE_TENSOR_IMPORT="\nuse crate::gen::SparseTensor::*;" -TENSOR_IMPORT="\nuse crate::gen::Tensor::*;" +SCHEMA_IMPORT="\nuse crate::r#gen::Schema::*;" +SPARSE_TENSOR_IMPORT="\nuse crate::r#gen::SparseTensor::*;" +TENSOR_IMPORT="\nuse crate::r#gen::Tensor::*;" # For flatbuffer(1.12.0+), remove: use crate::${name}::\*; names=("File" "Message" "Schema" "SparseTensor" "Tensor") @@ -129,7 +129,7 @@ for f in `ls *.rs`; do sed --in-place='' 's/TYPE__/TYPE_/g' $f # Some files need prefixes - if [[ $f == "File.rs" ]]; then + if [[ $f == "File.rs" ]]; then # Now prefix the file with the static contents echo -e "${PREFIX}" "${SCHEMA_IMPORT}" | cat - $f > temp && mv temp $f elif [[ $f == "Message.rs" ]]; then diff --git a/arrow-ipc/src/compression.rs b/arrow-ipc/src/compression.rs index 47ea7785cbec..9bbc6e752c12 100644 --- a/arrow-ipc/src/compression.rs +++ b/arrow-ipc/src/compression.rs @@ -22,6 +22,41 @@ use arrow_schema::ArrowError; const LENGTH_NO_COMPRESSED_DATA: i64 = -1; const LENGTH_OF_PREFIX_DATA: i64 = 8; +/// Additional context that may be needed for compression. +/// +/// In the case of zstd, this will contain the zstd context, which can be reused between subsequent +/// compression calls to avoid the performance overhead of initialising a new context for every +/// compression. +pub struct CompressionContext { + #[cfg(feature = "zstd")] + compressor: zstd::bulk::Compressor<'static>, +} + +// the reason we allow derivable_impls here is because when zstd feature is not enabled, this +// becomes derivable. however with zstd feature want to be explicit about the compression level. +#[allow(clippy::derivable_impls)] +impl Default for CompressionContext { + fn default() -> Self { + CompressionContext { + // safety: `new` here will only return error here if using an invalid compression level + #[cfg(feature = "zstd")] + compressor: zstd::bulk::Compressor::new(zstd::DEFAULT_COMPRESSION_LEVEL) + .expect("can use default compression level"), + } + } +} + +impl std::fmt::Debug for CompressionContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut ds = f.debug_struct("CompressionContext"); + + #[cfg(feature = "zstd")] + ds.field("compressor", &"zstd::bulk::Compressor"); + + ds.finish() + } +} + /// Represents compressing a ipc stream using a particular compression algorithm #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CompressionCodec { @@ -58,6 +93,7 @@ impl CompressionCodec { &self, input: &[u8], output: &mut Vec, + context: &mut CompressionContext, ) -> Result { let uncompressed_data_len = input.len(); let original_output_len = output.len(); @@ -67,7 +103,7 @@ impl CompressionCodec { } else { // write compressed data directly into the output buffer output.extend_from_slice(&uncompressed_data_len.to_le_bytes()); - self.compress(input, output)?; + self.compress(input, output, context)?; let compression_len = output.len() - original_output_len; if compression_len > uncompressed_data_len { @@ -115,10 +151,15 @@ impl CompressionCodec { /// Compress the data in input buffer and write to output buffer /// using the specified compression - fn compress(&self, input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { + fn compress( + &self, + input: &[u8], + output: &mut Vec, + context: &mut CompressionContext, + ) -> Result<(), ArrowError> { match self { CompressionCodec::Lz4Frame => compress_lz4(input, output), - CompressionCodec::Zstd => compress_zstd(input, output), + CompressionCodec::Zstd => compress_zstd(input, output, context), } } @@ -175,17 +216,23 @@ fn decompress_lz4(_input: &[u8], _decompressed_size: usize) -> Result, A } #[cfg(feature = "zstd")] -fn compress_zstd(input: &[u8], output: &mut Vec) -> Result<(), ArrowError> { - use std::io::Write; - let mut encoder = zstd::Encoder::new(output, 0)?; - encoder.write_all(input)?; - encoder.finish()?; +fn compress_zstd( + input: &[u8], + output: &mut Vec, + context: &mut CompressionContext, +) -> Result<(), ArrowError> { + let result = context.compressor.compress(input)?; + output.extend_from_slice(&result); Ok(()) } #[cfg(not(feature = "zstd"))] #[allow(clippy::ptr_arg)] -fn compress_zstd(_input: &[u8], _output: &mut Vec) -> Result<(), ArrowError> { +fn compress_zstd( + _input: &[u8], + _output: &mut Vec, + _context: &mut CompressionContext, +) -> Result<(), ArrowError> { Err(ArrowError::InvalidArgumentError( "zstd IPC compression requires the zstd feature".to_string(), )) @@ -227,7 +274,9 @@ mod tests { let input_bytes = b"hello lz4"; let codec = super::CompressionCodec::Lz4Frame; let mut output_bytes: Vec = Vec::new(); - codec.compress(input_bytes, &mut output_bytes).unwrap(); + codec + .compress(input_bytes, &mut output_bytes, &mut Default::default()) + .unwrap(); let result = codec .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); @@ -240,7 +289,9 @@ mod tests { let input_bytes = b"hello zstd"; let codec = super::CompressionCodec::Zstd; let mut output_bytes: Vec = Vec::new(); - codec.compress(input_bytes, &mut output_bytes).unwrap(); + codec + .compress(input_bytes, &mut output_bytes, &mut Default::default()) + .unwrap(); let result = codec .decompress(output_bytes.as_slice(), input_bytes.len()) .unwrap(); diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 9c6c3831067c..16e61deadb0f 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -19,6 +19,7 @@ use arrow_buffer::Buffer; use arrow_schema::*; +use core::panic; use flatbuffers::{ FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier, VerifierOptions, WIPOffset, @@ -28,7 +29,7 @@ use std::fmt::{Debug, Formatter}; use std::sync::Arc; use crate::writer::DictionaryTracker; -use crate::{KeyValue, Message, CONTINUATION_MARKER}; +use crate::{CONTINUATION_MARKER, KeyValue, Message}; use DataType::*; /// Low level Arrow [Schema] to IPC bytes converter @@ -127,12 +128,6 @@ impl<'a> IpcSchemaEncoder<'a> { } } -/// Serialize a schema in IPC format -#[deprecated(since = "54.0.0", note = "Use `IpcSchemaConverter`.")] -pub fn schema_to_fb(schema: &Schema) -> FlatBufferBuilder<'_> { - IpcSchemaEncoder::new().schema_to_fb(schema) -} - /// Push a key-value metadata into a FlatBufferBuilder and return [WIPOffset] pub fn metadata_to_fb<'a>( fbb: &mut FlatBufferBuilder<'a>, @@ -170,7 +165,7 @@ impl From> for Field { let arrow_field = if let Some(dictionary) = field.dictionary() { #[allow(deprecated)] Field::new_dict( - field.name().unwrap(), + field.name().unwrap_or_default(), get_data_type(field, true), field.nullable(), dictionary.id(), @@ -178,7 +173,7 @@ impl From> for Field { ) } else { Field::new( - field.name().unwrap(), + field.name().unwrap_or_default(), get_data_type(field, true), field.nullable(), ) @@ -284,9 +279,9 @@ pub fn try_schema_from_ipc_buffer(buffer: &[u8]) -> Result { if buffer.len() < len as usize { let actual_len = buffer.len(); - return Err(ArrowError::ParseError( - format!("The buffer length ({actual_len}) is less than the encapsulated message's reported length ({len})") - )); + return Err(ArrowError::ParseError(format!( + "The buffer length ({actual_len}) is less than the encapsulated message's reported length ({len})" + ))); } let msg = crate::root_as_message(buffer) @@ -430,6 +425,20 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat } DataType::LargeList(Arc::new(children.get(0).into())) } + crate::Type::ListView => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a listview to have one child") + } + DataType::ListView(Arc::new(children.get(0).into())) + } + crate::Type::LargeListView => { + let children = field.children().unwrap(); + if children.len() != 1 { + panic!("expect a large listview to have one child") + } + DataType::LargeListView(Arc::new(children.get(0).into())) + } crate::Type::FixedSizeList => { let children = field.children().unwrap(); if children.len() != 1 { @@ -471,6 +480,8 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat let precision: u8 = fsb.precision().try_into().unwrap(); let scale: i8 = fsb.scale().try_into().unwrap(); match bit_width { + 32 => DataType::Decimal32(precision, scale), + 64 => DataType::Decimal64(precision, scale), 128 => DataType::Decimal128(precision, scale), 256 => DataType::Decimal256(precision, scale), _ => panic!("Unexpected decimal bit width {bit_width}"), @@ -493,8 +504,9 @@ pub(crate) fn get_data_type(field: crate::Field, may_be_dictionary: bool) -> Dat }; let fields = match union.typeIds() { - None => UnionFields::new(0_i8..fields.len() as i8, fields), - Some(ids) => UnionFields::new(ids.iter().map(|i| i as i8), fields), + None => UnionFields::from_fields(fields), + Some(ids) => UnionFields::try_new(ids.iter().map(|i| i as i8), fields) + .expect("invalid union field"), }; DataType::Union(fields, union_mode) @@ -528,24 +540,13 @@ pub(crate) fn build_field<'a>( match dictionary_tracker { Some(tracker) => Some(get_fb_dictionary( index_type, - #[allow(deprecated)] - tracker.set_dict_id(field), - field - .dict_is_ordered() - .expect("All Dictionary types have `dict_is_ordered`"), - fbb, - )), - None => Some(get_fb_dictionary( - index_type, - #[allow(deprecated)] - field - .dict_id() - .expect("Dictionary type must have a dictionary id"), + tracker.next_dict_id(), field .dict_is_ordered() .expect("All Dictionary types have `dict_is_ordered`"), fbb, )), + None => panic!("IPC must no longer be used without dictionary tracker"), } } else { None @@ -774,7 +775,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&empty_fields[..])), } } - List(ref list_type) => { + List(list_type) => { let child = build_field(fbb, dictionary_tracker, list_type); FBFieldType { type_type: crate::Type::List, @@ -782,8 +783,25 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&[child])), } } - ListView(_) | LargeListView(_) => unimplemented!("ListView/LargeListView not implemented"), - LargeList(ref list_type) => { + ListView(list_type) => { + let child = build_field(fbb, dictionary_tracker, list_type); + FBFieldType { + type_type: crate::Type::ListView, + type_: crate::ListViewBuilder::new(fbb).finish().as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + LargeListView(list_type) => { + let child = build_field(fbb, dictionary_tracker, list_type); + FBFieldType { + type_type: crate::Type::LargeListView, + type_: crate::LargeListViewBuilder::new(fbb) + .finish() + .as_union_value(), + children: Some(fbb.create_vector(&[child])), + } + } + LargeList(list_type) => { let child = build_field(fbb, dictionary_tracker, list_type); FBFieldType { type_type: crate::Type::LargeList, @@ -791,7 +809,7 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&[child])), } } - FixedSizeList(ref list_type, len) => { + FixedSizeList(list_type, len) => { let child = build_field(fbb, dictionary_tracker, list_type); let mut builder = crate::FixedSizeListBuilder::new(fbb); builder.add_listSize(*len); @@ -841,6 +859,28 @@ pub(crate) fn get_fb_field_type<'a>( // type in the DictionaryEncoding metadata in the parent field get_fb_field_type(value_type, dictionary_tracker, fbb) } + Decimal32(precision, scale) => { + let mut builder = crate::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(32); + FBFieldType { + type_type: crate::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } + Decimal64(precision, scale) => { + let mut builder = crate::DecimalBuilder::new(fbb); + builder.add_precision(*precision as i32); + builder.add_scale(*scale as i32); + builder.add_bitWidth(64); + FBFieldType { + type_type: crate::Type::Decimal, + type_: builder.finish().as_union_value(), + children: Some(fbb.create_vector(&empty_fields[..])), + } + } Decimal128(precision, scale) => { let mut builder = crate::DecimalBuilder::new(fbb); builder.add_precision(*precision as i32); @@ -1143,13 +1183,14 @@ mod tests { Field::new( "union", DataType::Union( - UnionFields::new( + UnionFields::try_new( vec![2, 3], // non-default type ids vec![ Field::new("int32", DataType::Int32, true), Field::new("utf8", DataType::Utf8, true), ], - ), + ) + .unwrap(), UnionMode::Dense, ), true, diff --git a/arrow-ipc/src/gen/File.rs b/arrow-ipc/src/gen/File.rs index 427cf75de096..ab2273614759 100644 --- a/arrow-ipc/src/gen/File.rs +++ b/arrow-ipc/src/gen/File.rs @@ -18,7 +18,7 @@ #![allow(dead_code)] #![allow(unused_imports)] -use crate::gen::Schema::*; +use crate::r#gen::Schema::*; use flatbuffers::EndianScalar; use std::{cmp::Ordering, mem}; // automatically generated by the FlatBuffers compiler, do not modify @@ -49,21 +49,26 @@ impl<'a> flatbuffers::Follow<'a> for Block { type Inner = &'a Block; #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { - <&'a Block>::follow(buf, loc) + unsafe { <&'a Block>::follow(buf, loc) } } } impl<'a> flatbuffers::Follow<'a> for &'a Block { type Inner = &'a Block; #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { - flatbuffers::follow_cast_ref::(buf, loc) + unsafe { flatbuffers::follow_cast_ref::(buf, loc) } } } impl<'b> flatbuffers::Push for Block { type Output = Block; #[inline] unsafe fn push(&self, dst: &mut [u8], _written_len: usize) { - let src = ::core::slice::from_raw_parts(self as *const Block as *const u8, Self::size()); + let src = unsafe { + ::core::slice::from_raw_parts( + self as *const Block as *const u8, + ::size(), + ) + }; dst.copy_from_slice(src); } #[inline] @@ -200,7 +205,7 @@ impl<'a> flatbuffers::Follow<'a> for Footer<'a> { #[inline] unsafe fn follow(buf: &'a [u8], loc: usize) -> Self::Inner { Self { - _tab: flatbuffers::Table::new(buf, loc), + _tab: unsafe { flatbuffers::Table::new(buf, loc) }, } } } @@ -470,14 +475,14 @@ pub fn size_prefixed_root_as_footer_with_opts<'b, 'o>( /// # Safety /// Callers must trust the given bytes do indeed contain a valid `Footer`. pub unsafe fn root_as_footer_unchecked(buf: &[u8]) -> Footer { - flatbuffers::root_unchecked::