diff --git a/nmn/evaluate/program_executor.py b/nmn/evaluate/program_executor.py index 19ed9b8b0024e7923d402fb6727e7361bc28473e..6621803fb4da491e86cfe57a9c9136efef098391 100644 --- a/nmn/evaluate/program_executor.py +++ b/nmn/evaluate/program_executor.py @@ -9,11 +9,7 @@ import torch.nn as nn from nmn.modules.executors.filter import FilterExecutor from nmn.modules.executors.query import QueryExecutor -from nmn.modules.executors.relate import ( - ExtremeRelateExecutor, - ImmediateRelateExecutor, - RelateExecutor, -) +from nmn.modules.executors.relate import ExtremeRelateExecutor, ImmediateRelateExecutor, RelateExecutor from nmn.modules.executors.same import Same from nmn.modules.executors.segment import SegmentExecutor from nmn.modules.symbolic.compare import Equal, GreaterThan, LessThan @@ -70,8 +66,6 @@ class ProgramExecutor(nn.Module): super(ProgramExecutor, self).__init__() # initialise modules and set to eval mode self.modules: dict[str, nn.Module] = modules - for module in self.modules.values(): - module.eval() # configs self.config: dict = config # stack where results are saved diff --git a/scripts/evaluate/test_experiment.py b/scripts/evaluate/test_experiment.py index 700255f94d9963bcf4a9e58f4f29cf9f825420be..04da1c4dbd305abf47ab31c536530501fe4d85b5 100644 --- a/scripts/evaluate/test_experiment.py +++ b/scripts/evaluate/test_experiment.py @@ -45,6 +45,10 @@ def main() -> None: device=opt.device, threshold=opt.detectron_threshold, ) + # set modules in eval mode and send to device + for module in modules.values(): + module.eval() + module.to(opt.device) logging.info("Modules loaded") # Create the ProgramExecutor