1 package org.oxerr.spring.cache.redis.scored;
2
3 import java.time.Duration;
4 import java.util.Optional;
5 import java.util.function.Function;
6
7 import org.springframework.data.redis.cache.CacheStatistics;
8 import org.springframework.data.redis.cache.CacheStatisticsCollector;
9 import org.springframework.data.redis.cache.RedisCacheWriter;
10 import org.springframework.data.redis.connection.RedisConnection;
11 import org.springframework.data.redis.connection.RedisConnectionFactory;
12 import org.springframework.data.redis.connection.RedisZSetCommands.Limit;
13 import org.springframework.data.redis.connection.RedisZSetCommands.Range;
14 import org.springframework.data.redis.connection.RedisZSetCommands.ZAddArgs;
15 import org.springframework.lang.NonNull;
16 import org.springframework.lang.Nullable;
17 import org.springframework.util.Assert;
18
19
20
21
22 public class ScoredRedisCacheWriter implements RedisCacheWriter {
23
24 private static final String MUST_NOT_BE_NULL = " must not be null!";
25
26 private static final String NAME_NOT_NULL = "Name" + MUST_NOT_BE_NULL;
27 private static final String KEY_NOT_NULL = "Key" + MUST_NOT_BE_NULL;
28 private static final String VALUE_NOT_NULL = "Value" + MUST_NOT_BE_NULL;
29
30 private static final String OK = "OK";
31
32 private static final Double DEFAULT_SCORE = Double.valueOf(0d);
33
34 private final RedisConnectionFactory connectionFactory;
35 private final CacheStatisticsCollector statistics;
36 private final ScoreHolder scoreHolder;
37 private final RedisCacheWriter cacheWriter;
38
39 public ScoredRedisCacheWriter(
40 @NonNull RedisConnectionFactory connectionFactory
41 ) {
42 this(connectionFactory, CacheStatisticsCollector.none());
43 }
44
45 public ScoredRedisCacheWriter(
46 @NonNull RedisConnectionFactory connectionFactory,
47 @NonNull CacheStatisticsCollector cacheStatisticsCollector
48 ) {
49 this(
50 connectionFactory,
51 cacheStatisticsCollector,
52 new InheritableThreadLocalScoreHolder()
53 );
54 }
55
56 public ScoredRedisCacheWriter(
57 @NonNull RedisConnectionFactory connectionFactory,
58 @NonNull CacheStatisticsCollector cacheStatisticsCollector,
59 @NonNull ScoreHolder scoreHolder
60 ) {
61 this(
62 connectionFactory,
63 cacheStatisticsCollector,
64 scoreHolder,
65 RedisCacheWriter
66 .nonLockingRedisCacheWriter(connectionFactory)
67 .withStatisticsCollector(cacheStatisticsCollector)
68 );
69 }
70
71 public ScoredRedisCacheWriter(
72 @NonNull RedisConnectionFactory connectionFactory,
73 @NonNull CacheStatisticsCollector cacheStatisticsCollector,
74 @NonNull ScoreHolder scoreHolder,
75 @NonNull RedisCacheWriter cacheWriter
76 ) {
77 this.connectionFactory = connectionFactory;
78 this.statistics = cacheStatisticsCollector;
79 this.scoreHolder = scoreHolder;
80 this.cacheWriter = cacheWriter;
81 }
82
83 @Override
84 public CacheStatistics getCacheStatistics(String cacheName) {
85 return statistics.getCacheStatistics(cacheName);
86 }
87
88 @Override
89 public void put(String name, byte[] key, byte[] value, Duration ttl) {
90 Assert.notNull(name, NAME_NOT_NULL);
91 Assert.notNull(key, KEY_NOT_NULL);
92 Assert.notNull(value, VALUE_NOT_NULL);
93
94 final double score = getScore();
95
96 final long millis = ttl.toMillis();
97 final Range range = Range.range().lt(score);
98
99 execute(name, connection -> {
100
101 connection.zAdd(key, score, value);
102 connection.zRemRangeByScore(key, range);
103
104 if (shouldExpireWithin(ttl)) {
105 connection.pExpire(key, millis);
106 }
107
108 return OK;
109 });
110
111 statistics.incPuts(name);
112 }
113
114 @Override
115 public byte[] get(String name, byte[] key) {
116 Assert.notNull(name, NAME_NOT_NULL);
117 Assert.notNull(key, KEY_NOT_NULL);
118
119 final Range range = Range.unbounded();
120 final Limit limit = Limit.limit().count(1);
121
122 byte[] result = execute(name, connection -> connection.zRevRangeByScore(key, range, limit))
123 .stream()
124 .findFirst()
125 .orElse(null);
126
127 statistics.incGets(name);
128
129 if (result != null) {
130 statistics.incHits(name);
131 } else {
132 statistics.incMisses(name);
133 }
134
135 return result;
136 }
137
138 @Override
139 public byte[] putIfAbsent(String name, byte[] key, byte[] value, Duration ttl) {
140 Assert.notNull(name, NAME_NOT_NULL);
141 Assert.notNull(key, KEY_NOT_NULL);
142 Assert.notNull(value, VALUE_NOT_NULL);
143
144 final double score = getScore();
145
146 final long millis = ttl.toMillis();
147 final Range range = Range.range().lt(score);
148 final Limit limit = Limit.limit().count(1);
149
150 return execute(name, connection -> {
151 final Boolean added = connection.zAdd(key, score, value, ZAddArgs.ifNotExists());
152 connection.zRemRangeByScore(key, range);
153
154 if (shouldExpireWithin(ttl)) {
155 connection.pExpire(key, millis);
156 }
157
158 if (Boolean.TRUE.equals(added)) {
159 statistics.incPuts(name);
160 return null;
161 }
162
163 return connection.zRevRangeByScore(key, range, limit)
164 .stream()
165 .findFirst()
166 .orElse(null);
167 });
168 }
169
170 @Override
171 public void remove(String name, byte[] key) {
172 this.cacheWriter.remove(name, key);
173 }
174
175 @Override
176 public void clean(String name, byte[] pattern) {
177 this.cacheWriter.clean(name, pattern);
178 }
179
180 @Override
181 public void clearStatistics(String name) {
182 statistics.reset(name);
183 }
184
185 @Override
186 public ScoredRedisCacheWriter withStatisticsCollector(CacheStatisticsCollector cacheStatisticsCollector) {
187 return new ScoredRedisCacheWriter(connectionFactory, cacheStatisticsCollector);
188 }
189
190 private <T> T execute(String name, Function<RedisConnection, T> callback) {
191 try (RedisConnection connection = connectionFactory.getConnection()) {
192 return callback.apply(connection);
193 }
194 }
195
196 private static boolean shouldExpireWithin(@Nullable Duration ttl) {
197 return ttl != null && !ttl.isZero() && !ttl.isNegative();
198 }
199
200 private double getScore() {
201 return Optional.ofNullable(this.scoreHolder.get()).orElse(DEFAULT_SCORE);
202 }
203
204 }