This repository was archived by the owner on Sep 15, 2021. It is now read-only.
benchmark/synthetic_benchmark.py: 16 changes (13 additions, 3 deletions)
@@ -64,10 +64,16 @@
 )
 parser.add_argument(
     "--async-sync-interval",
-    default=50,
+    default=500,
     type=int,
     help="Model synchronization interval(ms) for async algorithm",
 )
+parser.add_argument(
+    "--async-warmup-steps",
+    default=0,
+    type=int,
+    help="Warmup(allreduce) steps for async algorithm",
+)
 parser.add_argument(
     "--amp",
     action="store_true",
@@ -131,13 +137,16 @@
 elif args.algorithm == "qadam":
     from bagua.torch_api.algorithms import q_adam

-    optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=0.01 * bagua.get_world_size(), warmup_steps=100)
+    optimizer = q_adam.QAdamOptimizer(
+        model.parameters(), lr=0.01 * bagua.get_world_size(), warmup_steps=100
+    )
     algorithm = q_adam.QAdamAlgorithm(optimizer)
 elif args.algorithm == "async":
     from bagua.torch_api.algorithms import async_model_average

     algorithm = async_model_average.AsyncModelAverageAlgorithm(
-        sync_interval_ms=args.async_sync_interval
+        sync_interval_ms=args.async_sync_interval,
+        warmup_steps=args.async_warmup_steps,
     )
 else:
     raise NotImplementedError
@@ -186,6 +195,7 @@ def benchmark_step():

 # Warm-up
 logging.info("Running warmup...")
+
 timeit.timeit(benchmark_step, number=args.num_warmup_batches)

 # Benchmark
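For reference, a minimal sketch of how the new --async-warmup-steps flag is wired through in benchmark/synthetic_benchmark.py. Only the AsyncModelAverageAlgorithm constructor arguments (sync_interval_ms, warmup_steps) are taken from this diff; the standalone argparse setup is illustrative, not the benchmark's actual structure.

# Sketch only: wiring the new flag into the async algorithm.
import argparse

from bagua.torch_api.algorithms import async_model_average

parser = argparse.ArgumentParser()
parser.add_argument("--async-sync-interval", default=500, type=int,
                    help="Model synchronization interval(ms) for async algorithm")
parser.add_argument("--async-warmup-steps", default=0, type=int,
                    help="Warmup(allreduce) steps for async algorithm")
args = parser.parse_args()

# Per the help text above, the first warmup_steps steps run synchronous
# allreduce before asynchronous model averaging (every sync_interval_ms)
# takes over.
algorithm = async_model_average.AsyncModelAverageAlgorithm(
    sync_interval_ms=args.async_sync_interval,
    warmup_steps=args.async_warmup_steps,
)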
imagenet/main.py: 37 changes (30 additions, 7 deletions)
@@ -149,11 +149,18 @@

 parser.add_argument(
     "--async-sync-interval",
-    default=100,
+    default=500,
     type=int,
     help="Model synchronization interval(ms) for async algorithm",
 )

+parser.add_argument(
+    "--async-warmup-steps",
+    default=100,
+    type=int,
+    help="Warmup(allreduce) steps for async algorithm",
+)
+
 best_acc1 = 0

@@ -232,13 +239,16 @@ def main_worker(args):
     elif args.algorithm == "qadam":
         from bagua.torch_api.algorithms import q_adam

-        optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=args.lr, warmup_steps=100)
+        optimizer = q_adam.QAdamOptimizer(
+            model.parameters(), lr=args.lr, warmup_steps=100
+        )
         algorithm = q_adam.QAdamAlgorithm(optimizer)
     elif args.algorithm == "async":
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
+            warmup_steps=args.async_warmup_steps,
         )
     else:
         raise NotImplementedError
@@ -335,9 +345,15 @@ def main_worker(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)

+        if args.algorithm == "async":
+            algorithm.resume(model)
+
         # train for one epoch
         train(train_loader, model, criterion, optimizer, scaler, epoch, args)

+        if args.algorithm == "async":
+            algorithm.abort(model)
+
         # evaluate on validation set
         acc1 = validate(val_loader, model, criterion, epoch, args)

@@ -357,9 +373,6 @@ def main_worker(args):
                 is_best,
             )

-    if args.algorithm == "async":
-        algorithm.abort(model)
-

 def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
     batch_time = AverageMeter("Time", ":6.3f")
@@ -415,10 +428,17 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
         top5.update(acc5[0], images.size(0))

         if args.prof >= 0:
-            torch.cuda.nvtx.range_push("optimizer.step()")
+            torch.cuda.nvtx.range_push("backward")

         # compute gradient and do SGD step
         scaler.scale(loss).backward()
+
+        if args.prof >= 0:
+            torch.cuda.nvtx.range_pop()
+
+        if args.prof >= 0:
+            torch.cuda.nvtx.range_push("optimizer.step()")
+
         scaler.step(optimizer)
         scaler.update()

@@ -439,6 +459,9 @@ def train(train_loader, model, criterion, optimizer, scaler, epoch, args):
         if args.prof >= 0 and i == args.prof + 10:
             print("Profiling ended at iteration {}".format(i))
             torch.cuda.cudart().cudaProfilerStop()
+
+            if args.algorithm == "async":
+                model.bagua_algorithm.abort(model)
             quit()

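The epoch loop in main_worker above now restarts asynchronous communication before each training epoch and stops it before validation and checkpointing. A minimal sketch of the resulting pattern; only algorithm.resume(model) and algorithm.abort(model) come from this diff, and every other name is a placeholder for the example's own helpers and arguments.

# Sketch of the per-epoch resume/abort pattern for the async algorithm.
def run_epochs(args, algorithm, model, train_loader, val_loader,
               criterion, optimizer, scaler, train, validate):
    for epoch in range(args.start_epoch, args.epochs):
        if args.algorithm == "async":
            algorithm.resume(model)   # restart asynchronous model averaging

        train(train_loader, model, criterion, optimizer, scaler, epoch, args)

        if args.algorithm == "async":
            algorithm.abort(model)    # stop async communication before eval

        validate(val_loader, model, criterion, epoch, args)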
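The profiling hunks above split the single NVTX range so that the backward pass and the optimizer step show up as separate regions in an Nsight/nvprof timeline. A short sketch of the same pattern in isolation; the loss, scaler, and optimizer names stand for the example's AMP objects, and the closing range_pop for the "optimizer.step()" range (not visible in the hunk) is assumed.

# Sketch: separate NVTX ranges for backward and optimizer.step().
import torch

def backward_and_step(loss, scaler, optimizer, profiling=True):
    if profiling:
        torch.cuda.nvtx.range_push("backward")
    scaler.scale(loss).backward()
    if profiling:
        torch.cuda.nvtx.range_pop()

    if profiling:
        torch.cuda.nvtx.range_push("optimizer.step()")
    scaler.step(optimizer)
    scaler.update()
    if profiling:
        torch.cuda.nvtx.range_pop()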
mnist/main.py: 15 changes (10 additions, 5 deletions)
@@ -232,13 +232,15 @@ def main():
     elif args.algorithm == "qadam":
         from bagua.torch_api.algorithms import q_adam

-        optimizer = q_adam.QAdamOptimizer(model.parameters(), lr=args.lr, warmup_steps=100)
+        optimizer = q_adam.QAdamOptimizer(
+            model.parameters(), lr=args.lr, warmup_steps=100
+        )
         algorithm = q_adam.QAdamAlgorithm(optimizer)
     elif args.algorithm == "async":
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
         )
     else:
         raise NotImplementedError
@@ -250,14 +252,17 @@

     scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)
     for epoch in range(1, args.epochs + 1):
+        if args.algorithm == "async":
+            algorithm.resume(model)
+
         train(args, model, train_loader, optimizer, epoch)

+        if args.algorithm == "async":
+            algorithm.abort(model)
+
         test(model, test_loader)
         scheduler.step()

-    if args.algorithm == "async":
-        algorithm.abort(model)
-
     if args.save_model:
         torch.save(model.state_dict(), "mnist_cnn.pt")

squad/main.py: 10 changes (8 additions, 2 deletions)
@@ -164,7 +164,8 @@ def train(args, train_dataset, model, tokenizer):
         from bagua.torch_api.algorithms import async_model_average

         algorithm = async_model_average.AsyncModelAverageAlgorithm(
-            sync_interval_ms=args.async_sync_interval
+            sync_interval_ms=args.async_sync_interval,
+            warmup_steps=args.async_warmup_steps,
         )
     else:
         raise NotImplementedError
@@ -399,7 +400,6 @@ def train(args, train_dataset, model, tokenizer):

     if args.algorithm == "async":
         algorithm.abort(model)
-        torch.cuda.synchronize()

     return global_step, tr_loss / global_step

@@ -919,6 +919,12 @@ def main():
         type=int,
         help="Model synchronization interval(ms) for async algorithm",
     )
+    parser.add_argument(
+        "--async-warmup-steps",
+        default=100,
+        type=int,
+        help="Warmup(allreduce) steps for async algorithm",
+    )
     args = parser.parse_args()

     if args.doc_stride >= args.max_seq_length - args.max_query_length: