
Can you provide a test demo that reproduces the optimal policy and saves the video? #100

Open
whyy812 opened this issue Jan 17, 2022 · 2 comments
Labels: Suggestion (New feature or request)

whyy812 commented Jan 17, 2022

Can you provide a test demo that reproduces the optimal policy and saves the video? (e.g. for LunarLanderContinuous-v2 or BipedalWalker-v3)

YangletLiu added the Suggestion (New feature or request) label on Jan 17, 2022
whyy812 (Author) commented Jan 20, 2022

I have written a test demo that may help. However, the function `save_or_load_agent(self, cwd: str, if_save: bool)` in AgentBase.py has to be modified slightly so that it returns the networks:
```python
def save_or_load_agent(self, cwd: str, if_save: bool):
    ...
    if if_save:
        for name, obj in name_obj_list:
            save_path = f"{cwd}/{name}.pth"
            torch.save(obj.state_dict(), save_path)
    else:
        for name, obj in name_obj_list:
            save_path = f"{cwd}/{name}.pth"
            load_torch_file(obj, save_path) if os.path.isfile(save_path) else None
    return self.act, self.act_target, self.act_optim, self.cri, self.cri_target, self.cri_optim
```
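To see what the two branches of the modified function do, here is a minimal, self-contained sketch of the same save/load round-trip. The `actor` network below is a hypothetical stand-in (state_dim=8, action_dim=2, matching LunarLanderContinuous-v2); in the real code the object comes from `agent.act` and the file lands in `args.cwd`.

```python
import os
import torch
import torch.nn as nn

# Hypothetical stand-in for the ElegantRL actor network;
# in the real code this object is agent.act.
actor = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2))

save_path = "actor.pth"
torch.save(actor.state_dict(), save_path)        # the if_save=True branch
if os.path.isfile(save_path):                    # the if_save=False branch
    actor.load_state_dict(torch.load(save_path))
actor.eval()  # switch to inference mode for the test demo
```

Saving `state_dict()` rather than the whole module keeps the checkpoint independent of the class definition, which is why the load side needs an already-constructed network of the same shape.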

```python
import gym
import torch

from elegantrl.agents.AgentSAC import AgentSAC, AgentModSAC
from elegantrl.envs.Gym import get_gym_env_args
from elegantrl.train.config import Arguments, build_env
from elegantrl.train.utils import init_agent

get_gym_env_args(gym.make('LunarLanderContinuous-v2'), if_print=True)

env_func = gym.make
env_args = {
    'env_num': 1,
    'env_name': 'LunarLanderContinuous-v2',
    'max_step': 1000,
    'state_dim': 8,
    'action_dim': 2,
    'if_discrete': False,
    'target_return': 200,
    'id': 'LunarLanderContinuous-v2',
}

args = Arguments(agent=AgentModSAC(), env_func=env_func, env_args=env_args)
args.init_before_training()  # necessary!
learner_gpu = args.learner_gpus[0]
env = build_env(env=args.env, env_func=args.env_func, env_args=args.env_args, gpu_id=learner_gpu)
agent = init_agent(args, gpu_id=learner_gpu, env=env)
cwd = args.cwd

act, act_target, act_optim, cri, cri_target, cri_optim = agent.save_or_load_agent(cwd, False)

act.load_state_dict(torch.load(f"{cwd}/actor.pth"))  # load from the training directory

s = env.reset()
print(s, s.shape)
for i in range(1000):
    action = act.get_action(torch.as_tensor(s, dtype=torch.float32))
    # agent.train()
    next_state, reward, done, _ = env.step(action.detach().numpy())
    if done:
        next_state = env.reset()
    s = next_state  # fix: the state must advance each step, not only on reset
    env.render()
```
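One caveat about the rollout above: SAC's `get_action` draws a stochastic sample from the policy, so runs are not repeatable. To reproduce the learned policy at evaluation time, the usual pattern is a deterministic forward pass under `torch.no_grad()` (assuming the actor exposes a plain forward that returns the mean action; check the SAC actor's API). A minimal sketch, with a hypothetical stand-in for the loaded `act` network:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the loaded `act` network; the final Tanh keeps
# actions in [-1, 1], matching LunarLanderContinuous-v2's action bounds.
act = nn.Sequential(nn.Linear(8, 64), nn.ReLU(), nn.Linear(64, 2), nn.Tanh())

state = torch.zeros(8)  # placeholder for env.reset() / next_state
with torch.no_grad():                 # no gradients needed at evaluation time
    action = act(state.unsqueeze(0))  # add a batch dim; deterministic forward pass
action = action.squeeze(0).numpy()    # numpy array of shape (2,) for env.step
```

Using `torch.no_grad()` also avoids building the autograd graph during the rollout, which keeps evaluation memory flat over the 1000-step loop.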

YangletLiu (Contributor) commented

Thanks! We are looking into your code.
