
Parallel Apache Arrow + DuckDB Solution to the One Trillion Row Challenge 26 mins 15 secs on Dell Workstation #3

MurrayData opened this issue Feb 8, 2024 · 16 comments



I wanted to test my new Dell Precision 7960 workstation on this task.

The hardware spec is: Intel(R) Xeon(R) w5-3435X CPU, 16 cores/32 threads, max speed 4.7 GHz; 512 GB DDR5 LRDIMMs running at 4400 MHz; 4 x Samsung 990 Pro 2 TB Gen 4 NVMe drives in RAID 0 on a Dell UltraSpeed card in a PCIe 5.0 x16 slot; NVIDIA RTX A6000 (Ampere) GPU with 48 GB.

I tried several approaches, but settled on a native Apache Arrow table group-by solution, using parallel workers to execute the chunks. The first-stage aggregation uses Apache Arrow tables to compute the min, max, sum and count of temperature for each station in a group by.

import pyarrow as pa
import pyarrow.dataset  # needed for pa.dataset.dataset

def aggregate_chunk(start, inc, files):
    # First-stage aggregation: min/max/sum/count of measure per station for one batch of files.
    last = min(len(files), start + inc)
    station_ds = pa.dataset.dataset(files[start:last])
    station_table = station_ds.to_table()
    result_table = pa.TableGroupBy(station_table, 'station').aggregate([('measure', 'min'), ('measure', 'max'), ('measure', 'sum'), ('measure', 'count')])
    return result_table
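
The surrounding setup for files, inc, max_value and f isn't shown for the multiprocessing run; based on the later Dask version it presumably looks something like the sketch below (the data path and file pattern are assumptions; the batch size of 10 matches the output further down).

import functools
import glob
import multiprocessing as mp

import tqdm

files = sorted(glob.glob('/data/1trc/*.parquet'))  # hypothetical data location
max_value = len(files)                             # 100,000 files in total
inc = 10                                           # files per chunk
f = functools.partial(aggregate_chunk, inc=inc, files=files)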

The chunks are then executed in parallel using a multiprocessing pool:

%%time
cpus = mp.cpu_count()
pool_size = cpus // 4
print(f'CPU thread count: {cpus}\nParallel workers pool size: {pool_size}\nFile batch size: {inc}')

with mp.Pool(pool_size) as pool:
    results = list(tqdm.tqdm(pool.imap(f, range(0,max_value,inc)), total=max_value//inc))

Which generates the following output:

CPU thread count: 32
Parallel workers pool size: 8
File batch size: 10
100%|██████████| 10000/10000 [26:14<00:00,  6.35it/s]
CPU times: user 6min 11s, sys: 1min 55s, total: 8min 7s
Wall time: 26min 14s

Interestingly, the optimal configuration, found by trial and error, was a smaller file batch size (10 files) with 8 parallel workers.
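
For reference, that trial-and-error search could be scripted as a small sweep over worker count and batch size. This is only an illustrative sketch (time_run and the probe size are not from the original notebook), timing the first-stage aggregation over a subset of the files for each combination.

import functools
import itertools
import multiprocessing as mp
import time

def time_run(pool_size, inc, files, n_probe=200):
    # Time the first-stage aggregation of n_probe files with one configuration.
    f = functools.partial(aggregate_chunk, inc=inc, files=files[:n_probe])
    t0 = time.time()
    with mp.Pool(pool_size) as pool:
        pool.map(f, range(0, n_probe, inc))
    return time.time() - t0

for pool_size, inc in itertools.product((4, 8, 16, 32), (5, 10, 20, 50)):
    print(f'{pool_size} workers, {inc} files/chunk: {time_run(pool_size, inc, files):.1f}s')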

Following concatenation of the group-by tables, a second-stage aggregation is run using DuckDB to group and sort by station name, computing the min and max over the aggregate chunks and the mean by dividing the aggregate sum by the aggregate count.

query = """
        SELECT station,
        MIN(measure_min) AS measure_min,
        MAX(measure_max) AS measure_max,
        SUM(measure_sum) / SUM(measure_count) AS measure_mean
        FROM summary_table
        GROUP BY station
        ORDER BY station
        """
print(duckdb_explain(con, query))
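
The duckdb_explain helper isn't shown in the thread; presumably it just wraps DuckDB's EXPLAIN output, something along the lines of this sketch:

def duckdb_explain(con, query):
    # EXPLAIN returns (explain_key, explain_value) rows; stitch the plan text back together.
    rows = con.execute(f'EXPLAIN {query}').fetchall()
    return '\n'.join(f'{key}\n{value}' for key, value in rows)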

Explanation:

physical_plan
┌───────────────────────────┐
│          ORDER_BY         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          ORDERS:          │
│ summary_table.station ASC │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_max        │
│        measure_mean       │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│       HASH_GROUP_BY       │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│             #0            │
│          min(#1)          │
│          max(#2)          │
│          sum(#3)          │
│          sum(#4)          │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│         PROJECTION        │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_min        │
│        measure_sum        │
│       measure_count       │
└─────────────┬─────────────┘                             
┌─────────────┴─────────────┐
│        ARROW_SCAN         │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│          station          │
│        measure_min        │
│        measure_sum        │
│       measure_count       │
│   ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─   │
│           EC: 1           │
└───────────────────────────┘                             

Execute the query:

%%time
result = con.execute(query)

CPU times: user 773 ms, sys: 152 ms, total: 926 ms
Wall time: 285 ms

Generates the output:

           station  measure_min  measure_max  measure_mean
0             Abha        -43.8        -21.6     18.000158
1          Abidjan        -34.4        -13.7     25.999815
2           Abéché        -33.6        -10.2     29.400132
3            Accra        -34.1        -13.2     26.400023
4      Addis Ababa        -58.0        -22.7     16.000132
..             ...          ...          ...           ...
407       Yinchuan        -53.1        -30.5      9.000168
408         Zagreb        -52.8        -29.2     10.699659
409  Zanzibar City        -34.2        -13.7     26.000115
410         Ürümqi        -52.9        -32.0      7.399789
411          İzmir        -45.3        -21.7     17.899969

[412 rows x 4 columns]

Total elapsed time:

print(f'Time taken: {dt.timedelta(seconds=time.time()-start)}')

Time taken: 0:26:15.065602
@MurrayData

Addendum: although not specified in the task spec, the original output looked untidy, so I modified the final DuckDB query to perform a diacritic-insensitive sort of the station names.

SELECT station,
MIN(measure_min) AS measure_min,
MAX(measure_max) AS measure_max,
SUM(measure_sum) / SUM(measure_count) AS measure_mean
FROM summary_table
GROUP BY station
ORDER BY strip_accents(station)

Which generated the much neater-looking output:

           station  measure_min  measure_max  measure_mean
0           Abéché        -33.6        -10.2     29.400132
1             Abha        -43.8        -21.6     18.000158
2          Abidjan        -34.4        -13.7     25.999815
3            Accra        -34.1        -13.2     26.400023
4      Addis Ababa        -58.0        -22.7     16.000132
..             ...          ...          ...           ...
407    Yellowknife        -66.4        -44.1     -4.300343
408        Yerevan        -52.3        -27.6     12.400337
409       Yinchuan        -53.1        -30.5      9.000168
410         Zagreb        -52.8        -29.2     10.699659
411  Zanzibar City        -34.2        -13.7     26.000115

[412 rows x 4 columns]

The final output, in CSV format, may be found here:

1trc_results.csv
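
(The exact export call isn't shown; the CSV was presumably written with something like the following.)

result_df = con.execute(query).df()               # fetch the final summary as a pandas DataFrame
result_df.to_csv('1trc_results.csv', index=False)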


mrocklin commented Feb 8, 2024

Cool. I'm curious, where was the data stored, on the local RAID?

@MurrayData

Cool. I'm curious, where was the data stored, on the local RAID?

Yes, it's generated on the NVMe RAID 0 device which provides up to 26.4 GB/s read speed.

[image]


mrocklin commented Feb 8, 2024

That's fun. Did you achieve that read speed during the calculation? (if so, my guess is that the choice of computational tool doesn't matter that much, and we're just IO bound here)

@MurrayData

That's fun. Did you achieve that read speed during the calculation? (if so, my guess is that the choice of computational tool doesn't matter that much, and we're just IO bound here)

I was monitoring the I/O. It peaked at around 7 GB/s per worker, but also had idle spots in between. My estimate of the split is 18 minutes I/O and data transfer, 8 minutes computation time.
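
(For reference, aggregate read throughput can be sampled during a run with something like the psutil sketch below; this is illustrative, not the monitoring tool actually used here.)

import time
import psutil

# Print system-wide disk read throughput once per second for a minute.
prev = psutil.disk_io_counters().read_bytes
for _ in range(60):
    time.sleep(1)
    cur = psutil.disk_io_counters().read_bytes
    print(f'{(cur - prev) / 1e9:.1f} GB/s')
    prev = cur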


mrocklin commented Feb 8, 2024

I'm surprised that they weren't overlapping more. Need more concurrency maybe?

@MurrayData

I'm surprised that they weren't overlapping more. Need more concurrency maybe?

The Python Multiprocessing library, which is what I used in this case, is quite basic and lacks many controls. There is almost certainly a better approach. I have some ideas to try.

@shughes-uk

I'm surprised that they weren't overlapping more. Need more concurrency maybe?

The Python Multiprocessing library, which is what I used in this case, is quite basic and lacks many controls. There is almost certainly a better approach. I have some ideas to try.

I hear dask is quite a solid upgrade from the stdlib multiprocessing library 😉

@MurrayData

I'm surprised that they weren't overlapping more. Need more concurrency maybe?

The Python Multiprocessing library, which is what I used in this case, is quite basic and lacks many controls. There is almost certainly a better approach. I have some ideas to try.

I hear dask is quite a solid upgrade from the stdlib multiprocessing library 😉

I tried Dask and Dask-cuDF (see my LinkedIn posts). Both worked, but took around 4 to 5 times longer than the native Apache Arrow solution. Running Arrow batches in parallel, with an optimal batch size and number of workers, while not particularly elegant, was by far the fastest on this particular system.


mrocklin commented Feb 9, 2024

I think that @shughes-uk is likely talking about just using raw Dask, the parallel computing solution, not Dask DataFrame, the big pandas implementation (sometimes people conflate the two).

This doc might be of interest: https://docs.dask.org/en/stable/futures.html

@MurrayData

I think that @shughes-uk is likely talking about just using raw Dask, the parallel computing solution, not Dask DataFrame, the big pandas implementation (sometimes people conflate the two).

This doc might be of interest: https://docs.dask.org/en/stable/futures.html

That's really interesting. Being totally honest, I have always thought of the Dask Dataframe. I will definitely look into this. Thank you @shughes-uk and @mrocklin

@MurrayData

I really appreciate the advice @shughes-uk @mrocklin. I just read the docs, and tried the example notebook on the workstation. Dask futures is a really powerful tool, thank you.

[image]

[image]

@mrocklin

I'm glad you like it. Was it able to boost your performance at all or did you get the same as with multiprocessing?

@MurrayData

I'm glad you like it. Was it able to boost your performance at all or did you get the same as with multiprocessing?

It was marginally faster, but it's a neater solution to code, as I could tier the final-stage aggregation of the chunk summaries as a function using the results from the first stage.

Total run time came down to 24 minutes, 10 seconds.

The task is very much I/O bound but by tweaking parameters I was able to reduce idle times and speed it up slightly.

I experimented with parameters. It's a trade-off between maximising occupancy and minimising file system contention across the parallel tasks. On this hardware I found:

  • The optimal number of workers was 32, the same as the CPU hyperthread count (2 x physical cores). This is fine because I/O latency leaves the CPU idle for stretches that the extra workers can fill.
  • The optimal number of threads per worker was 1. Any more and performance degraded: with 2 threads and 32 workers, CPU occupancy dropped to around 60%.
  • The optimal number of files read per chunk was 8. Any more, or any less, and performance degraded: too few and the CPU was under-occupied; too many and file system contention slowed the process down, even with a small increase to 10.
  • With a large number of future jobs like this (12,500), it is advisable to use the batch_size parameter of client.map. I found the optimum was to set batch_size to the number of workers, so processing begins immediately. If the batch_size parameter is not used, there is a 12-minute delay before processing commences with these parameters.

Initialisation:

import multiprocessing as mp
from dask.distributed import Client, progress

cpus = mp.cpu_count()  # number of CPU hyperthreads
client = Client(n_workers=cpus)
client
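
Given the finding above that one thread per worker was optimal, the worker layout could also be pinned explicitly; a minimal sketch (the snippet above relies on Dask's defaults):

from dask.distributed import Client

# One single-threaded worker process per CPU hyperthread, matching the tuning notes above.
client = Client(n_workers=cpus, threads_per_worker=1)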

Aggregate chunks function:

def aggregate_chunk(start, inc, files):
    last = min(len(files), start+inc)
    station_ds = pa.dataset.dataset(files[start:last])
    station_table = station_ds.to_table() 
    return pa.TableGroupBy(station_table, 'station').aggregate([('measure','min'),('measure','max'),('measure','sum'),('measure','count')])

Aggregate results of chunks to produce summary:

import duckdb
import pyarrow as pa

def aggregate_results(results):
    # Concatenate the per-chunk group-by tables into a single Arrow table.
    summary_table = pa.concat_tables(results)

    # DuckDB's replacement scan picks up the local Arrow table `summary_table` by name.
    query = """
            SELECT station,
            MIN(measure_min) AS measure_min,
            MAX(measure_max) AS measure_max,
            SUM(measure_sum) / SUM(measure_count) AS measure_mean
            FROM summary_table
            GROUP BY station
            ORDER BY strip_accents(station)
            """
    result = duckdb.execute(query)
    return result.df()

Set the parameters and create the partial function:

max_value = n  # n is the total number of input files, i.e. len(files)
inc = 8
print(f'No of tasks: {max_value // inc:,}\nNo of cores: {cpus}')
f = functools.partial(aggregate_chunk, inc=inc, files=files)

Map the aggregate chunks function:

%%time
futures = client.map(f, range(0,max_value,inc), batch_size=cpus)

Submit the summary aggregation function to work on the futures from the above:

%%time
results = client.submit(aggregate_results, futures)

Get the results from the summary aggregation:

%%time
result_df = results.result()

And here is the summary:

           station  measure_min  measure_max  measure_mean
0           Abéché        -33.6         -9.5     29.400132
1             Abha        -43.8        -20.5     18.000158
2          Abidjan        -34.4        -13.1     25.999815
3            Accra        -34.1        -12.8     26.400023
4      Addis Ababa        -58.0        -22.5     16.000132
..             ...          ...          ...           ...
407    Yellowknife        -66.4        -43.3     -4.300343
408        Yerevan        -52.3        -26.6     12.400337
409       Yinchuan        -53.1        -29.7      9.000168
410         Zagreb        -52.8        -28.5     10.699659
411  Zanzibar City        -34.2        -12.9     26.000115

[412 rows x 4 columns]
CPU times: user 5.67 ms, sys: 433 µs, total: 6.1 ms
Wall time: 6.09 ms

[image]

@mrocklin

If the batch_size parameter is not used, there is a 12-minute delay before processing commences with these parameters

This is odd. It sounds like the partial'ed function might be really hard to serialize, maybe because it contains the list of 100,000 files? I'll bet that if you were to partition the files ahead of time (maybe with toolz.partition_all(nfiles, files)) and pass that iterator of lists of files to the map call, things would compute more cleanly.
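
A sketch of what that suggestion might look like (aggregate_file_batch is a hypothetical variant of aggregate_chunk that takes a list of files directly):

import toolz
import pyarrow as pa
import pyarrow.dataset

def aggregate_file_batch(batch):
    # Each task now receives its own small list of files rather than a start index
    # plus a closure over the full 100,000-file list.
    station_table = pa.dataset.dataset(list(batch)).to_table()
    return pa.TableGroupBy(station_table, 'station').aggregate(
        [('measure', 'min'), ('measure', 'max'), ('measure', 'sum'), ('measure', 'count')])

file_batches = list(toolz.partition_all(inc, files))
futures = client.map(aggregate_file_batch, file_batches)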

@RichardScottOZ

That workstation could do lots of fun things with Dask
