Skip to content
Snippets Groups Projects
Commit a76c988d authored by Jérôme Botoko Ekila's avatar Jérôme Botoko Ekila
Browse files

fix: check when detectron threshold is none

parent 86adcfea
No related branches found
No related tags found
No related merge requests found
......@@ -9,7 +9,11 @@ import torch.nn as nn
from nmn.modules.executors.filter import FilterExecutor
from nmn.modules.executors.query import QueryExecutor
from nmn.modules.executors.relate import ExtremeRelateExecutor, ImmediateRelateExecutor, RelateExecutor
from nmn.modules.executors.relate import (
ExtremeRelateExecutor,
ImmediateRelateExecutor,
RelateExecutor,
)
from nmn.modules.executors.same import Same
from nmn.modules.executors.segment import SegmentExecutor
from nmn.modules.symbolic.compare import Equal, GreaterThan, LessThan
......
......@@ -30,28 +30,31 @@ class SegmentExecutor(nn.Module):
def filter_doubles(
self, detected_instances: list[torch.Tensor]
) -> list[torch.Tensor]:
jaccards = {}
for object_i in range(len(detected_instances)):
for object_j in range(object_i + 1, len(detected_instances)):
iou = self.jaccard(
detected_instances[object_i], detected_instances[object_j]
)[1].item()
jaccards[(object_i, object_j)] = iou
largest_obj_pair = max(jaccards, key=jaccards.get)
largest_obj_iou = jaccards[largest_obj_pair]
largest_obj = largest_obj_pair[0]
if self.threshold and largest_obj_iou < self.threshold:
if not self.threshold:
return detected_instances
else:
logging.info(f" ---> filtered: {largest_obj_pair} from {jaccards}")
return [
detected_instances[i]
for i in range(len(detected_instances))
if i != largest_obj
]
jaccards = {}
for object_i in range(len(detected_instances)):
for object_j in range(object_i + 1, len(detected_instances)):
iou = self.jaccard(
detected_instances[object_i], detected_instances[object_j]
)[1].item()
jaccards[(object_i, object_j)] = iou
largest_obj_pair = max(jaccards, key=jaccards.get)
largest_obj_iou = jaccards[largest_obj_pair]
largest_obj = largest_obj_pair[0]
if largest_obj_iou < self.threshold:
# jaccard metric does not exceed threshold, thus do not filter!
return detected_instances
else:
logging.info(f" ---> filtered: {largest_obj_pair} from {jaccards}")
return [
detected_instances[i]
for i in range(len(detected_instances))
if i != largest_obj
]
def forward(self, image) -> AttentionSet:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment