Minor changes, add assets

This commit is contained in:
Kevin Black
2023-06-27 10:20:03 -07:00
parent 4c5322ca85
commit 8cab96dea4
5 changed files with 73 additions and 13 deletions

View File

@@ -1,4 +1,5 @@
from importlib import resources
import os
import functools
import random
import inflect
@@ -8,35 +9,45 @@ ASSETS_PATH = resources.files("ddpo_pytorch.assets")
@functools.cache
def load_lines(name):
with ASSETS_PATH.joinpath(name).open() as f:
def _load_lines(path):
"""
Load lines from a file. First tries to load from `path` directly, and if that doesn't exist, searches the
`ddpo_pytorch/assets` directory for a file named `path`.
"""
if not os.path.exists(path):
newpath = ASSETS_PATH.joinpath(path)
if not os.path.exists(newpath):
raise FileNotFoundError(f"Could not find {path} or ddpo_pytorch.assets/{path}")
path = newpath
with open(path, "r") as f:
return [line.strip() for line in f.readlines()]
def imagenet(low, high):
return random.choice(load_lines("imagenet_classes.txt")[low:high]), {}
def from_file(path, low=None, high=None):
prompts = _load_lines(path)[low:high]
return random.choice(prompts), {}
def imagenet_all():
return imagenet(0, 1000)
return from_file("imagenet_classes.txt")
def imagenet_animals():
return imagenet(0, 398)
return from_file("imagenet_classes.txt", 0, 398)
def imagenet_dogs():
return imagenet(151, 269)
return from_file("imagenet_classes.txt", 151, 269)
def nouns_activities(nouns_file, activities_file):
nouns = load_lines(nouns_file)
activities = load_lines(activities_file)
nouns = _load_lines(nouns_file)
activities = _load_lines(activities_file)
return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {}
def counting(nouns_file, low, high):
nouns = load_lines(nouns_file)
nouns = _load_lines(nouns_file)
number = IE.number_to_words(random.randint(low, high))
noun = random.choice(nouns)
plural_noun = IE.plural(noun)