diff --git a/wsds/ws_dataset.py b/wsds/ws_dataset.py index 16cf20a..8b5ae45 100644 --- a/wsds/ws_dataset.py +++ b/wsds/ws_dataset.py @@ -245,13 +245,14 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard subdirs = defaultdict(list) exprs = [] + needed_special_columns = [] needs_key = False for query in queries: expr = pl.sql_expr(query) for col in expr.meta.root_names(): - if col == "__key__": + if col == "__key__" or col == '__shard_path__' or col == '__shard_offset__': # __key__ exists in all shards - needs_key = True + needed_special_columns.append(col) continue subdir, field = self.fields[col] assert col == field, "renamed fields are not supported in SQL queries yet" @@ -259,13 +260,12 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard exprs.append(expr) # If only __key__ is in the query, we need to load shards from at least one subdir - if needs_key: - if not subdirs: - subdirs[self.fields["__key__"][0]].append('__key__') - else: - for f in subdirs.values(): - f.append('__key__') - break + key_value = self.fields["__key__"] + key_subdir = key_value[0] + if needed_special_columns: + if subdirs: + key_subdir = list(subdirs.keys())[0] + subdirs[key_subdir] += needed_special_columns if rng is None: rng = random @@ -284,7 +284,11 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard for subdir, fields in subdirs.items(): shard_path = self.get_shard_path(subdir, shard) if shard_ok: - df = scan_ipc(shard_path, glob=False).select(fields) + df = scan_ipc( + shard_path, glob=False, + include_file_paths="__shard_path__" if subdir == key_subdir else None, + row_index_name="__shard_offset__" if subdir == key_subdir else None, + ).select(fields) if subdir not in subdir_samples: subdir_samples[subdir] = df.clear().collect() else: