Skip to content

More pythonic interface #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
build
196 changes: 109 additions & 87 deletions ale_python_interface/ale_python_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,113 +8,135 @@
from numpy.ctypeslib import as_ctypes
import os

ale_lib = cdll.LoadLibrary(os.path.join(
os.path.dirname(__file__),'ale_c_wrapper.so'))
__all__ = ['ALEInterface']

ale_lib = cdll.LoadLibrary(os.path.join(os.path.dirname(__file__),
'ale_c_wrapper.so'))

# Properties taken from Arcade-Learning-Environment/src/common/Defaults.cpp
PROPS = {
'random_seed': str,
'game_controller': str,
'player_agent': str,
'max_num_episodes': int,
'max_num_frames': int,
'max_num_frames_per_episode': int,
'system_reset_steps': int,
'record_trajectory': bool,
'restricted_action_set': bool,
'use_starting_actions': bool,
'use_environment_distribution': bool,
'random_seed': str,
'disable_color_averaging': bool,
'send_rgb': bool,
'frame_skip': int,
'display_screen': bool,
}
GETTERS = {
str: ale_lib.getString,
int: ale_lib.getInt,
bool: ale_lib.getBool,
float: ale_lib.getFloat,
}
SETTERS = {
str: ale_lib.setString,
int: ale_lib.setInt,
bool: ale_lib.setBool,
float: ale_lib.setFloat,
}

class ALEInterface(object):
def __init__(self):
self.obj = ale_lib.ALE_new()

def getString(self,key):
return ale_lib.getString(self.obj,key)
def getInt(self,key):
return ale_lib.getInt(self.obj,key)
def getBool(self,key):
return ale_lib.getBool(self.obj,key)
def getFloat(self,key):
return ale_lib.getFloat(self.obj,key)

def set(self,key,value):
if(type(value) == str):
ale_lib.setString(self.obj,key,value)
elif(type(value) == int):
ale_lib.setInt(self.obj,key,value)
elif(type(value) == bool):
ale_lib.setBool(self.obj,key,value)
elif(type(value) == float):
ale_lib.setFloar(self.obj,key,value)

def loadROM(self,rom_file):
ale_lib.loadROM(self.obj,rom_file)

def act(self,action):
return ale_lib.act(self.obj,int(action))

def game_over(self):
return ale_lib.game_over(self.obj)

def reset_game(self):
ale_lib.reset_game(self.obj)
class ALEInterface(object):
def __init__(self, rom_file):
self._obj = ale_lib.ALE_new()
ale_lib.loadROM(self._obj, rom_file)

def getLegalActionSet(self):
act_size = ale_lib.getLegalActionSize(self.obj)
act = np.zeros((act_size),dtype=np.int32)
ale_lib.getLegalActionSet(self.obj,as_ctypes(act))
def __del__(self):
ale_lib.ALE_del(self._obj)

def __getitem__(self, key):
if key not in PROPS:
raise ValueError('Invalid key: %s' % key)
getter = GETTERS[PROPS[key]]
return getter(self._obj, key)

def __setitem__(self, key, value):
if key not in PROPS:
raise ValueError('Invalid key: %s' % key)
setter = SETTERS[PROPS[key]]
setter(self._obj, key, value)

@property
def legal_actions(self):
act_size = ale_lib.getLegalActionSize(self._obj)
act = np.zeros(act_size, dtype=np.int32)
ale_lib.getLegalActionSet(self._obj, as_ctypes(act))
return act

def getMinimalActionSet(self):
act_size = ale_lib.getMinimalActionSize(self.obj)
act = np.zeros((act_size),dtype=np.int32)
ale_lib.getMinimalActionSet(self.obj,as_ctypes(act))
@property
def minimal_actions(self):
act_size = ale_lib.getMinimalActionSize(self._obj)
act = np.zeros(act_size, dtype=np.int32)
ale_lib.getMinimalActionSet(self._obj, as_ctypes(act))
return act

def getFrameNumber(self):
return ale_lib.getFrameNumber(self.obj)
@property
def frame_number(self):
return ale_lib.getFrameNumber(self._obj)

def getEpisodeFrameNumber(self):
return ale_lib.getEpisodeFrameNumber(self.obj)
@property
def episode_frame_number(self):
return ale_lib.getEpisodeFrameNumber(self._obj)

def getScreenDims(self):
@property
def screen_dims(self):
"""returns a tuple that contains (screen_width,screen_height)
"""
width = ale_lib.getScreenWidth(self.obj)
height = ale_lib.getScreenHeight(self.obj)
return (width,height)
width = ale_lib.getScreenWidth(self._obj)
height = ale_lib.getScreenHeight(self._obj)
return width, height

@property
def ram_size(self):
return ale_lib.getRAMSize(self._obj)

@property
def is_game_over(self):
return ale_lib.game_over(self._obj)

def getScreen(self,screen_data=None):
def act(self,action):
return ale_lib.act(self._obj, int(action))

def reset_game(self):
ale_lib.reset_game(self._obj)

def fill_screen(self, screen_data=None):
"""This function fills screen_data with the RAW Pixel data
screen_data MUST be a numpy array of uint8/int8. This could be initialized like so:
screen_data = np.array(w*h,dtype=np.uint8)
Notice, it must be width*height in size also
If it is None, then this function will initialize it
Note: This is the raw pixel values from the atari, before any RGB palette transformation takes place
screen_data MUST be a numpy array of uint8.
Note: This is the raw pixel values from the atari,
before any RGB palette transformation takes place.
"""
if(screen_data is None):
width = ale_lib.getScreenWidth(self.obj)
height = ale_lib.getScreenWidth(self.obj)
screen_data = np.zeros(width*height,dtype=np.uint8)
ale_lib.getScreen(self.obj,as_ctypes(screen_data))
if screen_data is None:
size = np.prod(self.screen_dims)
screen_data = np.zeros(size, dtype=np.uint8)
ale_lib.getScreen(self._obj, as_ctypes(screen_data))
return screen_data

def getScreenRGB(self,screen_data=None):
def fill_screen_rgb(self, screen_data=None):
"""This function fills screen_data with the data
screen_data MUST be a numpy array of uint32/int32. This can be initialized like so:
screen_data = np.array(w*h,dtype=np.uint32)
Notice, it must be width*height in size also
If it is None, then this function will initialize it
screen_data MUST be a numpy array of uint32.
"""
if(screen_data is None):
width = ale_lib.getScreenWidth(self.obj)
height = ale_lib.getScreenWidth(self.obj)
screen_data = np.zeros(width*height,dtype=np.uint32)
ale_lib.getScreenRGB(self.obj,as_ctypes(screen_data))
if screen_data is None:
size = np.prod(self.screen_dims)
screen_data = np.zeros(size, dtype=np.uint32)
ale_lib.getScreenRGB(self._obj, as_ctypes(screen_data))
return screen_data

def getRAMSize(self):
return ale_lib.getRAMSize(self.obj)

def getRAM(self,ram=None):
def fill_ram(self, ram=None):
"""This function grabs the atari RAM.
ram MUST be a numpy array of uint8/int8. This can be initialized like so:
ram = np.array(ram_size,dtype=uint8)
Notice: It must be ram_size where ram_size can be retrieved via the getRAMSize function.
If it is None, then this function will initialize it.
ram MUST be a numpy array of uint8.
"""
if(ram is None):
ram_size = ale_lib.getRAMSize(self.obj)
ram = np.zeros(ram_size,dtype=np.uint8)
ale_lib.getRAM(self.obj,as_ctypes(ram))

def __del__(self):
ale_lib.ALE_del(self.obj)
if ram is None:
ram = np.zeros(self.ram_size, dtype=np.uint8)
ale_lib.getRAM(self._obj,as_ctypes(ram))
return ram
20 changes: 9 additions & 11 deletions examples/ale_python_test1.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,23 +14,21 @@
print("Usage ./ale_python_test1.py <ROM_FILE_NAME>")
sys.exit()

ale = ALEInterface()
ale = ALEInterface(sys.argv[1])

max_frames_per_episode = ale.getInt("max_num_frames_per_episode");
ale.set("random_seed",123)
max_frames_per_episode = ale["max_num_frames_per_episode"]
ale["random_seed"] = 123

random_seed = ale.getInt("random_seed")
random_seed = ale["random_seed"]
print("random_seed: " + str(random_seed))

ale.loadROM(sys.argv[1])
legal_actions = ale.getLegalActionSet()
legal_actions = ale.legal_actions

for episode in range(10):
total_reward = 0.0
while not ale.game_over():
a = legal_actions[np.random.randint(legal_actions.size)]
reward = ale.act(a);
total_reward = 0.0
while not ale.is_game_over:
a = np.random.choice(legal_actions)
reward = ale.act(a)
total_reward += reward
print("Episode " + str(episode) + " ended with score: " + str(total_reward))
ale.reset_game()

30 changes: 14 additions & 16 deletions examples/ale_python_test2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/usr/bin/env python

# ale_python_test1.py
# ale_python_test2.py
# Author: Ben Goodrich
#
# This modified ale_python_test1.py to do more extensive tests of the python interface
# by calling more functions
# This modified ale_python_test1.py to do more extensive tests of the
# python interface by calling more functions
import sys
from ale_python_interface import ALEInterface
import numpy as np
Expand All @@ -13,26 +13,24 @@
print("Usage ./ale_python_test2.py <ROM_FILE_NAME>")
sys.exit()

ale = ALEInterface()
ale = ALEInterface(sys.argv[1])

max_frames_per_episode = ale.getInt("max_num_frames_per_episode");
ale.set("random_seed",123)
max_frames_per_episode = ale["max_num_frames_per_episode"]
ale["random_seed"] = 123

random_seed = ale.getInt("random_seed")
random_seed = ale["random_seed"]
print("random_seed: " + str(random_seed))

ale.loadROM(sys.argv[1])
legal_actions = ale.getMinimalActionSet()
legal_actions = ale.minimal_actions

for episode in range(10):
total_reward = 0.0
while not ale.game_over():
a = legal_actions[np.random.randint(legal_actions.size)]
reward = ale.act(a);
total_reward = 0.0
while not ale.is_game_over:
a = np.random.choice(legal_actions)
reward = ale.act(a)
total_reward += reward
episode_frame_number = ale.getEpisodeFrameNumber()
frame_number = ale.getFrameNumber()
episode_frame_number = ale.episode_frame_number
frame_number = ale.frame_number
print("Frame Number: " + str(frame_number) + " Episode Frame Number: " + str(episode_frame_number))
print("Episode " + str(episode) + " ended with score: " + str(total_reward))
ale.reset_game()

48 changes: 23 additions & 25 deletions examples/ale_python_test_pygame.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,51 +13,49 @@
print("Usage ./ale_python_test_pygame.py <ROM_FILE_NAME>")
sys.exit()

ale = ALEInterface()
ale = ALEInterface(sys.argv[1])

max_frames_per_episode = ale.getInt("max_num_frames_per_episode");
ale.set("random_seed",123)
max_frames_per_episode = ale["max_num_frames_per_episode"]
ale["random_seed"] = 123

random_seed = ale.getInt("random_seed")
random_seed = ale["random_seed"]
print("random_seed: " + str(random_seed))

ale.loadROM(sys.argv[1])
legal_actions = ale.getMinimalActionSet()
legal_actions = ale.minimal_actions

(screen_width,screen_height) = ale.getScreenDims()
print("width/height: " +str(screen_width) + "/" + str(screen_height))
print("width/height: %s/%s" % ale.screen_dims)

#init pygame
pygame.init()
screen = pygame.display.set_mode((screen_width,screen_height))
screen = pygame.display.set_mode(ale.screen_dims)
pygame.display.set_caption("Arcade Learning Environment Random Agent Display")

pygame.display.flip()

episode = 0
total_reward = 0.0
while(episode < 10):
exit=False
total_reward = 0.0
while episode < 10:
exit = False
for event in pygame.event.get():
if event.type == pygame.QUIT:
exit=True
break;
if(exit):
break
if exit:
break

a = legal_actions[np.random.randint(legal_actions.size)]
reward = ale.act(a);
a = np.random.choice(legal_actions)
reward = ale.act(a)
total_reward += reward

numpy_surface = np.frombuffer(screen.get_buffer(),dtype=np.int32)
ale.getScreenRGB(numpy_surface)
ale.fill_screen_rgb(numpy_surface)
pygame.display.flip()
if(ale.game_over()):
episode_frame_number = ale.getEpisodeFrameNumber()
frame_number = ale.getFrameNumber()
print("Frame Number: " + str(frame_number) + " Episode Frame Number: " + str(episode_frame_number))
print("Episode " + str(episode) + " ended with score: " + str(total_reward))
if ale.is_game_over:
episode_frame_number = ale.episode_frame_number
frame_number = ale.frame_number
print("Frame Number: %d Episode Frame Number: %d" % (
frame_number, episode_frame_number))
print("Episode %d ended with score: %g" % (episode, total_reward))
ale.reset_game()
total_reward = 0.0
episode = episode + 1

total_reward = 0.0
episode += 1
Loading