Bug: Incorrect state initialization in _nupi_sgd_init leads to wrong second step in nuPI #107

@juan43ramirez

Description

Bug

The nuPI optimizer has a bug in its dense initialization function, _nupi_sgd_init, when init_type=nuPIInitType.SGD.

At the end of the first step (t=0), it performs an incorrect state-saving operation, which causes the second optimizer update to be wrong.

Given the state update rule
$$\xi_t = \nu\,\xi_{t-1} + (1 - \nu)\,e_t,$$
and the SGD initialization $\xi_{-1} = e_0$, the first saved state should be
$$\xi_0 = \nu\,\xi_{-1} + (1 - \nu)\,e_0 = \nu\,e_0 + (1 - \nu)\,e_0 = e_0.$$
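A minimal sketch of the recursion, using plain Python floats as stand-ins for the tensors, confirms that seeding $\xi_{-1} = e_0$ gives $\xi_0 = e_0$ for any $\nu$. The helper name `ema_update` is hypothetical and not part of the library.

```python
def ema_update(xi_prev: float, error: float, nu: float) -> float:
    """Hypothetical helper: xi_t = nu * xi_{t-1} + (1 - nu) * e_t."""
    return nu * xi_prev + (1.0 - nu) * error


e_0 = 0.75
for nu in (0.0, 0.5, 0.75):
    # SGD initialization: xi_{-1} is seeded with the first error e_0,
    # so the convex combination collapses to e_0.
    xi_0 = ema_update(e_0, e_0, nu)
    print(f"nu={nu}: xi_0={xi_0}")  # xi_0 equals e_0 for every nu
```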

Thus, after the first step (t=0), _nupi_sgd_init should store state["xi"] = e_0, where e_0 is given by detached_error.

Instead, it incorrectly saves state["xi"] = torch.zeros_like(param), which forces $\xi_0 = 0$.

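On the second step (t=1), the optimizer then computes the proportional term
$$K_P(1-\nu)(e_1 - \xi_0)$$
using this incorrect $\xi_0$. Instead of the intended $K_P(1-\nu)(e_1 - e_0)$, it applies $K_P(1-\nu)\,e_1$, which is disproportionately large whenever consecutive errors are similar, producing an erroneous update.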
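To make the size of the discrepancy concrete, the sketch below evaluates the $t=1$ proportional term with the intended $\xi_0 = e_0$ and with the buggy $\xi_0 = 0$. The values of $K_P$, $\nu$, $e_0$ and $e_1$ are arbitrary illustrative choices, and `proportional_term` is a hypothetical helper, not a library function.

```python
def proportional_term(kp: float, nu: float, e_1: float, xi_0: float) -> float:
    """Hypothetical helper: K_P * (xi_1 - xi_0) = K_P * (1 - nu) * (e_1 - xi_0)."""
    return kp * (1.0 - nu) * (e_1 - xi_0)


kp, nu = 2.0, 0.5
e_0, e_1 = 1.0, 1.25  # two similar consecutive errors

correct = proportional_term(kp, nu, e_1, xi_0=e_0)  # intended: xi_0 = e_0
buggy = proportional_term(kp, nu, e_1, xi_0=0.0)    # current behavior: xi_0 = 0

print(correct, buggy)  # 0.25 1.25
```

Here the buggy term is five times the intended one; the ratio $e_1 / (e_1 - e_0)$ grows as $e_1$ approaches $e_0$.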

Steps

  1. Initialize the nuPI optimizer with init_type=nuPIInitType.SGD (the default) on a dense parameter.
  2. Run optimizer.step() (t=0). The step update is correct, but the state xi_0 is incorrectly saved as 0.
  3. Run optimizer.step() (t=1).
  4. Observe that the update computed in step 3 (t=1) is wrong, because the proportional term uses the incorrect state xi_0 = 0.
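The steps above can be mirrored without the library in a scalar simulation. This is a sketch that assumes a per-step update of the form $\Delta_t = K_I e_t + K_P(\xi_t - \xi_{t-1})$ with the SGD initialization $\xi_{-1} = e_0$; the integral term, the sign and learning-rate conventions, and the name `simulate_nupi` are assumptions for illustration, not the library's code. The `buggy` flag reproduces storing $\xi_0 = 0$ after step t=0.

```python
def simulate_nupi(errors, kp, ki, nu, buggy=False):
    """Return per-step updates ki * e_t + kp * (xi_t - xi_{t-1}) (assumed form)."""
    updates = []
    xi_prev = errors[0]  # SGD initialization: xi_{-1} = e_0
    for t, e_t in enumerate(errors):
        xi_t = nu * xi_prev + (1.0 - nu) * e_t
        updates.append(ki * e_t + kp * (xi_t - xi_prev))
        # Buggy path: after step t=0 the stored state is zero instead of xi_0.
        xi_prev = 0.0 if (buggy and t == 0) else xi_t
    return updates


errors = [1.0, 1.25]
fixed = simulate_nupi(errors, kp=2.0, ki=1.0, nu=0.5)
buggy = simulate_nupi(errors, kp=2.0, ki=1.0, nu=0.5, buggy=True)
print(fixed)  # [1.0, 1.5]
print(buggy)  # [1.0, 2.5]
```

Both runs agree at t=0 and diverge only at t=1, matching steps 2 and 4.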

Expected behavior

The _nupi_sgd_init function must save state["xi"] = detached_error.clone() at the end of the first step (t=0) to correctly set $\xi_0 = e_0$.

Context

The bug is in this block of _nupi_sgd_init:

30b158e/.../nupi_optimizer.py#L360-L362

Proposed Fix

    if uses_kp_term:
        if "xi" not in state:
            # Step t=0: with the SGD init (xi_{-1} = e_0), xi_0 = e_0.
            # clone() so the stored state does not alias detached_error.
            state["xi"] = detached_error.clone()
        else:
            # Step t > 0: EMA update xi_t = nu * xi_{t-1} + (1 - nu) * e_t.
            state["xi"].mul_(ema_nu).add_(detached_error, alpha=1 - ema_nu)
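The branching in the proposed fix can be exercised with a plain dictionary standing in for the optimizer state. In this sketch floats replace tensors, so the in-place mul_/add_ chain becomes ordinary arithmetic and no clone() is needed (floats are immutable); `update_xi` is a hypothetical name, and the checks at the end illustrate the kind of regression test one might write.

```python
def update_xi(state: dict, detached_error: float, ema_nu: float) -> None:
    """Hypothetical float version of the proposed xi update (tensors -> floats)."""
    if "xi" not in state:
        # Step t=0: with the SGD init (xi_{-1} = e_0), xi_0 = e_0.
        state["xi"] = detached_error
    else:
        # Step t > 0: xi_t = nu * xi_{t-1} + (1 - nu) * e_t.
        state["xi"] = ema_nu * state["xi"] + (1.0 - ema_nu) * detached_error


state = {}
update_xi(state, detached_error=1.0, ema_nu=0.5)
print(state["xi"])  # 1.0, i.e. e_0 (the current code stores 0 here)
update_xi(state, detached_error=1.25, ema_nu=0.5)
print(state["xi"])  # 1.125
```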

Note on _sparse_nupi_sgd_init

The sparse implementation (_sparse_nupi_sgd_init) should also be checked to ensure it does not suffer from a similar state initialization error.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests