FiveLMStackedMontyConfig Error

Hi everyone,

I am having trouble getting a custom FiveLMStackedMontyConfig to run in a pretraining experiment.

I get the following error; the code I have modified so far is pasted below.

(tbp.monty) [sruiz10@node1928 benchmarks]$ python run.py -e FiveLM



Printing config below
----------------------------------------------------------------------------------------------------
{'dataset_args': {'env_init_args': {'agents': [{'agent_args': {'action_space_type': 'surface_agent',
                                                               'agent_id': 'agent_id_0',
                                                               'height': 0.0,
                                                               'position': [0.0,
                                                                            1.5,
                                                                            0.1],
                                                               'positions': [[0.0,
                                                                              0.0,
                                                                              0.0],
                                                                             [0.0,
                                                                              0.0,
                                                                              0.03]],
                                                               'resolutions': [[64,
                                                                                64],
                                                                               [64,
                                                                                64]],
                                                               'rotations': [[1.0,
                                                                              0.0,
                                                                              0.0,
                                                                              0.0],
                                                                             [1.0,
                                                                              0.0,
                                                                              0.0,
                                                                              0.0]],
                                                               'semantics': [False,
                                                                             False],
                                                               'sensor_ids': ['patch',
                                                                              'view_finder'],
                                                               'zooms': [10.0,
                                                                         1.0]},
                                                'agent_type': <class 'tbp.monty.simulators.habitat.agents.MultiSensorAgent'>}],
                                    'data_path': '/users/sruiz10/data/sruiz10/tbp/data/habitat/objects/ycb',
                                    'objects': [{'enable_physics': False,
                                                 'name': 'coneSolid',
                                                 'object_to_avoid': False,
                                                 'position': (0.0, 1.5, -0.1),
                                                 'primary_target_bb': None,
                                                 'rotation': (1.0,
                                                              0.0,
                                                              0.0,
                                                              0.0),
                                                 'scale': (1.0, 1.0, 1.0),
                                                 'semantic_id': None}],
                                    'scene_id': None,
                                    'seed': 42},
                  'env_init_func': <class 'tbp.monty.simulators.habitat.environment.HabitatEnvironment'>,
                  'rng': None,
                  'transform': [<tbp.monty.frameworks.environment_utils.transforms.MissingToMaxDepth object at 0x7fb38b3bccd0>,
                                <tbp.monty.frameworks.environment_utils.transforms.DepthTo3DLocations object at 0x7fb38b3c73a0>]},
 'dataset_class': <class 'tbp.monty.frameworks.environments.embodied_data.EnvironmentDataset'>,
 'eval_dataloader_args': {'object_init_sampler': PredefinedObjectInitializer with params: 
	 positions: [[0.0, 1.5, 0.0]]
	 rotations: [array([0, 0, 0]), array([ 0, 90,  0]), array([  0, 180,   0]), array([  0, 270,   0]), array([90,  0,  0]), array([ 90, 180,   0]), array([35, 45,  0]), array([325,  45,   0]), array([ 35, 315,   0]), array([325, 315,   0]), array([ 35, 135,   0]), array([325, 135,   0]), array([ 35, 225,   0]), array([325, 225,   0])]
	 change every episode: None,
                          'object_names': ['mug', 'banana']},
 'eval_dataloader_class': <class 'tbp.monty.frameworks.environments.embodied_data.InformedEnvironmentDataLoader'>,
 'experiment_args': {'do_eval': False,
                     'do_train': True,
                     'max_eval_steps': 500,
                     'max_total_steps': 6000,
                     'max_train_steps': 1000,
                     'min_lms_match': 1,
                     'model_name_or_path': '',
                     'n_eval_epochs': 3,
                     'n_train_epochs': 14,
                     'seed': 42,
                     'show_sensor_output': False},
 'experiment_class': <class 'tbp.monty.frameworks.experiments.pretraining_experiments.MontySupervisedObjectPretrainingExperiment'>,
 'logging_config': {'log_parallel_wandb': False,
                    'monty_handlers': [],
                    'monty_log_level': 'SILENT',
                    'output_dir': '/users/sruiz10/data/sruiz10/tbp/results/monty/projects/FiveLMStackedMonty',
                    'python_log_level': 'WARNING',
                    'python_log_to_file': True,
                    'python_log_to_stdout': True,
                    'resume_wandb_run': False,
                    'run_name': 'FiveLMStackedMonty',
                    'wandb_group': 'debugging',
                    'wandb_handlers': [],
                    'wandb_id': 'a92ke6lr'},
 'monty_config': {'learning_module_configs': {'learning_module_0': {'learning_module_args': {'k': 5,
                                                                                             'match_attribute': 'displacement'},
                                                                    'learning_module_class': <class 'tbp.monty.frameworks.models.displacement_matching.DisplacementGraphLM'>},
                                              'learning_module_1': {'learning_module_args': {'k': 5,
                                                                                             'match_attribute': 'displacement'},
                                                                    'learning_module_class': <class 'tbp.monty.frameworks.models.displacement_matching.DisplacementGraphLM'>},
                                              'learning_module_2': {'learning_module_args': {'k': 5,
                                                                                             'match_attribute': 'displacement'},
                                                                    'learning_module_class': <class 'tbp.monty.frameworks.models.displacement_matching.DisplacementGraphLM'>},
                                              'learning_module_3': {'learning_module_args': {'k': 5,
                                                                                             'match_attribute': 'displacement'},
                                                                    'learning_module_class': <class 'tbp.monty.frameworks.models.displacement_matching.DisplacementGraphLM'>},
                                              'learning_module_4': {'learning_module_args': {'k': 5,
                                                                                             'match_attribute': 'displacement'},
                                                                    'learning_module_class': <class 'tbp.monty.frameworks.models.displacement_matching.DisplacementGraphLM'>}},
                  'lm_to_lm_matrix': [[], [0], [1], [2], [3]],
                  'lm_to_lm_vote_matrix': None,
                  'monty_args': {'max_total_steps': 2500,
                                 'min_eval_steps': 3,
                                 'min_train_steps': 3,
                                 'num_exploratory_steps': 500},
                  'monty_class': <class 'tbp.monty.frameworks.models.graph_matching.MontyForGraphMatching'>,
                  'motor_system_config': {'motor_system_args': {'policy_args': {'action_sampler_args': {'actions': [<class 'tbp.monty.frameworks.actions.actions.MoveForward'>,
                                                                                                                    <class 'tbp.monty.frameworks.actions.actions.MoveTangentially'>,
                                                                                                                    <class 'tbp.monty.frameworks.actions.actions.OrientHorizontal'>,
                                                                                                                    <class 'tbp.monty.frameworks.actions.actions.OrientVertical'>,
                                                                                                                    <class 'tbp.monty.frameworks.actions.actions.SetAgentPose'>,
                                                                                                                    <class 'tbp.monty.frameworks.actions.actions.SetSensorRotation'>]},
                                                                                'action_sampler_class': <class 'tbp.monty.frameworks.actions.action_samplers.ConstantSampler'>,
                                                                                'agent_id': 'agent_id_0',
                                                                                'alpha': 0.1,
                                                                                'desired_object_distance': 0.025,
                                                                                'file_name': None,
                                                                                'good_view_percentage': 0.5,
                                                                                'max_pc_bias_steps': 32,
                                                                                'min_general_steps': 8,
                                                                                'min_heading_steps': 12,
                                                                                'min_perc_on_obj': 0.25,
                                                                                'pc_alpha': 0.5,
                                                                                'switch_frequency': 1.0,
                                                                                'use_goal_state_driven_actions': False},
                                                                'policy_class': <class 'tbp.monty.frameworks.models.motor_policies.SurfacePolicyCurvatureInformed'>},
                                          'motor_system_class': <class 'tbp.monty.frameworks.models.motor_system.MotorSystem'>},
                  'sensor_module_configs': {'sensor_module_0': {'sensor_module_args': {'features': ['pose_vectors',
                                                                                                    'pose_fully_defined',
                                                                                                    'on_object',
                                                                                                    'object_coverage',
                                                                                                    'hsv',
                                                                                                    'principal_curvatures',
                                                                                                    'principal_curvatures_log',
                                                                                                    'gaussian_curvature',
                                                                                                    'mean_curvature',
                                                                                                    'gaussian_curvature_sc',
                                                                                                    'mean_curvature_sc'],
                                                                                       'save_raw_obs': True,
                                                                                       'sensor_module_id': 'patch_0'},
                                                                'sensor_module_class': <class 'tbp.monty.frameworks.models.sensor_modules.HabitatDistantPatchSM'>},
                                            'sensor_module_1': {'sensor_module_args': {'save_raw_obs': True,
                                                                                       'sensor_module_id': 'view_finder'},
                                                                'sensor_module_class': <class 'tbp.monty.frameworks.models.sensor_modules.DetailedLoggingSM'>}},
                  'sm_to_agent_dict': {'patch_0': 'agent_id_0',
                                       'view_finder': 'agent_id_0'},
                  'sm_to_lm_matrix': [[0], [], [], [], []]},
 'train_dataloader_args': {'object_init_sampler': PredefinedObjectInitializer with params: 
	 positions: [[0.0, 1.5, 0.0]]
	 rotations: [array([0, 0, 0]), array([ 0, 90,  0]), array([  0, 180,   0]), array([  0, 270,   0]), array([90,  0,  0]), array([ 90, 180,   0]), array([35, 45,  0]), array([325,  45,   0]), array([ 35, 315,   0]), array([325, 315,   0]), array([ 35, 135,   0]), array([325, 135,   0]), array([ 35, 225,   0]), array([325, 225,   0])]
	 change every episode: None,
                           'object_names': ['mug', 'banana']},
 'train_dataloader_class': <class 'tbp.monty.frameworks.environments.embodied_data.InformedEnvironmentDataLoader'>}
----------------------------------------------------------------------------------------------------
---------training---------
Traceback (most recent call last):
  File "run.py", line 44, in <module>
    main(all_configs=CONFIGS, experiments=cmd_args.experiments)
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/run.py", line 112, in main
    run(exp_config)
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/run.py", line 52, in run
    exp.train()
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/experiments/pretraining_experiments.py", line 130, in train
    self.run_epoch()
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/experiments/monty_experiment.py", line 507, in run_epoch
    self.run_episode()
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/experiments/pretraining_experiments.py", line 83, in run_episode
    self.model.step(observation)
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/models/monty_base.py", line 145, in step
    self._exploratory_step(observation)
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/models/abstract_monty_classes.py", line 37, in _exploratory_step
    self.aggregate_sensory_inputs(observation)
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/models/monty_base.py", line 153, in aggregate_sensory_inputs
    sensor_module.update_state(self.get_agent_state())
  File "/oscar/data/msherif/sruiz10/tbp.monty/src/tbp/monty/frameworks/models/sensor_modules.py", line 506, in update_state
    sensor_position = state["sensors"][self.sensor_module_id + ".rgba"]["position"]
KeyError: 'patch_0.rgba'

config_args.py

@dataclass
## 5 Five Learning Modules Monty Config
# Can we adapt this to 5 LMs stacked on top of each other? (ADAPTED VERSION)
class FiveLMStackedMontyConfig(MontyConfig):
    monty_class: Callable = MontyForGraphMatching
    learning_module_configs: Union[dataclass, Dict] = field(
        default_factory=lambda: dict(
            learning_module_0=dict(
                learning_module_class=DisplacementGraphLM,
                learning_module_args=dict(k=5, match_attribute="displacement"),
            ),
            learning_module_1=dict(
                learning_module_class=DisplacementGraphLM,
                learning_module_args=dict(k=5, match_attribute="displacement"),
            ),
            learning_module_2=dict(
                learning_module_class=DisplacementGraphLM,
                learning_module_args=dict(k=5, match_attribute="displacement"),
            ),
            learning_module_3=dict(
                learning_module_class=DisplacementGraphLM,
                learning_module_args=dict(k=5, match_attribute="displacement"),
            ),
            learning_module_4=dict(
                learning_module_class=DisplacementGraphLM,
                learning_module_args=dict(k=5, match_attribute="displacement"),
            ),
        )
    )
    sensor_module_configs: Union[dataclass, Dict] = field(
        default_factory=lambda: dict(
            sensor_module_0=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_0",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=True,
                ),
            ),
            sensor_module_1=dict(
                # No need to extract features from the view finder since it is not
                # connected to a learning module (just used at beginning of episode)
                sensor_module_class=DetailedLoggingSM,
                sensor_module_args=dict(
                    sensor_module_id="view_finder",
                    save_raw_obs=True,
                ),
            ),
        )
    )
    motor_system_config: Union[dataclass, Dict] = field(
        default_factory=MotorSystemConfigInformedNoTrans
    )
    sm_to_agent_dict: Dict = field(
        default_factory=lambda: dict(
            patch_0="agent_id_0",
            view_finder="agent_id_0",
        )
    )
    sm_to_lm_matrix: List = field(
        default_factory=lambda: [
            [0], [], [], [], []
        ],  # View finder (sm1) not connected to lm
    )
    # First LM only gets sensory input, second gets input from first + sensor
    lm_to_lm_matrix: Optional[List] = field(default_factory=lambda: [[], [0], [1], [2], [3]])
    lm_to_lm_vote_matrix: Optional[List] = None
    monty_args: Union[Dict, dataclass] = field(default_factory=MontyArgs)

my_experiments.py

import os
from dataclasses import asdict

from benchmarks.configs.names import MyExperiments
from tbp.monty.frameworks.config_utils.config_args import (
    MontyArgs,
    MotorSystemConfigCurvatureInformedSurface,
    FiveLMStackedMontyConfig,
    PretrainLoggingConfig,
    get_cube_face_and_corner_views_rotations,
)
from tbp.monty.frameworks.config_utils.make_dataset_configs import (
    EnvironmentDataloaderPerObjectArgs,
    ExperimentArgs,
    PredefinedObjectInitializer,
)
from tbp.monty.frameworks.environments import embodied_data as ED
from tbp.monty.frameworks.experiments import (
    MontySupervisedObjectPretrainingExperiment,
)
from tbp.monty.frameworks.models.displacement_matching import DisplacementGraphLM
from tbp.monty.frameworks.models.sensor_modules import (
    DetailedLoggingSM,
    HabitatSurfacePatchSM,
)
from tbp.monty.simulators.habitat.configs import (
    SurfaceViewFinderMountHabitatDatasetArgs,
)

"""
Basic setup
-----------
"""
# Specify directory where an output directory will be created.
project_dir = os.path.expanduser("~/data/sruiz10/tbp/results/monty/projects")

# Specify a name for the model.
model_name = "FiveLMStackedMonty"


"""
Training
----------------------------------------------------------------------------------------
"""
# Here we specify which objects to learn. 'mug' and 'banana' come from the YCB dataset.
# If you don't have the YCB dataset, replace with names from habitat (e.g.,
# 'capsule3DSolid', 'cubeSolid', etc.).
object_names = ["mug", "banana"]
# Get predefined object rotations that give good views of the object from 14 angles.
train_rotations = get_cube_face_and_corner_views_rotations()

# The config dictionary for the pretraining experiment.
FiveLM = dict(
    # Specify monty experiment and its args.
    # The MontySupervisedObjectPretrainingExperiment class will provide the model
    # with object and pose labels for supervised pretraining.
    experiment_class=MontySupervisedObjectPretrainingExperiment,
    experiment_args=ExperimentArgs(
        n_train_epochs=len(train_rotations),
        do_eval=False,
    ),
    # Specify logging config.
    logging_config=PretrainLoggingConfig(
        output_dir=project_dir,
        run_name=model_name,
        wandb_handlers=[],
    ),
    # Specify the Monty config.
    monty_config=FiveLMStackedMontyConfig(
        monty_args=MontyArgs(num_exploratory_steps=500),
        motor_system_config=MotorSystemConfigCurvatureInformedSurface(),
    ),

    # Set up the environment and agent
    dataset_class=ED.EnvironmentDataset,
    dataset_args=SurfaceViewFinderMountHabitatDatasetArgs(),
    train_dataloader_class=ED.InformedEnvironmentDataLoader,
    train_dataloader_args=EnvironmentDataloaderPerObjectArgs(
        object_names=object_names,
        object_init_sampler=PredefinedObjectInitializer(rotations=train_rotations),
    ),
    # For a complete config we need to specify an eval_dataloader, but since we only train here, it is unused
    eval_dataloader_class=ED.InformedEnvironmentDataLoader,
    eval_dataloader_args=EnvironmentDataloaderPerObjectArgs(
        object_names=object_names,
        object_init_sampler=PredefinedObjectInitializer(rotations=train_rotations),
    ),
)

experiments = MyExperiments(
    FiveLM=FiveLM,
)
CONFIGS = asdict(experiments)

names.py

from dataclasses import dataclass


@dataclass
class MyExperiments:
    FiveLM: dict

Hi @sebastianruizsebas,

It looks like you’re quite close to getting it to work! I noticed a couple of things that are missing. Feel free to check out the five_lm_stacked branch on my fork and try it out directly; in the meantime, let me walk you through the major changes.

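The biggest problem is that you’re not using the correct dataset_args in the configs. You are using SurfaceViewFinderMountHabitatDatasetArgs, whose mount only defines the sensor ids "patch" and "view_finder" (see sensor_ids in the printed config). Your sensor module is configured with sensor_module_id "patch_0", so looking up "patch_0.rgba" in the agent state fails with the KeyError you see. A five-LM stack also needs five patch sensors (patch_{0,1,2,3,4}), not just one. In a stacked LM experiment, the higher-level LMs still get direct sensory input from SMs, in addition to receiving the outputs of lower-level LMs. Note that higher-level LMs receive sensory information at a larger receptive field, which is why I progressively reduced the zoom values of the SM patches here.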
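A quick way to catch this kind of mismatch before the simulator starts is to compare the sensor_module_id of every SM in the Monty config against the sensor_ids defined on the mount. Below is a minimal, self-contained sketch; find_unmounted_sensor_modules is a hypothetical helper (not part of tbp.monty), and it works on plain dicts shaped like the printed config rather than on the real config classes.

```python
from typing import Dict, List


def find_unmounted_sensor_modules(
    mount_sensor_ids: List[str], sensor_module_configs: Dict[str, dict]
) -> List[str]:
    """Return SM ids that have no matching sensor on the mount.

    The sensor module looks up observations under keys such as
    f"{sensor_module_id}.rgba", so every sensor_module_id should also
    appear in the mount's sensor_ids.
    """
    mounted = set(mount_sensor_ids)
    missing = []
    for sm_config in sensor_module_configs.values():
        sm_id = sm_config["sensor_module_args"]["sensor_module_id"]
        if sm_id not in mounted:
            missing.append(sm_id)
    return missing


# Values copied from the printed config in the question.
mount_sensor_ids = ["patch", "view_finder"]
sensor_module_configs = {
    "sensor_module_0": {"sensor_module_args": {"sensor_module_id": "patch_0"}},
    "sensor_module_1": {"sensor_module_args": {"sensor_module_id": "view_finder"}},
}

print(find_unmounted_sensor_modules(mount_sensor_ids, sensor_module_configs))
# ['patch_0']
```

With the five-patch mount (patch_0 to patch_4 plus view_finder), the same check should return an empty list.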

# tbp/monty/simulators/habitat/configs.py

@dataclass
class FiveLMStackedDistantMountConfig:
    # Five co-located sensor patches with different receptive field sizes, plus a view finder
    agent_id: Union[str, None] = "agent_id_0"
    sensor_ids: Union[List[str], None] = field(
        default_factory=lambda: [
            "patch_0",
            "patch_1",
            "patch_2",
            "patch_3",
            "patch_4",
            "view_finder",
        ]
    )
    height: Union[float, None] = 0.0
    position: List[Union[int, float]] = field(default_factory=lambda: [0.0, 1.5, 0.2])
    resolutions: List[List[Union[int, float]]] = field(
        default_factory=lambda: [
            [64, 64],
            [64, 64],
            [64, 64],
            [64, 64],
            [64, 64],
            [64, 64],
        ]
    )
    positions: List[List[Union[int, float]]] = field(
        default_factory=lambda: [
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0],
        ]
    )
    rotations: List[List[Union[int, float]]] = field(
        default_factory=lambda: [
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0, 0.0],
        ]
    )
    semantics: List[List[Union[int, float]]] = field(
        default_factory=lambda: [False, False, False, False, False, False]
    )
    zooms: List[float] = field(default_factory=lambda: [10.0, 8.0, 6.0, 4.0, 2.0, 1.0])


@dataclass
class EnvInitArgsFiveLMDistantStackedMount(EnvInitArgs):
    agents: List[AgentConfig] = field(
        default_factory=lambda: [
            AgentConfig(MultiSensorAgent, FiveLMStackedDistantMountConfig().__dict__)
        ]
    )


@dataclass
class FiveLMStackedDistantMountHabitatDatasetArgs(MultiLMMountHabitatDatasetArgs):
    env_init_args: Dict = field(
        default_factory=lambda: EnvInitArgsFiveLMDistantStackedMount().__dict__
    )
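Every per-sensor list in the mount config above has to line up with sensor_ids by index, so it can help to derive all of them from a single list of patch zooms. This is a minimal sketch, assuming N co-located 64x64 patches plus a zoom-1.0 view finder as in FiveLMStackedDistantMountConfig; build_stacked_mount_lists is a hypothetical helper that returns plain lists you could use as the dataclass field defaults.

```python
from typing import Dict, List


def build_stacked_mount_lists(patch_zooms: List[float]) -> Dict[str, list]:
    """Build index-aligned per-sensor lists for N patches plus a view finder.

    Patch i gets zoom patch_zooms[i]; the view finder is appended last with
    zoom 1.0. All sensors share the same position, rotation and resolution.
    """
    n_sensors = len(patch_zooms) + 1  # patches + view finder
    return {
        "sensor_ids": [f"patch_{i}" for i in range(len(patch_zooms))]
        + ["view_finder"],
        "resolutions": [[64, 64] for _ in range(n_sensors)],
        "positions": [[0.0, 0.0, 0.0] for _ in range(n_sensors)],
        "rotations": [[1.0, 0.0, 0.0, 0.0] for _ in range(n_sensors)],
        "semantics": [False] * n_sensors,
        "zooms": list(patch_zooms) + [1.0],
    }


mount = build_stacked_mount_lists([10.0, 8.0, 6.0, 4.0, 2.0])
# Every list must have exactly one entry per sensor id.
assert all(len(v) == len(mount["sensor_ids"]) for v in mount.values())
print(mount["sensor_ids"])
# ['patch_0', 'patch_1', 'patch_2', 'patch_3', 'patch_4', 'view_finder']
print(mount["zooms"])
# [10.0, 8.0, 6.0, 4.0, 2.0, 1.0]
```

Keeping the explicit lists, as in the branch, works just as well; generating them mainly avoids off-by-one mismatches when you change the number of LMs.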

I’ve also modified the FiveLMStackedMontyConfig. The biggest changes are adding more SMs for the higher-level LMs and defining these connections in the sm_to_lm_matrix. We now mainly use EvidenceGraphLM for all of our experiments and benchmarks, so I switched the LMs to it as well. Feel free to tune the parameters to your liking.
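For a linear stack of N LMs where LM i gets patch_i directly plus the output of LM i-1, both connection matrices and the per-LM arguments follow a simple pattern. The sketch below is my reading of the description above, not a copy of the branch: row i of each matrix lists the inputs of LM i (the same layout as the [[0], [], [], [], []] matrix in the original config), and the view finder is left unconnected. stacked_connectivity and stacked_lm_args are hypothetical helpers; patch_tolerances is passed in so you can keep using np.array values as in the config.

```python
import copy
from typing import Dict, List, Tuple


def stacked_connectivity(n_lms: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (sm_to_lm_matrix, lm_to_lm_matrix) for a linear LM stack.

    Row i of each matrix lists the inputs of LM i: SM i feeds LM i, and
    LM i-1 feeds LM i. The view finder (SM index n_lms) is not listed.
    """
    sm_to_lm = [[i] for i in range(n_lms)]
    lm_to_lm = [[]] + [[i - 1] for i in range(1, n_lms)]
    return sm_to_lm, lm_to_lm


def stacked_lm_args(n_lms: int, patch_tolerances: Dict) -> List[Dict]:
    """Build learning_module_args dicts mirroring the EvidenceGraphLM configs."""
    all_args = []
    for i in range(n_lms):
        # Each LM gets its own copy of the patch feature tolerances.
        tolerances = {f"patch_{i}": copy.deepcopy(patch_tolerances)}
        feature_weights = {}
        if i > 0:
            # Higher LMs also match on the object_id sent by the LM below.
            lower = f"learning_module_{i - 1}"
            tolerances[lower] = {"object_id": 1}
            feature_weights[lower] = {"object_id": 1}
        all_args.append(
            dict(
                max_match_distance=0.001,
                tolerances=tolerances,
                feature_weights=feature_weights,
                max_graph_size=0.2 if i == 0 else 0.3,
                num_model_voxels_per_dim=50,
                max_nodes_per_graph=50,
            )
        )
    return all_args


sm_to_lm, lm_to_lm = stacked_connectivity(5)
print(sm_to_lm)  # [[0], [1], [2], [3], [4]]
print(lm_to_lm)  # [[], [0], [1], [2], [3]]

# In a real config you would wrap each entry, e.g.
# {f"learning_module_{i}": dict(learning_module_class=EvidenceGraphLM,
#                                learning_module_args=a)
#  for i, a in enumerate(stacked_lm_args(5, tolerances_with_np_arrays))}
```

Building the configs in a loop keeps the five LMs consistent when you tune a parameter; writing them out explicitly, as below, makes each LM easier to tweak individually.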

# tbp/monty/frameworks/config_utils/config_args.py

@dataclass
## 5 Five Learning Modules Monty Config
# Can we adapt this to 5 LMs stacked on top of each other? (ADAPTED VERSION)
class FiveLMStackedMontyConfig(MontyConfig):
    monty_class: Callable = MontyForGraphMatching
    learning_module_configs: Union[dataclass, Dict] = field(
        default_factory=lambda: dict(
            learning_module_0=dict(
                learning_module_class=EvidenceGraphLM,
                learning_module_args=dict(
                    max_match_distance=0.001,
                    tolerances={
                        "patch_0": {
                            "hsv": np.array([0.1, 1, 1]),
                            "principal_curvatures_log": np.ones(2),
                        }
                    },
                    feature_weights={},
                    max_graph_size=0.2,
                    num_model_voxels_per_dim=50,
                    max_nodes_per_graph=50,
                ),
            ),
            learning_module_1=dict(
                learning_module_class=EvidenceGraphLM,
                learning_module_args=dict(
                    max_match_distance=0.001,
                    tolerances={
                        "patch_1": {
                            "hsv": np.array([0.1, 1, 1]),
                            "principal_curvatures_log": np.ones(2),
                        },
                        "learning_module_0": {"object_id": 1},
                    },
                    feature_weights={"learning_module_0": {"object_id": 1}},
                    max_graph_size=0.3,
                    num_model_voxels_per_dim=50,
                    max_nodes_per_graph=50,
                ),
            ),
            learning_module_2=dict(
                learning_module_class=EvidenceGraphLM,
                learning_module_args=dict(
                    max_match_distance=0.001,
                    tolerances={
                        "patch_2": {
                            "hsv": np.array([0.1, 1, 1]),
                            "principal_curvatures_log": np.ones(2),
                        },
                        "learning_module_1": {"object_id": 1},
                    },
                    feature_weights={"learning_module_1": {"object_id": 1}},
                    max_graph_size=0.3,
                    num_model_voxels_per_dim=50,
                    max_nodes_per_graph=50,
                ),
            ),
            learning_module_3=dict(
                learning_module_class=EvidenceGraphLM,
                learning_module_args=dict(
                    max_match_distance=0.001,
                    tolerances={
                        "patch_3": {
                            "hsv": np.array([0.1, 1, 1]),
                            "principal_curvatures_log": np.ones(2),
                        },
                        "learning_module_2": {"object_id": 1},
                    },
                    feature_weights={"learning_module_2": {"object_id": 1}},
                    max_graph_size=0.3,
                    num_model_voxels_per_dim=50,
                    max_nodes_per_graph=50,
                ),
            ),
            learning_module_4=dict(
                learning_module_class=EvidenceGraphLM,
                learning_module_args=dict(
                    max_match_distance=0.001,
                    tolerances={
                        "patch_4": {
                            "hsv": np.array([0.1, 1, 1]),
                            "principal_curvatures_log": np.ones(2),
                        },
                        "learning_module_3": {"object_id": 1},
                    },
                    feature_weights={"learning_module_3": {"object_id": 1}},
                    max_graph_size=0.3,
                    num_model_voxels_per_dim=50,
                    max_nodes_per_graph=50,
                ),
            ),
        )
    )
    sensor_module_configs: Union[dataclass, Dict] = field(
        default_factory=lambda: dict(
            sensor_module_0=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_0",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "mean_depth",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=False,
                ),
            ),
            sensor_module_1=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_1",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "mean_depth",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=False,
                ),
            ),
            sensor_module_2=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_2",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "mean_depth",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=False,
                ),
            ),
            sensor_module_3=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_3",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "mean_depth",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=False,
                ),
            ),
            sensor_module_4=dict(
                sensor_module_class=HabitatDistantPatchSM,
                sensor_module_args=dict(
                    sensor_module_id="patch_4",
                    features=[
                        # morphological features (necessary)
                        "pose_vectors",
                        "pose_fully_defined",
                        "on_object",
                        # non-morphological features (optional)
                        "object_coverage",
                        "hsv",
                        "principal_curvatures",
                        "principal_curvatures_log",
                        "gaussian_curvature",
                        "mean_curvature",
                        "mean_depth",
                        "gaussian_curvature_sc",
                        "mean_curvature_sc",
                    ],
                    save_raw_obs=False,
                ),
            ),
            sensor_module_5=dict(
                # No need to extract features from the view finder since it is not
                # connected to a learning module (just used at beginning of episode)
                sensor_module_class=DetailedLoggingSM,
                sensor_module_args=dict(
                    sensor_module_id="view_finder",
                    save_raw_obs=True,
                ),
            ),
        )
    )
    motor_system_config: Union[dataclass, Dict] = field(
        default_factory=MotorSystemConfigInformedNoTrans
    )
    sm_to_agent_dict: Dict = field(
        default_factory=lambda: dict(
            patch_0="agent_id_0",
            patch_1="agent_id_0",
            patch_2="agent_id_0",
            patch_3="agent_id_0",
            patch_4="agent_id_0",
            view_finder="agent_id_0",
        )
    )
    sm_to_lm_matrix: List = field(
        default_factory=lambda: [
            [0],
            [1],
            [2],
            [3],
            [4],
        ],  # View finder (sensor_module_5) is not connected to any LM
    )
    # LM 0 only gets sensory input; each later LM i gets input from
    # LM i-1 plus its own sensor module
    lm_to_lm_matrix: Optional[List] = field(
        default_factory=lambda: [[], [0], [1], [2], [3]]
    )
    lm_to_lm_vote_matrix: Optional[List] = None
    monty_args: Union[Dict, dataclass] = field(default_factory=MontyArgs)
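The five patch sensor module entries above differ only in their index, so a copy-paste slip in one of them is easy to miss. As a minimal sketch, assuming you are free to build the dict with a helper inside `default_factory`, a hypothetical `make_patch_sm_configs` function (not part of Monty) could generate them. The SM classes are passed in as arguments so the sketch runs without importing Monty:

```python
from typing import Any, Dict, List

# Feature list copied from the config above; shared by every patch SM.
PATCH_FEATURES: List[str] = [
    # morphological features (necessary)
    "pose_vectors",
    "pose_fully_defined",
    "on_object",
    # non-morphological features (optional)
    "object_coverage",
    "hsv",
    "principal_curvatures",
    "principal_curvatures_log",
    "gaussian_curvature",
    "mean_curvature",
    "mean_depth",
    "gaussian_curvature_sc",
    "mean_curvature_sc",
]


def make_patch_sm_configs(
    patch_sm_class: Any, view_finder_sm_class: Any, num_patches: int
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical helper: build sensor_module_configs for N patches + view finder."""
    configs: Dict[str, Dict[str, Any]] = {}
    for i in range(num_patches):
        configs[f"sensor_module_{i}"] = dict(
            sensor_module_class=patch_sm_class,
            sensor_module_args=dict(
                sensor_module_id=f"patch_{i}",
                # Copy the list so the SMs do not share one mutable object.
                features=list(PATCH_FEATURES),
                save_raw_obs=False,
            ),
        )
    # The view finder goes last, keeping the sensor_module_5 slot it has above.
    configs[f"sensor_module_{num_patches}"] = dict(
        sensor_module_class=view_finder_sm_class,
        sensor_module_args=dict(sensor_module_id="view_finder", save_raw_obs=True),
    )
    return configs
```

Inside the config this would become something like `field(default_factory=lambda: make_patch_sm_configs(HabitatDistantPatchSM, DetailedLoggingSM, 5))`, so adding or removing a patch only means changing one number (plus the matching `sm_to_agent_dict` and matrix rows).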

These stacked LM experiments are still a work in progress as we move toward modeling compositional objects, so support is a bit limited for now. But you’re on the right track!

For more context, check out our latest paper, “Hierarchy or Heterarchy?”, especially Figure 5, which shows how sensor modules with different receptive fields are wired into the hierarchy. I also highly recommend watching Viviane’s explainer video!
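To see how the matrices in the config above map onto a hierarchy, a hypothetical helper such as `describe_wiring` (not part of Monty) can print what feeds each learning module. It assumes, as the config's own comments suggest, that row *i* of `sm_to_lm_matrix` lists the indices of the sensor modules feeding LM *i*, and that row *i* of `lm_to_lm_matrix` lists the lower LMs feeding LM *i*. The range and self-loop checks are sanity checks added for this sketch, not Monty's own validation:

```python
from typing import List, Optional


def describe_wiring(
    sm_ids: List[str],
    sm_to_lm_matrix: List[List[int]],
    lm_to_lm_matrix: Optional[List[List[int]]] = None,
) -> List[str]:
    """Hypothetical helper: one readable line per LM, plus unconnected SMs."""
    num_lms = len(sm_to_lm_matrix)
    if lm_to_lm_matrix is not None and len(lm_to_lm_matrix) != num_lms:
        raise ValueError(
            f"lm_to_lm_matrix has {len(lm_to_lm_matrix)} rows but "
            f"sm_to_lm_matrix has {num_lms}; both need one row per LM"
        )
    lines = []
    for lm, sm_indices in enumerate(sm_to_lm_matrix):
        # Every SM index must point at an existing sensor module.
        for sm in sm_indices:
            if not 0 <= sm < len(sm_ids):
                raise ValueError(f"LM {lm} refers to unknown SM index {sm}")
        lm_inputs = lm_to_lm_matrix[lm] if lm_to_lm_matrix is not None else []
        # Every LM input must be another existing LM (no self-loops).
        for src in lm_inputs:
            if not 0 <= src < num_lms or src == lm:
                raise ValueError(f"LM {lm} has invalid LM input {src}")
        sensors = ", ".join(sm_ids[sm] for sm in sm_indices) or "none"
        lower = ", ".join(f"LM {src}" for src in lm_inputs) or "none"
        lines.append(f"LM {lm}: sensors [{sensors}], lower LMs [{lower}]")
    # Report SMs that feed no LM (e.g. the view finder).
    used = {sm for row in sm_to_lm_matrix for sm in row}
    unused = [sm_ids[i] for i in range(len(sm_ids)) if i not in used]
    lines.append("unconnected SMs: " + (", ".join(unused) or "none"))
    return lines


if __name__ == "__main__":
    sm_ids = ["patch_0", "patch_1", "patch_2", "patch_3", "patch_4", "view_finder"]
    for line in describe_wiring(
        sm_ids,
        sm_to_lm_matrix=[[0], [1], [2], [3], [4]],
        lm_to_lm_matrix=[[], [0], [1], [2], [3]],
    ):
        print(line)
```

For the matrices in the config, it reports LM 0 fed only by `patch_0`, each LM *i* > 0 fed by `patch_i` and LM *i*-1, and `view_finder` as the one unconnected sensor module.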


Thank you so much for your help @rmounir!! I am currently watching the Hierarchy/heterarchy video 😉
