Compare commits

...

1 commit

Author SHA1 Message Date
a279965fed Add dalle
All checks were successful
continuous-integration/drone/push Build is passing
2022-07-10 19:47:06 +02:00
11 changed files with 350 additions and 246 deletions

View file

@ -1,3 +1,5 @@
from typing import List
import uvicorn
from fastapi import FastAPI
@ -5,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from evaluators import dialog_gpt
from evaluators import roberta
from evaluators.dalle import dalle
from pydantic import BaseModel
@ -14,7 +17,7 @@ from pathlib import Path
app = FastAPI()
origins = ["*"]
origins = ["https://warp.friedl.net"]
app.add_middleware(
CORSMiddleware,
@ -83,23 +86,30 @@ def get_continuation(text: str) -> str:
return dialog_gpt.dialogGPT.eval(text)
# @app.get("/dalle/generate")
# def get_image(text: str):
# text_prompt = text
# generated_imgs = dalle.dallE.eval(text_prompt, 2)
class DalleResponse(BaseModel):
imagePaths: List[str]
# returned_generated_images = []
# dir_name = os.path.join("/home/armin/Desktop/dalle", f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}")
# Path(dir_name).mkdir(parents=True, exist_ok=True)
# for idx, img in enumerate(generated_imgs):
# img.save(os.path.join(dir_name, f'{idx}.png'), format="png")
@app.get("/dalle/generate", response_model=DalleResponse)
def get_image(text: str, count: int) -> DalleResponse:
text_prompt = text
generated_imgs = dalle.dallE.eval(text_prompt, count)
# print(f"Created {2} images from text prompt [{text_prompt}]")
base_dir = "/home/armin/dev/incubator/alas/web/public/"
sub_dir = f"{time.strftime('%Y-%m-%d_%H:%M:%S')}_{text_prompt}"
dir_name = os.path.join(base_dir, sub_dir)
Path(dir_name).mkdir(parents=True, exist_ok=True)
# response = {'generatedImgs': returned_generated_images,
# 'generatedImgsFormat': "img"}
# return response
img_paths = []
for idx, img in enumerate(generated_imgs):
f_name = f'{idx}.png'
img_path = os.path.join(dir_name, f_name)
img.save(img_path, format="png")
img_paths.append(os.path.join(sub_dir, f_name))
return {
'imagePaths': img_paths
}
if __name__ == "__main__":

214
web/components/Artsy.tsx Normal file
View file

@ -0,0 +1,214 @@
import {
Configuration, DalleResponse,
DefaultApi,
RobertaEmotionResponse,
RobertaHateResponse,
RobertaIronyResponse,
RobertaOffensiveResponse,
RobertaSentimentResponse
} from "../service/openapi";
import {
Accordion,
Button,
Card,
Col,
Container,
Dropdown,
DropdownButton,
Form,
FormControl,
FormGroup, FormLabel, Image, InputGroup,
Row, Spinner,
Table
} from "react-bootstrap";
import styles from "../styles/TextInput.module.scss";
import {createRef, FormEvent, forwardRef, useCallback, useEffect, useState} from "react";
import _ from "lodash";
let debounce = _.debounce(async (fn) => await fn(), 10000)
export default function Artsy(props) {
const [text, setText] = useState("");
const [calculating, setCalculating] = useState(false)
const [model, setModel] = useState("1");
const [count, setCount] = useState("5");
const [prompt, setPrompt] = useState("")
const [artery, setArtery] = useState<Array<string>>(Array.of());
interface History {
model: string,
count: string,
prompt: string
artery: Array<string>
}
const [history, setHistory] = useState<Array<History>>(Array.of())
const configuration = new Configuration({
basePath: 'https://warp2.friedl.net'
});
const api = new DefaultApi(configuration);
let handleSubmit = async (e: FormEvent) => {
e.preventDefault();
if (!text) {
setText(undefined);
return;
}
setCalculating(true);
setPrompt("");
setArtery(Array.of());
try {
const r: DalleResponse = await api.getImageDalleGenerateGet({text: text, count: parseInt(count)});
setPrompt(text);
setArtery(r.imagePaths);
history.push({ model: model, count: count, prompt: prompt, artery: [...artery]});
setHistory(history)
} catch (error) {
console.error(error)
}
setCalculating(false);
return false;
}
let renderCards = () => {
let cards = []
for(const p of artery) {
cards.push(
<Col key={p} className={"flex-grow-0"}>
<Card style={{width: '18rem', height: '100%'}}>
<Card.Body>
<Image src={p} />
</Card.Body>
</Card>
</Col>
)
}
return (
<>{cards}</>
);
}
let renderHistory = () => {
let hist = []
let idx = 0
for(const h of history) {
let cards = []
for(const p of h.artery) {
cards.push(
<Col key={p} className={"flex-grow-0"}>
<Card style={{width: '18rem', height: '100%'}}>
<Card.Body>
<Image src={p} />
</Card.Body>
</Card>
</Col>
)
}
hist.push(
<>
<Row className={"mb-3 justify-content-center"} key={`${idx}-prompt`}>
<h5 style={{width: "auto"}} className={"text-muted"}>{h.prompt}</h5>
</Row>
<Row className={"justify-content-center gap-2 mb-5"} key={`${idx}-card`}>
{cards}
</Row>
</>
)
idx++;
}
return(
<>
{hist}
</>
)
}
return (
<>
<Container fluid={true} className={"bg-primary pb-5"}>
<Container className={"pb-5 pt-2"}>
<Row xs={1} xl={2} className={"justify-content-center"}>
<Col>
<h1 className={`text-light`}>ALAS</h1>
</Col>
</Row>
</Container>
</Container>
<Container>
<Row xs={1} xl={2} className={`pb-5 justify-content-center`}>
<Form className={styles.inputForm} onSubmit={handleSubmit}>
<FormControl value={text}
placeholder={"Prompt..."}
onChange={(e) => setText(e.target.value)}
className={`${styles.textBox} mb-3`} as="textarea"
disabled={calculating}/>
<InputGroup className={"mb-3"}>
<InputGroup.Text id="model-addon">Model</InputGroup.Text>
<Form.Select aria-label="Model selection" className={"me-3"} disabled={calculating} value={model} onChange={(e) => setModel(e.currentTarget.value)}>
<option value="1">MINI</option>
<option value="2" disabled={true}>MEGA</option>
<option value="3" disabled={true}>MEGA FULL</option>
</Form.Select>
<InputGroup.Text id="number-addon"># of Images</InputGroup.Text>
<Form.Select aria-label="Number of generated images" className={"me-3"} onChange={(e) => setCount(e.currentTarget.value)} value={count} disabled={calculating}>
<option value="1">1</option>
<option value="2">2</option>
<option value="3">3</option>
<option value="4">4</option>
<option value="5">5</option>
<option value="6">6</option>
<option value="7">7</option>
<option value="8">8</option>
<option value="9">9</option>
<option value="10">10</option>
</Form.Select>
<Button className={"float-end"} type={"submit"} disabled={calculating}>Submit</Button>
</InputGroup>
</Form>
</Row>
<Row className={"justify-content-center mb-3"}>
<h1 style={{width: "auto"}} className={"text-primary"}>Art<sup>e<sub>r</sub>y</sup></h1>
</Row>
<Row className={"mb-3 justify-content-center"}>
<h5 style={{width: "auto"}} className={"text-muted"}>{prompt}</h5>
</Row>
<Row className={"justify-content-center gap-2"}>
{
calculating ? <Spinner animation={"border"}/> : renderCards()
}
</Row>
<Row className={"mt-5"} />
{
history.length > 1 &&
<Row className={"mt-5 justify-content-center"}>
<h2 style={{width: "auto"}} className={"text-muted"}>History</h2>
</Row>
}
<Row className={"justify-content-center gap-2"}>
{
history.length > 1 && renderHistory()
}
</Row>
</Container>
</>
)
}

View file

@ -1,53 +0,0 @@
import {Col, Container, Row, Spinner} from "react-bootstrap";
import {Configuration, DefaultApi} from "../service/openapi";
import {useEffect, useState} from "react";
export default function Continuation(props) {
let [continuation, setContinuation] = useState<string>("")
let [computing, setComputing] = useState<boolean>(false)
const configuration = new Configuration({
basePath: 'https://api.alas.friedl.net'
});
const api = new DefaultApi(configuration);
const fetchContinuation = async () => {
try {
setComputing(true)
let resp = await api.getContinuationDialogContinuationGet({text: props.text});
setComputing(false)
return resp;
} catch (error) {
console.error(error);
}
}
useEffect(() => {
fetchContinuation()
.then(r => setContinuation(r));
}, [props])
return (
<Container className={"mt-5"}>
<Row className={"justify-content-center mb-3"}>
<h4 style={{width: "auto"}}>Effective Response</h4>
</Row>
<Row className={"justify-content-center"}>
{computing
? <Spinner animation="border" />
: continuation && continuation !== "\"I\"" && continuation !== '""'
? <figure style={{width: "auto"}}>
<blockquote className="blockquote">
<p>{continuation}</p>
</blockquote>
<figcaption className={"blockquote-footer"}>
use at own risk
</figcaption>
</figure>
: <p className={"text-muted"} style={{width: "auto"}}>Does not compute.</p>
}
</Row>
</Container>
);
}

View file

@ -1,168 +0,0 @@
import { Configuration, DefaultApi, RobertaEmotionResponse, RobertaHateResponse, RobertaIronyResponse, RobertaOffensiveResponse, RobertaSentimentResponse } from "../service/openapi";
import {Button, Card, Col, Container, Form, FormControl, FormGroup, Row, Table} from "react-bootstrap";
import styles from "../styles/TextInput.module.scss";
import {createRef, FormEvent, forwardRef, useCallback, useEffect, useState} from "react";
import _ from "lodash";
let debounce = _.debounce(async (fn) => await fn(), 400)
export default function TextInput(props) {
const [text, setText] = useState("");
const [emotions, setEmotions] = useState<RobertaEmotionResponse|undefined>(undefined);
const [hate, setHate] = useState<RobertaHateResponse|undefined>(undefined);
const [irony, setIrony] = useState<RobertaIronyResponse|undefined>(undefined);
const [offensive, setOffensive] = useState<RobertaOffensiveResponse|undefined>(undefined);
const [sentiment, setSentiment] = useState<RobertaSentimentResponse|undefined>(undefined);
useEffect(() => { debounce(async () => {
props.textFn(text);
await evaluateInput();
}) }, [text])
const configuration = new Configuration({
basePath: 'https://api.alas.friedl.net'
});
const api = new DefaultApi(configuration);
let evaluateInput = async () => {
if(!text) {
setEmotions(undefined);
setHate(undefined);
setIrony(undefined);
setOffensive(undefined);
setSentiment(undefined);
return;
}
try{
const [emotions, hate, irony, offensive, sentiment] = await Promise.all([
api.getEmotionsRobertaEmotionGet({text: text}),
api.getHateRobertaHateGet({ text: text }),
api.getIronyRobertaIronyGet({ text: text }),
api.getOffensiveRobertaOffensiveGet({ text: text }),
api.getSentimentRobertaSentimentGet({ text: text })
])
setEmotions(emotions);
setHate(hate);
setIrony(irony);
setOffensive(offensive);
setSentiment(sentiment);
} catch(error) {
console.error(error)
}
}
return (
<>
<Container fluid={true} className={"bg-primary pb-5"}>
<Container className={"pb-5 pt-2"}>
<Row xs={1} xl={2} className={"justify-content-center"}>
<Col>
<h1 className={`text-light`}>ALAS</h1>
</Col>
</Row>
</Container>
</Container>
<Container>
<Row xs={1} xl={2} className={`pb-5 justify-content-center`}>
<Form className={styles.inputForm}>
<FormGroup controlId="formText">
<FormControl value={text}
placeholder={"Start typing..."}
onChange={(e) => setText(e.target.value)}
className={styles.textBox} as="textarea" />
</FormGroup>
</Form>
</Row>
<Row className={"justify-content-center mb-3"}>
<h4 style={{width: "auto"}}>Analysis</h4>
</Row>
<Row className={"justify-content-center gap-3"}>
<Col className={"flex-grow-0"}>
<Card style={{width: '18rem', height: '100%'}}>
<Card.Body>
<Card.Title className={"border-bottom border-2"}>
Emotion
</Card.Title>
<Table size="sm" borderless={true}>
<tbody>
<tr className={"mt-2"}>
<td>Joy</td>
<td>{emotions?.joy.toFixed(3)}</td>
</tr>
<tr>
<td>Sadness</td>
<td>{emotions?.sadness.toFixed(3)}</td>
</tr>
<tr>
<td>Anger</td>
<td>{emotions?.anger.toFixed(3)}</td>
</tr>
<tr>
<td>Optimism</td>
<td>{emotions?.optimism.toFixed(3)}</td>
</tr>
</tbody>
</Table>
</Card.Body>
</Card>
</Col>
<Col className={"flex-grow-0"}>
<Card style={{width: '18rem', height: '100%'}}>
<Card.Body>
<Card.Title className={"border-bottom border-2"}>
Temper
</Card.Title>
<Table size="sm" borderless={true}>
<tbody>
<tr>
<td>Hate</td>
<td>{hate?.hate.toFixed(3)}</td>
</tr>
<tr>
<td>Irony</td>
<td>{irony?.irony.toFixed(3)}</td>
</tr>
<tr>
<td>Offense</td>
<td>{offensive?.offensive.toFixed(3)}</td>
</tr>
</tbody>
</Table>
</Card.Body>
</Card>
</Col>
<Col className={"flex-grow-0"}>
<Card style={{width: '18rem', height: '100%'}}>
<Card.Body>
<Card.Title className={"border-bottom border-2"}>
Sentiment
</Card.Title>
<Table size="sm" borderless={true}>
<tbody>
<tr>
<td>Negative</td>
<td>{sentiment?.negative.toFixed(3)}</td>
</tr>
<tr>
<td>Neutral</td>
<td>{sentiment?.neutral.toFixed(3)}</td>
</tr>
<tr>
<td>Positive</td>
<td>{sentiment?.positive.toFixed(3)}</td>
</tr>
</tbody>
</Table>
</Card.Body>
</Card>
</Col>
</Row>
</Container>
</>
)
}

View file

@ -1,13 +1,8 @@
import Head from 'next/head'
import TextInput from "../components/TextInput";
import Continuation from '../components/Continuation';
import {useRef, useState} from "react";
import Artsy from "../components/Artsy";
export default function Home() {
const [text, setText] = useState<string>("");
// @ts-ignore
return (
<div>
<Head>
@ -16,9 +11,7 @@ export default function Home() {
<link rel="icon" href="/favicon.ico"/>
</Head>
<TextInput textFn={setText} />
<Continuation text={text}/>
<Artsy />
</div>
)
}

View file

@ -2,6 +2,7 @@
apis/DefaultApi.ts
apis/index.ts
index.ts
models/DalleResponse.ts
models/HTTPValidationError.ts
models/LocationInner.ts
models/RobertaEmotionResponse.ts

View file

@ -15,6 +15,9 @@
import * as runtime from '../runtime';
import {
DalleResponse,
DalleResponseFromJSON,
DalleResponseToJSON,
HTTPValidationError,
HTTPValidationErrorFromJSON,
HTTPValidationErrorToJSON,
@ -47,6 +50,11 @@ export interface GetHateRobertaHateGetRequest {
text: string;
}
export interface GetImageDalleGenerateGetRequest {
text: string;
count: number;
}
export interface GetIronyRobertaIronyGetRequest {
text: string;
}
@ -166,6 +174,48 @@ export class DefaultApi extends runtime.BaseAPI {
return await response.value();
}
/**
* Get Image
*/
async getImageDalleGenerateGetRaw(requestParameters: GetImageDalleGenerateGetRequest, initOverrides?: RequestInit | runtime.InitOverideFunction): Promise<runtime.ApiResponse<DalleResponse>> {
if (requestParameters.text === null || requestParameters.text === undefined) {
throw new runtime.RequiredError('text','Required parameter requestParameters.text was null or undefined when calling getImageDalleGenerateGet.');
}
if (requestParameters.count === null || requestParameters.count === undefined) {
throw new runtime.RequiredError('count','Required parameter requestParameters.count was null or undefined when calling getImageDalleGenerateGet.');
}
const queryParameters: any = {};
if (requestParameters.text !== undefined) {
queryParameters['text'] = requestParameters.text;
}
if (requestParameters.count !== undefined) {
queryParameters['count'] = requestParameters.count;
}
const headerParameters: runtime.HTTPHeaders = {};
const response = await this.request({
path: `/dalle/generate`,
method: 'GET',
headers: headerParameters,
query: queryParameters,
}, initOverrides);
return new runtime.JSONApiResponse(response, (jsonValue) => DalleResponseFromJSON(jsonValue));
}
/**
* Get Image
*/
async getImageDalleGenerateGet(requestParameters: GetImageDalleGenerateGetRequest, initOverrides?: RequestInit | runtime.InitOverideFunction): Promise<DalleResponse> {
const response = await this.getImageDalleGenerateGetRaw(requestParameters, initOverrides);
return await response.value();
}
/**
* Get Irony
*/

View file

@ -0,0 +1,56 @@
/* tslint:disable */
/* eslint-disable */
/**
* FastAPI
* No description provided (generated by Openapi Generator https://github.com/openapitools/openapi-generator)
*
* The version of the OpenAPI document: 0.1.0
*
*
* NOTE: This class is auto generated by OpenAPI Generator (https://openapi-generator.tech).
* https://openapi-generator.tech
* Do not edit the class manually.
*/
import { exists, mapValues } from '../runtime';
/**
*
* @export
* @interface DalleResponse
*/
export interface DalleResponse {
/**
*
* @type {Array<string>}
* @memberof DalleResponse
*/
imagePaths: Array<string>;
}
export function DalleResponseFromJSON(json: any): DalleResponse {
return DalleResponseFromJSONTyped(json, false);
}
export function DalleResponseFromJSONTyped(json: any, ignoreDiscriminator: boolean): DalleResponse {
if ((json === undefined) || (json === null)) {
return json;
}
return {
'imagePaths': json['imagePaths'],
};
}
export function DalleResponseToJSON(value?: DalleResponse | null): any {
if (value === undefined) {
return undefined;
}
if (value === null) {
return null;
}
return {
'imagePaths': value.imagePaths,
};
}

View file

@ -1,5 +1,6 @@
/* tslint:disable */
/* eslint-disable */
export * from './DalleResponse';
export * from './HTTPValidationError';
export * from './LocationInner';
export * from './RobertaEmotionResponse';

View file

@ -3,7 +3,7 @@
}
.textBox {
height: 200px;
height: 150px;
font-family: Nunito;
}