Skip to content
Snippets Groups Projects
Commit 6c7e1b95 authored by Jérôme Botoko Ekila's avatar Jérôme Botoko Ekila
Browse files

style: formatting and removeal of whitespaces

parent 3247cae4
No related branches found
Tags nmn-v3
No related merge requests found
......@@ -177,9 +177,13 @@ class ProgramExecutor(nn.Module):
module: FilterExecutor = self.modules["filter"] # type: ignore
category = self.get_category(instruction)
concept = self.get_concept(instruction)
concepts= self.config["categories"][category]
concepts = self.config["categories"][category]
return module(
image=image, source=inputs[0], category=category, concept=concept, concepts=concepts
image=image,
source=inputs[0],
category=category,
concept=concept,
concepts=concepts,
)
# relate
elif function == "relate":
......
......@@ -11,8 +11,13 @@ class FilterExecutor(nn.Module):
self.query_modules: dict[str, dict[str, Classifier]] = query_modules
def forward(
self, image: torch.Tensor, source: AttentionSet, category: str, concept: str, concepts: list[str],
) -> AttentionSet:
self,
image: torch.Tensor,
source: AttentionSet,
category: str,
concept: str,
concepts: list[str],
) -> AttentionSet:
"""Forward the image and source set to the query module specified by the given category.
......@@ -31,11 +36,11 @@ class FilterExecutor(nn.Module):
target: AttentionSet = AttentionSet()
filter_concept = concept
# For every attention in the source set,
# For every attention in the source set,
# find the concept with the highest yes probability for the category to which the input concept belongs.
for attn_id, attn in source.get_attentions():
prev_certainty = source.get_certainty(attn_id)
probs = torch.full((len(concepts),), -1.0)
probs = torch.full((len(concepts),), -1.0)
for idx, concept in enumerate(concepts, 0):
# retrieve query module (trained on the given concept)
......@@ -50,12 +55,12 @@ class FilterExecutor(nn.Module):
highest_concept: str = concepts[index]
# only include attention if the concept with the highest yes probability is equal to the concept that needs to be filtered.
if filter_concept == highest_concept:
if filter_concept == highest_concept:
target.add_attention(
attention_id=attn_id,
attention=attn,
certainty=prob.item(),
previous_certainty=prev_certainty,
)
attention_id=attn_id,
attention=attn,
certainty=prob.item(),
previous_certainty=prev_certainty,
)
return target
......@@ -32,12 +32,12 @@ class RelateExecutor(nn.Module):
) -> AttentionSet:
relations = None
if relationship == "left" or relationship == "right":
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
# source dictionary must contain only a single object!
if not source.contains_one_object():
......@@ -49,9 +49,9 @@ class RelateExecutor(nn.Module):
for attn_id, candidate_neighbour in segmented_scene.get_attentions():
if attn_id != source_object_id:
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.relate_modules[relation]
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations, 0):
module: Classifier = self.relate_modules[relation]
# source_object has the attribute
log_probs = module(
......@@ -102,12 +102,12 @@ class ImmediateRelateExecutor(nn.Module):
) -> AttentionSet:
relations = None
if relationship == "left" or relationship == "right":
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
if not source.contains_one_object():
return AttentionSet()
......@@ -122,9 +122,9 @@ class ImmediateRelateExecutor(nn.Module):
for attn_id, candidate_neighbour in segmented_scene.get_attentions():
if attn_id != source_object_id:
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.immediate_relate_modules[relation]
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations, 0):
module: Classifier = self.immediate_relate_modules[relation]
# source_object has the attribute
log_probs = module(
......@@ -146,7 +146,7 @@ class ImmediateRelateExecutor(nn.Module):
):
best_candidate_certainty = prob
best_candidate_id = attn_id
# if no attention can be found, return an empty set
if best_candidate_id:
attention: torch.Tensor = segmented_scene.get_attention(
......@@ -186,34 +186,29 @@ class ExtremeRelateExecutor(nn.Module):
# BE AWARE: this piece of code assumes single element batches
relations = None
if relationship == "left" or relationship == "right":
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
elif relationship == "middle":
elif relationship == "middle":
relations = ["middle"]
candidates = {}
# find attn_id that is the most likely 'extreme' neighbour
for candidate_id, candidate_attention in source.get_attentions():
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.extreme_relate_modules[relation]
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations, 0):
module: Classifier = self.extreme_relate_modules[relation]
# source_object has the attribute
log_probs = module(
{
"image": image,
"attention": candidate_attention,
}
)
log_probs = module({"image": image, "attention": candidate_attention,})
probs[idx] = log_probs[0][1].item()
prob, index = torch.max(probs, dim=0)
highest_relation: str = relations[index]
if highest_relation == relationship:
if highest_relation == relationship:
candidates[prob] = candidate_id
# select candidate attention with highest confidence
......
......@@ -100,7 +100,7 @@ def load_predictor(
weights: str,
n_classes: int,
threshold: float,
mask_format: str
mask_format: str,
) -> DefaultPredictor:
"""Instantiates a Mask R-CNN instance with a set of pretrained weights.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment