definetly run on pc instead of laptop sheesh

This commit is contained in:
2023-01-24 22:05:03 +01:00
parent 4c39b063f5
commit 376e8c7321

View File

@@ -14,7 +14,7 @@ use rust_bert::resources::RemoteResource;
#[derive(Deserialize, Serialize)] #[derive(Deserialize, Serialize, Clone)]
struct Human { struct Human {
firstName: String, firstName: String,
lastName: String, lastName: String,
@@ -38,34 +38,6 @@ fn main() {
let mut i = 0; let mut i = 0;
let mut l = &Humans.len().clone(); let mut l = &Humans.len().clone();
println!("there are {} humans to process", l - i); println!("there are {} humans to process", l - i);
for mut human in &mut Humans {
//let (gender, age, country, job) = getHumanFromContext(
match getHumanFromContext(human.bio.clone(), human.firstName.clone()) {
Ok((gender, age, country, job)) => {
human.gender = gender;
human.age = age;
human.country = country;
human.job = job;
println!("just did {} at index {}", human.firstName.clone(), i);
},
Err(e) => {
println!("skipping {} because of {}", human.firstName.clone(), e.to_string())
}
}
println!("There are {} humans left to process", l - (i+1));
i = i+1;
if (i > 100) { break; }
}
let serialized: String = serde_json::to_string(&Humans).unwrap();
let mut file = File::create(save_res.as_path()).unwrap();
file.write_all(serialized.as_bytes()).expect("oopsie");
}
fn getHumanFromContext(context: String, firstName: String) -> Result<(String, String, String, String), Box<dyn std::error::Error>> {
//TODO use the other ai to get answers from a given context
let bertconfig = QuestionAnsweringConfig::new( let bertconfig = QuestionAnsweringConfig::new(
ModelType::Bert, ModelType::Bert,
RemoteResource::from_pretrained(BertModelResources::BERT_QA), RemoteResource::from_pretrained(BertModelResources::BERT_QA),
@@ -76,7 +48,99 @@ fn getHumanFromContext(context: String, firstName: String) -> Result<(String, St
false, false,
None, None,
); );
let mut model = QuestionAnsweringModel::new(bertconfig)?; let mut model = QuestionAnsweringModel::new(bertconfig).unwrap();
let mut questions: Vec<QaInput> = Vec::new();
for human in &mut Humans {
questions.extend(getQuestions(human.bio.clone(), human.firstName.clone()));
//let (gender, age, country, job) = getHumanFromContext(
// match getHumanFromContext(&model, human.bio.clone(), human.firstName.clone()) {
// Ok((gender, age, country, job)) => {
// human.gender = gender;
// human.age = age;
// human.country = country;
// human.job = job;
// println!("just did {} at index {}", human.firstName.clone(), i);
// },
// Err(e) => {
// println!("skipping {} because of {}", human.firstName.clone(), e.to_string())
// }
// }
println!("There are {} humans left to process", l - (i+1));
i = i+1;
}
println!("predicting {} questions:", questions.len());
let mut answers = model.predict(&questions, 1, 32);
println!("answers gotten:");
for i in &answers {
println!("{:?}", i);
}
let mut looper = answers.iter();
i = 0;
let mut finishedHumans: Vec<Human> = Vec::new();
for mut human in &mut Humans {
println!("{}'s attributes ", human.firstName.clone());
let mut h = human.clone();
match looper.next().expect("euh").first() {
Ok(gender) => h.gender = gender.answer.clone(),
Err(e) => println!("failed to get gender for {}: {:?}", human.firstName.clone(), e)
}
match looper.next().expect("euh").first() {
Ok(age) => h.age = age.answer.clone(),
Err(e) => println!("failed to get age for {}: {:?}", human.firstName.clone(), e)
}
match looper.next().expect("euh").first() {
Ok(country) => h.country = country.answer.clone(),
Err(e) => println!("failed to get country for {}: {:?}", human.firstname.clone(), e)
}
match looper.next().expect("euh").first() {
Ok(job) => h.job = job.answer.clone(),
Err(e) => println!("failed to get job for {}: {:?}", human.firstName.clone(), e)
}
finishedHumans.push(h);
println!("{} to go", l - i);
i = i+1;
}
//let iter: Vec<&[Answer]> = answers.first().unwrap().chunks(16).collect();
//println!("{:?}", iter);
// for item in iter {
// for i in item {
// println!("{}", i.answer.to_string())
// }
// //println!("{:?}", item)
// }
let serialized: String = serde_json::to_string(&finishedHumans).unwrap();
let mut file = File::create(save_res.as_path()).unwrap();
file.write_all(serialized.as_bytes()).expect("oopsie");
}
fn getQuestions(context: String, firstName: String) -> Vec<QaInput> {
let mut temp :Vec<QaInput>= Vec::new();
temp.push(QaInput {
question: format!("What is {}'s gender?", firstName),
context: context.clone()
});
temp.push(QaInput {
question: format!("What is {}'s age?", firstName),
context: context.clone()
});
temp.push(QaInput {
question: format!("Where does {} live?", firstName),
context: context.clone()
});
temp.push(QaInput {
question: format!("What is {}'s job?", firstName),
context: context.clone()
});
return temp;
}
fn getHumanFromContext(model: &QuestionAnsweringModel, context: String, firstName: String) -> Result<(String, String, String, String), Box<dyn std::error::Error>> {
//TODO make this more efficient
let mut genderQuestion = QaInput { let mut genderQuestion = QaInput {
question: format!("What is {}'s gender?", firstName), question: format!("What is {}'s gender?", firstName),
context: context.clone() context: context.clone()
@@ -100,6 +164,7 @@ fn getHumanFromContext(context: String, firstName: String) -> Result<(String, St
let mut age = looper.next().expect("euh").first(); let mut age = looper.next().expect("euh").first();
let mut country= looper.next().expect("euh").first(); let mut country= looper.next().expect("euh").first();
let mut job = looper.next().expect("euh").first(); let mut job = looper.next().expect("euh").first();
//Err("whatthefuck");
return Ok((gender.unwrap().answer.clone(), age.unwrap().answer.clone(), country.unwrap().answer.clone(), job.unwrap().answer.clone())) return Ok((gender.unwrap().answer.clone(), age.unwrap().answer.clone(), country.unwrap().answer.clone(), job.unwrap().answer.clone()))