diff --git a/nmn/datasets/experiment.py b/nmn/datasets/experiment.py index d363f7dd97cfd05081227c46594d4e6f11dac406..7b6cde22808314b91459344086dfb6eb8d31186a 100644 --- a/nmn/datasets/experiment.py +++ b/nmn/datasets/experiment.py @@ -5,6 +5,7 @@ class Experiment(str, Enum): """Types of modules.""" CLEVR = "clevr" + COGENT = "cogent" CLEVR_DIALOG = "clevr_dialog" MNIST_DIALOG = "mnist_dialog" diff --git a/nmn/datasets/generate/query_concept.py b/nmn/datasets/generate/query_concept.py index 0fff8e7cba43cdae6f003521662f2a4eec5d3465..4013dbd9b4520af18d0ad9937e91c6e66899548c 100644 --- a/nmn/datasets/generate/query_concept.py +++ b/nmn/datasets/generate/query_concept.py @@ -46,7 +46,7 @@ def get_query_concept_samples_mnist( def get_query_concept_samples( args: DotDict, scene: dict, category: str, concept: str ) -> list[dict]: - if args.experiment == Experiment.CLEVR: + if args.experiment == Experiment.CLEVR or args.experiment == Experiment.COGENT: return get_query_concept_samples_clevr(scene, category, concept) elif args.experiment == Experiment.MNIST_DIALOG: return get_query_concept_samples_mnist(scene, category, concept) diff --git a/nmn/datasets/generate/query_spatial.py b/nmn/datasets/generate/query_spatial.py index 5a161fedada0e0df99493f5b0c00aeab2b4368b7..e351f3ebd1c4ae0e086fbcdaba12d1bb23af68f8 100644 --- a/nmn/datasets/generate/query_spatial.py +++ b/nmn/datasets/generate/query_spatial.py @@ -172,7 +172,7 @@ def get_query_immediate_relate_concept_samples_mnist(args: DotDict) -> None: def generate_query_spatial_dataset(args: DotDict) -> None: - if args.experiment == Experiment.CLEVR: + if args.experiment == Experiment.CLEVR or args.experiment == Experiment.COGENT: get_query_relate_concept_samples_clevr(args) get_query_immediate_relate_concept_samples_clevr(args) get_query_extreme_relate_concept_samples_clevr(args) diff --git a/nmn/datasets/load/experiment_dataset.py b/nmn/datasets/load/experiment_dataset.py index 8356c32a0dfb6e9ea0ecf79bcf01333d3e83ee80..d70ad661c02ddb4b02be01e3ce7aba6f2458769e 100644 --- a/nmn/datasets/load/experiment_dataset.py +++ b/nmn/datasets/load/experiment_dataset.py @@ -16,7 +16,7 @@ def get_dataset( experiment: str, questions_path: str, images_dir: str, device: str, from_idx, to_idx ) -> Dataset: # only CLEVR is executed with the ProgramEvaluator - if experiment == Experiment.CLEVR: + if experiment == Experiment.CLEVR or args.experiment == Experiment.COGENT: dataset = CLEVRDataset(questions_path, images_dir, device, from_idx, to_idx) else: raise UnsupportedExperimentError(experiment) diff --git a/nmn/modules/load_modules/get_modules.py b/nmn/modules/load_modules/get_modules.py index e580cf20d7d7bbdd8ed37b66307fcc6562ac177e..bff768adf1e0cbc870d24d363d23f1a8856382d0 100644 --- a/nmn/modules/load_modules/get_modules.py +++ b/nmn/modules/load_modules/get_modules.py @@ -36,7 +36,7 @@ def get_evaluation_modules( pt["opt"] = DotDict(pt["opt"]) logging.info(f"Found {len(checkpoints)} checkpoint files") # dispatch to a separate function depending on the experiment - if experiment == Experiment.CLEVR: + if experiment == Experiment.CLEVR or args.experiment == Experiment.COGENT: return get_clevr_evaluation_modules(checkpoints, device, threshold) elif experiment == Experiment.MNIST_DIALOG: return get_mnist_evaluation_modules(checkpoints, device) diff --git a/nmn/utils/config.py b/nmn/utils/config.py index 3c5d9953365e1bd2bf42707eeffc37876cee7f64..3669f8ba2e1735e3dd8f3f569774490e6f06f908 100644 --- a/nmn/utils/config.py +++ b/nmn/utils/config.py @@ -15,6 +15,8 @@ def load_experiment_config(opt) -> None: opt.experiment == Experiment.CLEVR or opt.experiment == Experiment.CLEVR_DIALOG ): config_file = "data/clevr/config.json" + elif opt.experimet == Experiment.COGENT: + config_file = "data/cogent/config.json" else: raise UnsupportedExperimentError(opt.experiment) diff --git a/scripts/preprocess_data/generate_module_data.py b/scripts/preprocess_data/generate_module_data.py index 4874d4307baf34d06a6bd9b366780a3121d725e0..edf0319c808670d63234594743d348680f00c42f 100644 --- a/scripts/preprocess_data/generate_module_data.py +++ b/scripts/preprocess_data/generate_module_data.py @@ -28,6 +28,7 @@ def main(args): setup_logger(args) for dataset_name in ["train", "val", "test"]: + # NOTE: train, val, test -> in cogent trainA, valA, ... # read the dataset logging.info(f"Reading dataset {dataset_name}...") args.scenes_basename = dataset_name diff --git a/scripts/utils/loading.py b/scripts/utils/loading.py index a4efa549ec679d94136b80dcb2414a5c16fc40a8..6755c0b8a2b2e92bf93f11e08ed1548f819dcb4d 100644 --- a/scripts/utils/loading.py +++ b/scripts/utils/loading.py @@ -8,6 +8,8 @@ def get_images(experiment: str, split: str) -> str: return f"data/clevr/images/{split}/" elif experiment == Experiment.MNIST_DIALOG: return f"data/mnist_dialog/images/{split}/" + elif experiment == Experiment.COGENT: + return f"data/cogent/images/{split}/" else: raise UnsupportedExperimentError(experiment) @@ -21,6 +23,8 @@ def get_scenes(experiment: str, split: str, module_name: str) -> str: def get_experiment_questions(experiment: str, split: str) -> str: if experiment == Experiment.CLEVR: return f"data/clevr/questions/CLEVR_{split}_questions.json" + elif experiment == Experiment.COGENT: + return f"data/cogent/questions/CLEVR_{split}_questions.json" else: raise UnsupportedExperimentError(experiment)