Skip to content
Snippets Groups Projects
Commit 121118a5 authored by Jérôme Botoko Ekila's avatar Jérôme Botoko Ekila
Browse files

fix: evaluate clevr using indexing

parent fa9883e4
No related branches found
No related tags found
No related merge requests found
......@@ -67,9 +67,13 @@ def get_evaluation_dataloader(
batch_size: int,
images_dir: str,
questions_path: str,
from_idx,
to_idx,
max_samples: Optional[int] = None,
) -> DataLoader:
ds: Dataset = get_dataset(experiment, questions_path, images_dir, device)
ds: Dataset = get_dataset(
experiment, questions_path, images_dir, device, from_idx, to_idx
)
if max_samples is not None:
all_indices = list(range(0, len(ds))) # type: ignore
if shuffle:
......
......@@ -13,23 +13,23 @@ from nmn.utils.config import load_json
def get_dataset(
experiment: str, questions_path: str, images_dir: str, device: str
experiment: str, questions_path: str, images_dir: str, device: str, from_idx, to_idx
) -> Dataset:
# only CLEVR is executed with the ProgramEvaluator
if experiment == Experiment.CLEVR:
dataset = CLEVRDataset(questions_path, images_dir, device)
dataset = CLEVRDataset(questions_path, images_dir, device, from_idx, to_idx)
else:
raise UnsupportedExperimentError(experiment)
return dataset
class CLEVRDataset(Dataset):
def __init__(self, questions_path: str, images_dir: str, device: str):
def __init__(self, questions_path: str, images_dir: str, device: str, from_idx: int, to_idx: int):
super(CLEVRDataset, self).__init__()
self.images_dir: str = images_dir
self.device: str = device
data: dict = load_json(questions_path)
self.questions: list = data["questions"]
self.questions: list = data["questions"][from_idx:to_idx]
self.image2tensor: ImageToTensor = ImageToTensor(device)
def __getitem__(self, idx) -> tuple[np.ndarray, torch.Tensor, list, Optional[str]]:
......
......@@ -64,3 +64,15 @@ class EvalOptions(Options):
help="Threshold for detectron2 to remove duplicate detected instances.",
)
self.parser.add_argument(
"--from_idx",
type=int,
required=True,
)
self.parser.add_argument(
"--to_idx",
type=int,
required=True,
)
......@@ -38,6 +38,8 @@ def main() -> None:
batch_size=1,
images_dir=get_images(opt.experiment, opt.split),
questions_path=get_experiment_questions(opt.experiment, image_split),
from_idx=opt.from_idx,
to_idx=opt.to_idx,
max_samples=opt.max_samples,
)
logging.info(f"Data loader has {len(data_loader)} samples")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment