Initial commit of Hearbit AI App
This commit is contained in:
390
src-tauri/src/lib.rs
Normal file
390
src-tauri/src/lib.rs
Normal file
@@ -0,0 +1,390 @@
|
||||
use tauri::State;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::process::Command;
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use std::time::Duration;
|
||||
use tokio::time::sleep;
|
||||
|
||||
// State to hold the active recording stream
|
||||
struct AppState {
|
||||
recording_stream: Mutex<Option<cpal::Stream>>,
|
||||
recording_file_path: Mutex<Option<String>>,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct AudioDevice {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn greet(name: &str) -> String {
|
||||
format!("Hello, {}! You've been greeted from Rust!", name)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn get_input_devices() -> Result<Vec<AudioDevice>, String> {
|
||||
let host = cpal::default_host();
|
||||
let devices = host.input_devices().map_err(|e| e.to_string())?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
for device in devices {
|
||||
#[allow(deprecated)]
|
||||
if let Ok(name) = device.name() {
|
||||
// macOS often produces weird names, but let's just use what we get
|
||||
result.push(AudioDevice {
|
||||
id: name.clone(), // Using name as ID for simplicity in this MVP
|
||||
name,
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn install_driver() -> Result<String, String> {
|
||||
let output = Command::new("brew")
|
||||
.args(["install", "blackhole-2ch"])
|
||||
.output()
|
||||
.map_err(|e| format!("Failed to execute command: {}", e))?;
|
||||
|
||||
if output.status.success() {
|
||||
Ok(String::from_utf8_lossy(&output.stdout).to_string())
|
||||
} else {
|
||||
Err(String::from_utf8_lossy(&output.stderr).to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn start_recording(state: State<'_, AppState>, device_id: String) -> Result<(), String> {
|
||||
let host = cpal::default_host();
|
||||
|
||||
// Find device by name (using name as ID)
|
||||
#[allow(deprecated)]
|
||||
let device = host.input_devices()
|
||||
.map_err(|e| e.to_string())?
|
||||
.find(|d| d.name().map(|n| n == device_id).unwrap_or(false))
|
||||
.or_else(|| host.default_input_device())
|
||||
.ok_or("No input device found")?;
|
||||
|
||||
let config = device.default_input_config().map_err(|e| e.to_string())?;
|
||||
let spec = hound::WavSpec {
|
||||
channels: config.channels(),
|
||||
sample_rate: config.sample_rate(),
|
||||
bits_per_sample: 16,
|
||||
sample_format: hound::SampleFormat::Int,
|
||||
};
|
||||
|
||||
// Create a temporary file
|
||||
let temp_dir = std::env::temp_dir();
|
||||
let file_path = temp_dir.join(format!("recording_{}.wav", std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_secs()));
|
||||
let file_path_str = file_path.to_string_lossy().to_string();
|
||||
|
||||
let writer = hound::WavWriter::create(&file_path, spec).map_err(|e| e.to_string())?;
|
||||
let writer = Arc::new(Mutex::new(writer));
|
||||
let writer_clone = writer.clone();
|
||||
|
||||
let err_fn = |err| eprintln!("an error occurred on stream: {}", err);
|
||||
|
||||
let stream = match config.sample_format() {
|
||||
cpal::SampleFormat::F32 => device.build_input_stream(
|
||||
&config.into(),
|
||||
move |data: &[f32], _: &_| {
|
||||
let mut guard = writer_clone.lock().unwrap();
|
||||
for &sample in data {
|
||||
let amplitude = i16::MAX as f32;
|
||||
guard.write_sample((sample * amplitude) as i16).ok();
|
||||
}
|
||||
},
|
||||
err_fn,
|
||||
None
|
||||
),
|
||||
cpal::SampleFormat::I16 => device.build_input_stream(
|
||||
&config.into(),
|
||||
move |data: &[i16], _: &_| {
|
||||
let mut guard = writer_clone.lock().unwrap();
|
||||
for &sample in data {
|
||||
guard.write_sample(sample).ok();
|
||||
}
|
||||
},
|
||||
err_fn,
|
||||
None
|
||||
),
|
||||
cpal::SampleFormat::U16 => device.build_input_stream(
|
||||
&config.into(),
|
||||
move |data: &[u16], _: &_| {
|
||||
let mut guard = writer_clone.lock().unwrap();
|
||||
for &sample in data {
|
||||
guard.write_sample((sample as i32 - 32768) as i16).ok();
|
||||
}
|
||||
},
|
||||
err_fn,
|
||||
None
|
||||
),
|
||||
_ => return Err("Unsupported sample format".to_string()),
|
||||
}.map_err(|e| e.to_string())?;
|
||||
|
||||
stream.play().map_err(|e| e.to_string())?;
|
||||
|
||||
// Store state
|
||||
*state.recording_stream.lock().unwrap() = Some(stream);
|
||||
*state.recording_file_path.lock().unwrap() = Some(file_path_str);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn stop_recording(state: State<'_, AppState>) -> Result<String, String> {
|
||||
// Drop stream to stop recording
|
||||
{
|
||||
let mut stream_guard = state.recording_stream.lock().unwrap();
|
||||
if stream_guard.is_none() {
|
||||
return Err("Not recording".to_string());
|
||||
}
|
||||
*stream_guard = None; // This drops the stream and stops recording
|
||||
}
|
||||
|
||||
// Return file path
|
||||
let mut path_guard = state.recording_file_path.lock().unwrap();
|
||||
path_guard.take().ok_or("No recording path found".to_string())
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ModelListResponse {
|
||||
data: Vec<ModelData>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ModelData {
|
||||
id: String,
|
||||
owned_by: Option<String>,
|
||||
}
|
||||
|
||||
// Structs for Infomaniak API responses
|
||||
#[derive(serde::Deserialize)]
|
||||
struct WhisperResponse {
|
||||
text: Option<String>,
|
||||
batch_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct ChatCompletionResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Choice {
|
||||
message: Message,
|
||||
}
|
||||
#[derive(serde::Deserialize)]
|
||||
struct Message {
|
||||
content: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct ModelInfo {
|
||||
id: String,
|
||||
name: String,
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn get_available_models(api_key: String, product_id: String) -> Result<Vec<ModelInfo>, String> {
|
||||
let client = reqwest::Client::new();
|
||||
// Use the v2/openai compliant endpoint as per docs
|
||||
let url = format!("https://api.infomaniak.com/2/ai/{}/openai/v1/models", product_id);
|
||||
|
||||
let res = client.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if res.status().is_success() {
|
||||
let raw_body = res.text().await.map_err(|e| e.to_string())?;
|
||||
println!("Models Raw Response: {}", raw_body);
|
||||
let list: ModelListResponse = serde_json::from_str(&raw_body)
|
||||
.map_err(|e| format!("Failed to parse models: {}. Body: {}", e, raw_body))?;
|
||||
|
||||
let models = list.data.into_iter()
|
||||
.filter(|m| !m.id.to_lowercase().contains("mini_lm") && !m.id.to_lowercase().contains("bert") && !m.id.to_lowercase().contains("embedding"))
|
||||
.map(|m| ModelInfo {
|
||||
id: m.id.clone(),
|
||||
name: m.id, // Use ID as name for now, or fetch more details if available
|
||||
}).collect();
|
||||
|
||||
Ok(models)
|
||||
} else {
|
||||
// Fallback to v1 if v2 fails or try another common path?
|
||||
// For now just error out
|
||||
let err = res.text().await.unwrap_or_default();
|
||||
Err(format!("Failed to fetch models: {}", err))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
#[tauri::command]
|
||||
async fn transcribe_audio(file_path: String, api_key: String, product_id: String) -> Result<String, String> {
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Prepare file part
|
||||
let file_bytes = std::fs::read(&file_path).map_err(|e| e.to_string())?;
|
||||
// We must use a known file name for the part, Infomaniak might care, or not.
|
||||
let file_part = reqwest::multipart::Part::bytes(file_bytes)
|
||||
.file_name("recording.wav")
|
||||
.mime_str("audio/wav")
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
let form = reqwest::multipart::Form::new()
|
||||
.part("file", file_part)
|
||||
.text("model", "whisper");
|
||||
|
||||
let url = format!("https://api.infomaniak.com/1/ai/{}/openai/audio/transcriptions", product_id);
|
||||
|
||||
let res = client.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if res.status().is_success() {
|
||||
let raw_body = res.text().await.map_err(|e| e.to_string())?;
|
||||
println!("Transcription Raw Response: {}", raw_body);
|
||||
|
||||
// Attempt to parse text or batch_id
|
||||
// Attempt to parse text or batch_id
|
||||
let response: WhisperResponse = serde_json::from_str(&raw_body)
|
||||
.map_err(|e| format!("Failed to decode JSON: {}. Body: {}", e, raw_body))?;
|
||||
|
||||
match (response.text, response.batch_id) {
|
||||
(Some(text), _) => Ok(text),
|
||||
(_, Some(batch_id)) => {
|
||||
// Need to poll
|
||||
poll_transcription(&client, &api_key, &product_id, &batch_id).await
|
||||
},
|
||||
_ => Err(format!("Response contained neither text nor batch_id. Body: {}", raw_body))
|
||||
}
|
||||
} else {
|
||||
let error_text = res.text().await.unwrap_or_default();
|
||||
Err(format!("Transcription failed: {}", error_text))
|
||||
}
|
||||
}
|
||||
|
||||
async fn poll_transcription(client: &reqwest::Client, api_key: &str, product_id: &str, batch_id: &str) -> Result<String, String> {
|
||||
// Polling URL: /1/ai/{product_id}/results/{batch_id} (or similar, verifying via trial)
|
||||
// If that fails, we can try /openai/audio/transcriptions/{batch_id} but documentation suggests results endpoint.
|
||||
// Let's assume the standard Infomaniak pattern for batches.
|
||||
let status_url = format!("https://api.infomaniak.com/1/ai/{}/results/{}", product_id, batch_id);
|
||||
|
||||
let mut attempts = 0;
|
||||
while attempts < 40 { // 40 * 2s = 80s timeout
|
||||
attempts += 1;
|
||||
sleep(Duration::from_secs(2)).await;
|
||||
|
||||
let res = client.get(&status_url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Polling error: {}", e))?;
|
||||
|
||||
if res.status().is_success() {
|
||||
let json: serde_json::Value = res.json().await.map_err(|e| e.to_string())?;
|
||||
// Check 'status'
|
||||
if let Some(status) = json.get("status").and_then(|s| s.as_str()) {
|
||||
if status == "success" {
|
||||
// Download the result
|
||||
let download_url = format!("https://api.infomaniak.com/1/ai/{}/results/{}/download", product_id, batch_id);
|
||||
let dl_res = client.get(&download_url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if dl_res.status().is_success() {
|
||||
let content = dl_res.text().await.map_err(|e| e.to_string())?;
|
||||
|
||||
// Try to parse the content as JSON to see if it's { "text": "..." }
|
||||
if let Ok(json_val) = serde_json::from_str::<serde_json::Value>(&content) {
|
||||
if let Some(text_content) = json_val.get("text").and_then(|t| t.as_str()) {
|
||||
return Ok(text_content.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
// If not JSON or no text field, return raw content
|
||||
return Ok(content);
|
||||
} else {
|
||||
return Err(format!("Download failed: {}", dl_res.status()));
|
||||
}
|
||||
} else if status == "failed" || status == "error" {
|
||||
return Err(format!("Batch processing failed: {:?}", json));
|
||||
}
|
||||
// If 'processing' or 'pending', continue loop
|
||||
}
|
||||
}
|
||||
}
|
||||
Err("Transcription timed out".to_string())
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
async fn summarize_text(text: String, api_key: String, product_id: String, prompt: String, model: String) -> Result<String, String> {
|
||||
let client = reqwest::Client::new();
|
||||
let url = format!("https://api.infomaniak.com/2/ai/{}/openai/v1/chat/completions", product_id);
|
||||
|
||||
let messages = serde_json::json!([
|
||||
{ "role": "system", "content": prompt },
|
||||
{ "role": "user", "content": text }
|
||||
]);
|
||||
|
||||
let model_to_use = if model.is_empty() { "mixtral".to_string() } else { model };
|
||||
|
||||
let body = serde_json::json!({
|
||||
"model": model_to_use,
|
||||
"messages": messages
|
||||
});
|
||||
|
||||
let res = client.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", api_key))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| e.to_string())?;
|
||||
|
||||
if res.status().is_success() {
|
||||
let raw_body = res.text().await.map_err(|e| e.to_string())?;
|
||||
println!("Summarization Raw Response: {}", raw_body);
|
||||
|
||||
let response_body: ChatCompletionResponse = serde_json::from_str(&raw_body)
|
||||
.map_err(|e| format!("Failed to decode JSON: {}. Body: {}", e, raw_body))?;
|
||||
|
||||
if let Some(choice) = response_body.choices.first() {
|
||||
Ok(choice.message.content.clone())
|
||||
} else {
|
||||
Err("No summary generated".to_string())
|
||||
}
|
||||
} else {
|
||||
let error_text = res.text().await.unwrap_or_default();
|
||||
Err(format!("Summarization failed: {}", error_text))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg_attr(mobile, tauri::mobile_entry_point)]
|
||||
pub fn run() {
|
||||
tauri::Builder::default()
|
||||
.plugin(tauri_plugin_opener::init())
|
||||
.manage(AppState {
|
||||
recording_stream: Mutex::new(None),
|
||||
recording_file_path: Mutex::new(None),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
greet,
|
||||
get_input_devices,
|
||||
install_driver,
|
||||
start_recording,
|
||||
stop_recording,
|
||||
transcribe_audio,
|
||||
summarize_text,
|
||||
get_available_models
|
||||
])
|
||||
.run(tauri::generate_context!())
|
||||
.expect("error while running tauri application");
|
||||
}
|
||||
Reference in New Issue
Block a user