Откакто разбрах за генеративните състезателни мрежи (GAN), бях очарован от тях. GAN е вид невронна мрежа, която е в състояние да генерира нови данни от нулата. Можете да му подадете малко произволен шум като вход и той може да създаде реалистични изображения на спални, птици или каквото е обучен да генерира.
Едно нещо, с което всички учени могат да се съгласят, е, че се нуждаем от повече данни.
GAN, които могат да се използват за създаване на нови данни в ситуации с ограничени данни, могат да се окажат наистина полезни. Данните понякога могат да бъдат трудни и скъпи и отнемащи много време за генериране. За да бъдат полезни обаче, новите данни трябва да бъдат достатъчно реалистични, че каквато и информация да получим от генерираните данни, все пак да се отнася за реални данни. Ако тренирате котка за лов на мишки и използвате фалшиви мишки, по-добре се уверете, че фалшивите мишки всъщност изглеждат като мишки.
Друг начин да се мисли за това е, че GAN откриват структура в данните, която им позволява да правят реалистични данни. Това може да бъде полезно, ако не можем сами да видим тази структура или не можем да я извадим с други методи.
В тази статия ще научите как GAN могат да се използват за генериране на нови данни. За да запазим този урок реалистично, ще използваме набор от данни за откриване на измами с кредитни карти от Kaggle.
В експериментите си се опитах да използвам този набор от данни, за да видя дали мога да получа GAN, за да създам достатъчно реалистични данни, които да ни помогнат да открием измамни случаи. Този набор от данни подчертава проблема с ограничените данни: От 285 000 транзакции само 492 са измами. 492 случая на измама не е голям набор от данни, за да се обучават, особено когато става въпрос за задачи за машинно обучение, където хората обичат да имат набори от данни с няколко порядъка по-големи. Въпреки че резултатите от моя експеримент не бяха невероятни, научих много за GAN по пътя, които с удоволствие споделям.
Преди да се задълбочим в това царство на GAN, ако искате бързо да се справите с машинното си обучение или с уменията си за дълбоко обучение, можете да разгледате тези две свързани публикации в блога:
Генеративните състезателни мрежи (GAN) са архитектура на невронна мрежа, която е показала впечатляващи подобрения спрямо предишните генеративни методи, като например вариационни автокодери или ограничени машини на болцман. GAN са успели да генерират по-реалистични изображения (напр. DCGAN ), активирайте прехвърлянето на стилове между изображения (вж тук и тук ), генериране на изображения от текстови описания ( StackGAN ) и се учат от по-малки набори от данни чрез полууправлявано обучение . Поради тези постижения те генерират много лихва както в академичния, така и в търговския сектор.
Директорът на AI Research във Facebook Ян ЛеКун дори им се обади най-вълнуващото развитие в машинното обучение през последното десетилетие.
Помислете как научавате. Опитвате нещо, получавате обратна връзка. Вие коригирате стратегията си и опитайте отново.
Отзивите могат да бъдат под формата на критика, болка или печалба. Може да дойде от вашата собствена преценка колко добре сте се справили. Често най-полезната обратна връзка е обратната връзка, която идва от друг човек, защото това не е просто число или усещане, а интелигентна оценка на това колко добре сте изпълнили задачата.
Когато компютърът е обучен за дадена задача, човек обикновено предоставя обратната връзка под формата на коригирани параметри или алгоритми. Това работи добре, когато задачата е добре дефинирана, като например да се научите да умножавате две числа. Можете лесно и точно да кажете на компютъра как е сгрешил.
С по-сложна задача, като създаване на образ на куче, става по-трудно да се осигури обратна връзка. Размазано ли е изображението, прилича ли по-скоро на котка или изобщо прилича на нещо? Може да се приложи сложна статистика, но би било трудно да се уловят всички детайли, които правят изображението да изглежда реално.
Човек може да даде известна оценка, тъй като имаме много опит в оценяването на визуалния вход, но ние сме относително бавни и нашите оценки могат да бъдат силно субективни. Вместо това бихме могли да обучим невронна мрежа, за да се научим да разграничаваме реалните и генерираните изображения.
След това, като оставят генератора на изображения (също невронна мрежа) и дискриминатора да се редуват да се учат един от друг, те могат да се подобрят с времето. Тези две мрежи, играещи тази игра, са генеративна състезателна мрежа.
Можеш чувам изобретателят на GAN, Ян Гудфелоу, говори за това как спорът в бар на тази тема доведе до трескава нощ на кодиране, която доведе до първия GAN. И да, той признава летвата в своята хартия . Можете да научите повече за GANs от Блогът на Ian Goodfellow по тази тема.
Има редица предизвикателства при работата с GAN. Обучението на една невронна мрежа може да бъде трудно поради броя на участващите избори: Архитектура, функции за активиране, метод за оптимизация, скорост на обучение и степен на отпадане, за да назовем само няколко.
GAN удвояват всички тези избори и добавят нови сложности. И генераторът, и дискриминаторът може да забравят трикове, които са използвали по-рано в обучението си. Това може да доведе до това двете мрежи да попаднат в стабилен цикъл от решения, които не се подобряват с времето. Едната мрежа може да надвие другата мрежа, така че нито една от тях вече не може да се научи. Или генераторът може да не изследва голяма част от възможното пространство за решения, а само достатъчно, за да намери реалистични решения. Тази последна ситуация е известна като колапс на режима.
Свиването на режима е, когато генераторът научи само малка част от възможните реалистични режими. Например, ако задачата е да генерира изображения на кучета, генераторът може да се научи да създава само изображения на малки кафяви кучета. Генераторът би пропуснал всички останали режими, състоящи се от кучета с други размери или цветове.
Много стратегии са приложени за справяне с това, включително нормализиране на партиди, добавяне на етикети в данните за обучение или чрез промяна на начина, по който дискриминаторът преценява генерираните данни.
Хората отбелязват, че добавянето на етикети към данните - тоест за разделянето им по категории, почти винаги подобрява производителността на GAN. Вместо да се научите да генерирате изображения на домашни любимци като цяло, трябва да е по-лесно да генерирате изображения на котки, кучета, риби и порове, например.
Може би най-значителните пробиви в развитието на GAN са постигнати по отношение на промяната на начина, по който дискриминаторът оценява данните, така че нека разгледаме по-отблизо това.
В оригиналната формулировка на ГАН през 2014 г. от Goodfellow et al. , дискриминаторът генерира оценка на вероятността дадено изображение да е било реално или генерирано. Дискриминаторът ще получи набор от изображения, които се състоят както от реални, така и от генерирани изображения, и ще генерира оценка за всеки от тези входове. След това грешката между изхода на дискриминатора и действителните етикети ще бъде измерена чрез кръстосана загуба на ентропия. Загубата на кръстосана ентропия може да бъде приравнена на метриката на разстоянието Jensen-Shannon и е показана в началото на 2017 г. от Arjovsky et al. че тази метрика ще се провали в някои случаи и няма да посочи правилната посока в други случаи. Тази група показа, че метриката на разстоянието Wasserstein (известна още като земно движещо се или EM разстояние) работи и работи по-добре в много повече случаи.
Загубата на кръстосана ентропия е мярка за това колко точно дискриминаторът е идентифицирал реални и генерирани изображения. Вместо това показателят Wasserstein разглежда разпределението на всяка променлива (т.е. всеки цвят на всеки пиксел) в реалните и генерираните изображения и определя колко далеч са разпределенията за реални и генерирани данни. Метриката на Wasserstein разглежда колко усилия, по отношение на масата по разстоянието, ще са необходими, за да изтласкат генерираното разпределение във формата на реалното разпределение, откъдето идва и алтернативното наименование „разстояние от движещия се човек“. Тъй като метриката Wasserstein вече не оценява дали дадено изображение е реално или не, а вместо това дава критика за това колко далеч са генерираните изображения от реалните изображения, мрежата „дискриминатор“ се нарича мрежата на „критиците“ в Wasserstein архитектура.
За малко по-цялостно изследване на GAN, в тази статия ще разгледаме четири различни архитектури:
Но нека погледнем първо нашия набор от данни.
Ще работим с набор от данни за откриване на измами с кредитни карти от Kaggle.
Наборът от данни се състои от ~ 285 000 транзакции, от които само 492 са измамни. Данните се състоят от 31 функции: „време“, „сума“, „клас“ и 28 допълнителни, анонимизирани функции. Характеристиката на класа е етикетът, показващ дали дадена транзакция е измамна или не, като 0 показва нормална и 1 показва измама. Всички данни са цифрови и непрекъснати (с изключение на етикета). Наборът от данни няма липсващи стойности. Наборът от данни вече е в доста добра форма, за да започнете, но ще направя малко повече почистване, най-вече просто да коригирам средствата на всички функции на нула и стандартните отклонения на едно. Описах процеса на почистване повече в бележника тук . Засега просто ще покажа краен резултат :
Човек може лесно да забележи разликите между нормалните данни и данните за измами в тези дистрибуции, но има и много припокривания. Можем да приложим един от по-бързите и мощни алгоритми за машинно обучение, за да идентифицираме най-полезните функции за идентифициране на измами. Този алгоритъм, xgboost , е алгоритъм на дървото за вземане на решения с усилен градиент. Ще го обучим на 70% от набора от данни и ще го тестваме на останалите 30%. Можем да настроим алгоритъма да продължи, докато не подобри изземването (частта от извадките, открити) на тестовия набор от данни. Това постига 76% изземване на тестовия комплект, което очевидно оставя място за подобрение. Постига точност от 94%, което означава, че само 6% от предвидените случаи на измами са всъщност нормални транзакции. От този анализ получаваме и списък с функции, сортирани по тяхната полезност при откриване на измами. Можем да използваме най-важните функции, за да помогнем да визуализираме резултатите си по-късно.
Отново, ако имахме повече данни за измами, може би ще успеем да ги открием по-добре. Тоест, бихме могли да постигнем по-високо изземване. Сега ще се опитаме да генерираме нови, реалистични данни за измами, използвайки GAN, за да ни помогнат да открием действителната измама.
За да приложа различни GAN архитектури към този набор от данни, ще използвам GAN-пясъчник , който има редица популярни GAN архитектури, внедрени в Python, използвайки библиотеката Keras и TensorFlow back-end. Всички мои резултати са достъпни като тетрадка на Jupyter тук . Всички необходими библиотеки са включени в образа на Kaggle / Python Docker, ако имате нужда от лесна настройка.
Примерите в GAN-Sandbox са настроени за обработка на изображения. Генераторът създава 2D изображение с 3 цветни канала за всеки пиксел и дискриминаторът / критикът е конфигуриран да оценява такива данни. Конволюционните трансформации се използват между слоевете на мрежите, за да се възползват от пространствената структура на данните за изображенията. Всеки неврон в конволюционния слой работи само с малка група входове и изходи (например съседни пиксели в изображение), за да позволи изучаване на пространствени отношения. Нашият набор от данни за кредитни карти липсва каквато и да е пространствена структура между променливите, така че превърнах конволюционните мрежи в мрежи с плътно свързани слоеве. Невроните в плътно свързани слоеве са свързани към всеки вход и изход на слоя, което позволява на мрежата да научи своите връзки между характеристиките. Ще използвам тази настройка за всяка от архитектурите.
Първият GAN, който ще оценя, сравнява генераторната мрежа с дискриминаторната мрежа, използвайки загубата на кръстосана ентропия от дискриминатора за обучение на мрежите. Това е оригиналната „ванилова“ GAN архитектура. Вторият GAN, който ще оценя, добавя етикети на класове към данните по начина на условен GAN (CGAN). Този GAN има още една променлива в данните, етикетът на класа. Третият GAN ще използва показателя за дистанция Wasserstein за обучение на мрежите (WGAN), а последният ще използва етикетите на класовете и показателя за разстояние от Wasserstein (WCGAN).
Ще обучим различните GAN, като използваме набор от данни за обучение, който се състои от всички 492 измамни транзакции. Можем да добавим класове към набора от данни за измами, за да улесним условните GAN архитектури. Изследвах няколко различни метода на клъстериране в тетрадката и отидох с класификация на KMeans, която сортира данните за измами в 2 класа.
Ще тренирам всеки GAN за 5000 кръга и ще проуча резултатите по пътя. На фигура 4 можем да видим действителните данни за измами и генерираните данни за измами от различните архитектури на GAN с напредването на обучението. Можем да видим действителните данни за измами, разделени на 2 класа KMeans, нанесени с 2 измерения, които най-добре различават тези два класа (характеристики V10 и V17 от PCA трансформирани функции). Двете GAN, които не използват информация за класа, GAN и WGAN, имат генерираните си резултати като един клас. Условните архитектури, CGAN и WCGAN, показват генерираните си данни по клас. На стъпка 0 всички показва генерирани данни нормалното разпределение на случайния вход, подаван към генераторите.
Можем да видим, че оригиналната GAN архитектура започва да научава формата и обхвата на действителните данни, но след това се срива към малко разпространение. Това е обсъждането на режима, обсъдено по-рано. Генераторът е научил малък набор от данни, които дискриминаторът трудно открива като фалшиви. Архитектурата на CGAN се справя малко по-добре, разпространявайки и приближавайки се до разпределенията на всеки клас данни за измами, но след това се включва колапс на режим, както може да се види на стъпка 5000.
WGAN не изпитва колапс на режима, показан от архитектурите GAN и CGAN. Дори и без информация за класа, той започва да приема ненормално разпространение на действителните данни за измами. Архитектурата на WCGAN работи по подобен начин и е в състояние да генерира отделните класове данни.
Можем да оценим колко реалистично изглеждат данните, използвайки същия алгоритъм xgboost, използван по-рано за откриване на измами. Той е бърз и мощен и работи на разположение без много настройки. Ще обучим класификатора xgboost, като използваме половината от действителните данни за измами (246 проби) и равен брой генерирани от GAN примери. След това ще тестваме класификатора xgboost, като използваме другата половина от действителните данни за измами и различен набор от примери, генерирани от 246 GAN. Този ортогонален метод (в експериментален смисъл) ще ни даде някаква индикация за това колко успешен е генераторът при генерирането на реалистични данни. С напълно реалистични генерирани данни, алгоритъмът xgboost трябва да постигне точност от 0,50 (50%) - с други думи, това не е по-добре от предположението.
Ние можем вижте точността на xgboost на GAN генерираните данни намаляват отначало, а след това се увеличават след тренировъчна стъпка 1000, когато се задава режим на свиване. Архитектурата на CGAN постига малко по-реалистични данни след 2000 стъпки, но след това режимът на колапс се включва и за тази мрежа. Архитектурите WGAN и WCGAN постигат по-реалистични данни по-бързо и продължават да се учат с напредването на обучението. Изглежда, че WCGAN няма много предимство пред WGAN, което предполага, че тези създадени класове може да не са полезни за архитектурите на Wasserstein GAN.
Можете да научите повече за архитектурата WGAN от тук и тук .
Мрежата на критиците в архитектурите WGAN и WCGAN се учи да изчислява разстоянието Wasserstein (Earth-mover, EM) между даден набор от данни и действителните данни за измами. В идеалния случай той ще измери разстояние, близко до нула, за извадка от действителни данни за измами. Критикът обаче е в процес на научаване как да извършва това изчисление. Докато измерва по-голямо разстояние за генерирани данни, отколкото за реални данни, мрежата може да се подобри. Можем да наблюдаваме как разликата между разстоянията на Wasserstein за генерирани и реални данни се променя по време на обучението. Ако това е плата, тогава по-нататъшното обучение може да не помогне. На фигура 6 виждаме, че там изглежда допълнително подобрение да се използва както за WGAN, така и за WCGAN на този набор от данни.
Сега можем да тестваме дали сме в състояние да генерираме нови данни за измами, достатъчно реалистични, за да ни помогнат да открием действителни данни за измами. Можем да вземем обучения генератор, който е постигнал най-ниския резултат за точност и да го използваме за генериране на данни. За нашия основен набор от обучения ще използваме 70% от данните за измами (199 020 случая) и 100 случая от данните за измама (~ 20% от данните за измами). След това ще се опитаме да добавим различни количества реални или генерирани данни за измами към този набор от обучения, до 344 случая (70% от данните за измами). За тестовия набор ще използваме останалите 30% от случаите за измама (85 295 случая) и случаите на измами (148 дела). Можем да опитаме да добавим генерирани данни от необучен GAN и от най-добре обучения GAN, за да тестваме дали генерираните данни са по-добри от случаен шум. От нашите тестове изглежда, че най-добрата ни архитектура е WCGAN в стъпка 4800 на обучение, където е постигнал xgboost точност от 70% (не забравяйте, че в идеалния случай точността ще бъде 50%). Така че ще използваме тази архитектура, за да генерираме нови данни за измами.
На фигура 7 можем да видим, че припомнянето (частта от действителните проби за измами, точно идентифицирани в тестовия набор) не се увеличава, тъй като използваме повече генерирани данни за измами за обучение. Класификаторът xgboost е в състояние да запази цялата информация, използвана за идентифициране на измама от 100-те реални случая, и да не се обърка от допълнително генерираните данни, дори когато ги избира от стотици хиляди нормални случаи. Генерираните данни от нетренирания WCGAN не помагат или нараняват, не е изненадващо. Но генерираните данни от обучения WCGAN също не помагат. Изглежда данните не са достатъчно реалистични. Виждаме на фигура 7, че когато се използват реални данни за измами за допълване на обучителния набор, изземването значително се увеличава. Ако WCGAN току-що се беше научил да дублира примери за обучение, без изобщо да бъде креативен, той би могъл да постигне по-високи нива на изземване, както ние вижте с реалните данни .
Въпреки че не успяхме да генерираме достатъчно реалистични данни за измами с кредитни карти, които да ни помогнат да открием действителни измами, ние едвам надраскахме повърхността с тези методи. Можем да тренираме по-дълго, с по-големи мрежи и да настроим параметри за архитектурите, които опитахме в тази статия. Тенденциите в точността на xgboost и загубата на дискриминатор предполагат, че повече обучение ще помогне на архитектурите WGAN и WCGAN. Друг вариант е да преразгледаме почистването на данните, което извършихме, може би да проектираме някои нови променливи или да променим дали и как се справяме с изкривяването на характеристиките. Може би различни схеми за класификация на данните за измами биха помогнали.
Можем да опитаме и други GAN архитектури. The ДРАГАН има теоретични и експериментални доказателства, които показват, че тренира по-бързо и по-стабилно от GAN на Wasserstein. Бихме могли да интегрираме методи, които използват полууправлявано обучение, което показа обещание при учене от ограничени учебни комплекти (вж. „ Подобрени техники за обучение на GAN ”). Можем да изпробваме архитектура, която ни дава човешки разбираеми модели, за да можем да разберем по-добре структурата на данните (вж. InfoGAN ).
Трябва също така да следим за нови разработки в тази област и не на последно място, можем да работим върху създаването на собствени иновации в това бързо развиващо се пространство.
Можете да намерите всички съответни кодове за тази статия в това Хранилище на GitHub .
Свързани: Многото приложения на градиентното спускане в TensorFlowGAN е алгоритъм за машинно обучение, при който една невронна мрежа генерира данни, докато друга определя дали изходът изглежда реален. Двете мрежи се съревновават една срещу друга, за да подобрят реализма на генерираните данни.