This repository was archived by the owner on Feb 22, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbase_model.py
More file actions
42 lines (33 loc) · 2.13 KB
/
base_model.py
File metadata and controls
42 lines (33 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from typing import Any, Dict, List
class ShopBenchBaseModel:
    """Base interface for models that are called by the ShopBench evaluator.

    This class performs no inference itself: both `get_batch_size` and
    `batch_predict` raise `NotImplementedError` and must be overridden by
    concrete model classes.
    """

    def __init__(self) -> None:
        """Create the base model; there is no state to initialise here."""

    def get_batch_size(self) -> int:
        """
        Determines the batch size that is used by the evaluator when calling
        the `batch_predict` function.

        Returns:
            int: The batch size, an integer between 1 and 16. This value
                 indicates how many queries should be processed together in
                 a single batch. It can be dynamic across different
                 batch_predict calls, or stay a static value.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                 override this method.
        """
        raise NotImplementedError("get_batch_size method not implemented")

    def batch_predict(self, batch: Dict[str, Any], is_multiple_choice: bool) -> List[str]:
        """
        Generates a batch of predictions based on the associated prompts and
        task type.

        This base implementation does not predict anything; it only defines
        the contract that overriding implementations are expected to follow.

        Parameters:
            - batch (Dict[str, Any]): A dictionary containing a batch of input
              prompts with the following keys
                - prompt (List[str]): a list of input prompts for the model.
            - is_multiple_choice (bool): A boolean flag indicating if all the
              items in this batch belong to multiple choice tasks.

        Returns:
            List[str]: A list of predictions, one for each of the prompts
                received in the batch. Each prediction is
                    a string representing a single integer [0, 3] for
                    multiple choice tasks,
                    or a string representing a comma separated list of
                    integers for Ranking, Retrieval tasks,
                    or a string representing a comma separated list of named
                    entities for Named Entity Recognition tasks,
                    or a string representing the (unconstrained) generated
                    response for the generation tasks.
                Please refer to parsers.py for more details on how these
                responses will be parsed by the evaluator.

        Raises:
            NotImplementedError: Always, in this base class; subclasses must
                override this method.
        """
        # Name the real method in the message so a missing override is easy
        # to trace (this class has no method called `predict`).
        raise NotImplementedError("batch_predict method not implemented")