From d21ec03b6f93839d7e642b49818398ead926288d Mon Sep 17 00:00:00 2001 From: bvanroll Date: Sun, 22 Jan 2023 17:51:56 +0100 Subject: [PATCH] interessant projectjen --- .gitignore | 3 +- data_prep/Cargo.toml | 12 +++++++ data_prep/src/main.rs | 42 +++++++++++++++++++++++ main/Cargo.toml | 4 ++- main/src/main.rs | 79 +++++++++++++++++++++---------------------- 5 files changed, 97 insertions(+), 43 deletions(-) create mode 100644 data_prep/Cargo.toml create mode 100644 data_prep/src/main.rs diff --git a/.gitignore b/.gitignore index d1db5a2..65530ee 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ Cargo.lock # These are backup files generated by rustfmt **/*.rs.bk -secret/** \ No newline at end of file +secret/** +.idea/* \ No newline at end of file diff --git a/data_prep/Cargo.toml b/data_prep/Cargo.toml new file mode 100644 index 0000000..8d55952 --- /dev/null +++ b/data_prep/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "data_prep" +version = "0.1.0" +authors = ["bvanroll "] +edition = "2018" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +rfd = "0.10.0" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0.91" diff --git a/data_prep/src/main.rs b/data_prep/src/main.rs new file mode 100644 index 0000000..21d338a --- /dev/null +++ b/data_prep/src/main.rs @@ -0,0 +1,42 @@ +use std::fs::File; +use std::io::{BufRead, BufReader, Write}; +use serde::{Serialize}; + +#[derive(Serialize)] +struct MiniHuman { + name: String, + gender: String, +} + + +fn main() { + let mut miniHumans:Vec = Vec::new(); + let current_path = std::env::current_dir().unwrap(); + let res = rfd::FileDialog::new().set_directory(¤t_path).pick_file().unwrap(); + let file = File::open(res.as_path()).unwrap(); + let reader = BufReader::new(file); + for buffer in reader.lines() { + if let Ok(line) = buffer { + //# = comment + // let mut chars = line.chars(); + // let mut gender:String = String::from(chars.next().unwrap()); + // if gender != "?" && gender != "F" && gender != "M" { continue; } + // if gender == "?" { + // let preference = chars.next().unwrap(); + // gender = gender + &*(preference).to_string(); + // } + let mut parts = line.split(" "); + let gender = parts.next().unwrap().to_string(); + if gender.chars().next().unwrap() == '#' { continue; } + parts.next(); + let mut name = parts.next().unwrap().to_string(); + name = name.replace("+", " "); + miniHumans.push(MiniHuman{name: name, gender: gender}) + } + + } + let serialized: String = serde_json::to_string(&miniHumans).unwrap(); + let save_res = rfd::FileDialog::new().set_directory(¤t_path).save_file().unwrap(); + let mut file = File::create(save_res.as_path()).unwrap(); + file.write_all(serialized.as_bytes()).expect("oopsie"); +} diff --git a/main/Cargo.toml b/main/Cargo.toml index 8aa1ebc..c636283 100644 --- a/main/Cargo.toml +++ b/main/Cargo.toml @@ -10,5 +10,7 @@ edition = "2018" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.91" rfd = "0.10.0" -chatgpt_rs = "0.6.0" +async-openai = "0.5.0" rand = "0.8" +tokio = { version = "1", features = ["full"] } +rust-bert = "0.20.0" diff --git a/main/src/main.rs b/main/src/main.rs index 6dfdf5b..668bb84 100644 --- a/main/src/main.rs +++ b/main/src/main.rs @@ -1,8 +1,11 @@ use std::fs::File; +use std::io; use std::io::{BufReader, Read, Write}; -use chatgpt::client::ChatGPT; +//use chatgpt::client::ChatGPT; +use async_openai::{Client, types::{CreateCompletionRequestArgs}}; use serde::{Serialize, Deserialize}; use rand::Rng; +use rust_bert::roberta::RobertaForQuestionAnswering; #[derive(Deserialize)] struct MiniHuman { @@ -19,7 +22,9 @@ struct Human { job: String, bio: String, } -fn main() { + +#[tokio::main] +async fn main() { let current_path = std::env::current_dir().unwrap(); let res = rfd::FileDialog::new().set_directory(¤t_path).pick_file().unwrap(); let mut file = File::open(res.as_path()).unwrap(); @@ -27,6 +32,8 @@ fn main() { file.read_to_string(&mut json_string).unwrap(); let mut MiniHumans: Vec = serde_json::from_str(&json_string).unwrap(); let mut Humans: Vec = Vec::new(); + let save_res = rfd::FileDialog::new().set_directory(¤t_path).save_file().unwrap(); + let mut client = Client::new(); while MiniHumans.len() > 1 { let (mut firstName, mut firstGender) = getRngName(&mut MiniHumans); let (mut lastName, mut lastGender) = getRngName(&mut MiniHumans); @@ -34,10 +41,13 @@ fn main() { if firstGender == "" {firstGender = lastGender.clone()} if lastGender == "" {lastGender = firstGender.clone()} let gender: String = decideGender(firstGender, lastGender); - Humans.push(getHuman(firstName, lastName, gender)) + let h = match getHuman(&mut client, firstName, lastName, gender).await { + Ok(h) => Humans.push(h), + Err(e) => println!("some err occured: {:?}", e.to_string()), + }; + break; } let serialized: String = serde_json::to_string(&Humans).unwrap(); - let save_res = rfd::FileDialog::new().set_directory(¤t_path).save_file().unwrap(); let mut file = File::create(save_res.as_path()).unwrap(); file.write_all(serialized.as_bytes()).expect("oopsie"); } @@ -54,48 +64,35 @@ fn decideGender(first: String, second: String) -> String { return first; } -async fn getHuman(firstName: String, lastName: String, gender: String) -> Human { - let mut token: String = std::env::var("SESSION_TOKEN").unwrap(); // obtain the session token. More on session tokens later. - let mut client = ChatGPT::new(token).unwrap(); - client.refresh_token().await.unwrap(); // it is recommended to refresh token after creating a client - let mut conversation = client.new_conversation(); - let mut finalGender = writeGender(gender); - let age: String = conversation.send_message(&client, "Can you give me the age of a hypothetical {} named {} {}").await.unwrap(); - let country: String = conversation.send_message(&client, "What country would this person be from?").await.unwrap(); - let job: String = conversation.send_message(&client, "What would this persons job title be?").await.unwrap(); - let bio: String = conversation.send_message(&client, "Could you write a short bio this person could use on their online profiles?").await.unwrap(); - return Human{ +async fn getHuman(client: &mut Client, firstName: String, lastName: String, gender: String) -> Result> { + let request = CreateCompletionRequestArgs::default() + .model("text-ada-001") + .prompt(format!("Write a short bio for a character called {} {} making sure to mention their age, gender, current country of residence and current jobtitle", firstName, lastName)) + .max_tokens(200_u16) + .build()?; + let res = client.completions().create(request).await; + let response = String::from(format!("{}", res?.choices.first().unwrap().text)); + + let (finalGender, age, country, job) = getHumanFromContext(response.clone()); + + return Ok(Human{ firstName: firstName, lastName: lastName, gender: finalGender, age: age, country: country, job: job, - bio: bio, - }; + bio: response, + }); } - -fn writeGender(gender: String) -> String { - if gender == "F" { - return "female".to_string() - } else if gender == "M" { - return "male".to_string() - } else if gender == "1M" { - if rand::thread_rng().gen_bool(1.0/9.0) { return "male".to_string() } - return "female".to_string() - } else if gender == "1F" { - if rand::thread_rng().gen_bool(1.0/9.0) { return "female".to_string() } - return "male".to_string() - } else if gender == "?" { - if rand::thread_rng().gen_bool(1.0/2.0) { return "male".to_string() } - return "female".to_string() - } else if gender == "?F" { - if rand::thread_rng().gen_bool(1.0/5.0) { return "male".to_string() } - return "female".to_string() - } else if gender == "?M" { - if rand::thread_rng().gen_bool(1.0/5.0) { return "female".to_string() } - return "male".to_string() - } - - return "male".to_string() +//returns in order: gender, age, country, job +fn getHumanFromContext(context: String, firstName: String) -> (String, String, String, String) { + //TODO use the other ai to get answers from a given context + let qa_model = QuestionAnsweringModel::new(Default::default())?; + let gender = String::from(format!("What is {}'s gender?", firstname)); + let age = String::from(format!("What is {}'s age?", firstName)); + let country= String::from(format!("Where does {} live?", firstName)); + let job = String::from(format!("What is {}'s job?", firstName)); + let answers = qa_model.predict(&[QaInput { question, context }], 1, 32); + return ("".to_string(), "".to_string(), "".to_string(), "".to_string()) } \ No newline at end of file