#include "script.h"
#include "keyboard.h"
#include <array>
#include <ctime>
#include <direct.h>
#include <exception>
#include <fstream>
#include <string>
#include <vector>
#include <onnxruntime_cxx_api.h>
// ONNX Runtime objects, allocated in ScriptMain(). `session` stays nullptr until the model has
// been loaded; runAI() checks for that before running inference.
Ort::Env* env = nullptr;
Ort::SessionOptions* session_options = nullptr;
Ort::Session* session = nullptr;
// Per-feature standardisation applied in runAI(): x' = (x - SCALER_MEAN[i]) / SCALER_STD[i].
// Feature order is [speed, ray_0 .. ray_6]. Presumably copied from the scaler fitted when the
// model was trained — TODO confirm these values match the exported model.
const float SCALER_MEAN[8] = { 40.31f, 13.32f, 17.94f, 44.74f, 77.81f, 35.56f, 16.78f, 12.84f };
const float SCALER_STD[8] = { 5.48f, 9.25f, 9.72f, 24.72f, 32.81f, 18.14f, 4.20f, 5.41f };
// Number of script ticks between Debug-mode readouts (a period in frames, not a rate).
const int TICK_RATE = 300;
// Sensor rays: count and yaw offsets in degrees from the vehicle's heading (see rotateVector).
// The Debug readout labels indices 0-2 (negative angles) as left and 4-6 (positive) as right.
const int NUM_RAYS = 7;
const float RAY_ANGLES[NUM_RAYS] = { -60.0f, -40.0f, -20.0f, 0.0f, 20.0f, 40.0f, 60.0f };
// Ray length in world units (presumably metres); castSingleRay() returns this on a miss.
const float MAX_DISTANCE = 120.0f;
// Directory for the CSV logs. It is used as a plain filename prefix, so it must end with a backslash.
const std::string FILE_PATH = "C:\\Users\\sgemm\\Desktop\\Racebot\\gtav-racebot\\data\\"; // TODO: make configurable instead of hard-coding a user-specific path
// Script operating mode, switched by the F9-F12 hotkeys in handleInput().
enum State {
Default,  // nothing is driven, logged or drawn
Debug,    // draw the sensor rays and show a periodic readout
AIMode,   // the ONNX model drives the vehicle
Logging,  // record one CSV sample per tick
};
State state = Default;
// CSV output stream; open only while a logging session is active.
std::ofstream logFile;
// Rows written since the current logging session started.
int sampleCount = 0;
// Incremented once per script tick in ScriptMain(); drives the Debug readout cadence.
int frameCounter = 0;
// Posts `text` to the on-screen notification feed as a single "STRING" text component.
// The text-component native takes a non-const char*, so constness is cast away only to satisfy
// that signature (presumably the native does not modify the string).
void showNotification(const char* text) {
    char* const component = const_cast<char*>(text);
    UI::_SET_NOTIFICATION_TEXT_ENTRY("STRING");
    UI::_ADD_TEXT_COMPONENT_STRING(component);
    UI::_DRAW_NOTIFICATION(false, false);
}
// Rotates `vec` in the XY plane by `angleDeg` degrees; the z component is copied unchanged.
// With the matrix used here, a positive angle turns the vector clockwise when viewed from +Z
// (i.e. to the right of a heading, assuming +Z is up).
Vector3 rotateVector(Vector3 vec, float angleDeg) {
    const float rad = angleDeg * 3.14159265f / 180.0f;
    // Evaluate each trig function once; `auto` keeps whatever overload/return type cos/sin resolve to.
    const auto c = cos(rad);
    const auto s = sin(rad);

    Vector3 rotated;
    rotated.x = vec.x * c + vec.y * s;
    rotated.y = -vec.x * s + vec.y * c;
    rotated.z = vec.z;
    return rotated;
}
// Draws the NUM_RAYS sensor rays (same geometry as castSingleRay) for in-game inspection.
// A ray that hits something is drawn green up to the hit point; a miss is drawn red to its full length.
void drawRaycastDebug(Vehicle vehicle) {
    Vector3 origin = ENTITY::GET_ENTITY_COORDS(vehicle, true);
    Vector3 forward = ENTITY::GET_ENTITY_FORWARD_VECTOR(vehicle);
    forward.z = 0.0f;   // rays stay horizontal (rotateVector keeps z)
    origin.z += 0.5f;   // start the rays half a unit above the entity's position

    for (int i = 0; i < NUM_RAYS; i++) {
        const Vector3 direction = rotateVector(forward, RAY_ANGLES[i]);
        Vector3 endPos;
        endPos.x = origin.x + direction.x * MAX_DISTANCE;
        endPos.y = origin.y + direction.y * MAX_DISTANCE;
        endPos.z = origin.z + direction.z * MAX_DISTANCE;

        // Flags 1 | 16: presumably world geometry | objects, with `vehicle` as the entity to
        // ignore — confirm against the native's documentation.
        const int rayHandle = WORLDPROBE::_CAST_RAY_POINT_TO_POINT(
            origin.x, origin.y, origin.z, endPos.x, endPos.y, endPos.z,
            1 | 16, vehicle, 7
        );

        // Pre-initialise the outputs so a probe that writes nothing reads as a clean miss
        // instead of indeterminate stack values.
        BOOL hit = FALSE;
        Vector3 hitCoords{};
        Vector3 surfaceNormal{};
        Entity hitEntity = 0;
        WORLDPROBE::_GET_RAYCAST_RESULT(rayHandle, &hit, &hitCoords, &surfaceNormal, &hitEntity);

        const int r = hit ? 0 : 255;
        const int g = hit ? 255 : 0;
        const Vector3 drawEnd = hit ? hitCoords : endPos;
        GRAPHICS::DRAW_LINE(origin.x, origin.y, origin.z, drawEnd.x, drawEnd.y, drawEnd.z, r, g, 0, 255);
    }
}
// Casts one horizontal sensor ray from half a unit above the vehicle, rotated `angleDeg` degrees
// from its heading (see rotateVector). Returns the distance to the first hit, or MAX_DISTANCE
// when nothing is hit.
// NOTE(review): forward.z is zeroed without renormalising, so on a slope the probe is slightly
// shorter than MAX_DISTANCE while a miss still reports MAX_DISTANCE. This is left unchanged
// because logged training data was produced with this behaviour.
float castSingleRay(Vehicle vehicle, float angleDeg) {
    Vector3 origin = ENTITY::GET_ENTITY_COORDS(vehicle, true);
    Vector3 forward = ENTITY::GET_ENTITY_FORWARD_VECTOR(vehicle);
    forward.z = 0.0f;
    origin.z += 0.5f;

    const Vector3 direction = rotateVector(forward, angleDeg);
    Vector3 endPos;
    endPos.x = origin.x + direction.x * MAX_DISTANCE;
    endPos.y = origin.y + direction.y * MAX_DISTANCE;
    endPos.z = origin.z + direction.z * MAX_DISTANCE;

    // Flags 1 | 16: presumably world geometry | objects, with `vehicle` as the entity to ignore —
    // confirm against the native's documentation. Must stay in sync with drawRaycastDebug().
    const int rayHandle = WORLDPROBE::_CAST_RAY_POINT_TO_POINT(
        origin.x, origin.y, origin.z, endPos.x, endPos.y, endPos.z,
        1 | 16, vehicle, 7
    );

    // Pre-initialise the outputs so a probe that writes nothing reads as a miss rather than
    // returning a distance computed from indeterminate values.
    BOOL hit = FALSE;
    Vector3 hitCoords{};
    Vector3 surfaceNormal{};
    Entity hitEntity = 0;
    WORLDPROBE::_GET_RAYCAST_RESULT(rayHandle, &hit, &hitCoords, &surfaceNormal, &hitEntity);

    if (!hit) {
        return MAX_DISTANCE;
    }
    const float dx = hitCoords.x - origin.x;
    const float dy = hitCoords.y - origin.y;
    const float dz = hitCoords.z - origin.z;
    return sqrt(dx * dx + dy * dy + dz * dz);
}
// Opens a new timestamped CSV file under FILE_PATH, writes the column header and resets
// sampleCount. The header's column order must match the rows written by logData().
// If the file cannot be opened (e.g. FILE_PATH does not exist), a notification is shown and
// logFile is left closed, so logData() writes nothing.
void startLogging() {
    // An ofstream that is still open cannot be re-opened: open() would fail and set failbit,
    // after which every write is silently discarded. Close any previous file and clear the flags.
    if (logFile.is_open()) {
        logFile.close();
    }
    logFile.clear();

    const time_t now = time(nullptr);
    tm localTime{};
    localtime_s(&localTime, &now);

    char filename[512];
    sprintf_s(filename, "%sracing_data_%04d%02d%02d_%02d%02d%02d.csv",
        FILE_PATH.c_str(),
        1900 + localTime.tm_year, 1 + localTime.tm_mon, localTime.tm_mday,
        localTime.tm_hour, localTime.tm_min, localTime.tm_sec);

    // Sized to hold the longest possible filename plus the message prefix; sprintf_s invokes the
    // invalid-parameter handler (terminating by default) if the output does not fit.
    char message[sizeof(filename) + 64];

    logFile.open(filename);
    if (!logFile.is_open()) {
        sprintf_s(message, "Logging failed: cannot open %s", filename);
        showNotification(message);
        return;
    }

    logFile << "speed,";
    logFile << "ray_0,ray_1,ray_2,ray_3,ray_4,ray_5,ray_6,";
    logFile << "steering,throttle\n";
    sampleCount = 0;

    sprintf_s(message, "Logging started: %s", filename);
    showNotification(message);
}
// Appends one CSV row for the current tick while in Logging mode:
// speed, ray_0 .. ray_6, steering, throttle. The field order must match the header written by
// startLogging(). Does nothing unless state == Logging and the log file is open.
void logData(Vehicle vehicle) {
    if (state != Logging || !logFile.is_open()) return;

    const float speed = ENTITY::GET_ENTITY_SPEED(vehicle);
    float distances[NUM_RAYS];
    for (int i = 0; i < NUM_RAYS; i++)
        distances[i] = castSingleRay(vehicle, RAY_ANGLES[i]);
    // Controls 59 (steering) and 71 (throttle): the same IDs runAI() drives.
    const float steering = CONTROLS::GET_CONTROL_NORMAL(0, 59);
    const float throttle = CONTROLS::GET_CONTROL_NORMAL(0, 71);

    logFile << speed;
    for (float distance : distances)
        logFile << ',' << distance;
    // No trailing comma: each row must have exactly as many fields as the 10-column header.
    logFile << ',' << steering << ',' << throttle << '\n';
    sampleCount++;
}
// Closes the current CSV log and reports how many samples were written to it.
void stopLogging() {
    logFile.close();
    const std::string summary = "Logging stopped. Samples: " + std::to_string(sampleCount);
    showNotification(summary.c_str());
}
void handleInput() {
// F9: Default
if (IsKeyJustUp(VK_F9)) {
switch (state) {
case AIMode:
break;
case Logging:
stopLogging();
break;
case Debug:
break;
}
state = Default;
showNotification("DEFAULT MODE");
}
// F10: AI
if (IsKeyJustUp(VK_F10)) {
state = AIMode;
showNotification("AI MODE");
}
// F11: Logging
if (IsKeyJustUp(VK_F11)) {
startLogging();
state = Logging;
showNotification("LOGGING MODE");
}
// F12: Debug
if (IsKeyJustUp(VK_F12)) {
state = Debug;
showNotification("DEBUGGING MODE");
}
}
// Runs one inference step of the driving model and applies its output to the vehicle controls.
// Input tensor "input" (shape {1, NUM_RAYS + 1}): [speed, ray_0 .. ray_6], standardised with
// SCALER_MEAN / SCALER_STD. Output "sequential_1": [steering, throttle]. As in training,
// steering goes through tanh and throttle is clipped to [0, 1].
// On any ONNX Runtime failure, or a too-small output, the error is shown once and the script
// falls back to Default mode instead of letting an exception escape into the script host.
void runAI(Vehicle vehicle) {
    constexpr int kNumFeatures = NUM_RAYS + 1;  // speed + one distance per ray
    static_assert(kNumFeatures == sizeof(SCALER_MEAN) / sizeof(SCALER_MEAN[0]),
                  "SCALER_MEAN must have one entry per model feature");
    static_assert(kNumFeatures == sizeof(SCALER_STD) / sizeof(SCALER_STD[0]),
                  "SCALER_STD must have one entry per model feature");

    if (!session || !vehicle) return;

    std::array<float, kNumFeatures> features{};
    features[0] = ENTITY::GET_ENTITY_SPEED(vehicle);
    for (int i = 0; i < NUM_RAYS; i++)
        features[i + 1] = castSingleRay(vehicle, RAY_ANGLES[i]);
    for (int i = 0; i < kNumFeatures; i++)
        features[i] = (features[i] - SCALER_MEAN[i]) / SCALER_STD[i];

    const std::array<int64_t, 2> input_shape = { 1, kNumFeatures };
    const char* input_names[] = { "input" };
    const char* output_names[] = { "sequential_1" };

    float steering = 0.0f;
    float throttle = 0.0f;
    try {
        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        // The tensor borrows `features`' storage; it must not outlive this function.
        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
            memory_info,
            features.data(),
            features.size(),
            input_shape.data(),
            input_shape.size()
        );
        auto output_tensors = session->Run(
            Ort::RunOptions{ nullptr },
            input_names, &input_tensor, 1,
            output_names, 1
        );
        // Guard the output[1] read below against a model with an unexpected output shape.
        if (output_tensors.empty() ||
            output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount() < 2) {
            showNotification("AI error: model output has fewer than 2 values");
            state = Default;
            return;
        }
        float* output = output_tensors[0].GetTensorMutableData<float>();
        steering = tanh(output[0]);
        throttle = fmax(0.0f, fmin(1.0f, output[1]));
    } catch (const std::exception& e) {
        char buffer[256];
        // The precision bounds the copied message so sprintf_s cannot overflow the buffer.
        sprintf_s(buffer, "AI error: %.200s", e.what());
        showNotification(buffer);
        state = Default;
        return;
    }

    // Steering is doubled (for a more aggressive response) and re-clamped to [-1, 1];
    // throttle is capped at 85%.
    const float amplified_steering = fmax(-1.0f, fmin(1.0f, steering * 2.0f));
    CONTROLS::_SET_CONTROL_NORMAL(0, 59, amplified_steering);
    CONTROLS::_SET_CONTROL_NORMAL(0, 71, throttle * 0.85f);
}
void ScriptMain() {
const wchar_t* model_path = L"C:\\Program Files\\Rockstar Games\\Grand Theft Auto V\\racebot_model.onnx"; // this is shit
env = new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "GTAVRaceBot");
session_options = new Ort::SessionOptions();
session_options->SetIntraOpNumThreads(1);
session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);
session = new Ort::Session(*env, model_path, *session_options);
showNotification("Basic Controls: \nF9: Default \nF10: AI Mode \nF11: Start/Stop Logging \nF12: Debug");
while (true) {
WAIT(0);
frameCounter++;
// Disable traffic
VEHICLE::SET_VEHICLE_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);
VEHICLE::SET_RANDOM_VEHICLE_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);
PED::SET_PED_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);
Ped player = PLAYER::PLAYER_PED_ID();
if (!PED::IS_PED_IN_ANY_VEHICLE(player, false)) continue;
Vehicle vehicle = PED::GET_VEHICLE_PED_IS_IN(player, false);
handleInput();
switch (state) {
case Default:
break;
case AIMode:
runAI(vehicle);
break;
case Logging:
logData(vehicle);
break;
case Debug:
drawRaycastDebug(vehicle);
if (frameCounter % TICK_RATE == 0) {
float distances[NUM_RAYS];
for (int i = 0; i < NUM_RAYS; i++)
distances[i] = castSingleRay(vehicle, RAY_ANGLES[i]);
float speed = ENTITY::GET_ENTITY_SPEED(vehicle);
float steering = CONTROLS::GET_CONTROL_NORMAL(0, 59);
float throttle = CONTROLS::GET_CONTROL_NORMAL(0, 71);
char buffer[256];
sprintf_s(buffer, "Speed:%.1f Throttle:%.1f Steer:%.2f~n~L:%.0f %.0f %.0f C:%.0f R:%.0f %.0f %.0f",
speed, throttle, steering,
distances[0], distances[1], distances[2],
distances[3],
distances[4], distances[5], distances[6]);
showNotification(buffer);
}
break;
}
}
}