diff --git a/finalise_from_context/src/main.rs b/finalise_from_context/src/main.rs index 290a3d5..1d239b2 100644 --- a/finalise_from_context/src/main.rs +++ b/finalise_from_context/src/main.rs @@ -14,7 +14,7 @@ use rust_bert::resources::RemoteResource; -#[derive(Deserialize, Serialize)] +#[derive(Deserialize, Serialize, Clone)] struct Human { firstName: String, lastName: String, @@ -38,34 +38,6 @@ fn main() { let mut i = 0; let mut l = &Humans.len().clone(); println!("there are {} humans to process", l - i); - - for mut human in &mut Humans { - //let (gender, age, country, job) = getHumanFromContext( - match getHumanFromContext(human.bio.clone(), human.firstName.clone()) { - Ok((gender, age, country, job)) => { - human.gender = gender; - human.age = age; - human.country = country; - human.job = job; - println!("just did {} at index {}", human.firstName.clone(), i); - }, - Err(e) => { - println!("skipping {} because of {}", human.firstName.clone(), e.to_string()) - } - } - println!("There are {} humans left to process", l - (i+1)); - i = i+1; - if (i > 100) { break; } - } - - let serialized: String = serde_json::to_string(&Humans).unwrap(); - let mut file = File::create(save_res.as_path()).unwrap(); - file.write_all(serialized.as_bytes()).expect("oopsie"); -} - - -fn getHumanFromContext(context: String, firstName: String) -> Result<(String, String, String, String), Box> { - //TODO use the other ai to get answers from a given context let bertconfig = QuestionAnsweringConfig::new( ModelType::Bert, RemoteResource::from_pretrained(BertModelResources::BERT_QA), @@ -76,7 +48,99 @@ fn getHumanFromContext(context: String, firstName: String) -> Result<(String, St false, None, ); - let mut model = QuestionAnsweringModel::new(bertconfig)?; + let mut model = QuestionAnsweringModel::new(bertconfig).unwrap(); + let mut questions: Vec = Vec::new(); + for human in &mut Humans { + questions.extend(getQuestions(human.bio.clone(), human.firstName.clone())); + //let (gender, age, country, job) = getHumanFromContext( + // match getHumanFromContext(&model, human.bio.clone(), human.firstName.clone()) { + // Ok((gender, age, country, job)) => { + // human.gender = gender; + // human.age = age; + // human.country = country; + // human.job = job; + // println!("just did {} at index {}", human.firstName.clone(), i); + // }, + // Err(e) => { + // println!("skipping {} because of {}", human.firstName.clone(), e.to_string()) + // } + // } + println!("There are {} humans left to process", l - (i+1)); + i = i+1; + } + println!("predicting {} questions:", questions.len()); + let mut answers = model.predict(&questions, 1, 32); + println!("answers gotten:"); + for i in &answers { + println!("{:?}", i); + } + + let mut looper = answers.iter(); + i = 0; + let mut finishedHumans: Vec = Vec::new(); + for mut human in &mut Humans { + println!("{}'s attributes ", human.firstName.clone()); + let mut h = human.clone(); + match looper.next().expect("euh").first() { + Ok(gender) => h.gender = gender.answer.clone(), + Err(e) => println!("failed to get gender for {}: {:?}", human.firstName.clone(), e) + } + match looper.next().expect("euh").first() { + Ok(age) => h.age = age.answer.clone(), + Err(e) => println!("failed to get age for {}: {:?}", human.firstName.clone(), e) + } + match looper.next().expect("euh").first() { + Ok(country) => h.country = country.answer.clone(), + Err(e) => println!("failed to get country for {}: {:?}", human.firstname.clone(), e) + } + match looper.next().expect("euh").first() { + Ok(job) => h.job = job.answer.clone(), + Err(e) => println!("failed to get job for {}: {:?}", human.firstName.clone(), e) + } + finishedHumans.push(h); + println!("{} to go", l - i); + i = i+1; + } + + //let iter: Vec<&[Answer]> = answers.first().unwrap().chunks(16).collect(); + //println!("{:?}", iter); + // for item in iter { + // for i in item { + // println!("{}", i.answer.to_string()) + // } + // //println!("{:?}", item) + // } + + let serialized: String = serde_json::to_string(&finishedHumans).unwrap(); + let mut file = File::create(save_res.as_path()).unwrap(); + file.write_all(serialized.as_bytes()).expect("oopsie"); +} +fn getQuestions(context: String, firstName: String) -> Vec { + let mut temp :Vec= Vec::new(); + temp.push(QaInput { + question: format!("What is {}'s gender?", firstName), + context: context.clone() + }); + temp.push(QaInput { + question: format!("What is {}'s age?", firstName), + context: context.clone() + }); + temp.push(QaInput { + question: format!("Where does {} live?", firstName), + context: context.clone() + }); + temp.push(QaInput { + question: format!("What is {}'s job?", firstName), + context: context.clone() + }); + return temp; + +} + + +fn getHumanFromContext(model: &QuestionAnsweringModel, context: String, firstName: String) -> Result<(String, String, String, String), Box> { + //TODO make this more efficient + let mut genderQuestion = QaInput { question: format!("What is {}'s gender?", firstName), context: context.clone() @@ -100,6 +164,7 @@ fn getHumanFromContext(context: String, firstName: String) -> Result<(String, St let mut age = looper.next().expect("euh").first(); let mut country= looper.next().expect("euh").first(); let mut job = looper.next().expect("euh").first(); + //Err("whatthefuck"); return Ok((gender.unwrap().answer.clone(), age.unwrap().answer.clone(), country.unwrap().answer.clone(), job.unwrap().answer.clone()))