Hi @skj9865
I did a deeper dive into this than planned, so I hope you don’t mind the long message.
You are definitely on the right track with a lot of what you are doing; in most cases you are simply running into Monty’s current limitations. So I am sorry if this post doesn’t provide a full solution yet.
I took your branch and ran the experiments you had on there. I made a few modifications, which you can see here: Mnist Tests by vkakerbeck · Pull Request #1 · skj9865/tbp.monty · GitHub
I didn’t know where you downloaded the MNIST dataset, so I got it from here: MNIST_png | Kaggle
I had to make a couple of adjustments to the dataloader since I think the dataset you use is in a different format (e.g. I got pixel values in the range 0-1 where I think your code expected 0-255). I also had to install opencv in addition to the standard Monty setup.
I also added some dummy visualization code that I pushed to that branch. It is clunkily integrated, but I thought it might be useful to you, so I left it there. You can just set enable_plotting=True in the EvidenceLM arguments to see a live plot of Monty’s observations and current hypotheses. This helped me debug.
From running some experiments and visualizing what is going on, here are some insights:
Sensor Processing
- It seems like some of the sensor processing is not ideal yet. For one, the principal curvature estimates seem to be noisy (though it is only a few outliers), which can lead to Monty initializing bad pose hypotheses or getting large pose errors when matching.
You could try experimenting with a different patch size or smoothing the depth image (something I did when working with the Omniglot dataset; there is a rough sketch of this at the end of this bullet).
If you have unreliable pose estimates, it can also help to just fix which poses you test for each object instead of inferring them from the sensory data, especially since in this dataset the digits are not presented in arbitrary orientations. For example, you could set
initial_possible_poses=[
[0, 0, 0],
[0, 10, 0],
[0, -10, 0],
[0, 20, 0],
[0, -20, 0],
[0, 30, 0],
[0, -30, 0],
],
to test tilts of the digits of up to 30 degrees. Eyeballing it on a few examples, this seems to cover most cases. Below is the model of the number 1 (right) and a 30-degree rotated version (left).
You could test smaller steps or larger rotations if you want to. But after a couple of other fixes, the sensation-informed poses (initial_possible_poses='informed') worked similarly well for me, so I didn’t go down this path much further.
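On the depth-image smoothing idea above, here is a minimal sketch of what I mean, using opencv (which this setup installs anyway). It assumes the image is a 2D float array; the function name is just for illustration.

```python
import cv2
import numpy as np


def smooth_depth(depth_image: np.ndarray, ksize: int = 5, sigma: float = 1.0) -> np.ndarray:
    """Apply a Gaussian blur to the image before estimating curvature.

    This reduces pixel-level noise that otherwise leads to outliers in the
    principal curvature estimates.
    """
    return cv2.GaussianBlur(depth_image.astype(np.float32), (ksize, ksize), sigma)
```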
- The principal curvature magnitudes don’t seem to express the amount of curvature in the stroke. Instead, they seem to be related to how close a given pixel is to the border of the stroke.
You may need to update your curvature calculation code to look at curvature in the 2D image. I would imagine that some very basic kernels could do a good job, or just try to fit a circle through the white pixels in the patch (like in the example below).
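To make the circle-fit idea a bit more concrete, here is a rough sketch (my own illustration, not code from the branch): fit a circle to the white pixel coordinates in the patch with a simple algebraic least-squares fit and use 1/radius as the curvature magnitude.

```python
import numpy as np


def patch_curvature(patch: np.ndarray, threshold: float = 0.5) -> float:
    """Approximate curvature magnitude (1/radius) of the stroke in a 2D patch."""
    ys, xs = np.nonzero(patch > threshold)
    if len(xs) < 3:
        return 0.0  # not enough stroke pixels to fit a circle
    # Circle model x^2 + y^2 + d*x + e*y + f = 0, solved as a linear system.
    A = np.column_stack([xs, ys, np.ones_like(xs)]).astype(float)
    rhs = -(xs.astype(float) ** 2 + ys.astype(float) ** 2)
    d, e, f = np.linalg.lstsq(A, rhs, rcond=None)[0]
    cx, cy = -d / 2.0, -e / 2.0
    radius_sq = cx ** 2 + cy ** 2 - f
    if radius_sq <= 0:
        return 0.0  # degenerate fit, e.g. nearly collinear pixels
    return 1.0 / np.sqrt(radius_sq)
```

A straight stroke gives a very large fitted radius and therefore a curvature close to zero, which is the behavior you want.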
- We are not making use of “off-object” observations. This has nothing to do with your code; it is something on our roadmap (see Use Off-Object Observations), and if you have an idea of how to add this feature, we would be happy about contributions.
You can see the issue in the image below:
The left plot shows the observations in world coordinates (what the sensor is sensing). The two right plots show the observations projected into the respective model’s reference frame: each shows the model Monty has learned (grey dots) together with the observations translated and rotated by the most likely pose hypothesis for that object. The middle plot shows this for the current most likely object (the 7) and the right plot for the target object (the 1).
You can see that when transforming the observations by Monty’s pose hypothesis, the observations align well with both the model of the 7 and the 1.
We will not be able to get a better estimate since the sensor has already sensed all the pixels in the test image.
What we would ideally want is that when the sensor moves into an area where the model of the 7 predicts the digit to exist and we only sense black pixels, we accumulate negative evidence. However, this currently doesn’t happen since black pixels are labeled as ‘off object’ and those observations are not sent to the LM.
The solution to this is complicated. Internally, we have tabled this problem for now since it mostly exists in artificial scenarios where an object sits in an empty void. In most natural settings, all observations will be on some object (although not all on the same object), and since the background varies while the object remains the same, the grid object model will start to represent only the constant features and not the varying backgrounds. We have not tested this mechanism much yet since we also still run most of our experiments with one object in an empty void.
One option you could try, although not ideal, is to go back to also sending the black-pixel observations to the LM. It’s not great because the black pixels then become part of the model for each digit, but it might help with examples like the one above.
You could also make some larger modifications to the code where you send off-object observations to the LM but don’t use them when building the graph. That way your graph would only consist of the digits themselves, but you could still register a prediction error when the graph predicted a white pixel and you got a black one (roughly as in the sketch below).
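Here is a hypothetical illustration of that second option in plain Python (these are not Monty’s actual observation or graph-building structures, just the idea): every observation carries an on-object flag, all observations count toward prediction error, but only on-object ones become graph nodes.

```python
# Toy observations: a white pixel is "on object", a black pixel is not.
observations = [
    {"location": (3, 5), "on_object": True},
    {"location": (3, 6), "on_object": True},
    {"location": (10, 2), "on_object": False},
]

# Only on-object observations are used to build or extend the digit's graph.
graph_nodes = [obs["location"] for obs in observations if obs["on_object"]]


def prediction_error(obs, model_predicts_on_object):
    """Negative evidence when the model predicts the digit at a location
    but we only sense a black pixel (and vice versa)."""
    return 1.0 if obs["on_object"] == model_predicts_on_object else -1.0
```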
Motor System
Generally, using the spiral policy for your use case doesn’t seem to be an issue. There is just one little bug I found with it: Once you get to the edge of the move_area the sensor keeps moving in circles like this:
(Numbers are pixel values of new_loc with your default settings)
The issue with that is that the LM keeps getting the same observations at the end and artificially accumulates more evidence based on those. As a quick fix, I simply set max_total_steps=21 * 21 so the episode times out once the sensor has seen everything in the image once. (Remember to also set num_exploratory_steps=0 in monty_args, because otherwise the max_total_steps value gets overwritten by it.)
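For reference, here is roughly what that looks like in the config (argument names as used in this post; where exactly they go may differ slightly in your setup):

```python
monty_args = dict(
    max_total_steps=21 * 21,   # 21x21 move_area -> stop after covering the image once
    num_exploratory_steps=0,   # otherwise this value overwrites max_total_steps
)
```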
There is definitely a cleaner solution where the dataloader just raises StopIteration when the boundary is reached.
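Sketched out, that could look something like the following (a hypothetical, simplified iterator, not Monty’s actual dataloader classes):

```python
class SpiralPatchIterator:
    """Yields patch observations and stops once the move_area is covered."""

    def __init__(self, image, locations, move_area=21):
        self.image = image            # 2D array with the digit
        self.locations = locations    # precomputed spiral of (row, col) positions
        self.move_area = move_area
        self.step = 0

    def __iter__(self):
        return self

    def __next__(self):
        if self.step >= min(self.move_area ** 2, len(self.locations)):
            raise StopIteration  # boundary reached: end the episode cleanly
        row, col = self.locations[self.step]
        self.step += 1
        return self.image[row, col]
```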
The MNIST Dataset
Generally, I noticed some more peculiarities about the MNIST dataset. For example, a given digit can be written in very different ways. For instance, take these two versions of a ‘1’:
The left one is the one Monty was trained on, and the right one is the one it is tested on. Those two digits have very different morphologies, so Monty has trouble fitting them together. It actually finds that the test 1 fits better onto the model of a 2 than onto the one-stroke version of the 1 it has learned:
Just reflecting on this, it seems like I would also have two separate models for those two ways of writing a 1. So, one solution could be to let Monty learn multiple models of each digit (instead of one model of all the versions combined). A way to do this could be to run supervised learning and, instead of supplying the label ‘1’, supply the labels ‘1_1’, ‘1_2’, … for each representative version of the digit. Then, when you analyze the model’s performance, you would need a small update that looks at the most_likely_object and primary_target_object columns in the .csv stats file and just compares the digit before the _ (i.e. whether it recognized one of the models of the correct digit).
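The analysis update could be as simple as something like this (column names as mentioned above; the file name is just a placeholder for your eval stats .csv):

```python
import pandas as pd

stats = pd.read_csv("eval_stats.csv")  # placeholder path to the stats file


def base_digit(label):
    """'1_2' -> '1'; labels without an underscore suffix are returned unchanged."""
    return str(label).split("_")[0]


correct = (
    stats["most_likely_object"].map(base_digit)
    == stats["primary_target_object"].map(base_digit)
)
print(f"Digit-level accuracy: {correct.mean():.2%}")
```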
Then, there are also some generally noisy examples and edge cases in MNIST. Just in the 10 examples of 1s I tested on, I found these two that usually made Monty misclassify. The left one was usually classified as a 2, and the right one had some strange artifacts in it.
When you look at Monty’s hypotheses for the 1 with the artifacts (on the right), you can see that, because of those artifacts, it fits better onto the model of a 6 (something that off-object observations might also help with).
Hyperparameters
A couple of tweaks I made to the hyperparameters:
- Set min_eval_steps=80 (or 100 or similar) in the monty_args to ensure Monty sees a significant amount of the digit before making a classification.
- max_match_distance=0.01 worked well for me, but in the plots you shared, it looked like your locations were in a different range than they are for me. Just something to double-check and adjust if your locations are more in the 6-22 range instead of 0-0.1.
- I set the tolerance on the pose_vectors a bit lower, to "pose_vectors": np.ones(3) * 5, since the noise I saw on the pose vectors was either extreme (flipped the wrong way, which can’t be solved with higher tolerances) or very small.
- Set the feature_weights for pose_vectors to "pose_vectors": [0, 1, 0], so that we only look at the first principal curvature direction. The first vector is the point normal, which in this dataset always points up and is therefore not informative. The third vector is always 90 degrees from the second, so it also doesn’t give us any extra information.
- Like I mentioned above, you can test initial_possible_poses as 'informed' or specify slight tilts around the Y axis.
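Put together, a schematic of these tweaks looks roughly like this (argument names are the ones used in this post; where exactly they nest in your config, e.g. per sensor module, may differ in your setup):

```python
import numpy as np

monty_args = dict(
    min_eval_steps=80,          # or 100; see enough of the digit before classifying
    max_total_steps=21 * 21,    # time out after covering the 21x21 move_area once
    num_exploratory_steps=0,    # otherwise this overwrites max_total_steps
)

evidence_lm_args = dict(
    max_match_distance=0.01,                       # double-check your location scale first
    tolerances={"pose_vectors": np.ones(3) * 5},   # tighter pose tolerance
    feature_weights={"pose_vectors": [0, 1, 0]},   # only the first PC direction
    initial_possible_poses="informed",             # or the explicit Y-axis tilts above
)
```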
Continuing Training on Pretrained Models
This is not something we have tried before, so I am sorry that our code doesn’t seem to support this out of the box. The way I got it to work is by specifying the EvidenceGraphLM in the training config. Our usual pretraining experiments use the PatchAndViewMontyConfig, which uses the DisplacementGraphLM and therefore learns GraphObjectModels. But there is no issue with running supervised pretraining with the EvidenceGraphLM and learning GridObjectModels (besides making sure to set the voxel size parameters correctly, as you already noticed; I didn’t run into this issue myself, maybe because the change was already in your branch). You can see the updated config in my branch.
Additionally, I had to remove the model.use_original_graph = True line in _initialize_model_with_graph so that the loaded graphs are again initialized as GridObjectModels and can be extended.
Potential Next Steps
While the little fixes I made in the PR should help a bit, they still don’t give great performance. To summarize, some things I would recommend you try are:
- Have Monty learn several separate models of the same digit to represent the different ways a digit can be written. Then, adapt the analysis to count recognition of any version of the correct digit as correct.
- Look into using the black-pixel observations (at least for calculating prediction error). This one is a bit open-ended, and I would have to think more about it myself to come up with an ideal solution. Feel free to write up your thoughts and questions if you go down this road.
- Improve the feature extraction in your SM to have more robust curvature detection (both the direction and magnitude).
I hope this helps!