diff --git a/src/flat_retrieve.py b/src/flat_retrieve.py index bc75e7d..e0d7233 100644 --- a/src/flat_retrieve.py +++ b/src/flat_retrieve.py @@ -33,8 +33,9 @@ bank_emb.div_(bank_emb.norm(2, 1, keepdim=True).expand_as(bank_emb)) # score and rank +K = args.K scores = bank_emb.mm(query_emb.t()) # B x Q -_, indices = torch.topk(scores, params.k, dim=0) # K x Q +_, indices = torch.topk(scores, K, dim=0) # K x Q # fetch and print retrieved text txt_mmap, ref_mmap = IndexTextOpen(args.bank)