[go: nahoru, domu]

Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Additional functionality #127

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .clang-format
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ UseTab: Never
IndentWidth: 4
TabWidth: 4
AllowShortIfStatementsOnASingleLine: false
IndentCaseLabels: false
ColumnLimit: 0
AccessModifierOffset: -4
NamespaceIndentation: All
Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ test/
.vscode/
.cache/
*.swp
.vscode/
*.bat
*.bin
*.exe
Expand Down
5 changes: 3 additions & 2 deletions examples/cli/main.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <stdio.h>
#include <cstdio>
#include <ctime>
#include <random>

#include "ggml/ggml.h"
#include "stable-diffusion.h"
#include "util.h"
Expand Down Expand Up @@ -427,7 +428,7 @@ int main(int argc, const char* argv[]) {
if (params.verbose) {
print_params(params);
printf("%s", sd_get_system_info().c_str());
set_sd_log_level(SDLogLevel::DEBUG);
set_sd_log_level(SDLogLevel::SD_LOG_LEVEL_DEBUG);
}

bool vae_decode_only = true;
Expand Down
11 changes: 5 additions & 6 deletions stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,11 @@ void pretty_progress(int step, int steps, float time) {
}
}
progress += "|";
printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
progress.c_str(), step, steps,
time > 1.0f || time == 0 ? time : (1.0f / time));
fflush(stdout); // for linux
LOG_DEFAULT(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s",
progress.c_str(), step, steps,
time > 1.0f || time == 0 ? time : (1.0f / time));
if (step == steps) {
printf("\n");
LOG_DEFAULT("\n");
}
}

Expand Down Expand Up @@ -1749,7 +1748,7 @@ struct SpatialTransformer {
#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS)
struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, h * w, d_head]
#else
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position]
struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, h * w, max_position]
// kq = ggml_diag_mask_inf_inplace(ctx, kq, 0);
kq = ggml_soft_max_inplace(ctx, kq);

Expand Down
93 changes: 70 additions & 23 deletions util.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#include "util.h"

#include <stdarg.h>
#include <codecvt>
#include <cstdarg>
#include <fstream>
#include <locale>
#include <thread>
Expand Down Expand Up @@ -164,40 +164,87 @@ std::string path_join(const std::string& p1, const std::string& p2) {
return p1 + "/" + p2;
}

static SDLogLevel log_level = SDLogLevel::INFO;
// Minimum level emitted by log_printf (changed via set_sd_log_level);
// messages whose level compares below it are dropped before formatting.
static SDLogLevel log_level = SDLogLevel::SD_LOG_LEVEL_INFO;

// Sets the logging threshold: log_printf discards any message whose level
// compares below `level`.
void set_sd_log_level(SDLogLevel level) {
    log_level = level;
}

void log_printf(SDLogLevel level, const char* file, int line, const char* format, ...) {
// Default log sink. Writes `text` verbatim (no newline is added here) and
// flushes right away. SD_LOG_LEVEL_ERROR goes to stderr; every other level
// goes to stdout.
void default_sd_logger(SDLogLevel level, const char* text) {
    FILE* const out = (level == SDLogLevel::SD_LOG_LEVEL_ERROR) ? stderr : stdout;
    fputs(text, out);
    fflush(out);
}

// Active log sink invoked by log_printf; replaced through set_sd_logger().
static sd_logger_function_t sd_logger = &default_sd_logger;

// Ref: https://stackoverflow.com/questions/2342162/stdstring-formatting-like-sprintf
// printf-style formatting into a std::string.
//
// format - printf format string; the arguments must match it exactly.
// args   - values substituted into `format`.
// Returns the formatted text. Returns an empty string if snprintf reports an
// error (negative return, e.g. an encoding failure). The previous version
// turned that case into a zero/wrapped buffer size and undefined behavior.
template <typename... Args>
std::string string_format(const std::string& format, Args... args) {
    // First pass with a null buffer only measures the output length.
    const int len = std::snprintf(nullptr, 0, format.c_str(), args...);
    if (len <= 0) {
        return std::string();
    }
    // snprintf always writes a terminating '\0', so reserve room for it and
    // trim it off afterwards.
    std::string result(static_cast<size_t>(len) + 1, '\0');
    std::snprintf(&result[0], result.size(), format.c_str(), args...);
    result.resize(static_cast<size_t>(len));
    return result;
}

// Builds the "[LEVEL] file:line - " tag that log_printf prepends to tagged
// messages. `file` is passed through basename() (presumably stripping the
// directory part). The line number is left-aligned in a 4-character field.
// A level outside the four known values yields an empty prefix.
std::string log_prefix(SDLogLevel level, const char* file, int line) {
    const auto short_name = basename(file);
    const char* const name = short_name.c_str();
    switch (level) {
        case SDLogLevel::SD_LOG_LEVEL_DEBUG:
            return string_format("[DEBUG] %s:%-4d - ", name, line);
        case SDLogLevel::SD_LOG_LEVEL_INFO:
            return string_format("[INFO] %s:%-4d - ", name, line);
        case SDLogLevel::SD_LOG_LEVEL_WARN:
            return string_format("[WARN] %s:%-4d - ", name, line);
        case SDLogLevel::SD_LOG_LEVEL_ERROR:
            return string_format("[ERROR] %s:%-4d - ", name, line);
    }
    return std::string();
}

// Formats a log message and hands it to the active sink (sd_logger).
//
// level              - severity; messages below the threshold set with
//                      set_sd_log_level() are dropped before any formatting.
// enable_log_tag     - when true, prefix the text with log_prefix()'s
//                      "[LEVEL] file:line - " tag.
// enable_log_newline - when true, append "\n" to the text.
// file, line         - call site, supplied as __FILE__/__LINE__ by the LOG_* macros.
// format, ...        - printf-style format string and its arguments.
void log_printf(SDLogLevel level, bool enable_log_tag, bool enable_log_newline, const char* file, int line, const char* format, ...) {
    if (level < log_level) {
        return;
    }

    std::string log_message;
    if (enable_log_tag) {
        log_message = log_prefix(level, file, line);
    }

    va_list args;
    va_start(args, format);
    // vsnprintf consumes the va_list it is given. The second formatting pass
    // (for messages that overflow the stack buffer) therefore needs its own
    // copy; reusing `args` there is undefined behavior.
    va_list args_copy;
    va_copy(args_copy, args);

    char buffer[128];
    const int len = std::vsnprintf(buffer, sizeof(buffer), format, args);
    if (len >= 0) {
        const size_t needed = static_cast<size_t>(len);
        if (needed < sizeof(buffer)) {
            // Common case: the whole message fit in the stack buffer.
            log_message += buffer;
        } else {
            // Output was truncated; format again into an exactly-sized buffer
            // (+1 for the terminating '\0' that vsnprintf always writes).
            std::string long_message(needed + 1, '\0');
            std::vsnprintf(&long_message[0], long_message.size(), format, args_copy);
            long_message.resize(needed);
            log_message += long_message;
        }
    }
    // A negative len means vsnprintf failed (e.g. an encoding error). The
    // message body is dropped, but the prefix/newline are still emitted.

    va_end(args_copy);
    va_end(args);

    if (enable_log_newline) {
        log_message += "\n";
    }
    sd_logger(level, log_message.c_str());
}

void set_sd_logger(const sd_logger_function_t& sd_logger_function) {
sd_logger = sd_logger_function;
}
26 changes: 16 additions & 10 deletions util.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#ifndef __UTIL_H__
#define __UTIL_H__

#include <string>
#include <cstdint>
#include <functional>
#include <string>

bool ends_with(const std::string& str, const std::string& ending);
bool starts_with(const std::string& str, const std::string& start);
Expand All @@ -25,18 +26,23 @@ std::string path_join(const std::string& p1, const std::string& p2);
int32_t get_num_physical_cores();

// Log severities. The SD_LOG_LEVEL_* values are declared from least to most
// severe, and the order matters: log_printf drops any message whose level
// compares below the threshold configured with set_sd_log_level.
enum SDLogLevel {
DEBUG,
INFO,
WARN,
ERROR
// NOTE(review): the unprefixed names above are the ones this change removes;
// the prefixed names below replace them (presumably to avoid clashing with
// platform macros such as ERROR — confirm).
SD_LOG_LEVEL_DEBUG,
SD_LOG_LEVEL_INFO,
SD_LOG_LEVEL_WARN,
SD_LOG_LEVEL_ERROR
};

// Sets the minimum level that log_printf emits.
void set_sd_log_level(SDLogLevel level);

void log_printf(SDLogLevel level, const char* file, int line, const char* format, ...);
// printf-style logging entry point used by the LOG_* macros.
//   enable_log_tag     - prepend a "[LEVEL] file:line - " tag
//   enable_log_newline - append "\n" to the formatted text
void log_printf(SDLogLevel level, bool enable_log_tag, bool enable_log_newline, const char* file, int line, const char* format, ...);

// Log sink signature: receives the message level and the fully formatted text.
typedef std::function<void(SDLogLevel level, const char* text)> sd_logger_function_t;

// Replaces the log sink. The built-in default writes to stdout, or to stderr
// for SD_LOG_LEVEL_ERROR.
void set_sd_logger(const sd_logger_function_t& sd_logger_function);

#define LOG_DEBUG(format, ...) log_printf(SDLogLevel::DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_INFO(format, ...) log_printf(SDLogLevel::INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_WARN(format, ...) log_printf(SDLogLevel::WARN, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_ERROR(format, ...) log_printf(SDLogLevel::ERROR, __FILE__, __LINE__, format, ##__VA_ARGS__)
// Logging macros.
// - LOG_DEBUG/INFO/WARN/ERROR prepend a "[LEVEL] file:line - " tag and append
//   a newline.
// - LOG_DEFAULT logs at INFO level with neither tag nor newline, so callers
//   can emit raw or incremental output (pretty_progress uses it for its
//   "\r" progress bar).
#define LOG_DEFAULT(format, ...) log_printf(SDLogLevel::SD_LOG_LEVEL_INFO, false, false, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_DEBUG(format, ...) log_printf(SDLogLevel::SD_LOG_LEVEL_DEBUG, true, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_INFO(format, ...) log_printf(SDLogLevel::SD_LOG_LEVEL_INFO, true, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_WARN(format, ...) log_printf(SDLogLevel::SD_LOG_LEVEL_WARN, true, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_ERROR(format, ...) log_printf(SDLogLevel::SD_LOG_LEVEL_ERROR, true, true, __FILE__, __LINE__, format, ##__VA_ARGS__)
#endif // __UTIL_H__
Loading