commit 2fda3d4e78eec14913aa71f5d009aba9bb246fba Author: Kevin Black Date: Fri Jun 23 19:25:54 2023 -0700 Initial commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c4f2c1c --- /dev/null +++ b/.gitignore @@ -0,0 +1,305 @@ +# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim +# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,python,intellij+all,vim + +### Intellij+all ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# Gradle and Maven with auto-import +# When using Gradle or Maven with auto-import, you should exclude module files, +# since they will be recreated, and may cause churn. Uncomment if using +# auto-import. +# .idea/artifacts +# .idea/compiler.xml +# .idea/jarRepositories.xml +# .idea/modules.xml +# .idea/*.iml +# .idea/modules +# *.iml +# *.ipr + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### Intellij+all Patch ### +# Ignore everything but code style settings and run configurations +# that are supposed to be shared within teams. + +.idea/* + +!.idea/codeStyles +!.idea/runConfigurations + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history +.ionide + +# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,python,intellij+all,vim + diff --git a/README.md b/README.md new file mode 100644 index 0000000..9040f06 --- /dev/null +++ b/README.md @@ -0,0 +1 @@ +# ddpo-pytorch \ No newline at end of file diff --git a/config/__pycache__/base.cpython-310.pyc b/config/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000..8b8f958 Binary files /dev/null and b/config/__pycache__/base.cpython-310.pyc differ diff --git a/config/base.py b/config/base.py new file mode 100644 index 0000000..e38e6bb --- /dev/null +++ b/config/base.py @@ -0,0 +1,56 @@ +import ml_collections + +def get_config(): + + config = ml_collections.ConfigDict() + + # misc + config.seed = 42 + config.logdir = "logs" + config.num_epochs = 100 + config.mixed_precision = "fp16" + config.allow_tf32 = True + + # pretrained model initialization + config.pretrained = pretrained = ml_collections.ConfigDict() + pretrained.model = "runwayml/stable-diffusion-v1-5" + pretrained.revision = "main" + + # training + config.train = train = ml_collections.ConfigDict() + train.mixed_precision = "fp16" + train.batch_size = 1 + train.use_8bit_adam = False + train.scale_lr = False + train.learning_rate = 1e-4 + train.adam_beta1 = 0.9 + train.adam_beta2 = 0.999 + train.adam_weight_decay = 1e-2 + train.adam_epsilon = 1e-8 + train.gradient_accumulation_steps = 1 + train.max_grad_norm = 1.0 + train.num_inner_epochs = 1 + train.cfg = True + train.adv_clip_max = 10 + train.clip_range = 1e-4 + + # sampling + config.sample = sample = ml_collections.ConfigDict() + sample.num_steps = 5 + sample.eta = 1.0 + sample.guidance_scale = 5.0 + sample.batch_size = 1 + sample.num_batches_per_epoch = 4 + + # prompting + config.prompt_fn = "imagenet_animals" + config.prompt_fn_kwargs = {} + + # rewards + config.reward_fn = "jpeg_compressibility" + + config.per_prompt_stat_tracking = ml_collections.ConfigDict() + config.per_prompt_stat_tracking.buffer_size = 128 + config.per_prompt_stat_tracking.min_count = 16 + + return config \ No newline at end of file diff --git a/ddpo_pytorch/__pycache__/prompts.cpython-310.pyc b/ddpo_pytorch/__pycache__/prompts.cpython-310.pyc new file mode 100644 index 0000000..440ec87 Binary files /dev/null and b/ddpo_pytorch/__pycache__/prompts.cpython-310.pyc differ diff --git a/ddpo_pytorch/__pycache__/rewards.cpython-310.pyc b/ddpo_pytorch/__pycache__/rewards.cpython-310.pyc new file mode 100644 index 0000000..7b3db0f Binary files /dev/null and b/ddpo_pytorch/__pycache__/rewards.cpython-310.pyc differ diff --git a/ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc b/ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc new file mode 100644 index 0000000..14e6708 Binary files /dev/null and b/ddpo_pytorch/__pycache__/stat_tracking.cpython-310.pyc differ diff --git a/ddpo_pytorch/assets/imagenet_classes.txt b/ddpo_pytorch/assets/imagenet_classes.txt new file mode 100644 index 0000000..722c984 --- /dev/null +++ b/ddpo_pytorch/assets/imagenet_classes.txt @@ -0,0 +1,1000 @@ +tench, Tinca tinca +goldfish, Carassius auratus +great white shark, white shark, man-eater, man-eating shark, Carcharodon carcharias +tiger shark, Galeocerdo cuvieri +hammerhead, hammerhead shark +electric ray, crampfish, numbfish, torpedo +stingray +cock +hen +ostrich, Struthio camelus +brambling, Fringilla montifringilla +goldfinch, Carduelis carduelis +house finch, linnet, Carpodacus mexicanus +junco, snowbird +indigo bunting, indigo finch, indigo bird, Passerina cyanea +robin, American robin, Turdus migratorius +bulbul +jay +magpie +chickadee +water ouzel, dipper +kite +bald eagle, American eagle, Haliaeetus leucocephalus +vulture +great grey owl, great gray owl, Strix nebulosa +European fire salamander, Salamandra salamandra +common newt, Triturus vulgaris +eft +spotted salamander, Ambystoma maculatum +axolotl, mud puppy, Ambystoma mexicanum +bullfrog, Rana catesbeiana +tree frog, tree-frog +tailed frog, bell toad, ribbed toad, tailed toad, Ascaphus trui +loggerhead, loggerhead turtle, Caretta caretta +leatherback turtle, leatherback, leathery turtle, Dermochelys coriacea +mud turtle +terrapin +box turtle, box tortoise +banded gecko +common iguana, iguana, Iguana iguana +American chameleon, anole, Anolis carolinensis +whiptail, whiptail lizard +agama +frilled lizard, Chlamydosaurus kingi +alligator lizard +Gila monster, Heloderma suspectum +green lizard, Lacerta viridis +African chameleon, Chamaeleo chamaeleon +Komodo dragon, Komodo lizard, dragon lizard, giant lizard, Varanus komodoensis +African crocodile, Nile crocodile, Crocodylus niloticus +American alligator, Alligator mississipiensis +triceratops +thunder snake, worm snake, Carphophis amoenus +ringneck snake, ring-necked snake, ring snake +hognose snake, puff adder, sand viper +green snake, grass snake +king snake, kingsnake +garter snake, grass snake +water snake +vine snake +night snake, Hypsiglena torquata +boa constrictor, Constrictor constrictor +rock python, rock snake, Python sebae +Indian cobra, Naja naja +green mamba +sea snake +horned viper, cerastes, sand viper, horned asp, Cerastes cornutus +diamondback, diamondback rattlesnake, Crotalus adamanteus +sidewinder, horned rattlesnake, Crotalus cerastes +trilobite +harvestman, daddy longlegs, Phalangium opilio +scorpion +black and gold garden spider, Argiope aurantia +barn spider, Araneus cavaticus +garden spider, Aranea diademata +black widow, Latrodectus mactans +tarantula +wolf spider, hunting spider +tick +centipede +black grouse +ptarmigan +ruffed grouse, partridge, Bonasa umbellus +prairie chicken, prairie grouse, prairie fowl +peacock +quail +partridge +African grey, African gray, Psittacus erithacus +macaw +sulphur-crested cockatoo, Kakatoe galerita, Cacatua galerita +lorikeet +coucal +bee eater +hornbill +hummingbird +jacamar +toucan +drake +red-breasted merganser, Mergus serrator +goose +black swan, Cygnus atratus +tusker +echidna, spiny anteater, anteater +platypus, duckbill, duckbilled platypus, duck-billed platypus, Ornithorhynchus anatinus +wallaby, brush kangaroo +koala, koala bear, kangaroo bear, native bear, Phascolarctos cinereus +wombat +jellyfish +sea anemone, anemone +brain coral +flatworm, platyhelminth +nematode, nematode worm, roundworm +conch +snail +slug +sea slug, nudibranch +chiton, coat-of-mail shell, sea cradle, polyplacophore +chambered nautilus, pearly nautilus, nautilus +Dungeness crab, Cancer magister +rock crab, Cancer irroratus +fiddler crab +king crab, Alaska crab, Alaskan king crab, Alaska king crab, Paralithodes camtschatica +American lobster, Northern lobster, Maine lobster, Homarus americanus +spiny lobster, langouste, rock lobster, crawfish, crayfish, sea crawfish +crayfish, crawfish, crawdad, crawdaddy +hermit crab +isopod +white stork, Ciconia ciconia +black stork, Ciconia nigra +spoonbill +flamingo +little blue heron, Egretta caerulea +American egret, great white heron, Egretta albus +bittern +crane +limpkin, Aramus pictus +European gallinule, Porphyrio porphyrio +American coot, marsh hen, mud hen, water hen, Fulica americana +bustard +ruddy turnstone, Arenaria interpres +red-backed sandpiper, dunlin, Erolia alpina +redshank, Tringa totanus +dowitcher +oystercatcher, oyster catcher +pelican +king penguin, Aptenodytes patagonica +albatross, mollymawk +grey whale, gray whale, devilfish, Eschrichtius gibbosus, Eschrichtius robustus +killer whale, killer, orca, grampus, sea wolf, Orcinus orca +dugong, Dugong dugon +sea lion +Chihuahua +Japanese spaniel +Maltese dog, Maltese terrier, Maltese +Pekinese, Pekingese, Peke +Shih-Tzu +Blenheim spaniel +papillon +toy terrier +Rhodesian ridgeback +Afghan hound, Afghan +basset, basset hound +beagle +bloodhound, sleuthhound +bluetick +black-and-tan coonhound +Walker hound, Walker foxhound +English foxhound +redbone +borzoi, Russian wolfhound +Irish wolfhound +Italian greyhound +whippet +Ibizan hound, Ibizan Podenco +Norwegian elkhound, elkhound +otterhound, otter hound +Saluki, gazelle hound +Scottish deerhound, deerhound +Weimaraner +Staffordshire bullterrier, Staffordshire bull terrier +American Staffordshire terrier, Staffordshire terrier, American pit bull terrier, pit bull terrier +Bedlington terrier +Border terrier +Kerry blue terrier +Irish terrier +Norfolk terrier +Norwich terrier +Yorkshire terrier +wire-haired fox terrier +Lakeland terrier +Sealyham terrier, Sealyham +Airedale, Airedale terrier +cairn, cairn terrier +Australian terrier +Dandie Dinmont, Dandie Dinmont terrier +Boston bull, Boston terrier +miniature schnauzer +giant schnauzer +standard schnauzer +Scotch terrier, Scottish terrier, Scottie +Tibetan terrier, chrysanthemum dog +silky terrier, Sydney silky +soft-coated wheaten terrier +West Highland white terrier +Lhasa, Lhasa apso +flat-coated retriever +curly-coated retriever +golden retriever +Labrador retriever +Chesapeake Bay retriever +German short-haired pointer +vizsla, Hungarian pointer +English setter +Irish setter, red setter +Gordon setter +Brittany spaniel +clumber, clumber spaniel +English springer, English springer spaniel +Welsh springer spaniel +cocker spaniel, English cocker spaniel, cocker +Sussex spaniel +Irish water spaniel +kuvasz +schipperke +groenendael +malinois +briard +kelpie +komondor +Old English sheepdog, bobtail +Shetland sheepdog, Shetland sheep dog, Shetland +collie +Border collie +Bouvier des Flandres, Bouviers des Flandres +Rottweiler +German shepherd, German shepherd dog, German police dog, alsatian +Doberman, Doberman pinscher +miniature pinscher +Greater Swiss Mountain dog +Bernese mountain dog +Appenzeller +EntleBucher +boxer +bull mastiff +Tibetan mastiff +French bulldog +Great Dane +Saint Bernard, St Bernard +Eskimo dog, husky +malamute, malemute, Alaskan malamute +Siberian husky +dalmatian, coach dog, carriage dog +affenpinscher, monkey pinscher, monkey dog +basenji +pug, pug-dog +Leonberg +Newfoundland, Newfoundland dog +Great Pyrenees +Samoyed, Samoyede +Pomeranian +chow, chow chow +keeshond +Brabancon griffon +Pembroke, Pembroke Welsh corgi +Cardigan, Cardigan Welsh corgi +toy poodle +miniature poodle +standard poodle +Mexican hairless +timber wolf, grey wolf, gray wolf, Canis lupus +white wolf, Arctic wolf, Canis lupus tundrarum +red wolf, maned wolf, Canis rufus, Canis niger +coyote, prairie wolf, brush wolf, Canis latrans +dingo, warrigal, warragal, Canis dingo +dhole, Cuon alpinus +African hunting dog, hyena dog, Cape hunting dog, Lycaon pictus +hyena, hyaena +red fox, Vulpes vulpes +kit fox, Vulpes macrotis +Arctic fox, white fox, Alopex lagopus +grey fox, gray fox, Urocyon cinereoargenteus +tabby, tabby cat +tiger cat +Persian cat +Siamese cat, Siamese +Egyptian cat +cougar, puma, catamount, mountain lion, painter, panther, Felis concolor +lynx, catamount +leopard, Panthera pardus +snow leopard, ounce, Panthera uncia +jaguar, panther, Panthera onca, Felis onca +lion, king of beasts, Panthera leo +tiger, Panthera tigris +cheetah, chetah, Acinonyx jubatus +brown bear, bruin, Ursus arctos +American black bear, black bear, Ursus americanus, Euarctos americanus +ice bear, polar bear, Ursus Maritimus, Thalarctos maritimus +sloth bear, Melursus ursinus, Ursus ursinus +mongoose +meerkat, mierkat +tiger beetle +ladybug, ladybeetle, lady beetle, ladybird, ladybird beetle +ground beetle, carabid beetle +long-horned beetle, longicorn, longicorn beetle +leaf beetle, chrysomelid +dung beetle +rhinoceros beetle +weevil +fly +bee +ant, emmet, pismire +grasshopper, hopper +cricket +walking stick, walkingstick, stick insect +cockroach, roach +mantis, mantid +cicada, cicala +leafhopper +lacewing, lacewing fly +dragonfly, darning needle, devil's darning needle, sewing needle, snake feeder, snake doctor, mosquito hawk, skeeter hawk +damselfly +admiral +ringlet, ringlet butterfly +monarch, monarch butterfly, milkweed butterfly, Danaus plexippus +cabbage butterfly +sulphur butterfly, sulfur butterfly +lycaenid, lycaenid butterfly +starfish, sea star +sea urchin +sea cucumber, holothurian +wood rabbit, cottontail, cottontail rabbit +hare +Angora, Angora rabbit +hamster +porcupine, hedgehog +fox squirrel, eastern fox squirrel, Sciurus niger +marmot +beaver +guinea pig, Cavia cobaya +sorrel +zebra +hog, pig, grunter, squealer, Sus scrofa +wild boar, boar, Sus scrofa +warthog +hippopotamus, hippo, river horse, Hippopotamus amphibius +ox +water buffalo, water ox, Asiatic buffalo, Bubalus bubalis +bison +ram, tup +bighorn, bighorn sheep, cimarron, Rocky Mountain bighorn, Rocky Mountain sheep, Ovis canadensis +ibex, Capra ibex +hartebeest +impala, Aepyceros melampus +gazelle +Arabian camel, dromedary, Camelus dromedarius +llama +weasel +mink +polecat, fitch, foulmart, foumart, Mustela putorius +black-footed ferret, ferret, Mustela nigripes +otter +skunk, polecat, wood pussy +badger +armadillo +three-toed sloth, ai, Bradypus tridactylus +orangutan, orang, orangutang, Pongo pygmaeus +gorilla, Gorilla gorilla +chimpanzee, chimp, Pan troglodytes +gibbon, Hylobates lar +siamang, Hylobates syndactylus, Symphalangus syndactylus +guenon, guenon monkey +patas, hussar monkey, Erythrocebus patas +baboon +macaque +langur +colobus, colobus monkey +proboscis monkey, Nasalis larvatus +marmoset +capuchin, ringtail, Cebus capucinus +howler monkey, howler +titi, titi monkey +spider monkey, Ateles geoffroyi +squirrel monkey, Saimiri sciureus +Madagascar cat, ring-tailed lemur, Lemur catta +indri, indris, Indri indri, Indri brevicaudatus +Indian elephant, Elephas maximus +African elephant, Loxodonta africana +lesser panda, red panda, panda, bear cat, cat bear, Ailurus fulgens +giant panda, panda, panda bear, coon bear, Ailuropoda melanoleuca +barracouta, snoek +eel +coho, cohoe, coho salmon, blue jack, silver salmon, Oncorhynchus kisutch +rock beauty, Holocanthus tricolor +anemone fish +sturgeon +gar, garfish, garpike, billfish, Lepisosteus osseus +lionfish +puffer, pufferfish, blowfish, globefish +abacus +abaya +academic gown, academic robe, judge's robe +accordion, piano accordion, squeeze box +acoustic guitar +aircraft carrier, carrier, flattop, attack aircraft carrier +airliner +airship, dirigible +altar +ambulance +amphibian, amphibious vehicle +analog clock +apiary, bee house +apron +ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, dustbin, trash barrel, trash bin +assault rifle, assault gun +backpack, back pack, knapsack, packsack, rucksack, haversack +bakery, bakeshop, bakehouse +balance beam, beam +balloon +ballpoint, ballpoint pen, ballpen, Biro +Band Aid +banjo +bannister, banister, balustrade, balusters, handrail +barbell +barber chair +barbershop +barn +barometer +barrel, cask +barrow, garden cart, lawn cart, wheelbarrow +baseball +basketball +bassinet +bassoon +bathing cap, swimming cap +bath towel +bathtub, bathing tub, bath, tub +beach wagon, station wagon, wagon, estate car, beach waggon, station waggon, waggon +beacon, lighthouse, beacon light, pharos +beaker +bearskin, busby, shako +beer bottle +beer glass +bell cote, bell cot +bib +bicycle-built-for-two, tandem bicycle, tandem +bikini, two-piece +binder, ring-binder +binoculars, field glasses, opera glasses +birdhouse +boathouse +bobsled, bobsleigh, bob +bolo tie, bolo, bola tie, bola +bonnet, poke bonnet +bookcase +bookshop, bookstore, bookstall +bottlecap +bow +bow tie, bow-tie, bowtie +brass, memorial tablet, plaque +brassiere, bra, bandeau +breakwater, groin, groyne, mole, bulwark, seawall, jetty +breastplate, aegis, egis +broom +bucket, pail +buckle +bulletproof vest +bullet train, bullet +butcher shop, meat market +cab, hack, taxi, taxicab +caldron, cauldron +candle, taper, wax light +cannon +canoe +can opener, tin opener +cardigan +car mirror +carousel, carrousel, merry-go-round, roundabout, whirligig +carpenter's kit, tool kit +carton +car wheel +cash machine, cash dispenser, automated teller machine, automatic teller machine, automated teller, automatic teller, ATM +cassette +cassette player +castle +catamaran +CD player +cello, violoncello +cellular telephone, cellular phone, cellphone, cell, mobile phone +chain +chainlink fence +chain mail, ring mail, mail, chain armor, chain armour, ring armor, ring armour +chain saw, chainsaw +chest +chiffonier, commode +chime, bell, gong +china cabinet, china closet +Christmas stocking +church, church building +cinema, movie theater, movie theatre, movie house, picture palace +cleaver, meat cleaver, chopper +cliff dwelling +cloak +clog, geta, patten, sabot +cocktail shaker +coffee mug +coffeepot +coil, spiral, volute, whorl, helix +combination lock +computer keyboard, keypad +confectionery, confectionary, candy store +container ship, containership, container vessel +convertible +corkscrew, bottle screw +cornet, horn, trumpet, trump +cowboy boot +cowboy hat, ten-gallon hat +cradle +crane +crash helmet +crate +crib, cot +Crock Pot +croquet ball +crutch +cuirass +dam, dike, dyke +desk +desktop computer +dial telephone, dial phone +diaper, nappy, napkin +digital clock +digital watch +dining table, board +dishrag, dishcloth +dishwasher, dish washer, dishwashing machine +disk brake, disc brake +dock, dockage, docking facility +dogsled, dog sled, dog sleigh +dome +doormat, welcome mat +drilling platform, offshore rig +drum, membranophone, tympan +drumstick +dumbbell +Dutch oven +electric fan, blower +electric guitar +electric locomotive +entertainment center +envelope +espresso maker +face powder +feather boa, boa +file, file cabinet, filing cabinet +fireboat +fire engine, fire truck +fire screen, fireguard +flagpole, flagstaff +flute, transverse flute +folding chair +football helmet +forklift +fountain +fountain pen +four-poster +freight car +French horn, horn +frying pan, frypan, skillet +fur coat +garbage truck, dustcart +gasmask, respirator, gas helmet +gas pump, gasoline pump, petrol pump, island dispenser +goblet +go-kart +golf ball +golfcart, golf cart +gondola +gong, tam-tam +gown +grand piano, grand +greenhouse, nursery, glasshouse +grille, radiator grille +grocery store, grocery, food market, market +guillotine +hair slide +hair spray +half track +hammer +hamper +hand blower, blow dryer, blow drier, hair dryer, hair drier +hand-held computer, hand-held microcomputer +handkerchief, hankie, hanky, hankey +hard disc, hard disk, fixed disk +harmonica, mouth organ, harp, mouth harp +harp +harvester, reaper +hatchet +holster +home theater, home theatre +honeycomb +hook, claw +hoopskirt, crinoline +horizontal bar, high bar +horse cart, horse-cart +hourglass +iPod +iron, smoothing iron +jack-o'-lantern +jean, blue jean, denim +jeep, landrover +jersey, T-shirt, tee shirt +jigsaw puzzle +jinrikisha, ricksha, rickshaw +joystick +kimono +knee pad +knot +lab coat, laboratory coat +ladle +lampshade, lamp shade +laptop, laptop computer +lawn mower, mower +lens cap, lens cover +letter opener, paper knife, paperknife +library +lifeboat +lighter, light, igniter, ignitor +limousine, limo +liner, ocean liner +lipstick, lip rouge +Loafer +lotion +loudspeaker, speaker, speaker unit, loudspeaker system, speaker system +loupe, jeweler's loupe +lumbermill, sawmill +magnetic compass +mailbag, postbag +mailbox, letter box +maillot +maillot, tank suit +manhole cover +maraca +marimba, xylophone +mask +matchstick +maypole +maze, labyrinth +measuring cup +medicine chest, medicine cabinet +megalith, megalithic structure +microphone, mike +microwave, microwave oven +military uniform +milk can +minibus +miniskirt, mini +minivan +missile +mitten +mixing bowl +mobile home, manufactured home +Model T +modem +monastery +monitor +moped +mortar +mortarboard +mosque +mosquito net +motor scooter, scooter +mountain bike, all-terrain bike, off-roader +mountain tent +mouse, computer mouse +mousetrap +moving van +muzzle +nail +neck brace +necklace +nipple +notebook, notebook computer +obelisk +oboe, hautboy, hautbois +ocarina, sweet potato +odometer, hodometer, mileometer, milometer +oil filter +organ, pipe organ +oscilloscope, scope, cathode-ray oscilloscope, CRO +overskirt +oxcart +oxygen mask +packet +paddle, boat paddle +paddlewheel, paddle wheel +padlock +paintbrush +pajama, pyjama, pj's, jammies +palace +panpipe, pandean pipe, syrinx +paper towel +parachute, chute +parallel bars, bars +park bench +parking meter +passenger car, coach, carriage +patio, terrace +pay-phone, pay-station +pedestal, plinth, footstall +pencil box, pencil case +pencil sharpener +perfume, essence +Petri dish +photocopier +pick, plectrum, plectron +pickelhaube +picket fence, paling +pickup, pickup truck +pier +piggy bank, penny bank +pill bottle +pillow +ping-pong ball +pinwheel +pirate, pirate ship +pitcher, ewer +plane, carpenter's plane, woodworking plane +planetarium +plastic bag +plate rack +plow, plough +plunger, plumber's helper +Polaroid camera, Polaroid Land camera +pole +police van, police wagon, paddy wagon, patrol wagon, wagon, black Maria +poncho +pool table, billiard table, snooker table +pop bottle, soda bottle +pot, flowerpot +potter's wheel +power drill +prayer rug, prayer mat +printer +prison, prison house +projectile, missile +projector +puck, hockey puck +punching bag, punch bag, punching ball, punchball +purse +quill, quill pen +quilt, comforter, comfort, puff +racer, race car, racing car +racket, racquet +radiator +radio, wireless +radio telescope, radio reflector +rain barrel +recreational vehicle, RV, R.V. +reel +reflex camera +refrigerator, icebox +remote control, remote +restaurant, eating house, eating place, eatery +revolver, six-gun, six-shooter +rifle +rocking chair, rocker +rotisserie +rubber eraser, rubber, pencil eraser +rugby ball +rule, ruler +running shoe +safe +safety pin +saltshaker, salt shaker +sandal +sarong +sax, saxophone +scabbard +scale, weighing machine +school bus +schooner +scoreboard +screen, CRT screen +screw +screwdriver +seat belt, seatbelt +sewing machine +shield, buckler +shoe shop, shoe-shop, shoe store +shoji +shopping basket +shopping cart +shovel +shower cap +shower curtain +ski +ski mask +sleeping bag +slide rule, slipstick +sliding door +slot, one-armed bandit +snorkel +snowmobile +snowplow, snowplough +soap dispenser +soccer ball +sock +solar dish, solar collector, solar furnace +sombrero +soup bowl +space bar +space heater +space shuttle +spatula +speedboat +spider web, spider's web +spindle +sports car, sport car +spotlight, spot +stage +steam locomotive +steel arch bridge +steel drum +stethoscope +stole +stone wall +stopwatch, stop watch +stove +strainer +streetcar, tram, tramcar, trolley, trolley car +stretcher +studio couch, day bed +stupa, tope +submarine, pigboat, sub, U-boat +suit, suit of clothes +sundial +sunglass +sunglasses, dark glasses, shades +sunscreen, sunblock, sun blocker +suspension bridge +swab, swob, mop +sweatshirt +swimming trunks, bathing trunks +swing +switch, electric switch, electrical switch +syringe +table lamp +tank, army tank, armored combat vehicle, armoured combat vehicle +tape player +teapot +teddy, teddy bear +television, television system +tennis ball +thatch, thatched roof +theater curtain, theatre curtain +thimble +thresher, thrasher, threshing machine +throne +tile roof +toaster +tobacco shop, tobacconist shop, tobacconist +toilet seat +torch +totem pole +tow truck, tow car, wrecker +toyshop +tractor +trailer truck, tractor trailer, trucking rig, rig, articulated lorry, semi +tray +trench coat +tricycle, trike, velocipede +trimaran +tripod +triumphal arch +trolleybus, trolley coach, trackless trolley +trombone +tub, vat +turnstile +typewriter keyboard +umbrella +unicycle, monocycle +upright, upright piano +vacuum, vacuum cleaner +vase +vault +velvet +vending machine +vestment +viaduct +violin, fiddle +volleyball +waffle iron +wall clock +wallet, billfold, notecase, pocketbook +wardrobe, closet, press +warplane, military plane +washbasin, handbasin, washbowl, lavabo, wash-hand basin +washer, automatic washer, washing machine +water bottle +water jug +water tower +whiskey jug +whistle +wig +window screen +window shade +Windsor tie +wine bottle +wing +wok +wooden spoon +wool, woolen, woollen +worm fence, snake fence, snake-rail fence, Virginia fence +wreck +yawl +yurt +web site, website, internet site, site +comic book +crossword puzzle, crossword +street sign +traffic light, traffic signal, stoplight +book jacket, dust cover, dust jacket, dust wrapper +menu +plate +guacamole +consomme +hot pot, hotpot +trifle +ice cream, icecream +ice lolly, lolly, lollipop, popsicle +French loaf +bagel, beigel +pretzel +cheeseburger +hotdog, hot dog, red hot +mashed potato +head cabbage +broccoli +cauliflower +zucchini, courgette +spaghetti squash +acorn squash +butternut squash +cucumber, cuke +artichoke, globe artichoke +bell pepper +cardoon +mushroom +Granny Smith +strawberry +orange +lemon +fig +pineapple, ananas +banana +jackfruit, jak, jack +custard apple +pomegranate +hay +carbonara +chocolate sauce, chocolate syrup +dough +meat loaf, meatloaf +pizza, pizza pie +potpie +burrito +red wine +espresso +cup +eggnog +alp +bubble +cliff, drop, drop-off +coral reef +geyser +lakeside, lakeshore +promontory, headland, head, foreland +sandbar, sand bar +seashore, coast, seacoast, sea-coast +valley, vale +volcano +ballplayer, baseball player +groom, bridegroom +scuba diver +rapeseed +daisy +yellow lady's slipper, yellow lady-slipper, Cypripedium calceolus, Cypripedium parviflorum +corn +acorn +hip, rose hip, rosehip +buckeye, horse chestnut, conker +coral fungus +agaric +gyromitra +stinkhorn, carrion fungus +earthstar +hen-of-the-woods, hen of the woods, Polyporus frondosus, Grifola frondosa +bolete +ear, spike, capitulum +toilet tissue, toilet paper, bathroom tissue \ No newline at end of file diff --git a/ddpo_pytorch/diffusers_patch/__pycache__/ddim_with_logprob.cpython-310.pyc b/ddpo_pytorch/diffusers_patch/__pycache__/ddim_with_logprob.cpython-310.pyc new file mode 100644 index 0000000..0ed6732 Binary files /dev/null and b/ddpo_pytorch/diffusers_patch/__pycache__/ddim_with_logprob.cpython-310.pyc differ diff --git a/ddpo_pytorch/diffusers_patch/__pycache__/pipeline_with_logprob.cpython-310.pyc b/ddpo_pytorch/diffusers_patch/__pycache__/pipeline_with_logprob.cpython-310.pyc new file mode 100644 index 0000000..a41be10 Binary files /dev/null and b/ddpo_pytorch/diffusers_patch/__pycache__/pipeline_with_logprob.cpython-310.pyc differ diff --git a/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py new file mode 100644 index 0000000..43b2fe8 --- /dev/null +++ b/ddpo_pytorch/diffusers_patch/ddim_with_logprob.py @@ -0,0 +1,143 @@ +# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/schedulers/scheduling_ddim.py +# with the following modifications: +# - + +from typing import Optional, Tuple, Union + +import math +import torch + +from diffusers.utils import randn_tensor +from diffusers.schedulers.scheduling_ddim import DDIMSchedulerOutput, DDIMScheduler + + +def ddim_step_with_logprob( + self: DDIMScheduler, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + prev_sample: Optional[torch.FloatTensor] = None, +) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + assert isinstance(self, DDIMScheduler) + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + self.alphas_cumprod = self.alphas_cumprod.to(timestep.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(timestep.device) + alpha_prod_t = self.alphas_cumprod.gather(0, timestep) + alpha_prod_t_prev = torch.where(prev_timestep >= 0, self.alphas_cumprod.gather(0, prev_timestep), self.final_alpha_cumprod) + print(timestep) + print(alpha_prod_t) + print(alpha_prod_t_prev) + print(prev_timestep) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # σ_t = sqrt((1 − α_t−1)/(1 − α_t)) * sqrt(1 − α_t/α_t−1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample_mean = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if prev_sample is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and prev_sample. Please make sure that either `generator` or" + " `prev_sample` stays `None`." + ) + + if prev_sample is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + prev_sample = prev_sample_mean + std_dev_t * variance_noise + + # log prob of prev_sample given prev_sample_mean and std_dev_t + log_prob = ( + -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std_dev_t**2)) + - torch.log(std_dev_t) + - torch.log(torch.sqrt(2 * torch.as_tensor(math.pi))) + ) + # mean along all but batch dimension + log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim))) + + return prev_sample, log_prob diff --git a/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py new file mode 100644 index 0000000..09378c2 --- /dev/null +++ b/ddpo_pytorch/diffusers_patch/pipeline_with_logprob.py @@ -0,0 +1,225 @@ +# Copied from https://github.com/huggingface/diffusers/blob/fc6acb6b97e93d58cb22b5fee52d884d77ce84d8/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +# with the following modifications: +# - + +from typing import Any, Callable, Dict, List, Optional, Union + +import torch + +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + StableDiffusionPipeline, + rescale_noise_cfg, +) +from diffusers.schedulers.scheduling_ddim import DDIMScheduler +from .ddim_with_logprob import ddim_step_with_logprob + + +@torch.no_grad() +def pipeline_with_logprob( + self: StableDiffusionPipeline, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, +): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + all_latents = [latents] + all_log_probs = [] + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents, log_prob = ddim_step_with_logprob(self.scheduler, noise_pred, t, latents, **extra_step_kwargs) + + all_latents.append(latents) + all_log_probs.append(log_prob) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + return image, has_nsfw_concept, all_latents, all_log_probs diff --git a/ddpo_pytorch/prompts.py b/ddpo_pytorch/prompts.py new file mode 100644 index 0000000..8cecf28 --- /dev/null +++ b/ddpo_pytorch/prompts.py @@ -0,0 +1,54 @@ +from importlib import resources +import functools +import random +import inflect + +IE = inflect.engine() +ASSETS_PATH = resources.files("ddpo_pytorch.assets") + + +@functools.cache +def load_lines(name): + with ASSETS_PATH.joinpath(name).open() as f: + return [line.strip() for line in f.readlines()] + + +def imagenet(low, high): + return random.choice(load_lines("imagenet_classes.txt")[low:high]), {} + + +def imagenet_all(): + return imagenet(0, 1000) + + +def imagenet_animals(): + return imagenet(0, 398) + + +def imagenet_dogs(): + return imagenet(151, 269) + + +def nouns_activities(nouns_file, activities_file): + nouns = load_lines(nouns_file) + activities = load_lines(activities_file) + return f"{IE.a(random.choice(nouns))} {random.choice(activities)}", {} + + +def counting(nouns_file, low, high): + nouns = load_lines(nouns_file) + number = IE.number_to_words(random.randint(low, high)) + noun = random.choice(nouns) + plural_noun = IE.plural(noun) + prompt = f"{number} {plural_noun}" + metadata = { + "questions": [ + f"How many {plural_noun} are there in this image?", + f"What animal is in this image?", + ], + "answers": [ + number, + noun, + ], + } + return prompt, metadata diff --git a/ddpo_pytorch/rewards.py b/ddpo_pytorch/rewards.py new file mode 100644 index 0000000..9ec6218 --- /dev/null +++ b/ddpo_pytorch/rewards.py @@ -0,0 +1,29 @@ +from PIL import Image +import io +import numpy as np +import torch + + +def jpeg_incompressibility(): + def _fn(images, prompts, metadata): + if isinstance(images, torch.Tensor): + images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy() + images = images.transpose(0, 2, 3, 1) # NCHW -> NHWC + images = [Image.fromarray(image) for image in images] + buffers = [io.BytesIO() for _ in images] + for image, buffer in zip(images, buffers): + image.save(buffer, format="JPEG", quality=95) + sizes = [buffer.tell() / 1000 for buffer in buffers] + return np.array(sizes), {} + + return _fn + + +def jpeg_compressibility(): + jpeg_fn = jpeg_incompressibility() + + def _fn(images, prompts, metadata): + rew, meta = jpeg_fn(images, prompts, metadata) + return -rew, meta + + return _fn diff --git a/ddpo_pytorch/stat_tracking.py b/ddpo_pytorch/stat_tracking.py new file mode 100644 index 0000000..4199ab9 --- /dev/null +++ b/ddpo_pytorch/stat_tracking.py @@ -0,0 +1,34 @@ +import numpy as np +from collections import deque + + +class PerPromptStatTracker: + def __init__(self, buffer_size, min_count): + self.buffer_size = buffer_size + self.min_count = min_count + self.stats = {} + + def update(self, prompts, rewards): + unique = np.unique(prompts) + advantages = np.empty_like(rewards) + for prompt in unique: + prompt_rewards = rewards[prompts == prompt] + if prompt not in self.stats: + self.stats[prompt] = deque(maxlen=self.buffer_size) + self.stats[prompt].extend(prompt_rewards) + + if len(self.stats[prompt]) < self.min_count: + mean = np.mean(rewards) + std = np.std(rewards) + 1e-6 + else: + mean = np.mean(self.stats[prompt]) + std = np.std(self.stats[prompt]) + 1e-6 + advantages[prompts == prompt] = (prompt_rewards - mean) / std + + return advantages + + def get_stats(self): + return { + k: {"mean": np.mean(v), "std": np.std(v), "count": len(v)} + for k, v in self.stats.items() + } diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000..c123ba7 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,341 @@ +from absl import app, flags, logging +from ml_collections import config_flags +from accelerate import Accelerator +from accelerate.utils import set_seed +from accelerate.logging import get_logger +from diffusers import StableDiffusionPipeline, DDIMScheduler +from diffusers.loaders import AttnProcsLayers +from diffusers.models.attention_processor import LoRAAttnProcessor +import ddpo_pytorch.prompts +import ddpo_pytorch.rewards +from ddpo_pytorch.stat_tracking import PerPromptStatTracker +from ddpo_pytorch.diffusers_patch.pipeline_with_logprob import pipeline_with_logprob +from ddpo_pytorch.diffusers_patch.ddim_with_logprob import ddim_step_with_logprob +import torch +import tqdm + + +FLAGS = flags.FLAGS +config_flags.DEFINE_config_file("config", "config/base.py", "Training configuration.") + +logger = get_logger(__name__) + + +def main(_): + # basic Accelerate and logging setup + config = FLAGS.config + accelerator = Accelerator( + log_with="all", + mixed_precision=config.mixed_precision, + project_dir=config.logdir, + ) + if accelerator.is_main_process: + accelerator.init_trackers(project_name="ddpo-pytorch", config=config) + logger.info(config) + + # set seed + set_seed(config.seed) + + # load scheduler, tokenizer and models. + pipeline = StableDiffusionPipeline.from_pretrained(config.pretrained.model, revision=config.pretrained.revision) + # freeze parameters of models to save more memory + pipeline.unet.requires_grad_(False) + pipeline.vae.requires_grad_(False) + pipeline.text_encoder.requires_grad_(False) + # disable safety checker + pipeline.safety_checker = None + # make the progress bar nicer + pipeline.set_progress_bar_config( + position=1, + disable=not accelerator.is_local_main_process, + leave=False, + ) + # switch to DDIM scheduler + pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) + + # For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move unet, vae and text_encoder to device and cast to weight_dtype + pipeline.unet.to(accelerator.device, dtype=weight_dtype) + pipeline.vae.to(accelerator.device, dtype=weight_dtype) + pipeline.text_encoder.to(accelerator.device, dtype=weight_dtype) + + # Set correct lora layers + lora_attn_procs = {} + for name in pipeline.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else pipeline.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = pipeline.unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(pipeline.unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = pipeline.unet.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim) + + pipeline.unet.set_attn_processor(lora_attn_procs) + lora_layers = AttnProcsLayers(pipeline.unet.attn_processors) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if config.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if config.train.scale_lr: + config.train.learning_rate = ( + config.train.learning_rate + * config.train.gradient_accumulation_steps + * config.train.batch_size + * accelerator.num_processes + ) + + # Initialize the optimizer + if config.train.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`" + ) + + optimizer_cls = bnb.optim.AdamW8bit + else: + optimizer_cls = torch.optim.AdamW + + optimizer = optimizer_cls( + lora_layers.parameters(), + lr=config.train.learning_rate, + betas=(config.train.adam_beta1, config.train.adam_beta2), + weight_decay=config.train.adam_weight_decay, + eps=config.train.adam_epsilon, + ) + + # prepare prompt and reward fn + prompt_fn = getattr(ddpo_pytorch.prompts, config.prompt_fn) + reward_fn = getattr(ddpo_pytorch.rewards, config.reward_fn)() + + # Prepare everything with our `accelerator`. + lora_layers, optimizer = accelerator.prepare(lora_layers, optimizer) + + # Train! + samples_per_epoch = config.sample.batch_size * accelerator.num_processes * config.sample.num_batches_per_epoch + total_train_batch_size = ( + config.train.batch_size * accelerator.num_processes * config.train.gradient_accumulation_steps + ) + + assert config.sample.batch_size % config.train.batch_size == 0 + assert samples_per_epoch % total_train_batch_size == 0 + + logger.info("***** Running training *****") + logger.info(f" Num Epochs = {config.num_epochs}") + logger.info(f" Sample batch size per device = {config.sample.batch_size}") + logger.info(f" Train batch size per device = {config.train.batch_size}") + logger.info(f" Gradient Accumulation steps = {config.train.gradient_accumulation_steps}") + logger.info("") + logger.info(f" Total number of samples per epoch = {samples_per_epoch}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Number of gradient updates per inner epoch = {samples_per_epoch // total_train_batch_size}") + logger.info(f" Number of inner epochs = {config.train.num_inner_epochs}") + + neg_prompt_embed = pipeline.text_encoder( + pipeline.tokenizer( + [""], + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + )[0] + sample_neg_prompt_embeds = neg_prompt_embed.repeat(config.sample.batch_size, 1, 1) + train_neg_prompt_embeds = neg_prompt_embed.repeat(config.train.batch_size, 1, 1) + + if config.per_prompt_stat_tracking: + stat_tracker = PerPromptStatTracker( + config.per_prompt_stat_tracking.buffer_size, + config.per_prompt_stat_tracking.min_count, + ) + + for epoch in range(config.num_epochs): + #################### SAMPLING #################### + samples = [] + prompts = [] + for i in tqdm.tqdm( + range(config.sample.num_batches_per_epoch), + desc=f"Epoch {epoch}: sampling", + disable=not accelerator.is_local_main_process, + position=0, + ): + # generate prompts + prompts, prompt_metadata = zip( + *[prompt_fn(**config.prompt_fn_kwargs) for _ in range(config.sample.batch_size)] + ) + + # encode prompts + prompt_ids = pipeline.tokenizer( + prompts, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=pipeline.tokenizer.model_max_length, + ).input_ids.to(accelerator.device) + prompt_embeds = pipeline.text_encoder(prompt_ids)[0] + + # sample + pipeline.unet.eval() + pipeline.vae.eval() + images, _, latents, log_probs = pipeline_with_logprob( + pipeline, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=sample_neg_prompt_embeds, + num_inference_steps=config.sample.num_steps, + guidance_scale=config.sample.guidance_scale, + eta=config.sample.eta, + output_type="pt", + ) + + latents = torch.stack(latents, dim=1) # (batch_size, num_steps + 1, 4, 64, 64) + log_probs = torch.stack(log_probs, dim=1) # (batch_size, num_steps, 1) + timesteps = pipeline.scheduler.timesteps.repeat(config.sample.batch_size, 1) # (batch_size, num_steps) + + # compute rewards + rewards, reward_metadata = reward_fn(images, prompts, prompt_metadata) + + samples.append( + { + "prompt_ids": prompt_ids, + "prompt_embeds": prompt_embeds, + "timesteps": timesteps, + "latents": latents[:, :-1], # each entry is the latent before timestep t + "next_latents": latents[:, 1:], # each entry is the latent after timestep t + "log_probs": log_probs, + "rewards": torch.as_tensor(rewards), + } + ) + + # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...) + samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()} + + # gather rewards across processes + rewards = accelerator.gather(samples["rewards"]).cpu().numpy() + + # per-prompt mean/std tracking + if config.per_prompt_stat_tracking: + # gather the prompts across processes + prompt_ids = accelerator.gather(samples["prompt_ids"]).cpu().numpy() + prompts = pipeline.tokenizer.batch_decode(prompt_ids, skip_special_tokens=True) + advantages = stat_tracker.update(prompts, rewards) + else: + advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8) + + # ungather advantages; we only need to keep the entries corresponding to the samples on this process + samples["advantages"] = ( + torch.as_tensor(advantages) + .reshape(accelerator.num_processes, -1)[accelerator.process_index] + .to(accelerator.device) + ) + + del samples["rewards"] + del samples["prompt_ids"] + + total_batch_size, num_timesteps = samples["timesteps"].shape + assert total_batch_size == config.sample.batch_size * config.sample.num_batches_per_epoch + assert num_timesteps == config.sample.num_steps + + #################### TRAINING #################### + for inner_epoch in range(config.train.num_inner_epochs): + # shuffle samples along batch dimension + indices = torch.randperm(total_batch_size, device=accelerator.device) + samples = {k: v[indices] for k, v in samples.items()} + + # shuffle along time dimension, independently for each sample + for i in range(total_batch_size): + indices = torch.randperm(num_timesteps, device=accelerator.device) + for key in ["timesteps", "latents", "next_latents"]: + samples[key][i] = samples[key][i][indices] + + # rebatch for training + samples_batched = {k: v.reshape(-1, config.train.batch_size, *v.shape[1:]) for k, v in samples.items()} + + # dict of lists -> list of dicts for easier iteration + samples_batched = [dict(zip(samples_batched, x)) for x in zip(*samples_batched.values())] + + # train + for i, sample in tqdm.tqdm( + list(enumerate(samples_batched)), + desc=f"Outer epoch {epoch}, inner epoch {inner_epoch}: training", + position=0, + ): + if config.train.cfg: + # concat negative prompts to sample prompts to avoid two forward passes + embeds = torch.cat([train_neg_prompt_embeds, sample["prompt_embeds"]]) + else: + embeds = sample["prompt_embeds"] + + for j in tqdm.trange( + num_timesteps, + desc=f"Timestep", + position=1, + leave=False, + ): + with accelerator.accumulate(pipeline.unet): + if config.train.cfg: + noise_pred = pipeline.unet( + torch.cat([sample["latents"][:, j]] * 2), + torch.cat([sample["timesteps"][:, j]] * 2), + embeds, + ).sample + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + config.sample.guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + else: + noise_pred = pipeline.unet( + sample["latents"][:, j], sample["timesteps"][:, j], embeds + ).sample + _, log_prob = ddim_step_with_logprob( + pipeline.scheduler, + noise_pred, + sample["timesteps"][:, j], + sample["latents"][:, j], + eta=config.sample.eta, + prev_sample=sample["next_latents"][:, j], + ) + + # ppo logic + advantages = torch.clamp( + sample["advantages"][:, j], -config.train.adv_clip_max, config.train.adv_clip_max + ) + ratio = torch.exp(log_prob - sample["log_probs"][:, j]) + unclipped_loss = -advantages * ratio + clipped_loss = -advantages * torch.clamp( + ratio, 1.0 - config.train.clip_range, 1.0 + config.train.clip_range + ) + loss = torch.mean(torch.maximum(unclipped_loss, clipped_loss)) + + # debugging values + info = {} + # John Schulman says that (ratio - 1) - log(ratio) is a better + # estimator, but most existing code uses this so... + # http://joschu.net/blog/kl-approx.html + info["approx_kl"] = 0.5 * torch.mean((log_prob - sample["log_probs"][:, j]) ** 2) + info["clipfrac"] = torch.mean(torch.abs(ratio - 1.0) > config.train.clip_range) + info["loss"] = loss + + # backward pass + accelerator.backward(loss) + if accelerator.sync_gradients: + accelerator.clip_grad_norm_(lora_layers.parameters(), config.train.max_grad_norm) + optimizer.step() + optimizer.zero_grad() + + +if __name__ == "__main__": + app.run(main) diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..076cb2a --- /dev/null +++ b/setup.py @@ -0,0 +1,10 @@ +from setuptools import setup, find_packages + +setup( + name='ddpo-pytorch', + version='0.0.1', + packages=["ddpo_pytorch"], + install_requires=[ + "ml-collections", "absl-py" + ], +) \ No newline at end of file