import whest as we
# Repro: arithmetic dunders on a symmetrized tensor (e.g. `A * A`) are not
# FLOP-counted, while the equivalent explicit call `we.multiply(A, A)` is.
# Each operation runs in its own namespace so the per-namespace summary
# exposes the difference.
with we.BudgetContext(flop_budget=1e20) as bc:
    n = 10
    with we.namespace('init'):
        A = we.random.randn(n, n)
        # Symmetrize over axes (0, 1). NOTE(review): the expected count of 55
        # below presumably equals n * (n + 1) / 2, i.e. one FLOP per unique
        # entry of the symmetric 10x10 matrix -- confirm against whest docs.
        A = we.symmetrize(
            A, group=we.PermutationGroup.symmetric(2, axes=(0, 1))
        )
    with we.namespace('mult1'):
        B = A * A  # operator dunder (__mul__): the path under test
    with we.namespace('mult2'):
        B2 = we.multiply(A, A)  # explicit whest call: reference path

    def get_flops(namespace):
        """Return the FLOPs recorded under *namespace*, or 0 if none."""
        by_namespace = bc.summary_dict(by_namespace=True)['by_namespace']
        return by_namespace.get(namespace, {}).get('flops_used', 0)

    print(get_flops('mult1'))  # INCORRECT: 0, should be 55
    print(get_flops('mult2'))  # Correct: 55
Dunders (`__mul__`, `__add__`, etc.) on `SymmetricTensor` use NumPy's native operators, and thus do not have FLOPs tracked at all. Example: