// gtav-racebot/src/script.cpp
#include "script.h"
#include "keyboard.h"

#include <direct.h>

#include <array>
#include <cmath>
#include <ctime>
#include <exception>
#include <fstream>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>

// ONNX Runtime handles, allocated in ScriptMain(); no delete for them appears in this file.
Ort::Env* env{ nullptr };
Ort::SessionOptions* session_options{ nullptr };
Ort::Session* session{ nullptr };  // runAI() returns early while this is null

// Per-feature standardization statistics applied in runAI() as (x - mean) / std.
// Index 0 is vehicle speed; indices 1..7 are ray_0..ray_6 (the CSV column order).
// Presumably fitted on the logged training data -- TODO confirm they match the model.
constexpr float SCALER_MEAN[8] = { 40.31f, 13.32f, 17.94f, 44.74f, 77.81f, 35.56f, 16.78f, 12.84f };
constexpr float SCALER_STD[8]  = { 5.48f, 9.25f, 9.72f, 24.72f, 32.81f, 18.14f, 4.20f, 5.41f };

// Number of ScriptMain() loop iterations between Debug-mode telemetry notifications.
constexpr int TICK_RATE = 300;

// Probe rays, as angles in degrees relative to the vehicle heading.
// Negative angles fan out to the left, positive to the right; index 3 points straight ahead.
constexpr int NUM_RAYS = 7;
constexpr float RAY_ANGLES[NUM_RAYS] = { -60.0f, -40.0f, -20.0f, 0.0f, 20.0f, 40.0f, 60.0f };

// Ray length, and the distance reported when a ray hits nothing.
constexpr float MAX_DISTANCE = 120.0f;

// Output directory for CSV logs (must end in a path separator).
// TODO: hard-coded, user-specific path; make this configurable.
const std::string FILE_PATH = "C:\\Users\\sgemm\\Desktop\\Racebot\\gtav-racebot\\data\\";

// Script operating mode, switched by the F9-F12 handlers in handleInput().
enum State {
    Default = 0,  // no per-frame work besides input handling
    Debug   = 1,  // draw probe rays every frame, show telemetry every TICK_RATE frames
    AIMode  = 2,  // drive the vehicle with the ONNX model (runAI)
    Logging = 3,  // append one CSV row per frame (logData)
};

// Mutable script-wide state.
State state{ Default };  // current mode, changed by handleInput()
std::ofstream logFile;   // CSV output; opened by startLogging(), closed by stopLogging()
int sampleCount{ 0 };    // rows written since the last startLogging()
int frameCounter{ 0 };   // incremented once per ScriptMain() loop iteration

/// Shows `text` as an in-game notification.
/// The native's parameter is a non-const char*; it is assumed not to modify the string
/// (NOTE(review): confirm against the native's documentation).
void showNotification(const char* text) {
    char* const message = const_cast<char*>(text);

    UI::_SET_NOTIFICATION_TEXT_ENTRY("STRING");
    UI::_ADD_TEXT_COMPONENT_STRING(message);
    UI::_DRAW_NOTIFICATION(false, false);
}

/**
 * Rotates `vec` about the Z axis by `angleDeg` degrees.
 * The rotation is clockwise when viewed from +Z, so positive angles turn a heading
 * to the right in a Z-up world. The z component is copied unchanged.
 */
Vector3 rotateVector(Vector3 vec, float angleDeg) {
    const float radians = angleDeg * 3.14159265f / 180.0f;
    // `auto` keeps whatever type the selected cos/sin overload returns, so the
    // arithmetic below is evaluated in exactly the same precision as before.
    const auto c = cos(radians);
    const auto s = sin(radians);

    Vector3 rotated;
    rotated.x = vec.x * c + vec.y * s;
    rotated.y = -vec.x * s + vec.y * c;
    rotated.z = vec.z;
    return rotated;
}

/// Draws the NUM_RAYS probe rays for `vehicle` as debug lines for the current frame.
/// Each ray starts 0.5 units above the vehicle position and runs horizontally at
/// RAY_ANGLES[i] from the heading for MAX_DISTANCE. It is drawn green up to the hit
/// point when it hits something, and red to full length otherwise.
void drawRaycastDebug(Vehicle vehicle) {
    Vector3 pos = ENTITY::GET_ENTITY_COORDS(vehicle, true);
    Vector3 forward = ENTITY::GET_ENTITY_FORWARD_VECTOR(vehicle);

    // Project the heading onto the horizontal plane and renormalize. Zeroing z alone leaves
    // a vector shorter than 1 whenever the vehicle is pitched, which shortens every ray.
    forward.z = 0.0f;
    const float horizontalLen = sqrt(forward.x * forward.x + forward.y * forward.y);
    if (horizontalLen > 1e-4f) {
        forward.x /= horizontalLen;
        forward.y /= horizontalLen;
    }
    pos.z += 0.5f;

    for (int i = 0; i < NUM_RAYS; i++) {
        const Vector3 direction = rotateVector(forward, RAY_ANGLES[i]);

        Vector3 endPos;
        endPos.x = pos.x + direction.x * MAX_DISTANCE;
        endPos.y = pos.y + direction.y * MAX_DISTANCE;
        endPos.z = pos.z + direction.z * MAX_DISTANCE;

        // `vehicle` is passed as the entity argument, presumably so the probe ignores the car
        // itself; flags 1 | 16 select what the probe collides with (TODO: document which).
        int rayHandle = WORLDPROBE::_CAST_RAY_POINT_TO_POINT(
            pos.x, pos.y, pos.z, endPos.x, endPos.y, endPos.z,
            1 | 16, vehicle, 7
        );

        // Initialized so that a result the native does not fill in reads as "no hit".
        BOOL hit = FALSE;
        Vector3 hitCoords{}, surfaceNormal{};
        Entity hitEntity = 0;
        WORLDPROBE::_GET_RAYCAST_RESULT(rayHandle, &hit, &hitCoords, &surfaceNormal, &hitEntity);

        const int r = hit ? 0 : 255;
        const int g = hit ? 255 : 0;
        const Vector3 drawEnd = hit ? hitCoords : endPos;

        GRAPHICS::DRAW_LINE(pos.x, pos.y, pos.z, drawEnd.x, drawEnd.y, drawEnd.z, r, g, 0, 255);
    }
}

/// Casts one horizontal probe ray from `vehicle` and returns the distance to the first hit.
/// The ray starts 0.5 units above the vehicle position, points `angleDeg` degrees from the
/// heading (negative = left, positive = right), and is MAX_DISTANCE long.
/// @return Euclidean distance from the ray origin to the hit point, or MAX_DISTANCE if nothing was hit.
float castSingleRay(Vehicle vehicle, float angleDeg) {
    Vector3 pos = ENTITY::GET_ENTITY_COORDS(vehicle, true);
    Vector3 forward = ENTITY::GET_ENTITY_FORWARD_VECTOR(vehicle);

    // Project the heading onto the horizontal plane and renormalize. Zeroing z alone makes
    // the ray shorter than MAX_DISTANCE whenever the vehicle is pitched. Obstacles beyond that
    // shortened reach would then be missed, while MAX_DISTANCE would still be reported.
    forward.z = 0.0f;
    const float horizontalLen = sqrt(forward.x * forward.x + forward.y * forward.y);
    if (horizontalLen > 1e-4f) {
        forward.x /= horizontalLen;
        forward.y /= horizontalLen;
    }
    pos.z += 0.5f;

    const Vector3 direction = rotateVector(forward, angleDeg);
    Vector3 endPos;
    endPos.x = pos.x + direction.x * MAX_DISTANCE;
    endPos.y = pos.y + direction.y * MAX_DISTANCE;
    endPos.z = pos.z + direction.z * MAX_DISTANCE;

    // `vehicle` is passed as the entity argument, presumably so the probe ignores the car
    // itself; flags 1 | 16 select what the probe collides with (TODO: document which).
    int rayHandle = WORLDPROBE::_CAST_RAY_POINT_TO_POINT(
        pos.x, pos.y, pos.z, endPos.x, endPos.y, endPos.z,
        1 | 16, vehicle, 7
    );

    // Initialized so that a result the native does not fill in reads as "no hit".
    BOOL hit = FALSE;
    Vector3 hitCoords{}, surfaceNormal{};
    Entity hitEntity = 0;
    WORLDPROBE::_GET_RAYCAST_RESULT(rayHandle, &hit, &hitCoords, &surfaceNormal, &hitEntity);

    if (hit) {
        const float dx = hitCoords.x - pos.x;
        const float dy = hitCoords.y - pos.y;
        const float dz = hitCoords.z - pos.z;
        return sqrt(dx * dx + dy * dy + dz * dz);
    }
    return MAX_DISTANCE;
}

void startLogging() {
    time_t now = time(0);
    tm ltm;
    localtime_s(&ltm, &now);

    char filename[512];
    sprintf_s(filename, "%sracing_data_%04d%02d%02d_%02d%02d%02d.csv",
        FILE_PATH.c_str(),
        1900 + ltm.tm_year, 1 + ltm.tm_mon, ltm.tm_mday,
        ltm.tm_hour, ltm.tm_min, ltm.tm_sec);

    logFile.open(filename);

    logFile << "speed,";
    logFile << "ray_0,ray_1,ray_2,ray_3,ray_4,ray_5,ray_6,";
    logFile << "steering,throttle\n";

    sampleCount = 0;

    char buffer[256];
    sprintf_s(buffer, "Logging started: %s", filename);
    showNotification(buffer);
}

/// Appends one CSV row for `vehicle`: speed, ray_0..ray_6 distances, steering, throttle.
/// Does nothing unless the script is in Logging state with an open log file.
void logData(Vehicle vehicle) {
    if (state != Logging || !logFile.is_open()) return;

    const float speed = ENTITY::GET_ENTITY_SPEED(vehicle);

    float distances[NUM_RAYS];
    for (int i = 0; i < NUM_RAYS; i++)
        distances[i] = castSingleRay(vehicle, RAY_ANGLES[i]);

    // Controls 59 and 71 are the same indices runAI() writes as steering and throttle.
    const float steering = CONTROLS::GET_CONTROL_NORMAL(0, 59);
    const float throttle = CONTROLS::GET_CONTROL_NORMAL(0, 71);

    logFile << speed;
    for (const float d : distances)
        logFile << ',' << d;
    // No trailing separator. The header has exactly 10 columns, and an extra empty field
    // would give every data row 11 fields, misaligning columns for CSV readers.
    logFile << ',' << steering << ',' << throttle << '\n';

    sampleCount++;
}

/// Closes the CSV log file and reports how many rows the session recorded.
void stopLogging() {
    logFile.close();

    const std::string summary = "Logging stopped. Samples: " + std::to_string(sampleCount);
    showNotification(summary.c_str());
}

void handleInput() {
    // F9: Default 
    if (IsKeyJustUp(VK_F9)) {
        switch (state) {
        case AIMode: 
            break;
        case Logging:
            stopLogging();
            break;
        case Debug:
            break;
        }
           
        state = Default;
        showNotification("DEFAULT MODE");
    }

    // F10: AI
    if (IsKeyJustUp(VK_F10)) {
        state = AIMode;
        showNotification("AI MODE");
    }

    // F11: Logging
    if (IsKeyJustUp(VK_F11)) {
        startLogging();
        state = Logging;
        showNotification("LOGGING MODE");
    }

    // F12: Debug
    if (IsKeyJustUp(VK_F12)) {
        state = Debug;
        showNotification("DEBUGGING MODE");
    }
}

/// Drives `vehicle` with the ONNX model for one frame.
/// Builds the feature vector [speed, ray_0..ray_6] and standardizes it with SCALER_MEAN/SCALER_STD.
/// It then runs the session and writes steering (control 59) and throttle (control 71) for this frame.
/// Does nothing when the model is not loaded, `vehicle` is 0, inference throws, or the
/// output is too small or NaN.
void runAI(Vehicle vehicle) {
    if (!session || !vehicle) return;

    constexpr size_t NUM_FEATURES = 1 + NUM_RAYS;
    static_assert(sizeof(SCALER_MEAN) / sizeof(SCALER_MEAN[0]) == NUM_FEATURES,
        "SCALER_MEAN needs one entry per model input feature");
    static_assert(sizeof(SCALER_STD) / sizeof(SCALER_STD[0]) == NUM_FEATURES,
        "SCALER_STD needs one entry per model input feature");

    // Same feature order as the CSV columns written by logData().
    std::array<float, NUM_FEATURES> input_vector{};
    input_vector[0] = ENTITY::GET_ENTITY_SPEED(vehicle);
    for (int i = 0; i < NUM_RAYS; i++)
        input_vector[i + 1] = castSingleRay(vehicle, RAY_ANGLES[i]);

    for (size_t i = 0; i < NUM_FEATURES; i++)
        input_vector[i] = (input_vector[i] - SCALER_MEAN[i]) / SCALER_STD[i];

    try {
        Ort::MemoryInfo memory_info = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);

        const std::array<int64_t, 2> input_shape = { 1, static_cast<int64_t>(NUM_FEATURES) };
        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
            memory_info,
            input_vector.data(),
            input_vector.size(),
            input_shape.data(),
            input_shape.size()
        );

        const char* input_names[] = { "input" };
        const char* output_names[] = { "sequential_1" };

        auto output_tensors = session->Run(
            Ort::RunOptions{ nullptr },
            input_names, &input_tensor, 1,
            output_names, 1
        );

        // The code below reads output[0] (steering) and output[1] (throttle).
        if (output_tensors.empty() ||
            output_tensors[0].GetTensorTypeAndShapeInfo().GetElementCount() < 2)
            return;

        const float* output = output_tensors[0].GetTensorMutableData<float>();

        // fmin/fmax return the non-NaN operand, so a NaN output would turn into full
        // steering lock / full throttle. Skip the frame instead.
        if (std::isnan(output[0]) || std::isnan(output[1]))
            return;

        const float steering = std::tanh(output[0]);                         // Apply tanh (matches training)
        const float throttle = std::fmax(0.0f, std::fmin(1.0f, output[1]));  // Clip [0, 1] (matches training)

        // Double the steering signal, then clamp it back into [-1, 1].
        const float amplified_steering = std::fmax(-1.0f, std::fmin(1.0f, steering * 2.0f));

        CONTROLS::_SET_CONTROL_NORMAL(0, 59, amplified_steering);
        CONTROLS::_SET_CONTROL_NORMAL(0, 71, throttle * 0.85f);
    } catch (const std::exception& e) {
        // An exception escaping the script thread would reach std::terminate. Skip this frame
        // and report only the first failure so the notification feed is not flooded every frame.
        static bool reported = false;
        if (!reported) {
            reported = true;
            const std::string message = std::string("AI inference failed: ") + e.what();
            showNotification(message.c_str());
        }
    }
}

void ScriptMain() {
    const wchar_t* model_path = L"C:\\Program Files\\Rockstar Games\\Grand Theft Auto V\\racebot_model.onnx"; // this is shit

    env = new Ort::Env(ORT_LOGGING_LEVEL_WARNING, "GTAVRaceBot");
    session_options = new Ort::SessionOptions();
    session_options->SetIntraOpNumThreads(1);
    session_options->SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_BASIC);

    session = new Ort::Session(*env, model_path, *session_options);
    
    showNotification("Basic Controls: \nF9: Default \nF10: AI Mode \nF11: Start/Stop Logging \nF12: Debug");

    while (true) {
        WAIT(0);
        frameCounter++;

        // Disable traffic
        VEHICLE::SET_VEHICLE_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);
        VEHICLE::SET_RANDOM_VEHICLE_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);
        PED::SET_PED_DENSITY_MULTIPLIER_THIS_FRAME(0.0f);

        Ped player = PLAYER::PLAYER_PED_ID();

        if (!PED::IS_PED_IN_ANY_VEHICLE(player, false)) continue;
        Vehicle vehicle = PED::GET_VEHICLE_PED_IS_IN(player, false);

        handleInput();

        switch (state) {
        case Default: 
            break;
        case AIMode:
            runAI(vehicle);
            break;
        case Logging:
            logData(vehicle);
            break;
        case Debug:
            drawRaycastDebug(vehicle);

            if (frameCounter % TICK_RATE == 0) {
                float distances[NUM_RAYS];
                for (int i = 0; i < NUM_RAYS; i++) 
                    distances[i] = castSingleRay(vehicle, RAY_ANGLES[i]);

                float speed = ENTITY::GET_ENTITY_SPEED(vehicle);
                float steering = CONTROLS::GET_CONTROL_NORMAL(0, 59);
                float throttle = CONTROLS::GET_CONTROL_NORMAL(0, 71);

                char buffer[256];
                sprintf_s(buffer, "Speed:%.1f Throttle:%.1f Steer:%.2f~n~L:%.0f %.0f %.0f C:%.0f R:%.0f %.0f %.0f",
                    speed, throttle, steering,
                    distances[0], distances[1], distances[2],
                    distances[3],
                    distances[4], distances[5], distances[6]);
                showNotification(buffer);
            }
            break;
        }
    }
}