Skip to content
Snippets Groups Projects
Commit ff13ffa5 authored by Lara Verheyen's avatar Lara Verheyen
Browse files

fix: relate executors take highest prediction over all relations

parent 5cf74df4
No related branches found
No related tags found
No related merge requests found
......@@ -31,7 +31,13 @@ class RelateExecutor(nn.Module):
relationship: str,
) -> AttentionSet:
module: Classifier = self.relate_modules[relationship]
relations = None
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
# source dictionary must contain only a single object!
if not source.contains_one_object():
......@@ -42,18 +48,25 @@ class RelateExecutor(nn.Module):
target: AttentionSet = AttentionSet()
for attn_id, candidate_neighbour in segmented_scene.get_attentions():
if attn_id != source_object_id:
# source_object has the attribute
probs = module(
{
"image": image,
"attention1": source_object,
"attention2": candidate_neighbour,
}
)
# select binary choice with highest confidence
prob, index = torch.max(probs, dim=1)
# if candidate_neighbour is considered a neighbour
if index.item() == 1:
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.relate_modules[relation]
# source_object has the attribute
log_probs = module(
{
"image": image,
"attention1": source_object,
"attention2": candidate_neighbour,
}
)
probs[idx] = log_probs[0][1].item()
prob, index = torch.max(probs, dim=0)
highest_relation: str = relations[index]
if highest_relation == relationship:
target.add_attention(
attention_id=attn_id,
attention=candidate_neighbour,
......@@ -87,7 +100,14 @@ class ImmediateRelateExecutor(nn.Module):
source: AttentionSet,
relationship: str,
) -> AttentionSet:
module: Classifier = self.immediate_relate_modules[relationship]
relations = None
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
elif relationship == "above" or relationship == "below":
relations = ["above", "below"]
if not source.contains_one_object():
return AttentionSet()
......@@ -101,40 +121,32 @@ class ImmediateRelateExecutor(nn.Module):
target: AttentionSet = AttentionSet()
for attn_id, candidate_neighbour in segmented_scene.get_attentions():
if attn_id != source_object_id:
probs = module(
{
"image": image,
"attention1": source_object,
"attention2": candidate_neighbour,
}
)
no_prob = probs.detach().numpy()[0][0]
yes_prob = probs.detach().numpy()[0][1]
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.immediate_relate_modules[relation]
# source_object has the attribute
log_probs = module(
{
"image": image,
"attention1": source_object,
"attention2": candidate_neighbour,
}
)
probs[idx] = log_probs[0][1].item()
# select binary choice with highest confidence
prob, index = torch.max(probs, dim=1)
prob, index = torch.max(probs, dim=0)
highest_relation: str = relations[index]
# set mask2 as best neighbour if it is a better neighbour
# or when none has been selected so far
if index.item() == 1 and (
if highest_relation == relationship and (
best_candidate_certainty == None
or prob > best_candidate_certainty
):
best_candidate_certainty = prob
best_candidate_id = attn_id
elif (
(relationship == "behind" or relationship == "front")
and np.exp(no_prob) < 0.9
and (
best_candidate_certainty == None
or prob > best_candidate_certainty
)
):
best_candidate_certainty = prob
best_candidate_id = attn_id
# if no attention can be found, return an empty set
if best_candidate_id:
attention: torch.Tensor = segmented_scene.get_attention(
......@@ -172,18 +184,35 @@ class ExtremeRelateExecutor(nn.Module):
self, image: torch.Tensor, source: AttentionSet, relationship: str
) -> AttentionSet:
# BE AWARE: this piece of code assumes single element batches
module: Classifier = self.extreme_relate_modules[relationship]
relations = None
if relationship == "left" or relationship == "right":
relations = ["left", "right"]
elif relationship == "behind" or relationship == "front":
relations = ["behind", "front"]
candidates = {}
# find attn_id that is the most likely 'extreme' neighbour
for candidate_id, candidate_attention in source.get_attentions():
# source_object has the attribute
probs = module({"image": image, "attention": candidate_attention})
# select binary choice with highest confidence
value, index = torch.max(probs, dim=1)
# if attn_id is considered the most extreme neighbour
if index.item() == 1:
candidates[value] = candidate_id
probs = torch.full((len(relations),), -1.0)
for idx, relation in enumerate(relations,0):
module: Classifier = self.extreme_relate_modules[relation]
# source_object has the attribute
log_probs = module(
{
"image": image,
"attention": candidate_attention,
}
)
probs[idx] = log_probs[0][1].item()
prob, index = torch.max(probs, dim=0)
highest_relation: str = relations[index]
if highest_relation == relationship:
candidates[prob] = candidate_id
# select candidate attention with highest confidence
if candidates:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment