[Bug]: Impossible to load model to use it for training #49

@edofazza

Description

🐛 Bug

I'm trying to load a trained model to use it for testing, but I get an error when loading the checkpoint.
Thank you.

To Reproduce

import torch as th
import os
from rllte.xplore.reward import RND, Disagreement, RIDE
from rllte.env import make_mario_env
from rllte.agent import PPO, DDPG

if __name__ == '__main__':
    n_steps: int = 2048 * 16
    device = 'cuda' if th.cuda.is_available() else 'cpu'
    envs = make_mario_env('SuperMarioBros-1-1-v0', device=device, num_envs=1,
                          asynchronous=False, frame_stack=4, gray_scale=True)
    print(device, envs.observation_space, envs.action_space)
    # create the intrinsic reward module
    #irs = Disagreement(envs, device=device)
    # create the PPO agent
    agent = PPO(envs,
                device=device,
                batch_size=512,
                n_epochs=10,
                num_steps=n_steps//8,
                pretraining=True)
    agent.policy.load_state_dict(th.load("ride_1_1_1507328.pth", map_location=th.device('cpu')))
    agent.eval(100)

Relevant log output / Error message

/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:555: UserWarning: WARN: The environment SuperMarioBros-1-1-v0 is out of date. You should consider upgrading to version `v3`.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gym/envs/registration.py:627: UserWarning: WARN: The environment creator metadata doesn't include `render_modes`, contains: ['render.modes', 'video.frames_per_second']
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.metadata to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.metadata` for environment variables or `env.get_wrapper_attr('metadata')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_observation_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_observation_space` for environment variables or `env.get_wrapper_attr('single_observation_space')` that will search the reminding wrappers.
  logger.warn(
/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/gymnasium/core.py:311: UserWarning: WARN: env.single_action_space to get variables from other wrappers is deprecated and will be removed in v1.0, to get this variable you can do `env.unwrapped.single_action_space` for environment variables or `env.get_wrapper_attr('single_action_space')` that will search the reminding wrappers.
  logger.warn(
cpu Box(0, 255, (4, 84, 84), uint8) Discrete(7)
Traceback (most recent call last):
  File "/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/tests.py", line 22, in <module>
    agent.policy.load_state_dict(th.load("/Users/edoardofazzari/Documents/GitHub/got-it-memorized/src/ride_1_1_1507328.pth",
  File "/Users/edoardofazzari/miniconda3/envs/got-it-memorized/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2103, in load_state_dict
    raise TypeError(f"Expected state_dict to be dict-like, got {type(state_dict)}.")
TypeError: Expected state_dict to be dict-like, got <class 'rllte.common.utils.ExportModel'>.
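
The last line of the traceback suggests that `th.load` returned a whole pickled object (`rllte.common.utils.ExportModel`) rather than a dict of tensors, and `load_state_dict` only accepts dict-like input. A minimal sketch of a workaround follows. It assumes the saved object exposes a `state_dict()` method like a regular `torch.nn.Module`, which I have not confirmed for `ExportModel`. The helper name `extract_state_dict` is made up for illustration and is not part of rllte or PyTorch.

```python
from collections.abc import Mapping
from typing import Any


def extract_state_dict(checkpoint: Any) -> Mapping:
    """Return a dict-like state dict from whatever th.load() produced.

    Hypothetical helper: not part of rllte or PyTorch.
    """
    # Case 1: the file already holds a plain state dict.
    if isinstance(checkpoint, Mapping):
        return checkpoint

    # Case 2: the file holds a whole pickled object (e.g. an nn.Module).
    # Assumption: that object exposes a callable state_dict().
    state_dict_fn = getattr(checkpoint, "state_dict", None)
    if callable(state_dict_fn):
        return state_dict_fn()

    raise TypeError(f"Cannot extract a state dict from {type(checkpoint)}.")


# Intended usage in the script above (not executed here):
# checkpoint = th.load("ride_1_1_1507328.pth", map_location=th.device('cpu'))
# agent.policy.load_state_dict(extract_state_dict(checkpoint))
```

Even if this gets past the `TypeError`, the parameter names inside the saved object may not match those of `agent.policy`, in which case `load_state_dict` would report missing or unexpected keys. That depends on how the checkpoint was written, which the log does not show.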

System Info

No response

Checklist

  • I have checked that there is no similar issue in the repo
  • I have read the documentation
  • I have provided a minimal working example to reproduce the bug
  • I've used the markdown code blocks for both code and stack traces.
