
Error when stacking RGB frames using UniZero #316

Open
fjparrado opened this issue Jan 8, 2025 · 2 comments
Labels: bug (Something isn't working), config (New or improved configuration)


@fjparrado

Hi, I got very good results using UniZero with a single RGB frame, better than MuZero (MZ) and EfficientZero (EZ). Now I would like to stack 3 frames during training to see whether the results improve further. This is my config dictionary:

```python
# env_id, action_space_size, num_unroll_steps, infer_context_length,
# collector_env_num and evaluator_env_num are assumed to be defined
# elsewhere in the full config file (not shown in this issue).
titan_unizero_stack3_config = dict(
    exp_name=f'data_unizero/{env_id}_unizero',
    env=dict(
        stop_value=int(1e6),
        observation_shape=(9, 96, 96),  # 3 stacked frames x 3 RGB channels = 9 channels
        image_channel=3,
        frame_stack_num=3,
        gray_scale=False,
        collector_env_num=8,
        evaluator_env_num=1,
        n_evaluator_episode=5,
        manager=dict(shared_memory=False, ),
        # TODO: only for debug
        # collect_max_episode_steps=int(50),
        # eval_max_episode_steps=int(50),
    ),
    policy=dict(
        model=dict(
            observation_shape=(9, 96, 96),  # must match env.observation_shape
            image_channel=3,
            frame_stack_num=3,
            gray_scale=False,
            action_space_size=action_space_size,
            world_model_cfg=dict(
                max_blocks=num_unroll_steps,
                max_tokens=2 * num_unroll_steps,  # NOTE: each timestep has 2 tokens: obs and action
                context_length=2 * infer_context_length,
                device='cuda',
                # device='cpu',
                action_space_size=action_space_size,
                num_layers=4,
                num_heads=8,
                embed_dim=768,
                obs_type='image',
                env_num=max(collector_env_num, evaluator_env_num),
            ),
        ),
        # (str) The path of the pretrained model. If None, the model will be initialized by the default model.
        model_path=None,
        action_type='fixed_action_space',
        num_unroll_steps=num_unroll_steps,
        update_per_collect=False,
        replay_ratio=0.25,
        batch_size=64,
        optim_type='AdamW',
        num_simulations=50,
        reanalyze_ratio=0.,
        n_episode=8,
        replay_buffer_size=int(1e6),
        collector_env_num=8,
        evaluator_env_num=1,
    ),
)
```
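In this config, the leading dimension of `observation_shape` encodes the frame stack: each RGB frame contributes `image_channel=3` channels, so 3 stacked frames give 3 × 3 = 9 channels, and the `env` and `policy.model` sections have to agree. The sketch below is a hypothetical helper (`find_obs_shape_problems` is not part of LightZero) that checks this relationship, assuming frames are stacked channel-first along dimension 0.

```python
from typing import Any, Dict, List


def expected_channels(image_channel: int, frame_stack_num: int) -> int:
    # Channel-first stacking: each frame contributes image_channel channels.
    return image_channel * frame_stack_num


def find_obs_shape_problems(config: Dict[str, Any]) -> List[str]:
    """Hypothetical helper (not LightZero code): return a list of
    inconsistencies between observation_shape, image_channel and
    frame_stack_num in the env and policy.model sections."""
    sections = {
        "env": config["env"],
        "policy.model": config["policy"]["model"],
    }
    problems: List[str] = []
    for name, sec in sections.items():
        channels = sec["observation_shape"][0]
        expected = expected_channels(sec["image_channel"], sec["frame_stack_num"])
        if channels != expected:
            problems.append(
                f"{name}: observation_shape[0]={channels}, "
                f"but image_channel * frame_stack_num = {expected}"
            )
    if sections["env"]["observation_shape"] != sections["policy.model"]["observation_shape"]:
        problems.append("env and policy.model observation_shape differ")
    return problems


# Minimal stand-in holding only the keys the check reads (values from the config above).
stack3 = dict(
    env=dict(observation_shape=(9, 96, 96), image_channel=3, frame_stack_num=3),
    policy=dict(model=dict(observation_shape=(9, 96, 96), image_channel=3, frame_stack_num=3)),
)
print(find_obs_shape_problems(stack3))  # [] because 3 * 3 == 9 in both sections
```

Passing this check only shows that the config is internally consistent. It does not guarantee that UniZero's world model supports a given `frame_stack_num`, and the error reported below occurs even though this stack=3 config is consistent.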

However, this fails. An equivalent configuration that stacks 3 RGB frames works well with MuZero and EfficientZero, but not with UniZero. This is the error I get:

```
File "/LightZero/lzero/model/unizero_world_models/world_model.py", line 563, in _process_obs_act_combined
obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[64, 10, 1, -1]' is invalid for input of size 196608
```

I would appreciate any insight into how to fix this. Thanks in advance.
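For context on the error message: PyTorch's `Tensor.view` accepts at most one `-1` dimension, and the tensor's total number of elements must be divisible by the product of the other requested dimensions. The sketch below reproduces that size check in plain Python with the numbers from the traceback. `inferred_dim` is a hypothetical illustrative helper, and reading 768 as the per-token embedding size is an assumption based on `embed_dim=768` in the config.

```python
from math import prod
from typing import Optional, Sequence


def inferred_dim(numel: int, shape: Sequence[int]) -> Optional[int]:
    """Hypothetical helper: return the size that would be inferred for the
    single -1 in `shape`, or None if no integer size fits (the case in which
    Tensor.view raises "shape ... is invalid for input of size ...")."""
    assert list(shape).count(-1) == 1, "exactly one -1 expected"
    known = prod(d for d in shape if d != -1)
    if known == 0 or numel % known != 0:
        return None
    return numel // known


numel = 196608                # "input of size 196608" from the traceback
requested = (64, 10, 1, -1)   # (act_tokens.shape[0], act_tokens.shape[1], num_observations_tokens, -1)
print(inferred_dim(numel, requested))  # None: 196608 / 640 = 307.2 is not an integer

# Assumption: if the last dimension is embed_dim=768, a valid view would need
# 64 * 10 * 1 * 768 = 491520 elements, but the tensor holds 196608 = 256 * 768.
print(inferred_dim(64 * 10 * 1 * 768, requested))  # 768
print(numel // 768)                                # 256
```

In other words, `obs_embeddings` does not contain as many elements as the `[64, 10, 1, -1]` view implied by `act_tokens` requires, so the mismatch originates upstream of this `view` call rather than in the call itself.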

@puyuan1996 added the config (New or improved configuration) and bug (Something isn't working) labels on Jan 8, 2025
@puyuan1996
Collaborator

Thank you for your interest! So far, UniZero has only been tested with stack=1 and stack=4. As a first step, you can set the parameter to stack=4 and test the results. We will look into supporting stack=3 in the coming days. We truly appreciate your patience. If you have any other questions, feel free to raise them at any time!

@fjparrado
Author

fjparrado commented Jan 9, 2025

Thank you for your answer! A small update: as you suggested, I made the following change:

From:

```python
observation_shape=(9, 96, 96),
image_channel=3,
frame_stack_num=3,
```

To:

```python
observation_shape=(12, 96, 96),
image_channel=3,
frame_stack_num=4,
```

And now, there are no errors.
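The working stack=4 setting follows the same channel arithmetic: 4 RGB frames × 3 channels = 12, which is why `observation_shape` changes from `(9, 96, 96)` to `(12, 96, 96)` together with `frame_stack_num`. The sketch below illustrates channel-first frame stacking with a fixed-length `collections.deque`. It is a generic, hypothetical illustration (the `FrameStacker` class is not LightZero's implementation), and it represents each 96×96 channel plane by a string label instead of pixel data so that it stays dependency-free.

```python
from collections import deque
from typing import Deque, List


class FrameStacker:
    """Hypothetical channel-first frame stacker (not LightZero code).

    Keeps the last `frame_stack_num` frames; each frame is a list of
    `image_channel` channel planes (string labels here instead of arrays).
    """

    def __init__(self, frame_stack_num: int, image_channel: int = 3) -> None:
        self.image_channel = image_channel
        self.frames: Deque[List[str]] = deque(maxlen=frame_stack_num)

    def reset(self, first_frame: List[str]) -> List[str]:
        assert len(first_frame) == self.image_channel
        # Common convention: fill the whole stack with copies of the first frame.
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return self.observation()

    def step(self, frame: List[str]) -> List[str]:
        assert len(frame) == self.image_channel
        self.frames.append(frame)  # the deque drops the oldest frame automatically
        return self.observation()

    def observation(self) -> List[str]:
        # Concatenate along the channel axis, oldest frame first.
        return [plane for frame in self.frames for plane in frame]


def rgb(t: int) -> List[str]:
    # Label the R, G and B planes of the frame observed at step t.
    return [f"t{t}_{c}" for c in "RGB"]


stacker = FrameStacker(frame_stack_num=4)
obs = stacker.reset(rgb(0))
for t in range(1, 3):
    obs = stacker.step(rgb(t))
print(len(obs))  # 12 channels, matching observation_shape[0] for stack=4
print(obs[:3])   # ['t0_R', 't0_G', 't0_B'] (oldest frame still from the reset)
```

With `frame_stack_num=3` the same stacker yields 9 planes, which matches the original `(9, 96, 96)` config; the failure in this issue is therefore not a channel-count mismatch in the config itself.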
