Skip to content

Commit

Permalink
Merge pull request #850 from morgen-peschke/expand-testing-for-SlfjLo…
Browse files Browse the repository at this point in the history
…gger

Fix MDC handling in Slf4jLogger
  • Loading branch information
rossabaker authored Feb 13, 2025
2 parents d51d1ec + 61ee3e9 commit fb7d847
Show file tree
Hide file tree
Showing 3 changed files with 643 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@

package org.typelevel.log4cats.slf4j.internal

import org.typelevel.log4cats._
import cats.syntax.all._
import cats.effect._
import org.slf4j.{Logger => JLogger}
import org.typelevel.log4cats.*
import cats.syntax.all.*
import cats.effect.*
import org.slf4j.Logger as JLogger
import org.slf4j.MDC

import scala.annotation.nowarn

private[slf4j] object Slf4jLoggerInternal {

final val singletonsByName = true
Expand All @@ -34,21 +36,36 @@ private[slf4j] object Slf4jLoggerInternal {
def apply(t: Throwable)(msg: => String): F[Unit]
}

// Need this to make sure MDC is correctly cleared before logging
private[this] def noContextLog[F[_]](isEnabled: F[Boolean], logging: () => Unit)(implicit
F: Sync[F]
): F[Unit] =
contextLog[F](isEnabled, Map.empty, logging)

private[this] def contextLog[F[_]](
isEnabled: F[Boolean],
ctx: Map[String, String],
logging: () => Unit
)(implicit F: Sync[F]): F[Unit] = {

val ifEnabled = F.delay {
val backup = MDC.getCopyOfContextMap()
val backup =
try MDC.getCopyOfContextMap()
catch {
case e: IllegalStateException =>
// MDCAdapter is missing, no point in doing anything with
// the MDC, so just hope the logging backend can salvage
// something.
logging()
throw e
}

for {
(k, v) <- ctx
} MDC.put(k, v)

try logging()
finally
try {
// Once 2.12 is no longer supported, change this to MDC.setContextMap(ctx.asJava)
MDC.clear()
ctx.foreach { case (k, v) => MDC.put(k, v) }
logging()
} finally
if (backup eq null) MDC.clear()
else MDC.setContextMap(backup)
}
Expand All @@ -59,6 +76,7 @@ private[slf4j] object Slf4jLoggerInternal {
)
}

@nowarn("msg=used")
final class Slf4jLogger[F[_]](val logger: JLogger, sync: Sync.Type = Sync.Type.Delay)(implicit
F: Sync[F]
) extends SelfAwareStructuredLogger[F] {
Expand All @@ -76,53 +94,47 @@ private[slf4j] object Slf4jLoggerInternal {
override def isErrorEnabled: F[Boolean] = F.delay(logger.isErrorEnabled)

override def trace(t: Throwable)(msg: => String): F[Unit] =
isTraceEnabled
.ifM(F.suspend(sync)(logger.trace(msg, t)), F.unit)
noContextLog(isTraceEnabled, () => logger.trace(msg, t))
override def trace(msg: => String): F[Unit] =
isTraceEnabled
.ifM(F.suspend(sync)(logger.trace(msg)), F.unit)
noContextLog(isTraceEnabled, () => logger.trace(msg))
override def trace(ctx: Map[String, String])(msg: => String): F[Unit] =
contextLog(isTraceEnabled, ctx, () => logger.trace(msg))
override def trace(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isTraceEnabled, ctx, () => logger.trace(msg, t))

override def debug(t: Throwable)(msg: => String): F[Unit] =
isDebugEnabled
.ifM(F.suspend(sync)(logger.debug(msg, t)), F.unit)
noContextLog(isDebugEnabled, () => logger.debug(msg, t))
override def debug(msg: => String): F[Unit] =
isDebugEnabled
.ifM(F.suspend(sync)(logger.debug(msg)), F.unit)
noContextLog(isDebugEnabled, () => logger.debug(msg))
override def debug(ctx: Map[String, String])(msg: => String): F[Unit] =
contextLog(isDebugEnabled, ctx, () => logger.debug(msg))
override def debug(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isDebugEnabled, ctx, () => logger.debug(msg, t))

override def info(t: Throwable)(msg: => String): F[Unit] =
isInfoEnabled
.ifM(F.suspend(sync)(logger.info(msg, t)), F.unit)
noContextLog(isInfoEnabled, () => logger.info(msg, t))
override def info(msg: => String): F[Unit] =
isInfoEnabled
.ifM(F.suspend(sync)(logger.info(msg)), F.unit)
noContextLog(isInfoEnabled, () => logger.info(msg))
override def info(ctx: Map[String, String])(msg: => String): F[Unit] =
contextLog(isInfoEnabled, ctx, () => logger.info(msg))
override def info(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isInfoEnabled, ctx, () => logger.info(msg, t))

override def warn(t: Throwable)(msg: => String): F[Unit] =
isWarnEnabled
.ifM(F.suspend(sync)(logger.warn(msg, t)), F.unit)
noContextLog(isWarnEnabled, () => logger.warn(msg, t))
override def warn(msg: => String): F[Unit] =
isWarnEnabled
.ifM(F.suspend(sync)(logger.warn(msg)), F.unit)
noContextLog(isWarnEnabled, () => logger.warn(msg))
override def warn(ctx: Map[String, String])(msg: => String): F[Unit] =
contextLog(isWarnEnabled, ctx, () => logger.warn(msg))
override def warn(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isWarnEnabled, ctx, () => logger.warn(msg, t))

override def error(t: Throwable)(msg: => String): F[Unit] =
isErrorEnabled
.ifM(F.suspend(sync)(logger.error(msg, t)), F.unit)
noContextLog(isErrorEnabled, () => logger.error(msg, t))
override def error(msg: => String): F[Unit] =
isErrorEnabled
.ifM(F.suspend(sync)(logger.error(msg)), F.unit)
noContextLog(isErrorEnabled, () => logger.error(msg))
override def error(ctx: Map[String, String])(msg: => String): F[Unit] =
contextLog(isErrorEnabled, ctx, () => logger.error(msg))
override def trace(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isTraceEnabled, ctx, () => logger.trace(msg, t))
override def debug(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isDebugEnabled, ctx, () => logger.debug(msg, t))
override def info(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isInfoEnabled, ctx, () => logger.info(msg, t))
override def warn(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isWarnEnabled, ctx, () => logger.warn(msg, t))
override def error(ctx: Map[String, String], t: Throwable)(msg: => String): F[Unit] =
contextLog(isErrorEnabled, ctx, () => logger.error(msg, t))
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
/*
* Copyright 2018 Typelevel
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.typelevel.log4cats.slf4j.internal;

import org.slf4j.Logger;
import org.slf4j.MDC;
import org.slf4j.Marker;
import org.typelevel.log4cats.extras.LogLevel;
import scala.Option;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;

public class JTestLogger implements Logger {
// Java -> Scala compat helpers

private static final scala.Option<Throwable> none = scala.Option$.MODULE$.empty();
private static scala.Option<Throwable> some(Throwable t) { return scala.Option$.MODULE$.apply(t); }
private static final LogLevel.Trace$ Trace = LogLevel.Trace$.MODULE$;
private static final LogLevel.Debug$ Debug = LogLevel.Debug$.MODULE$;
private static final LogLevel.Info$ Info = LogLevel.Info$.MODULE$;
private static final LogLevel.Warn$ Warn = LogLevel.Warn$.MODULE$;
private static final LogLevel.Error$ Error = LogLevel.Error$.MODULE$;

private Map<String, String> captureContext () {
java.util.Map<String, String> mdc = MDC.getCopyOfContextMap();
if (mdc == null) {
return new HashMap<>();
}
return MDC.getCopyOfContextMap();
}

public static class TestLogMessage {
public final LogLevel logLevel;
public final java.util.Map<String, String> context;
public final Option<Throwable> throwableOpt;
public final Supplier<String> message;

public TestLogMessage(LogLevel logLevel,
java.util.Map<String, String> context,
Option<Throwable> throwableOpt,
Supplier<String> message) {
this.logLevel = logLevel;
this.context = context;
this.throwableOpt = throwableOpt;
this.message = message;
}

@Override
public String toString() {
return new StringBuilder()
.append("TestLogMessage(")
.append("logLevel=").append(logLevel)
.append(", ")
.append("context=").append(context)
.append(", ")
.append("throwableOpt=").append(throwableOpt)
.append(", ")
.append("message=").append(message.get())
.append(')')
.toString();
}

static TestLogMessage of(LogLevel logLevel,
java.util.Map<String, String> context,
Throwable throwable,
Supplier<String> message) {
return new TestLogMessage(logLevel, context, some(throwable), message);
}

static TestLogMessage of(LogLevel logLevel,
java.util.Map<String, String> context,
Supplier<String> message) {
return new TestLogMessage(logLevel, context, none, message);
}
}

private final String loggerName;
private final boolean traceEnabled;
private final boolean debugEnabled;
private final boolean infoEnabled;
private final boolean warnEnabled;
private final boolean errorEnabled;
private final AtomicReference<List<TestLogMessage>> loggedMessages;


public JTestLogger(String loggerName,
boolean traceEnabled,
boolean debugEnabled,
boolean infoEnabled,
boolean warnEnabled,
boolean errorEnabled) {
this.loggerName = loggerName;
this.traceEnabled = traceEnabled;
this.debugEnabled = debugEnabled;
this.infoEnabled = infoEnabled;
this.warnEnabled = warnEnabled;
this.errorEnabled = errorEnabled;
loggedMessages = new AtomicReference<>(new ArrayList<TestLogMessage>());
}

private void save(Function<Map<String, String>, TestLogMessage> mkLogMessage) {
loggedMessages.updateAndGet(ll -> {
ll.add(mkLogMessage.apply(captureContext()));
return ll;
});
}

public List<TestLogMessage> logs() { return loggedMessages.get(); }
public void reset() { loggedMessages.set(new ArrayList<>()); }

@Override public String getName() { return loggerName;}

@Override public boolean isTraceEnabled() { return traceEnabled; }
@Override public boolean isDebugEnabled() { return debugEnabled; }
@Override public boolean isInfoEnabled() { return infoEnabled; }
@Override public boolean isWarnEnabled() { return warnEnabled; }
@Override public boolean isErrorEnabled() { return errorEnabled; }

// We don't use them, so we're going to ignore Markers
@Override public boolean isTraceEnabled(Marker marker) { return traceEnabled; }
@Override public boolean isDebugEnabled(Marker marker) { return debugEnabled; }
@Override public boolean isInfoEnabled(Marker marker) { return infoEnabled; }
@Override public boolean isWarnEnabled(Marker marker) { return warnEnabled; }
@Override public boolean isErrorEnabled(Marker marker) { return errorEnabled; }

@Override public void trace(String msg) { save(ctx -> TestLogMessage.of(Trace, ctx, () -> msg)); }
@Override public void trace(String msg, Throwable t) { save(ctx -> TestLogMessage.of(Trace, ctx, t, () -> msg)); }

@Override public void debug(String msg) { save(ctx -> TestLogMessage.of(Debug, ctx, () -> msg)); }
@Override public void debug(String msg, Throwable t) { save(ctx -> TestLogMessage.of(Debug, ctx, t, () -> msg)); }

@Override public void info(String msg) { save(ctx -> TestLogMessage.of(Info, ctx, () -> msg)); }
@Override public void info(String msg, Throwable t) { save(ctx -> TestLogMessage.of(Info, ctx, t, () -> msg)); }

@Override public void warn(String msg) { save(ctx -> TestLogMessage.of(Warn, ctx, () -> msg)); }
@Override public void warn(String msg, Throwable t) { save(ctx -> TestLogMessage.of(Warn, ctx, t, () -> msg)); }

@Override public void error(String msg) { save(ctx -> TestLogMessage.of(Error, ctx, () -> msg)); }
@Override public void error(String msg, Throwable t) { save(ctx -> TestLogMessage.of(Error, ctx, t, () -> msg)); }

// We shouldn't need these for our tests, so we're treating these variants as if they were the standard method

@Override public void trace(String format, Object arg) { trace(format); }
@Override public void trace(String format, Object arg1, Object arg2) { trace(format); }
@Override public void trace(String format, Object... arguments) { trace(format); }
@Override public void trace(Marker marker, String msg) { trace(msg); }
@Override public void trace(Marker marker, String format, Object arg) { trace(format); }
@Override public void trace(Marker marker, String format, Object arg1, Object arg2) { trace(format); }
@Override public void trace(Marker marker, String format, Object... argArray) { trace(format); }
@Override public void trace(Marker marker, String msg, Throwable t) { trace(msg, t); }

@Override public void debug(String format, Object arg) { debug(format); }
@Override public void debug(String format, Object arg1, Object arg2) { debug(format); }
@Override public void debug(String format, Object... arguments) { debug(format); }
@Override public void debug(Marker marker, String msg) { debug(msg); }
@Override public void debug(Marker marker, String format, Object arg) { debug(format); }
@Override public void debug(Marker marker, String format, Object arg1, Object arg2) { debug(format); }
@Override public void debug(Marker marker, String format, Object... arguments) { debug(format); }
@Override public void debug(Marker marker, String msg, Throwable t) { debug(msg, t); }

@Override public void info(String format, Object arg) { info(format); }
@Override public void info(String format, Object arg1, Object arg2) { info(format); }
@Override public void info(String format, Object... arguments) { info(format); }
@Override public void info(Marker marker, String msg) { info(msg); }
@Override public void info(Marker marker, String format, Object arg) { info(format); }
@Override public void info(Marker marker, String format, Object arg1, Object arg2) { info(format); }
@Override public void info(Marker marker, String format, Object... arguments) { info(format); }
@Override public void info(Marker marker, String msg, Throwable t) { info(msg, t); }

@Override public void warn(String format, Object arg) { warn(format); }
@Override public void warn(String format, Object... arguments) { warn(format); }
@Override public void warn(String format, Object arg1, Object arg2) { warn(format); }
@Override public void warn(Marker marker, String msg) { warn(msg); }
@Override public void warn(Marker marker, String format, Object arg) { warn(format); }
@Override public void warn(Marker marker, String format, Object arg1, Object arg2) { warn(format); }
@Override public void warn(Marker marker, String format, Object... arguments) { warn(format); }
@Override public void warn(Marker marker, String msg, Throwable t) { warn(msg, t); }

@Override public void error(String format, Object arg) { error(format); }
@Override public void error(String format, Object arg1, Object arg2) { error(format); }
@Override public void error(String format, Object... arguments) { error(format); }
@Override public void error(Marker marker, String msg) { error(msg); }
@Override public void error(Marker marker, String format, Object arg) { error(format); }
@Override public void error(Marker marker, String format, Object arg1, Object arg2) { error(format); }
@Override public void error(Marker marker, String format, Object... arguments) { error(format); }
@Override public void error(Marker marker, String msg, Throwable t) { error(msg, t); }
}
Loading

0 comments on commit fb7d847

Please sign in to comment.