Dunder operations on symmetric tensors don't track flops. #38

@wiwu2390

Dunder methods (`__mul__`, `__add__`, etc.) on `SymmetricTensor` fall through to NumPy's native operators, so the FLOPs they perform are not tracked at all.
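The failure mode can be illustrated without whest. The sketch below is a hypothetical toy model: `FlopCounter`, `SymMatrix`, `tracked_multiply`, `BuggySymMatrix` and `FixedSymMatrix` are illustrative names, not whest's API. A symmetric matrix stores only its upper triangle, and a tracked function charges one FLOP per stored entry. The buggy class computes `__mul__` directly, while the fixed class routes it through the tracked function. It assumes the fix is to define each dunder in terms of the already-tracked function; the actual patch may look different.

```python
# Hypothetical toy model of the bug; none of these names are whest's API.

class FlopCounter:
    """Accumulates FLOPs charged by tracked operations."""

    def __init__(self) -> None:
        self.flops = 0


COUNTER = FlopCounter()


class SymMatrix:
    """n x n symmetric matrix storing only its n*(n+1)/2 upper-triangle entries."""

    def __init__(self, n: int, data) -> None:
        self.n = n
        self.data = list(data)
        if len(self.data) != n * (n + 1) // 2:
            raise ValueError("expected n*(n+1)/2 unique entries")


def tracked_multiply(a: SymMatrix, b: SymMatrix) -> SymMatrix:
    """Tracked elementwise multiply: charges one FLOP per unique entry."""
    COUNTER.flops += len(a.data)
    return type(a)(a.n, [x * y for x, y in zip(a.data, b.data)])


class BuggySymMatrix(SymMatrix):
    # Mirrors the bug: the operator computes directly and never touches COUNTER.
    def __mul__(self, other):
        return BuggySymMatrix(self.n, [x * y for x, y in zip(self.data, other.data)])


class FixedSymMatrix(SymMatrix):
    # Possible fix: the dunder delegates to the tracked function.
    def __mul__(self, other):
        return tracked_multiply(self, other)


n = 10
vals = range(n * (n + 1) // 2)  # 55 unique entries

COUNTER.flops = 0
buggy_result = BuggySymMatrix(n, vals) * BuggySymMatrix(n, vals)
buggy_flops = COUNTER.flops

COUNTER.flops = 0
fixed_result = FixedSymMatrix(n, vals) * FixedSymMatrix(n, vals)
fixed_flops = COUNTER.flops

print(buggy_flops, fixed_flops)  # 0 55
```

In this toy, both classes produce identical values and only the fixed one charges the counter. Delegating dunders to the tracked functions keeps a single place where cost is accounted for, so the operator form and the function form cannot drift apart.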

Example:

```python
import whest as we

with we.BudgetContext(flop_budget=1e20) as bc:
    n = 10
    with we.namespace('init'):
        A = we.random.randn(n, n)
        A = we.symmetrize(A, group=we.PermutationGroup.symmetric(2, axes=(0, 1)))
    with we.namespace('mult1'):
        B = A * A
    with we.namespace('mult2'):
        B2 = we.multiply(A, A)

get_flops = lambda k: bc.summary_dict(by_namespace=True)['by_namespace'].get(k, {}).get('flops_used', 0)
print(get_flops('mult1'))   # INCORRECT: 0, should be 55
print(get_flops('mult2'))   # Correct: 55
```
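The expected value of 55 is consistent with charging one FLOP per unique entry of the symmetrized 10 × 10 matrix, i.e. n(n+1)/2. The sketch below assumes that per-unique-entry cost model (inferred from the `we.multiply` result, not taken from whest's documentation) and uses hypothetical helpers `unique_entries` and `unique_entries_bruteforce`. It counts the unique entries of a side-n tensor symmetric under all permutations of its `order` axes as C(n + order - 1, order) and cross-checks that count by enumeration.

```python
from itertools import combinations_with_replacement
from math import comb


def unique_entries(n: int, order: int) -> int:
    """Unique entries of a side-n tensor symmetric under all axis permutations.

    Each unique entry corresponds to a sorted index tuple i1 <= ... <= i_order,
    i.e. a multiset of size `order` drawn from n indices: C(n + order - 1, order).
    """
    return comb(n + order - 1, order)


def unique_entries_bruteforce(n: int, order: int) -> int:
    # Enumerate the sorted index tuples directly as a cross-check.
    return sum(1 for _ in combinations_with_replacement(range(n), order))


n = 10
expected = unique_entries(n, 2)
assert expected == unique_entries_bruteforce(n, 2)
print(expected)  # 55, versus 100 for a dense 10 x 10 elementwise multiply
```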

Labels: bug (Something isn't working), priority:p1 (Ship in the current iteration)