Skip to content
Snippets Groups Projects
Commit d5102815 authored by Lara Verheyen's avatar Lara Verheyen
Browse files

fix: score returned as prob instead of tensor

parent ad639fee
No related branches found
No related tags found
No related merge requests found
......@@ -30,6 +30,7 @@ class Query(Primitive):
category=self.slots["attribute"],
concepts=concepts,
)
# Specific detail relating to MNIST_Dialog
if self.args.experiment == Experiment.MNIST_DIALOG:
if self.slots["attribute"] == "number":
......@@ -47,7 +48,7 @@ class Query(Primitive):
json.dumps(
{
"bindings": [
[{"variable": "target", "score": prob, "value": concept}]
[{"variable": "target", "score": prob.item(), "value": concept}]
]
}
),
......@@ -69,7 +70,7 @@ class Query(Primitive):
label = self.args.config.categories[category][index]
if label == self.slots["target"]:
bindings.append(
[{"variable": "attribute", "score": prob, "value": category}]
[{"variable": "attribute", "score": prob.item(), "value": category}]
)
return json.dumps({"bindings": bindings}), 200
......@@ -88,7 +89,7 @@ class Query(Primitive):
bindings.append(
[
{"variable": "attribute", "score": 1.0, "value": category},
{"variable": "target", "score": prob, "value": label},
{"variable": "target", "score": prob.item(), "value": label},
]
)
......@@ -134,4 +135,3 @@ class Query(Primitive):
return self.check_concept()
else:
self.raise_missing_case_error()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment