From 01bf349fe6bbb4e937426e9ab378b022f5e93d02 Mon Sep 17 00:00:00 2001 From: Amr Alaa Date: Thu, 17 Dec 2020 11:49:45 -0800 Subject: [PATCH] Fix missing params --- src/flat_retrieve.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/flat_retrieve.py b/src/flat_retrieve.py index bc75e7d..e0d7233 100644 --- a/src/flat_retrieve.py +++ b/src/flat_retrieve.py @@ -33,8 +33,9 @@ bank_emb.div_(bank_emb.norm(2, 1, keepdim=True).expand_as(bank_emb)) # score and rank +K = args.K scores = bank_emb.mm(query_emb.t()) # B x Q -_, indices = torch.topk(scores, params.k, dim=0) # K x Q +_, indices = torch.topk(scores, K, dim=0) # K x Q # fetch and print retrieved text txt_mmap, ref_mmap = IndexTextOpen(args.bank)