Why might my program hang indefinitely when I use torch_xla.sync? #9750

@oluwatimilehin

Hello team,

I am trying to profile an inference run for a model as follows:

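    import time

    import torch
    import torch_xla

    # torch_model, torch_inputs, model, num_warmup_runs and args are
    # defined elsewhere in the script (not shown).
    device = torch_xla.device()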

    torch_model = torch_model.to(device)
    torch_inputs = torch_inputs.to(device)

    print(f"Running inference on {model}")
    print(f"Warming up {model} with {num_warmup_runs} runs")
    for _ in range(num_warmup_runs):
        with torch.no_grad():
            logits = torch_model(torch_inputs).logits
            torch_xla.sync(wait=True)

    latencies_ms = []

    for i in range(args.num_iterations):
        start_time = time.time()
        with torch.no_grad():
            logits = torch_model(torch_inputs).logits
            torch_xla.sync(wait=True)

        end_time = time.time()
        latencies_ms.append((end_time - start_time) * 1000)

However, this program only terminates when I comment out the torch_xla.sync lines. I also noticed that I get similar performance numbers with and without the torch_xla.sync() call. When do I need to call sync, and why might it be causing the program to hang?
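One way to see what each timing loop measures is a toy model of lazy execution, in which calls are only recorded and the recorded work runs when sync() is called. PyTorch/XLA defers execution in a similar way, but this is plain Python, not torch_xla: `ToyLazyDevice` and `timed_loop` are hypothetical names, and the per-op cost is simulated with `time.sleep`. In the toy, the loop without a sync times only the recording step, which is one reason near-identical numbers with and without a sync are worth double-checking.

```python
import time
from typing import List


class ToyLazyDevice:
    """Hypothetical lazy backend (NOT the torch_xla API).

    record() only queues an op, the way a lazy-tensor trace does; nothing
    runs until sync() executes everything that is pending.
    """

    def __init__(self, op_cost_s: float = 0.005) -> None:
        self.op_cost_s = op_cost_s    # simulated device time per op
        self.pending: List[str] = []  # traced but not yet executed
        self.executed = 0             # ops that actually ran

    def record(self, name: str) -> None:
        # Tracing is cheap: just remember the op.
        self.pending.append(name)

    def sync(self) -> None:
        # Run every pending op; this toy always blocks until they finish.
        for _ in self.pending:
            time.sleep(self.op_cost_s)
            self.executed += 1
        self.pending.clear()


def timed_loop(device: ToyLazyDevice, iterations: int,
               ops_per_forward: int, do_sync: bool) -> List[float]:
    """Mirror the question's measurement loop against the toy device."""
    latencies_ms: List[float] = []
    for _ in range(iterations):
        start = time.perf_counter()
        for i in range(ops_per_forward):  # the "forward pass": trace only
            device.record(f"op{i}")
        if do_sync:
            device.sync()                 # execution happens here
        latencies_ms.append((time.perf_counter() - start) * 1000)
    return latencies_ms


if __name__ == "__main__":
    lazy = ToyLazyDevice()
    print("no sync:  ", timed_loop(lazy, 3, 4, do_sync=False), lazy.executed)
    synced = ToyLazyDevice()
    print("with sync:", timed_loop(synced, 3, 4, do_sync=True), synced.executed)
```

In this toy, unsynced ops simply accumulate in `pending` and never run; the sketch makes no claim about how a real backend handles graphs that are never synced.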

Perhaps more importantly, what is the right way to measure how long an inference takes on a TPU?
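A minimal, framework-agnostic sketch of the measurement pattern this question is about: run a few warm-up iterations first, then stop the clock only after a caller-supplied `block` hook has waited for the queued work to finish. `benchmark`, `step` and `block` are hypothetical names, and the sketch uses only the standard library, so it does not exercise torch_xla itself.

```python
import statistics
import time
from typing import Callable, Dict, List


def benchmark(
    step: Callable[[], object],
    block: Callable[[], None],
    warmup: int = 3,
    iterations: int = 10,
) -> Dict[str, float]:
    """Time `step` end to end, calling `block` so queued work finishes.

    `step` is assumed to enqueue one inference; `block` is assumed to wait
    until that work is done (a hypothetical hook, not a torch_xla function).
    """
    # Warm-up runs absorb one-time costs such as graph compilation.
    for _ in range(warmup):
        step()
        block()

    latencies_ms: List[float] = []
    for _ in range(iterations):
        start = time.perf_counter()  # monotonic, high-resolution clock
        step()
        block()                      # stop the clock only after work completes
        latencies_ms.append((time.perf_counter() - start) * 1000)

    return {
        "mean_ms": statistics.mean(latencies_ms),
        "median_ms": statistics.median(latencies_ms),
        "min_ms": min(latencies_ms),
        "max_ms": max(latencies_ms),
    }


if __name__ == "__main__":
    # Stand-in workload: step does nothing, block sleeps for 2 ms.
    print(benchmark(lambda: None, lambda: time.sleep(0.002)))
```

In the script above, `step` would presumably correspond to `torch_model(torch_inputs).logits` inside `torch.no_grad()`, and `block` to the `torch_xla.sync(wait=True)` call (or `torch_xla.sync()` followed by `xm.wait_device_ops()` from `torch_xla.core.xla_model`). That mapping is an assumption, and it does not by itself explain the hang.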
