From a279965fed5dfac3c908c44b476898f6992bb95c Mon Sep 17 00:00:00 2001 From: Armin Friedl Date: Sun, 10 Jul 2022 19:47:06 +0200 Subject: [PATCH] Add dalle --- api/alas/evaluators/dalle/dalle.py | 2 +- api/alas/main.py | 38 ++-- web/components/Artsy.tsx | 214 +++++++++++++++++++ web/components/Continuation.tsx | 53 ----- web/components/TextInput.tsx | 168 --------------- web/pages/index.tsx | 11 +- web/service/openapi/.openapi-generator/FILES | 1 + web/service/openapi/apis/DefaultApi.ts | 50 +++++ web/service/openapi/models/DalleResponse.ts | 56 +++++ web/service/openapi/models/index.ts | 1 + web/styles/TextInput.module.scss | 2 +- 11 files changed, 350 insertions(+), 246 deletions(-) create mode 100644 web/components/Artsy.tsx delete mode 100644 web/components/Continuation.tsx delete mode 100644 web/components/TextInput.tsx create mode 100644 web/service/openapi/models/DalleResponse.ts diff --git a/api/alas/evaluators/dalle/dalle.py b/api/alas/evaluators/dalle/dalle.py index 6bec8e3..7ba0d9f 100644 --- a/api/alas/evaluators/dalle/dalle.py +++ b/api/alas/evaluators/dalle/dalle.py @@ -112,4 +112,4 @@ class DalleModel: return images -dallE = DalleModel(ModelSize.MINI) \ No newline at end of file +dallE = DalleModel(ModelSize.MINI) diff --git a/api/alas/main.py b/api/alas/main.py index 2848063..bea5ac8 100644 --- a/api/alas/main.py +++ b/api/alas/main.py @@ -1,3 +1,5 @@ +from typing import List + import uvicorn from fastapi import FastAPI @@ -5,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware from evaluators import dialog_gpt from evaluators import roberta +from evaluators.dalle import dalle from pydantic import BaseModel @@ -14,7 +17,7 @@ from pathlib import Path app = FastAPI() -origins = ["*"] +origins = ["https://warp.friedl.net"] app.add_middleware( CORSMiddleware, @@ -83,23 +86,30 @@ def get_continuation(text: str) -> str: return dialog_gpt.dialogGPT.eval(text) -# @app.get("/dalle/generate") -# def get_image(text: str): -# text_prompt = text -# generated_imgs = dalle.dallE.eval(text_prompt, 2) +class DalleResponse(BaseModel): + imagePaths: List[str] -# returned_generated_images = [] -# dir_name = os.path.join("/home/armin/Desktop/dalle", f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}") -# Path(dir_name).mkdir(parents=True, exist_ok=True) -# for idx, img in enumerate(generated_imgs): -# img.save(os.path.join(dir_name, f'{idx}.png'), format="png") +@app.get("/dalle/generate", response_model=DalleResponse) +def get_image(text: str, count: int) -> DalleResponse: + text_prompt = text + generated_imgs = dalle.dallE.eval(text_prompt, count) -# print(f"Created {2} images from text prompt [{text_prompt}]") + base_dir = "/home/armin/dev/incubator/alas/web/public/" + sub_dir = f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}" + dir_name = os.path.join(base_dir, sub_dir) + Path(dir_name).mkdir(parents=True, exist_ok=True) -# response = {'generatedImgs': returned_generated_images, -# 'generatedImgsFormat': "img"} -# return response + img_paths = [] + for idx, img in enumerate(generated_imgs): + f_name = f'{idx}.png' + img_path = os.path.join(dir_name, f_name) + img.save(img_path, format="png") + img_paths.append(os.path.join(sub_dir, f_name)) + + return { + 'imagePaths': img_paths + } if __name__ == "__main__": diff --git a/web/components/Artsy.tsx b/web/components/Artsy.tsx new file mode 100644 index 0000000..77ddc9c --- /dev/null +++ b/web/components/Artsy.tsx @@ -0,0 +1,214 @@ +import { + Configuration, DalleResponse, + DefaultApi, + RobertaEmotionResponse, + RobertaHateResponse, + RobertaIronyResponse, + RobertaOffensiveResponse, + RobertaSentimentResponse +} from "../service/openapi"; +import { + Accordion, + Button, + Card, + Col, + Container, + Dropdown, + DropdownButton, + Form, + FormControl, + FormGroup, FormLabel, Image, InputGroup, + Row, Spinner, + Table +} from "react-bootstrap"; +import styles from "../styles/TextInput.module.scss"; +import {createRef, FormEvent, forwardRef, useCallback, useEffect, useState} from "react"; +import _ from "lodash"; + +let debounce = _.debounce(async (fn) => await fn(), 10000) + +export default function Artsy(props) { + const [text, setText] = useState(""); + + const [calculating, setCalculating] = useState(false) + const [model, setModel] = useState("1"); + const [count, setCount] = useState("5"); + const [prompt, setPrompt] = useState("") + const [artery, setArtery] = useState>(Array.of()); + + interface History { + model: string, + count: string, + prompt: string + artery: Array + } + const [history, setHistory] = useState>(Array.of()) + + const configuration = new Configuration({ + basePath: 'https://warp2.friedl.net' + }); + + const api = new DefaultApi(configuration); + + let handleSubmit = async (e: FormEvent) => { + e.preventDefault(); + if (!text) { + setText(undefined); + return; + } + + setCalculating(true); + setPrompt(""); + setArtery(Array.of()); + + try { + const r: DalleResponse = await api.getImageDalleGenerateGet({text: text, count: parseInt(count)}); + setPrompt(text); + setArtery(r.imagePaths); + + history.push({ model: model, count: count, prompt: prompt, artery: [...artery]}); + setHistory(history) + + } catch (error) { + console.error(error) + } + + setCalculating(false); + return false; + } + + let renderCards = () => { + + let cards = [] + for(const p of artery) { + cards.push( + + + + + + + + ) + } + + return ( + <>{cards} + ); + } + + let renderHistory = () => { + let hist = [] + + let idx = 0 + for(const h of history) { + + let cards = [] + for(const p of h.artery) { + cards.push( + + + + + + + + ) + } + + hist.push( + <> + +
{h.prompt}
+
+ + {cards} + + + ) + + idx++; + } + + return( + <> + {hist} + + ) + } + + return ( + <> + + + + +

ALAS

+ +
+
+
+ + +
+ setText(e.target.value)} + className={`${styles.textBox} mb-3`} as="textarea" + disabled={calculating}/> + + + Model + setModel(e.currentTarget.value)}> + + + + + + # of Images + setCount(e.currentTarget.value)} value={count} disabled={calculating}> + + + + + + + + + + + + + + + + +
+ +

Artery

+
+ +
{prompt}
+
+ + { + calculating ? : renderCards() + } + + + + { + history.length > 1 && + +

History

+
+ } + + { + history.length > 1 && renderHistory() + } + +
+ + ) +} diff --git a/web/components/Continuation.tsx b/web/components/Continuation.tsx deleted file mode 100644 index d8616b3..0000000 --- a/web/components/Continuation.tsx +++ /dev/null @@ -1,53 +0,0 @@ -import {Col, Container, Row, Spinner} from "react-bootstrap"; -import {Configuration, DefaultApi} from "../service/openapi"; -import {useEffect, useState} from "react"; - -export default function Continuation(props) { - let [continuation, setContinuation] = useState("") - let [computing, setComputing] = useState(false) - - - const configuration = new Configuration({ - basePath: 'https://api.alas.friedl.net' - }); - - const api = new DefaultApi(configuration); - const fetchContinuation = async () => { - try { - setComputing(true) - let resp = await api.getContinuationDialogContinuationGet({text: props.text}); - setComputing(false) - return resp; - } catch (error) { - console.error(error); - } - } - - useEffect(() => { - fetchContinuation() - .then(r => setContinuation(r)); - }, [props]) - - return ( - - -

Effective Response

-
- - {computing - ? - : continuation && continuation !== "\"I\"" && continuation !== '""' - ?
-
-

{continuation}

-
-
- use at own risk -
-
- :

Does not compute.

- } -
-
- ); -} diff --git a/web/components/TextInput.tsx b/web/components/TextInput.tsx deleted file mode 100644 index 5a91c0f..0000000 --- a/web/components/TextInput.tsx +++ /dev/null @@ -1,168 +0,0 @@ -import { Configuration, DefaultApi, RobertaEmotionResponse, RobertaHateResponse, RobertaIronyResponse, RobertaOffensiveResponse, RobertaSentimentResponse } from "../service/openapi"; -import {Button, Card, Col, Container, Form, FormControl, FormGroup, Row, Table} from "react-bootstrap"; -import styles from "../styles/TextInput.module.scss"; -import {createRef, FormEvent, forwardRef, useCallback, useEffect, useState} from "react"; -import _ from "lodash"; - - -let debounce = _.debounce(async (fn) => await fn(), 400) - -export default function TextInput(props) { - const [text, setText] = useState(""); - const [emotions, setEmotions] = useState(undefined); - const [hate, setHate] = useState(undefined); - const [irony, setIrony] = useState(undefined); - const [offensive, setOffensive] = useState(undefined); - const [sentiment, setSentiment] = useState(undefined); - - useEffect(() => { debounce(async () => { - props.textFn(text); - await evaluateInput(); - }) }, [text]) - - const configuration = new Configuration({ - basePath: 'https://api.alas.friedl.net' - }); - - const api = new DefaultApi(configuration); - - let evaluateInput = async () => { - if(!text) { - setEmotions(undefined); - setHate(undefined); - setIrony(undefined); - setOffensive(undefined); - setSentiment(undefined); - return; - } - - try{ - const [emotions, hate, irony, offensive, sentiment] = await Promise.all([ - api.getEmotionsRobertaEmotionGet({text: text}), - api.getHateRobertaHateGet({ text: text }), - api.getIronyRobertaIronyGet({ text: text }), - api.getOffensiveRobertaOffensiveGet({ text: text }), - api.getSentimentRobertaSentimentGet({ text: text }) - ]) - - setEmotions(emotions); - setHate(hate); - setIrony(irony); - setOffensive(offensive); - setSentiment(sentiment); - - } catch(error) { - console.error(error) - } - } - - return ( - <> - - - - -

ALAS

- -
-
-
- - -
- - setText(e.target.value)} - className={styles.textBox} as="textarea" /> - -
-
- -

Analysis

-
- - - - - - Emotion - - - - - - - - - - - - - - - - - - - - -
Joy{emotions?.joy.toFixed(3)}
Sadness{emotions?.sadness.toFixed(3)}
Anger{emotions?.anger.toFixed(3)}
Optimism{emotions?.optimism.toFixed(3)}
-
-
- - - - - - Temper - - - - - - - - - - - - - - - - -
Hate{hate?.hate.toFixed(3)}
Irony{irony?.irony.toFixed(3)}
Offense{offensive?.offensive.toFixed(3)}
-
-
- - - - - - Sentiment - - - - - - - - - - - - - - - - -
Negative{sentiment?.negative.toFixed(3)}
Neutral{sentiment?.neutral.toFixed(3)}
Positive{sentiment?.positive.toFixed(3)}
-
-
- -
-
- - ) -} diff --git a/web/pages/index.tsx b/web/pages/index.tsx index 534951d..ad555c8 100644 --- a/web/pages/index.tsx +++ b/web/pages/index.tsx @@ -1,13 +1,8 @@ import Head from 'next/head' -import TextInput from "../components/TextInput"; -import Continuation from '../components/Continuation'; -import {useRef, useState} from "react"; +import Artsy from "../components/Artsy"; export default function Home() { - const [text, setText] = useState(""); - - // @ts-ignore return (
@@ -16,9 +11,7 @@ export default function Home() { - - - +
) } diff --git a/web/service/openapi/.openapi-generator/FILES b/web/service/openapi/.openapi-generator/FILES index aefa923..fb5a5ec 100644 --- a/web/service/openapi/.openapi-generator/FILES +++ b/web/service/openapi/.openapi-generator/FILES @@ -2,6 +2,7 @@ apis/DefaultApi.ts apis/index.ts index.ts +models/DalleResponse.ts models/HTTPValidationError.ts models/LocationInner.ts models/RobertaEmotionResponse.ts diff --git a/web/service/openapi/apis/DefaultApi.ts b/web/service/openapi/apis/DefaultApi.ts index f696354..1f23f37 100644 --- a/web/service/openapi/apis/DefaultApi.ts +++ b/web/service/openapi/apis/DefaultApi.ts @@ -15,6 +15,9 @@ import * as runtime from '../runtime'; import { + DalleResponse, + DalleResponseFromJSON, + DalleResponseToJSON, HTTPValidationError, HTTPValidationErrorFromJSON, HTTPValidationErrorToJSON, @@ -47,6 +50,11 @@ export interface GetHateRobertaHateGetRequest { text: string; } +export interface GetImageDalleGenerateGetRequest { + text: string; + count: number; +} + export interface GetIronyRobertaIronyGetRequest { text: string; } @@ -166,6 +174,48 @@ export class DefaultApi extends runtime.BaseAPI { return await response.value(); } + /** + * Get Image + */ + async getImageDalleGenerateGetRaw(requestParameters: GetImageDalleGenerateGetRequest, initOverrides?: RequestInit | runtime.InitOverideFunction): Promise> { + if (requestParameters.text === null || requestParameters.text === undefined) { + throw new runtime.RequiredError('text','Required parameter requestParameters.text was null or undefined when calling getImageDalleGenerateGet.'); + } + + if (requestParameters.count === null || requestParameters.count === undefined) { + throw new runtime.RequiredError('count','Required parameter requestParameters.count was null or undefined when calling getImageDalleGenerateGet.'); + } + + const queryParameters: any = {}; + + if (requestParameters.text !== undefined) { + queryParameters['text'] = requestParameters.text; + } + + if (requestParameters.count !== undefined) { + queryParameters['count'] = requestParameters.count; + } + + const headerParameters: runtime.HTTPHeaders = {}; + + const response = await this.request({ + path: `/dalle/generate`, + method: 'GET', + headers: headerParameters, + query: queryParameters, + }, initOverrides); + + return new runtime.JSONApiResponse(response, (jsonValue) => DalleResponseFromJSON(jsonValue)); + } + + /** + * Get Image + */ + async getImageDalleGenerateGet(requestParameters: GetImageDalleGenerateGetRequest, initOverrides?: RequestInit | runtime.InitOverideFunction): Promise { + const response = await this.getImageDalleGenerateGetRaw(requestParameters, initOverrides); + return await response.value(); + } + /** * Get Irony */ diff --git a/web/service/openapi/models/DalleResponse.ts b/web/service/openapi/models/DalleResponse.ts new file mode 100644 index 0000000..41095a1 --- /dev/null +++ b/web/service/openapi/models/DalleResponse.ts @@ -0,0 +1,56 @@ +/* tslint:disable */ +/* eslint-disable */ +/** + * FastAPI + * No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator) + * + * The version of the OpenAPI document: 0.1.0 + * + * + * NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech). + * https://openapi-generator.tech + * Do not edit the class manually. + */ + +import { exists, mapValues } from '../runtime'; +/** + * + * @export + * @interface DalleResponse + */ +export interface DalleResponse { + /** + * + * @type {Array} + * @memberof DalleResponse + */ + imagePaths: Array; +} + +export function DalleResponseFromJSON(json: any): DalleResponse { + return DalleResponseFromJSONTyped(json, false); +} + +export function DalleResponseFromJSONTyped(json: any, ignoreDiscriminator: boolean): DalleResponse { + if ((json === undefined) || (json === null)) { + return json; + } + return { + + 'imagePaths': json['imagePaths'], + }; +} + +export function DalleResponseToJSON(value?: DalleResponse | null): any { + if (value === undefined) { + return undefined; + } + if (value === null) { + return null; + } + return { + + 'imagePaths': value.imagePaths, + }; +} + diff --git a/web/service/openapi/models/index.ts b/web/service/openapi/models/index.ts index d98e7db..1ad7621 100644 --- a/web/service/openapi/models/index.ts +++ b/web/service/openapi/models/index.ts @@ -1,5 +1,6 @@ /* tslint:disable */ /* eslint-disable */ +export * from './DalleResponse'; export * from './HTTPValidationError'; export * from './LocationInner'; export * from './RobertaEmotionResponse'; diff --git a/web/styles/TextInput.module.scss b/web/styles/TextInput.module.scss index 2d650ec..925a40c 100644 --- a/web/styles/TextInput.module.scss +++ b/web/styles/TextInput.module.scss @@ -3,7 +3,7 @@ } .textBox { - height: 200px; + height: 150px; font-family: Nunito; }