Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 49 additions & 12 deletions graph_net/sample_pass/fusible_subgraph_ranges_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
from itertools import groupby
from dataclasses import dataclass
from collections import defaultdict


class FusibleSubgraphRangesGenerator(SamplePass, ResumableSamplePassMixin):
Expand Down Expand Up @@ -91,7 +92,6 @@ def analyze(self):
naive_proposal_fused_num_ops_lists = [
sorted(set(num_ops_list))
for _, num_ops_list in num_kernels_and_num_ops_list
if len(set(num_ops_list)) > 1
]
proposal_fused_num_ops_lists = self._merge_all_decreasing_num_ops_lists(
analysis_ctx, naive_proposal_fused_num_ops_lists
Expand All @@ -114,6 +114,7 @@ def _merge_all_decreasing_num_ops_lists(self, analysis_ctx, num_ops_lists):
break
dead_loop_detect_cnt += 1
assert dead_loop_detect_cnt < kLimit, f"{dead_loop_detect_cnt=}"
num_ops_lists = [op_list for op_list in num_ops_lists if len(op_list) > 1]
return num_ops_lists

def _merge_one_decreasing_num_ops_lists(self, analysis_ctx, num_ops_lists):
Expand All @@ -138,7 +139,6 @@ def get_next_head_num_kernels(i):
return analysis_ctx.num_kernels4num_ops(num_ops_lists[i + 1][0])

for i in range(len(num_ops_lists) - 1):
assert len(num_ops_lists[i]) > 1
if get_cur_tail_num_kernels(i) >= get_next_head_num_kernels(i):
return i
return None
Expand All @@ -152,14 +152,14 @@ def is_a_range(int_list):
assert len(int_list) > 1
return (int_list[-1] + 1) - int_list[0] == len(int_list)

def have_any_increasing(num_ops_list: list[int]):
for i, cur_num_ops in enumerate(num_ops_list):
if i == 0:
continue
cur_num_kernels = analysis_ctx.num_kernels4num_ops(cur_num_ops)
last_num_kernels = analysis_ctx.num_kernels4num_ops(num_ops_list[i - 1])
if cur_num_kernels > last_num_kernels:
def have_tail_increasing(num_ops_list: list[int]):
for i in range(len(num_ops_list) - 1, 0, -1):
cur_num_kernels = analysis_ctx.num_kernels4num_ops(num_ops_list[i])
prev_num_kernels = analysis_ctx.num_kernels4num_ops(num_ops_list[i - 1])
if cur_num_kernels > prev_num_kernels:
return True
elif cur_num_kernels < prev_num_kernels:
return False
return False

def head_eq_tail(num_ops_list: list[int]):
Expand All @@ -174,9 +174,9 @@ def head_gt_tail(num_ops_list: list[int]):

def valid_fused_ops(num_ops_list: list[int]):
if head_gt_tail(num_ops_list):
return True
return not have_tail_increasing(num_ops_list)
if head_eq_tail(num_ops_list):
return not have_any_increasing(num_ops_list)
return not have_tail_increasing(num_ops_list)
return False

proposal_fused_num_ops_lists = [
Expand Down Expand Up @@ -243,7 +243,44 @@ def get_num_ops(pair):
(num_kernels, [num_ops for _, num_ops in group])
for num_kernels, group in grouped_num_kernels_and_num_ops
]
return num_kernels_and_num_ops_list

num_kernels_to_indexes = defaultdict(list)

for i, (num_kernels, _) in enumerate(num_kernels_and_num_ops_list):
num_kernels_to_indexes[num_kernels].append(i)

num_kernels_and_num_ops_closure_list = [
(num_kernels, num_ops)
for num_kernels, indexes in num_kernels_to_indexes.items()
for i in range(min(indexes), max(indexes) + 1)
for num_ops in num_kernels_and_num_ops_list[i][1]
]
num_kernels_and_num_ops_closure_list = sorted(
num_kernels_and_num_ops_closure_list, key=lambda pair: pair[1]
)
num_ops_and_grouped_num_kernels_list = groupby(
num_kernels_and_num_ops_closure_list, key=lambda pair: pair[1]
)

min_num_kernels_and_num_ops_list = [
(
min(num_kernels for num_kernels, _, in num_kernels_and_num_ops_list),
num_ops,
)
for num_ops, num_kernels_and_num_ops_list in num_ops_and_grouped_num_kernels_list
]

min_num_kernels_and_num_ops = sorted(
min_num_kernels_and_num_ops_list, key=lambda pair: pair[1]
)
grouped_min_num_kernels_and_num_ops = groupby(
min_num_kernels_and_num_ops, key=lambda pair: pair[0]
)
min_num_kernels_and_num_ops_list = [
(num_kernels, [num_ops for _, num_ops in group])
for num_kernels, group in grouped_min_num_kernels_and_num_ops
]
return min_num_kernels_and_num_ops_list


@dataclass
Expand Down
Loading
Loading